You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2017/10/23 04:04:23 UTC

[2/4] systemml git commit: [MINOR] Fix analysis of sparse-safeness for codegen cell/magg ops

[MINOR] Fix analysis of sparse-safeness for codegen cell/magg ops

Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/c70cb116
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/c70cb116
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/c70cb116

Branch: refs/heads/master
Commit: c70cb1166f4ec6c79d10248727a3eb7b85f70360
Parents: 78a3808
Author: Matthias Boehm <mb...@gmail.com>
Authored: Sun Oct 22 18:57:35 2017 -0700
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Sun Oct 22 18:57:35 2017 -0700

----------------------------------------------------------------------
 .../apache/sysml/hops/codegen/template/TemplateCell.java  | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/c70cb116/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
index c9b0734..4f3d4f4 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
@@ -322,10 +322,12 @@ public class TemplateCell extends TemplateBase
 	protected boolean isSparseSafe(List<Hop> roots, Hop mainInput, List<CNode> outputs, List<AggOp> aggOps, boolean onlySum) {
 		boolean ret = true;
 		for( int i=0; i<outputs.size() && ret; i++ ) {
-			ret &= (HopRewriteUtils.isBinary(roots.get(i), OpOp2.MULT) 
-					&& roots.get(i).getInput().contains(mainInput))
-				|| (HopRewriteUtils.isBinary(roots.get(i), OpOp2.DIV) 
-					&& roots.get(i).getInput().get(0) == mainInput)
+			Hop root = (roots.get(i) instanceof AggUnaryOp || roots.get(i) 
+				instanceof AggBinaryOp) ? roots.get(i).getInput().get(0) : roots.get(i);
+			ret &= (HopRewriteUtils.isBinarySparseSafe(root) 
+					&& root.getInput().contains(mainInput))
+				|| (HopRewriteUtils.isBinary(root, OpOp2.DIV) 
+					&& root.getInput().get(0) == mainInput)
 				|| (TemplateUtils.rIsSparseSafeOnly(outputs.get(i), BinType.MULT)
 					&& TemplateUtils.rContainsInput(outputs.get(i), mainInput.getHopID()));
 			if( onlySum )