You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2017/11/10 06:06:54 UTC

[3/4] systemml git commit: [SYSTEMML-1990] Generalized ctable rewrites (seq-table, const inputs)

[SYSTEMML-1990] Generalized ctable rewrites (seq-table, const inputs)

This patch generalizes the existing rewrite from table(seq(),X,...) to
rexpand(X,...) to handle cases with unknown dimensions, including common
scenarios with column indexing on X. Additionally, this patch
introduces a new rewrite for table with constant matrix inputs (e.g.,
table(X, matrix(7)) -> table(X,7)).


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/c9614324
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/c9614324
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/c9614324

Branch: refs/heads/master
Commit: c96143248349b6c68253ef9b3777afd5e5ed62f2
Parents: d696862
Author: Matthias Boehm <mb...@gmail.com>
Authored: Thu Nov 9 16:31:58 2017 -0800
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Thu Nov 9 22:08:02 2017 -0800

----------------------------------------------------------------------
 .../sysml/hops/rewrite/HopRewriteUtils.java     | 27 ++++++-
 .../RewriteAlgebraicSimplificationDynamic.java  | 11 ++-
 .../RewriteAlgebraicSimplificationStatic.java   | 22 +++++-
 .../misc/RewriteCTableToRExpandTest.java        | 83 ++++++++++++++------
 .../RewriteCTableToRExpandLeftUnknownPos.dml    | 28 +++++++
 .../RewriteCTableToRExpandRightUnknownPos.dml   | 28 +++++++
 6 files changed, 167 insertions(+), 32 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/c9614324/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index 28b2189..66f4fc7 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -965,6 +965,15 @@ public class HopRewriteUtils
 			|| isLiteralOfValue(hop.getInput().get(1), val));
 	}
 	
+	public static boolean isTernary(Hop hop, OpOp3 type) { //true iff hop is a TernaryOp with the given op type
+		return hop instanceof TernaryOp && ((TernaryOp)hop).getOp()==type;
+	}
+	
+	public static boolean isTernary(Hop hop, OpOp3... types) { //true iff hop is a TernaryOp whose op is any of the given types
+		return ( hop instanceof TernaryOp 
+			&& ArrayUtils.contains(types, ((TernaryOp) hop).getOp()));
+	}
+	
 	public static boolean containsInput(Hop current, Hop probe) {
 		return rContainsInput(current, probe, new HashSet<Long>());	
 	}
@@ -1052,6 +1061,15 @@ public class HopRewriteUtils
 		return true;
 	}
 	
+	public static boolean isColumnRightIndexing(Hop hop) { //X[,j]-like indexing: col lower == upper, covering all rows of X
+		return hop instanceof IndexingOp
+			&& ((IndexingOp) hop).isColLowerEqualsUpper()
+			&& ((hop.dimsKnown() && hop.getDim1() == hop.getInput().get(0).getDim1()) //known dims: same #rows as X
+			|| (isLiteralOfValue(hop.getInput().get(1), 1) //else inputs 1/2 (presumably row bounds) are 1 and nrow(X)
+				&& isUnary(hop.getInput().get(2), OpOp1.NROW) 
+				&& hop.getInput().get(2).getInput().get(0)==hop.getInput().get(0)));
+	}
+	
 	public static boolean isFullColumnIndexing(LeftIndexingOp hop) {
 		return hop.isColLowerEqualsUpper()
 			&& isLiteralOfValue(hop.getInput().get(2), 1)
@@ -1112,9 +1130,7 @@ public class HopRewriteUtils
 			Hop to = dgop.getInput().get(dgop.getParamIndex(Statement.SEQ_TO));
 			Hop incr = dgop.getInput().get(dgop.getParamIndex(Statement.SEQ_INCR));
 			return isLiteralOfValue(from, 1) && isLiteralOfValue(incr, 1)
-				&& (isLiteralOfValue(to, row?input.getDim1():input.getDim2())
-					|| (to instanceof UnaryOp && ((UnaryOp)to).getOp()==(row?
-						OpOp1.NROW:OpOp1.NCOL) && to.getInput().get(0)==input));
+				&& isSizeExpressionOf(to, input, row); //'to' is the literal dim or nrow/ncol of input
 		}
 		return false;
 	}
@@ -1149,6 +1165,23 @@ public class HopRewriteUtils
 		throw new HopsException("Failed to retrieve 'to' argument from basic 1-N sequence.");
 	}
 	
+	/**
+	 * Indicates if the given size hop is a size expression of the given input,
+	 * i.e., a literal equal to the known number of rows (row=true) or columns
+	 * (row=false) of input, or nrow(input) / ncol(input). For rows, nrow(X)
+	 * also matches an input X[,j] that spans all rows of X; this does not hold
+	 * for columns because X[,j] has only a single column.
+	 *
+	 * @param size candidate size hop (e.g., the 'to' of seq or a ctable dimension)
+	 * @param input input hop whose dimension is probed
+	 * @param row true to probe the number of rows, false to probe the number of columns
+	 * @return true if size is a size expression of input
+	 */
+	public static boolean isSizeExpressionOf(Hop size, Hop input, boolean row) {
+		return (input.dimsKnown() && isLiteralOfValue(size, row?input.getDim1():input.getDim2()))
+			|| ((row ? isUnary(size, OpOp1.NROW) : isUnary(size, OpOp1.NCOL)) && (size.getInput().get(0)==input 
+			|| (row && isColumnRightIndexing(input) && size.getInput().get(0)==input.getInput().get(0))));
+	}
 	
 	public static boolean hasOnlyWriteParents( Hop hop, boolean inclTransient, boolean inclPersistent )
 	{

http://git-wip-us.apache.org/repos/asf/systemml/blob/c9614324/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 0fa1aed..e07f97c 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -2540,15 +2540,14 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
 		//pattern: table(seq(1,nrow(v)), v, nrow(v), m) -> rexpand(v, max=m, dir=row, ignore=false, cast=true)
 		//note: this rewrite supports both left/right sequence 
 		if(    hi instanceof TernaryOp && hi.getInput().size()==5 //table without weights 
-			&& HopRewriteUtils.isLiteralOfValue(hi.getInput().get(2), 1) //i.e., weight of 1
-			&& hi.getInput().get(3) instanceof LiteralOp && hi.getInput().get(4) instanceof LiteralOp)
+			&& HopRewriteUtils.isLiteralOfValue(hi.getInput().get(2), 1) ) //i.e., weight of 1
 		{
 			Hop first = hi.getInput().get(0);
 			Hop second = hi.getInput().get(1);
 			
 			//pattern a: table(seq(1,nrow(v)), v, nrow(v), m, 1)
-			if( HopRewriteUtils.isBasic1NSequence(first, second, true) && second.dimsKnown() 
-				&& HopRewriteUtils.isLiteralOfValue(hi.getInput().get(3), second.getDim1()) )
+			if( HopRewriteUtils.isBasic1NSequence(first, second, true) //first is seq(1,nrow(second))
+				&& HopRewriteUtils.isSizeExpressionOf(hi.getInput().get(3), second, true) ) //output #rows given by nrow(second)
 			{
 				//setup input parameter hops
 				HashMap<String,Hop> args = new HashMap<>();
@@ -2568,8 +2567,8 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
 				LOG.debug("Applied simplifyTableSeqExpand1 (line "+hi.getBeginLine()+")");	
 			}
 			//pattern b: table(v, seq(1,nrow(v)), m, nrow(v))
-			else if( HopRewriteUtils.isBasic1NSequence(second, first, true) && first.dimsKnown() 
-				&& HopRewriteUtils.isLiteralOfValue(hi.getInput().get(4), first.getDim1()) )
+			else if( HopRewriteUtils.isBasic1NSequence(second, first, true) //second is seq(1,nrow(first))
+				&& HopRewriteUtils.isSizeExpressionOf(hi.getInput().get(4), first, true) ) //output #cols given by nrow(first)
 			{
 				//setup input parameter hops
 				HashMap<String,Hop> args = new HashMap<>();

http://git-wip-us.apache.org/repos/asf/systemml/blob/c9614324/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index 4c68fe2..cbfb527 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -152,6 +152,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
 			hi = foldMultipleAppendOperations(hi);               //e.g., cbind(X,cbind(Y,Z)) -> cbind(X,Y,Z)
 			hi = simplifyBinaryToUnaryOperation(hop, hi, i);     //e.g., X*X -> X^2 (pow2), X+X -> X*2, (X>0)-(X<0) -> sign(X)
 			hi = canonicalizeMatrixMultScalarAdd(hi);            //e.g., eps+U%*%t(V) -> U%*%t(V)+eps, U%*%t(V)-eps -> U%*%t(V)+(-eps) 
+			hi = simplifyCTableWithConstMatrixInputs(hi);        //e.g., table(X, matrix(1,...)) -> table(X, 1), i.e., const matrix inputs -> scalars
 			hi = simplifyReverseOperation(hop, hi, i);           //e.g., table(seq(1,nrow(X),1),seq(nrow(X),1,-1)) %*% X -> rev(X)
 			if(OptimizerUtils.ALLOW_OPERATOR_FUSION)
 				hi = simplifyMultiBinaryToBinaryOperation(hi);       //e.g., 1-X*Y -> X 1-* Y
@@ -664,13 +665,32 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
 			{
 				bop.setOp(OpOp2.PLUS);
 				HopRewriteUtils.replaceChildReference(bop,  right,
-						HopRewriteUtils.createBinaryMinus(right), 1);				
+						HopRewriteUtils.createBinaryMinus(right), 1);
 				LOG.debug("Applied canonicalizeMatrixMultScalarAdd2 (line "+hi.getBeginLine()+").");
 			}
 		}
 		
 		return hi;
 	}
+	
+	private static Hop simplifyCTableWithConstMatrixInputs( Hop hi ) 
+		throws HopsException
+	{
+		//pattern: table(X, matrix(1,...), matrix(7, ...)) -> table(X, 1, 7), i.e., constant matrix inputs become scalars
+		if( HopRewriteUtils.isTernary(hi, OpOp3.CTABLE) ) {
+			//note: the first input always expected to be a matrix, hence only inputs 1..n-1 are probed
+			for( int i=1; i<hi.getInput().size(); i++ ) {
+				Hop inCurr = hi.getInput().get(i);
+				if( HopRewriteUtils.isDataGenOpWithConstantValue(inCurr) ) { //e.g., matrix(7, ...); NOTE(review): its dims are not checked against X
+					Hop inNew = ((DataGenOp)inCurr).getInput(DataExpression.RAND_MIN); //datagen min argument, assumed to hold the constant
+					HopRewriteUtils.replaceChildReference(hi, inCurr, inNew, i); //swap matrix input i for the scalar
+					LOG.debug("Applied simplifyCTableWithConstMatrixInputs"
+						+ i + " (line "+hi.getBeginLine()+").");
+				}
+			}
+		}
+		return hi; //hi is modified in place and returned
+	}
 
 	/**
 	 * NOTE: this would be by definition a dynamic rewrite; however, we apply it as a static

http://git-wip-us.apache.org/repos/asf/systemml/blob/c9614324/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteCTableToRExpandTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteCTableToRExpandTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteCTableToRExpandTest.java
index b42a978..838fbb1 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteCTableToRExpandTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteCTableToRExpandTest.java
@@ -22,6 +22,7 @@ package org.apache.sysml.test.integration.functions.misc;
 import org.junit.Test;
 
 import org.junit.Assert;
+import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
 import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
 import org.apache.sysml.test.integration.AutomatedTestBase;
 import org.apache.sysml.test.integration.TestConfiguration;
@@ -33,6 +34,8 @@ public class RewriteCTableToRExpandTest extends AutomatedTestBase
 	private static final String TEST_NAME2 = "RewriteCTableToRExpandRightPos"; 
 	private static final String TEST_NAME3 = "RewriteCTableToRExpandLeftNeg"; 
 	private static final String TEST_NAME4 = "RewriteCTableToRExpandRightNeg"; 
+	private static final String TEST_NAME5 = "RewriteCTableToRExpandLeftUnknownPos"; //seq on left, input size unknown at compile time
+	private static final String TEST_NAME6 = "RewriteCTableToRExpandRightUnknownPos"; //seq on right, input size unknown at compile time
 	
 	private static final String TEST_DIR = "functions/misc/";
 	private static final String TEST_CLASS_DIR = TEST_DIR + RewriteCTableToRExpandTest.class.getSimpleName() + "/";
@@ -52,6 +55,8 @@ public class RewriteCTableToRExpandTest extends AutomatedTestBase
 		addTestConfiguration( TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "R" }) );
 		addTestConfiguration( TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] { "R" }) );
 		addTestConfiguration( TEST_NAME4, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] { "R" }) );
+		addTestConfiguration( TEST_NAME5, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME5, new String[] { "R" }) ); //left, unknown dims
+		addTestConfiguration( TEST_NAME6, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME6, new String[] { "R" }) ); //right, unknown dims
 	}
 
 	@Test
@@ -94,6 +99,25 @@ public class RewriteCTableToRExpandTest extends AutomatedTestBase
 		testRewriteCTableRExpand( TEST_NAME4, CropType.PAD );
 	}
 	
+	@Test
+	public void testRewriteCTableRExpandLeftUnknownDenseCrop()  { //seq on left, unknown dims, outDim < maxVal
+		testRewriteCTableRExpand( TEST_NAME5, CropType.CROP );
+	}
+	
+	@Test
+	public void testRewriteCTableRExpandLeftUnknownDensePad()  { //seq on left, unknown dims, outDim > maxVal
+		testRewriteCTableRExpand( TEST_NAME5, CropType.PAD );
+	}
+	
+	@Test
+	public void testRewriteCTableRExpandRightUnknownDenseCrop()  { //seq on right, unknown dims, outDim < maxVal
+		testRewriteCTableRExpand( TEST_NAME6, CropType.CROP );
+	}
+	
+	@Test
+	public void testRewriteCTableRExpandRightUnknownDensePad()  { //seq on right, unknown dims, outDim > maxVal
+		testRewriteCTableRExpand( TEST_NAME6, CropType.PAD );
+	}
 	
 	private void testRewriteCTableRExpand( String testname, CropType type )
 	{	
@@ -101,30 +125,45 @@ public class RewriteCTableToRExpandTest extends AutomatedTestBase
 		loadTestConfiguration(config);
 
 		int outDim = maxVal + ((type==CropType.CROP) ? -7 : 7);
+		//UnknownPos scripts derive their input size at runtime via sum(T)/2
+		boolean unknownTests = ( testname.equals(TEST_NAME5) || testname.equals(TEST_NAME6) );
 		
-		String HOME = SCRIPT_DIR + TEST_DIR;
-		fullDMLScriptName = HOME + testname + ".dml";
-		programArgs = new String[]{ "-stats","-args", 
-			input("A"), String.valueOf(outDim), output("R") };
-		
-		fullRScriptName = HOME + testname + ".R";
-		rCmd = getRCmd(inputDir(), String.valueOf(outDim), expectedDir());			
-
-		double[][] A = getRandomMatrix(rows, 1, 1, 10, 1.0, 7);
-		writeInputMatrixWithMTD("A", A, false);
-		
-		//run performance tests
-		runTest(true, false, null, -1); 
+		RUNTIME_PLATFORM platformOld = rtplatform;
+		if( unknownTests )
+			rtplatform = RUNTIME_PLATFORM.SINGLE_NODE;
 		
-		//compare output meta data
-		boolean left = (testname.equals(TEST_NAME1) || testname.equals(TEST_NAME3));
-		boolean pos = (testname.equals(TEST_NAME1) || testname.equals(TEST_NAME2));
-		int rrows = (left && pos) ? rows : outDim;
-		int rcols = (!left && pos) ? rows : outDim;
-		checkDMLMetaDataFile("R", new MatrixCharacteristics(rrows, rcols, 1, 1));
-		
-		//check for applied rewrite
-		Assert.assertEquals(Boolean.valueOf(testname.equals(TEST_NAME1) || testname.equals(TEST_NAME2)),
+		try
+		{
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + testname + ".dml";
+			programArgs = new String[]{ "-explain","-stats","-args",
+				input("A"), String.valueOf(outDim), output("R") };
+			
+			fullRScriptName = HOME + testname + ".R";
+			rCmd = getRCmd(inputDir(), String.valueOf(outDim), expectedDir());
+			
+			double[][] A = getRandomMatrix(rows, 1, 1, 10, 1.0, 7);
+			writeInputMatrixWithMTD("A", A, false);
+			
+			//run test
+			runTest(true, false, null, -1);
+			
+			//compare output meta data (left: seq() is the first ctable input,
+			//pos: rewrite to rexpand expected; skipped for unknown-size tests)
+			boolean left = (testname.equals(TEST_NAME1) || testname.equals(TEST_NAME3)
+				|| testname.equals(TEST_NAME5));
+			boolean pos = (testname.equals(TEST_NAME1) || testname.equals(TEST_NAME2) || unknownTests);
+			int rrows = (left && pos) ? rows : outDim;
+			int rcols = (!left && pos) ? rows : outDim;
+			if( !unknownTests )
+				checkDMLMetaDataFile("R", new MatrixCharacteristics(rrows, rcols, 1, 1));
+			
+			//check for applied rewrite
+			Assert.assertEquals(Boolean.valueOf(pos),
 				Boolean.valueOf(heavyHittersContainsSubString("rexpand")));
+		}
+		finally {
+			rtplatform = platformOld;
+		}
 	}
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/c9614324/src/test/scripts/functions/misc/RewriteCTableToRExpandLeftUnknownPos.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteCTableToRExpandLeftUnknownPos.dml b/src/test/scripts/functions/misc/RewriteCTableToRExpandLeftUnknownPos.dml
new file mode 100644
index 0000000..4b07462
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteCTableToRExpandLeftUnknownPos.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+A = read($1);
+
+T = matrix(1, nrow(A), 2); # sum(T)/2 equals nrow(A); presumably used so A2's size is unknown at compile time
+A2 = rand(rows=sum(T)/2, cols=100, min=1, max=10); # nrow(A2) == nrow(A)
+R = table(seq(1,nrow(A2)), A2[,1], nrow(A2), $2); # the test expects this to be rewritten to rexpand
+
+write(R, $3);

http://git-wip-us.apache.org/repos/asf/systemml/blob/c9614324/src/test/scripts/functions/misc/RewriteCTableToRExpandRightUnknownPos.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteCTableToRExpandRightUnknownPos.dml b/src/test/scripts/functions/misc/RewriteCTableToRExpandRightUnknownPos.dml
new file mode 100644
index 0000000..68d2860
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteCTableToRExpandRightUnknownPos.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+A = read($1);
+
+T = matrix(1, nrow(A), 2); # sum(T)/2 equals nrow(A); presumably used so A2's size is unknown at compile time
+A2 = rand(rows=sum(T)/2, cols=100, min=1, max=10); # nrow(A2) == nrow(A)
+R = table(A2[,1], seq(1,nrow(A2)), $2, nrow(A2)); # the test expects this to be rewritten to rexpand
+
+write(R, $3);