You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2017/07/15 04:15:31 UTC

[05/23] systemml git commit: TernaryAggregate now applies to a power of 3.

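TernaryAggregate now also applies to an aggregate over X^3 (a power whose exponent is the literal 3), which is compiled as the ternary product X*X*X. The RewriteEMultChainOp test scripts are renamed to RewriteEMultChainOpXYX.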


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/f005d949
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/f005d949
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/f005d949

Branch: refs/heads/master
Commit: f005d94997d9c17ad8e90b4d2bd340f81b9a752d
Parents: 8b832f6
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Fri Jun 9 22:06:10 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Sun Jun 18 17:43:24 2017 -0700

----------------------------------------------------------------------
 .../java/org/apache/sysml/hops/AggUnaryOp.java  | 67 ++++++++++++--------
 .../functions/misc/RewriteEMultChainTest.java   |  7 +-
 .../functions/misc/RewriteEMultChainOp.R        | 33 ----------
 .../functions/misc/RewriteEMultChainOp.dml      | 28 --------
 .../functions/misc/RewriteEMultChainOpXYX.R     | 33 ++++++++++
 .../functions/misc/RewriteEMultChainOpXYX.dml   | 28 ++++++++
 6 files changed, 106 insertions(+), 90 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/f005d949/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
index 4573b66..300a20c 100644
--- a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
@@ -490,29 +490,35 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop
 			(_direction == Direction.RowCol || _direction == Direction.Col)  ) 
 		{
 			Hop input1 = getInput().get(0);
-			if( input1.getParent().size() == 1 && //sum single consumer
-				input1 instanceof BinaryOp && ((BinaryOp)input1).getOp()==OpOp2.MULT
-				// As unary agg instruction is not implemented in MR and since MR is in maintenance mode, postponed it.
-				&& input1.optFindExecType() != ExecType.MR) 
-			{
-				Hop input11 = input1.getInput().get(0);
-				Hop input12 = input1.getInput().get(1);
-				
-				if( input11 instanceof BinaryOp && ((BinaryOp)input11).getOp()==OpOp2.MULT ) {
-					//ternary, arbitrary matrices but no mv/outer operations.
-					ret = HopRewriteUtils.isEqualSize(input11.getInput().get(0), input1)
-						&& HopRewriteUtils.isEqualSize(input11.getInput().get(1), input1)	
-						&& HopRewriteUtils.isEqualSize(input12, input1);
-				}
-				else if( input12 instanceof BinaryOp && ((BinaryOp)input12).getOp()==OpOp2.MULT ) {
-					//ternary, arbitrary matrices but no mv/outer operations.
-					ret = HopRewriteUtils.isEqualSize(input12.getInput().get(0), input1)
-							&& HopRewriteUtils.isEqualSize(input12.getInput().get(1), input1)	
-							&& HopRewriteUtils.isEqualSize(input11, input1);
+			if (input1.getParent().size() == 1
+					&& input1 instanceof BinaryOp) { //sum single consumer
+				BinaryOp binput1 = (BinaryOp)input1;
+
+				if (binput1.getOp() == OpOp2.POW
+						&& binput1.getInput().get(1) instanceof LiteralOp) {
+					LiteralOp lit = (LiteralOp)binput1.getInput().get(1);
+					ret = lit.getLongValue() == 3;
 				}
-				else {
-					//binary, arbitrary matrices but no mv/outer operations.
-					ret = HopRewriteUtils.isEqualSize(input11, input12);
+				else if (binput1.getOp() == OpOp2.MULT
+						// As unary agg instruction is not implemented in MR and since MR is in maintenance mode, postponed it.
+						&& input1.optFindExecType() != ExecType.MR) {
+					Hop input11 = input1.getInput().get(0);
+					Hop input12 = input1.getInput().get(1);
+
+					if (input11 instanceof BinaryOp && ((BinaryOp) input11).getOp() == OpOp2.MULT) {
+						//ternary, arbitrary matrices but no mv/outer operations.
+						ret = HopRewriteUtils.isEqualSize(input11.getInput().get(0), input1) && HopRewriteUtils
+								.isEqualSize(input11.getInput().get(1), input1) && HopRewriteUtils
+								.isEqualSize(input12, input1);
+					} else if (input12 instanceof BinaryOp && ((BinaryOp) input12).getOp() == OpOp2.MULT) {
+						//ternary, arbitrary matrices but no mv/outer operations.
+						ret = HopRewriteUtils.isEqualSize(input12.getInput().get(0), input1) && HopRewriteUtils
+								.isEqualSize(input12.getInput().get(1), input1) && HopRewriteUtils
+								.isEqualSize(input11, input1);
+					} else {
+						//binary, arbitrary matrices but no mv/outer operations.
+						ret = HopRewriteUtils.isEqualSize(input11, input12);
+					}
 				}
 			}
 		}
@@ -626,14 +632,25 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop
 	private Lop constructLopsTernaryAggregateRewrite(ExecType et) 
 		throws HopsException, LopsException
 	{
-		Hop input1 = getInput().get(0);
+		BinaryOp input1 = (BinaryOp)getInput().get(0);
 		Hop input11 = input1.getInput().get(0);
 		Hop input12 = input1.getInput().get(1);
 		
 		Lop in1 = null, in2 = null, in3 = null;
 		boolean handled = false;
-		
-		if( input11 instanceof BinaryOp ) {
+
+		if (input1.getOp() == OpOp2.POW) {
+			switch ((int)((LiteralOp)input12).getLongValue()) {
+			case 3:
+				in1 = input11.constructLops();
+				in2 = in1;
+				in3 = in1;
+				break;
+			default:
+				throw new AssertionError("unreachable; only applies to power 3");
+			}
+			handled = true;
+		} else if (input11 instanceof BinaryOp ) {
 			BinaryOp b11 = (BinaryOp)input11;
 			switch (b11.getOp()) {
 			case MULT: // A*B*C case

http://git-wip-us.apache.org/repos/asf/systemml/blob/f005d949/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
index 18ed55d..85dbea4 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteEMultChainTest.java
@@ -37,7 +37,7 @@ import org.junit.Test;
  */
 public class RewriteEMultChainTest extends AutomatedTestBase
 {
-	private static final String TEST_NAME1 = "RewriteEMultChainOp";
+	private static final String TEST_NAME1 = "RewriteEMultChainOpXYX";
 	private static final String TEST_DIR = "functions/misc/";
 	private static final String TEST_CLASS_DIR = TEST_DIR + RewriteEMultChainTest.class.getSimpleName() + "/";
 	
@@ -94,8 +94,7 @@ public class RewriteEMultChainTest extends AutomatedTestBase
 			
 			String HOME = SCRIPT_DIR + TEST_DIR;
 			fullDMLScriptName = HOME + testname + ".dml";
-			programArgs = new String[]{ "-explain", "hops", "-stats", 
-				"-args", input("X"), input("Y"), output("R") };
+			programArgs = new String[] { "-explain", "hops", "-stats", "-args", input("X"), input("Y"), output("R") };
 			fullRScriptName = HOME + testname + ".R";
 			rCmd = getRCmd(inputDir(), expectedDir());			
 
@@ -104,7 +103,7 @@ public class RewriteEMultChainTest extends AutomatedTestBase
 			double[][] Y = getRandomMatrix(rows, cols, -1, 1, Ysparsity, 3);
 			writeInputMatrixWithMTD("X", X, true);
 			writeInputMatrixWithMTD("Y", Y, true);
-			
+
 			//execute tests
 			runTest(true, false, null, -1); 
 			runRScript(true); 

http://git-wip-us.apache.org/repos/asf/systemml/blob/f005d949/src/test/scripts/functions/misc/RewriteEMultChainOp.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteEMultChainOp.R b/src/test/scripts/functions/misc/RewriteEMultChainOp.R
deleted file mode 100644
index 6d94cc8..0000000
--- a/src/test/scripts/functions/misc/RewriteEMultChainOp.R
+++ /dev/null
@@ -1,33 +0,0 @@
-#-------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-# 
-#   http://www.apache.org/licenses/LICENSE-2.0
-# 
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-#-------------------------------------------------------------
-
-
-args <- commandArgs(TRUE)
-options(digits=22)
-library("Matrix")
-library("matrixStats")
-
-X = as.matrix(readMM(paste(args[1], "X.mtx", sep="")))
-Y = as.matrix(readMM(paste(args[1], "Y.mtx", sep="")))
-
-R = X * Y * X;
-
-writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep="")); 

http://git-wip-us.apache.org/repos/asf/systemml/blob/f005d949/src/test/scripts/functions/misc/RewriteEMultChainOp.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteEMultChainOp.dml b/src/test/scripts/functions/misc/RewriteEMultChainOp.dml
deleted file mode 100644
index 3992403..0000000
--- a/src/test/scripts/functions/misc/RewriteEMultChainOp.dml
+++ /dev/null
@@ -1,28 +0,0 @@
-#-------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-# 
-#   http://www.apache.org/licenses/LICENSE-2.0
-# 
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-#-------------------------------------------------------------
-
-
-X = read($1);
-Y = read($2);
-
-R = X * Y * X;
-
-write(R, $3);
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/systemml/blob/f005d949/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.R b/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.R
new file mode 100644
index 0000000..6d94cc8
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.R
@@ -0,0 +1,33 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+args <- commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+library("matrixStats")
+
+X = as.matrix(readMM(paste(args[1], "X.mtx", sep="")))
+Y = as.matrix(readMM(paste(args[1], "Y.mtx", sep="")))
+
+R = X * Y * X;
+
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep="")); 

http://git-wip-us.apache.org/repos/asf/systemml/blob/f005d949/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.dml b/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.dml
new file mode 100644
index 0000000..3992403
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteEMultChainOpXYX.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+X = read($1);
+Y = read($2);
+
+R = X * Y * X;
+
+write(R, $3);
\ No newline at end of file