You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2017/06/18 19:14:38 UTC

[5/5] systemml git commit: [SYSTEMML-1543] Codegen multi-aggregates w/ matrix mult root nodes

[SYSTEMML-1543] Codegen multi-aggregates w/ matrix mult root nodes

This patch extends the compilation of codegen multi-aggregate templates
with support for dot products, which is important because sum(X^2) and
sum(X*Y) are rewritten to dot products by dynamic simplification
rewrites. It also includes minor fixes regarding indexing under
sum-square operations, avoidance of fusion for 1x1 matrices, and an
unrelated fix for COO sparse blocks that surfaced after the hash
function changes in SYSTEMML-1716.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/c4342085
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/c4342085
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/c4342085

Branch: refs/heads/master
Commit: c43420855d0d768d8826adac455bc03ff673d23e
Parents: 23a164a
Author: Matthias Boehm <mb...@gmail.com>
Authored: Sat Jun 17 23:22:37 2017 -0700
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Sun Jun 18 11:51:02 2017 -0700

----------------------------------------------------------------------
 .../sysml/hops/codegen/SpoofCompiler.java       |  5 +-
 .../template/PlanSelectionFuseCostBased.java    | 54 ++++++++++----------
 .../hops/codegen/template/TemplateCell.java     | 12 ++++-
 .../sysml/hops/rewrite/HopRewriteUtils.java     |  7 ++-
 .../runtime/matrix/data/SparseBlockCOO.java     |  2 +-
 .../functions/codegen/MultiAggTmplTest.java     | 18 ++++++-
 .../functions/codegen/multiAggPattern7.R        | 34 ++++++++++++
 .../functions/codegen/multiAggPattern7.dml      | 31 +++++++++++
 8 files changed, 132 insertions(+), 31 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/c4342085/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
index a58e28d..988af7c 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
@@ -56,6 +56,7 @@ import org.apache.sysml.hops.codegen.template.CPlanMemoTable.MemoTableEntrySet;
 import org.apache.sysml.hops.codegen.template.TemplateUtils;
 import org.apache.sysml.hops.recompile.RecompileStatus;
 import org.apache.sysml.hops.recompile.Recompiler;
+import org.apache.sysml.hops.AggUnaryOp;
 import org.apache.sysml.hops.Hop;
 import org.apache.sysml.hops.Hop.OpOp1;
 import org.apache.sysml.hops.HopsException;
@@ -631,7 +632,9 @@ public class SpoofCompiler
 					inHops[0].getRowsInBlock(), inHops[0].getColsInBlock(), -1);
 				//inject artificial right indexing operations for all parents of all nodes
 				for( int i=0; i<roots.size(); i++ ) {
-					Hop hnewi = HopRewriteUtils.createScalarIndexing(hnew, 1, i+1);
+					Hop hnewi = (roots.get(i) instanceof AggUnaryOp) ? 
+						HopRewriteUtils.createScalarIndexing(hnew, 1, i+1) :
+						HopRewriteUtils.createMatrixIndexing(hnew, 1, i+1);
 					HopRewriteUtils.rewireAllParentChildReferences(roots.get(i), hnewi);
 				}
 			}

http://git-wip-us.apache.org/repos/asf/systemml/blob/c4342085/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java
index 0e301e8..e3435e5 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/PlanSelectionFuseCostBased.java
@@ -245,8 +245,7 @@ public class PlanSelectionFuseCostBased extends PlanSelection
 		ArrayList<Long> fullAggs = new ArrayList<Long>();
 		for( Long hopID : R ) {
 			Hop root = memo._hopRefs.get(hopID);
-			if( !refHops.contains(hopID) && root instanceof AggUnaryOp 
-				&& ((AggUnaryOp)root).getDirection()==Direction.RowCol)
+			if( !refHops.contains(hopID) && isMultiAggregateRoot(root) )
 				fullAggs.add(hopID);
 		}
 		if( LOG.isTraceEnabled() ) {
@@ -306,10 +305,19 @@ public class PlanSelectionFuseCostBased extends PlanSelection
 		for( Long hopID : fullAggs ) {
 			Hop aggHop = memo._hopRefs.get(hopID);
 			AggregateInfo tmp = new AggregateInfo(aggHop);
-			for( Hop c : aggHop.getInput() )
+			for( int i=0; i<aggHop.getInput().size(); i++ ) {
+				Hop c = HopRewriteUtils.isMatrixMultiply(aggHop) && i==0 ? 
+					aggHop.getInput().get(0).getInput().get(0) : aggHop.getInput().get(i);
 				rExtractAggregateInfo(memo, c, tmp, TemplateType.CellTpl);
-			if( tmp._fusedInputs.isEmpty() )
-				tmp.addFusedInput(aggHop.getInput().get(0).getHopID());
+			}
+			if( tmp._fusedInputs.isEmpty() ) {
+				if( HopRewriteUtils.isMatrixMultiply(aggHop) ) {
+					tmp.addFusedInput(aggHop.getInput().get(0).getInput().get(0).getHopID());
+					tmp.addFusedInput(aggHop.getInput().get(1).getHopID());
+				}
+				else	
+					tmp.addFusedInput(aggHop.getInput().get(0).getHopID());
+			}
 			aggInfos.add(tmp);	
 		}
 		
@@ -319,10 +327,9 @@ public class PlanSelectionFuseCostBased extends PlanSelection
 				LOG.trace(info);
 		}
 		
-		//filter aggregations w/ matmults to ensure consistent dims
 		//sort aggregations by num dependencies to simplify merging
 		//clusters of aggregations with parallel dependencies
-		aggInfos = aggInfos.stream().filter(a -> !a.containsMatMult)
+		aggInfos = aggInfos.stream()
 			.sorted(Comparator.comparing(a -> a._inputAggs.size()))
 			.collect(Collectors.toList());
 		
@@ -366,6 +373,13 @@ public class PlanSelectionFuseCostBased extends PlanSelection
 		}
 	}
 	
+	private static boolean isMultiAggregateRoot(Hop root) {
+		return (HopRewriteUtils.isAggUnaryOp(root, AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX) 
+				&& ((AggUnaryOp)root).getDirection()==Direction.RowCol)
+			|| (root instanceof AggBinaryOp && root.getDim1()==1 && root.getDim2()==1
+				&& HopRewriteUtils.isTransposeOperation(root.getInput().get(0)));
+	}
+	
 	private static boolean isValidMultiAggregate(CPlanMemoTable memo, MemoTableEntry me) {
 		//ensure input consistent sizes (otherwise potential for incorrect results)
 		boolean ret = true;
@@ -402,11 +416,8 @@ public class PlanSelectionFuseCostBased extends PlanSelection
 			return;
 		
 		//collect all applicable full aggregations per read
-		if( HopRewriteUtils.isAggUnaryOp(current, AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX)
-			&& ((AggUnaryOp)current).getDirection()==Direction.RowCol )
-		{
+		if( isMultiAggregateRoot(current) )
 			aggs.add(current.getHopID());
-		}
 		
 		//recursively process children
 		for( Hop c : current.getInput() )
@@ -417,19 +428,12 @@ public class PlanSelectionFuseCostBased extends PlanSelection
 	
 	private static void rExtractAggregateInfo(CPlanMemoTable memo, Hop current, AggregateInfo aggInfo, TemplateType type) {
 		//collect input aggregates (dependents)
-		if( HopRewriteUtils.isAggUnaryOp(current, AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX)
-			&& ((AggUnaryOp)current).getDirection()==Direction.RowCol )
-		{
+		if( isMultiAggregateRoot(current) )
 			aggInfo.addInputAggregate(current.getHopID());
-		}
-		
-		//collect included matrix multiplications
-		if( type != null && HopRewriteUtils.isMatrixMultiply(current) )
-			aggInfo.setContainsMatMult();
 		
 		//recursively process children
 		MemoTableEntry me = (type!=null) ? memo.getBest(current.getHopID()) : null;
-		for( int i=0; i< current.getInput().size(); i++ ) {
+		for( int i=0; i<current.getInput().size(); i++ ) {
 			Hop c = current.getInput().get(i);
 			if( me != null && me.isPlanRef(i) )
 				rExtractAggregateInfo(memo, c, aggInfo, type);
@@ -960,7 +964,6 @@ public class PlanSelectionFuseCostBased extends PlanSelection
 		public final HashMap<Long,Hop> _aggregates;
 		public final HashSet<Long> _inputAggs = new HashSet<Long>();
 		public final HashSet<Long> _fusedInputs = new HashSet<Long>();
-		public boolean containsMatMult = false;
 		public AggregateInfo(Hop aggregate) {
 			_aggregates = new HashMap<Long, Hop>();
 			_aggregates.put(aggregate.getHopID(), aggregate);
@@ -971,9 +974,6 @@ public class PlanSelectionFuseCostBased extends PlanSelection
 		public void addFusedInput(long hopID) {
 			_fusedInputs.add(hopID);
 		}
-		public void setContainsMatMult() {
-			containsMatMult = true;
-		}
 		public boolean isMergable(AggregateInfo that) {
 			//check independence
 			boolean ret = _aggregates.size()<3 
@@ -986,9 +986,11 @@ public class PlanSelectionFuseCostBased extends PlanSelection
 			ret &= !CollectionUtils.intersection(
 				_fusedInputs, that._fusedInputs).isEmpty();
 			//check consistent sizes (result correctness)
+			Hop in1 = _aggregates.values().iterator().next();
+			Hop in2 = that._aggregates.values().iterator().next();
 			return ret && HopRewriteUtils.isEqualSize(
-				_aggregates.values().iterator().next().getInput().get(0),
-				that._aggregates.values().iterator().next().getInput().get(0));
+				in1.getInput().get(HopRewriteUtils.isMatrixMultiply(in1)?1:0),
+				in2.getInput().get(HopRewriteUtils.isMatrixMultiply(in2)?1:0));
 		}
 		public AggregateInfo merge(AggregateInfo that) {
 			_aggregates.putAll(that._aggregates);

http://git-wip-us.apache.org/repos/asf/systemml/blob/c4342085/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
index 91c61c2..5455775 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
@@ -76,6 +76,7 @@ public class TemplateCell extends TemplateBase
 	@Override
 	public boolean open(Hop hop) {
 		return hop.dimsKnown() && isValidOperation(hop)
+				&& !(hop.getDim1()==1 && hop.getDim2()==1) 	
 			|| (hop instanceof IndexingOp && (((IndexingOp)hop)
 				.isColLowerEqualsUpper() || hop.getDim2()==1));
 	}
@@ -162,6 +163,8 @@ public class TemplateCell extends TemplateBase
 			if( me!=null && me.isPlanRef(i) && !(c instanceof DataOp)
 				&& (me.type!=TemplateType.MultiAggTpl || memo.contains(c.getHopID(), TemplateType.CellTpl)))
 				rConstructCplan(c, memo, tmp, inHops, compileLiterals);
+			else if( me!=null && me.type==TemplateType.MultiAggTpl && HopRewriteUtils.isMatrixMultiply(hop) && i==0 )
+				rConstructCplan(c.getInput().get(0), memo, tmp, inHops, compileLiterals);
 			else {
 				CNodeData cdata = TemplateUtils.createCNodeData(c, compileLiterals);	
 				tmp.put(c.getHopID(), cdata);
@@ -233,7 +236,10 @@ public class TemplateCell extends TemplateBase
 		}
 		else if( HopRewriteUtils.isTransposeOperation(hop) ) 
 		{
-			out = tmp.get(hop.getInput().get(0).getHopID());	
+			out = TemplateUtils.skipTranspose(tmp.get(hop.getHopID()), 
+				hop, tmp, compileLiterals);
+			if( out instanceof CNodeData && !inHops.contains(hop.getInput().get(0)) )
+				inHops.add(hop.getInput().get(0));
 		}
 		else if( hop instanceof AggUnaryOp )
 		{
@@ -246,11 +252,15 @@ public class TemplateCell extends TemplateBase
 			//(1) t(X)%*%X -> sum(X^2) and t(X) %*% Y -> sum(X*Y)
 			if( HopRewriteUtils.isTransposeOfItself(hop.getInput().get(0), hop.getInput().get(1)) ) {
 				CNode cdata1 = tmp.get(hop.getInput().get(1).getHopID());
+				if( TemplateUtils.isColVector(cdata1) )
+					cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
 				out = new CNodeUnary(cdata1, UnaryType.POW2);
 			}
 			else {
 				CNode cdata1 = TemplateUtils.skipTranspose(tmp.get(hop.getInput().get(0).getHopID()), 
 						hop.getInput().get(0), tmp, compileLiterals);
+				if( cdata1 instanceof CNodeData && !inHops.contains(hop.getInput().get(0).getInput().get(0)) )
+					inHops.add(hop.getInput().get(0).getInput().get(0));
 				if( TemplateUtils.isColVector(cdata1) )
 					cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
 				CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());

http://git-wip-us.apache.org/repos/asf/systemml/blob/c4342085/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index bec7b38..cf6081b 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -532,13 +532,18 @@ public class HopRewriteUtils
 	}
 	
 	public static Hop createScalarIndexing(Hop input, long rix, long cix) {
+		Hop ix = createMatrixIndexing(input, rix, cix);
+		return createUnary(ix, OpOp1.CAST_AS_SCALAR);
+	}
+	
+	public static Hop createMatrixIndexing(Hop input, long rix, long cix) {
 		LiteralOp row = new LiteralOp(rix);
 		LiteralOp col = new LiteralOp(cix);
 		IndexingOp ix = new IndexingOp("tmp", DataType.MATRIX, ValueType.DOUBLE, input, row, row, col, col, true, true);
 		ix.setOutputBlocksizes(input.getRowsInBlock(), input.getColsInBlock());
 		copyLineNumbers(input, ix);
 		ix.refreshSizeInformation();
-		return createUnary(ix, OpOp1.CAST_AS_SCALAR);
+		return ix;
 	}
 	
 	public static Hop createValueHop( Hop hop, boolean row ) 

http://git-wip-us.apache.org/repos/asf/systemml/blob/c4342085/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockCOO.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockCOO.java b/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockCOO.java
index 9ca9418..9f527d4 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockCOO.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockCOO.java
@@ -218,7 +218,7 @@ public class SparseBlockCOO extends SparseBlock
 	@Override
 	public int size(int r) {
 		int pos = pos(r);
-		if( _rindexes[pos]!=r )
+		if( pos>=_size || _rindexes[pos]!=r )
 			return 0;
 		
 		//count number of equal row indexes

http://git-wip-us.apache.org/repos/asf/systemml/blob/c4342085/src/test/java/org/apache/sysml/test/integration/functions/codegen/MultiAggTmplTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/codegen/MultiAggTmplTest.java b/src/test/java/org/apache/sysml/test/integration/functions/codegen/MultiAggTmplTest.java
index c33d680..07a4396 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/codegen/MultiAggTmplTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/codegen/MultiAggTmplTest.java
@@ -42,6 +42,7 @@ public class MultiAggTmplTest extends AutomatedTestBase
 	private static final String TEST_NAME4 = TEST_NAME+"4"; //sum(X*Y), sum(X^2), sum(Y^2)
 	private static final String TEST_NAME5 = TEST_NAME+"5"; //sum(V*X), sum(Y*Z), sum(X+Y-Z)
 	private static final String TEST_NAME6 = TEST_NAME+"6"; //min(X), max(X), sum(X)
+	private static final String TEST_NAME7 = TEST_NAME+"7"; //t(X)%*%X, t(X)%*Y, t(Y)%*%Y
 	
 	private static final String TEST_DIR = "functions/codegen/";
 	private static final String TEST_CLASS_DIR = TEST_DIR + MultiAggTmplTest.class.getSimpleName() + "/";
@@ -53,7 +54,7 @@ public class MultiAggTmplTest extends AutomatedTestBase
 	@Override
 	public void setUp() {
 		TestUtils.clearAssertionInformation();
-		for(int i=1; i<=6; i++)
+		for(int i=1; i<=7; i++)
 			addTestConfiguration( TEST_NAME+i, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME+i, new String[] { String.valueOf(i) }) );
 	}
 	
@@ -147,6 +148,21 @@ public class MultiAggTmplTest extends AutomatedTestBase
 		testCodegenIntegration( TEST_NAME6, false, ExecType.SPARK );
 	}
 	
+	@Test	
+	public void testCodegenMultiAggRewrite7CP() {
+		testCodegenIntegration( TEST_NAME7, true, ExecType.CP );
+	}
+
+	@Test	
+	public void testCodegenMultiAgg7CP() {
+		testCodegenIntegration( TEST_NAME7, false, ExecType.CP );
+	}
+	
+	@Test	
+	public void testCodegenMultiAgg7Spark() {
+		testCodegenIntegration( TEST_NAME7, false, ExecType.SPARK );
+	}
+	
 	private void testCodegenIntegration( String testname, boolean rewrites, ExecType instType )
 	{	
 		boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;

http://git-wip-us.apache.org/repos/asf/systemml/blob/c4342085/src/test/scripts/functions/codegen/multiAggPattern7.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/codegen/multiAggPattern7.R b/src/test/scripts/functions/codegen/multiAggPattern7.R
new file mode 100644
index 0000000..b56f090
--- /dev/null
+++ b/src/test/scripts/functions/codegen/multiAggPattern7.R
@@ -0,0 +1,34 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args<-commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+
+X = seq(1,15);
+Y = seq(2,16);
+
+r1 = t(X)%*%X;
+r2 = t(X)%*%Y;
+r3 = t(Y)%*%Y;
+S = r1+r2+r3;
+
+writeMM(as(S, "CsparseMatrix"), paste(args[2], "S", sep="")); 

http://git-wip-us.apache.org/repos/asf/systemml/blob/c4342085/src/test/scripts/functions/codegen/multiAggPattern7.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/codegen/multiAggPattern7.dml b/src/test/scripts/functions/codegen/multiAggPattern7.dml
new file mode 100644
index 0000000..3306fd8
--- /dev/null
+++ b/src/test/scripts/functions/codegen/multiAggPattern7.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = seq(1,15);
+Y = seq(2,16);
+if(1==1){}
+
+r1 = t(X)%*%X;
+r2 = t(X)%*%Y;
+r3 = t(Y)%*%Y;
+S = r1+r2+r3;
+
+write(S,$1)