You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by ma...@apache.org on 2022/04/20 11:41:41 UTC

[systemds] branch main updated: [SYSTEMDS-3334] Codegen RowMaxs_VectMult rewrite

This is an automated email from the ASF dual-hosted git repository.

markd pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 811e3f474c [SYSTEMDS-3334] Codegen RowMaxs_VectMult rewrite
811e3f474c is described below

commit 811e3f474c7e4e1747e7b5e54ffa75e79afc1cd5
Author: Mark Dokter <ma...@dokter.cc>
AuthorDate: Tue Apr 19 23:16:52 2022 +0200

    [SYSTEMDS-3334] Codegen RowMaxs_VectMult rewrite
    
    This rewrite fuses a vector multiplication with a subsequent row-max aggregation, rowMaxs(X * v), into a single operation, which avoids materializing an intermediate vector in SPOOF's row template. The pattern occurs in components.dml when codegen is enabled.
    
    Closes #1566
---
 .../apache/sysds/hops/codegen/SpoofCompiler.java   | 15 ++++++++------
 .../sysds/hops/codegen/cplan/CNodeBinary.java      |  7 ++++++-
 .../sysds/hops/codegen/cplan/java/Binary.java      |  3 +++
 .../hops/codegen/template/CPlanOpRewriter.java     | 19 +++++++++++++++--
 .../sysds/hops/codegen/template/TemplateUtils.java | 24 ++++++++++++++++++++--
 .../sysds/runtime/codegen/LibSpoofPrimitives.java  | 18 ++++++++++++++--
 6 files changed, 73 insertions(+), 13 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/codegen/SpoofCompiler.java b/src/main/java/org/apache/sysds/hops/codegen/SpoofCompiler.java
index ade88775e1..55d75b092a 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/SpoofCompiler.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/SpoofCompiler.java
@@ -38,6 +38,7 @@ import org.apache.sysds.hops.Hop;
 import org.apache.sysds.hops.LiteralOp;
 import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.hops.codegen.cplan.CNode;
+import org.apache.sysds.hops.codegen.cplan.CNodeBinary;
 import org.apache.sysds.hops.codegen.cplan.CNodeCell;
 import org.apache.sysds.hops.codegen.cplan.CNodeData;
 import org.apache.sysds.hops.codegen.cplan.CNodeMultiAgg;
@@ -941,13 +942,15 @@ public class SpoofCompiler {
 			}
 			
 			//remove cplan w/ single op and w/o agg
-			if( (tpl instanceof CNodeCell && ((CNodeCell)tpl).getCellType()==CellType.NO_AGG
-					&& TemplateUtils.hasSingleOperation(tpl) )
-				|| (tpl instanceof CNodeRow && (((CNodeRow)tpl).getRowType()==RowType.NO_AGG
-					|| ((CNodeRow)tpl).getRowType()==RowType.NO_AGG_B1
-					|| ((CNodeRow)tpl).getRowType()==RowType.ROW_AGG )
+			if((tpl instanceof CNodeCell && ((CNodeCell)tpl).getCellType()==CellType.NO_AGG
 					&& TemplateUtils.hasSingleOperation(tpl))
-				|| TemplateUtils.hasNoOperation(tpl) ) 
+				|| (tpl instanceof CNodeRow
+					&& (((CNodeRow)tpl).getRowType()==RowType.NO_AGG
+						|| ((CNodeRow)tpl).getRowType()==RowType.NO_AGG_B1
+						|| (((CNodeRow)tpl).getRowType()==RowType.ROW_AGG  && !TemplateUtils.isBinary(tpl.getOutput(),
+							CNodeBinary.BinType.ROWMAXS_VECTMULT))) //ROW_AGG plans whose output is the fused rowMaxs(a*b) are never removed here
+					&& TemplateUtils.hasSingleOperation(tpl))
+				|| TemplateUtils.hasNoOperation(tpl))
 			{
 				cplans2.remove(e.getKey());
 				if( LOG.isTraceEnabled() )
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java
index 2e6bcd5d48..bebf0a221b 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java
@@ -30,6 +30,8 @@ import org.apache.sysds.hops.codegen.SpoofCompiler.GeneratorAPI;
 public class CNodeBinary extends CNode {
 
 	public enum BinType {
+		//fused vector op + row aggregation: rowMaxs(a * b), producing a scalar per row
+		ROWMAXS_VECTMULT,
 		//matrix multiplication operations
 		DOT_PRODUCT, VECT_MATRIXMULT, VECT_OUTERMULT_ADD,
 		//vector-scalar-add operations
@@ -373,7 +375,8 @@ public class CNodeBinary extends CNode {
 				_cols = _inputs.get(1)._cols;
 				_dataType = DataType.MATRIX;
 				break;
-			
+
+			case ROWMAXS_VECTMULT: //scalar output, same as DOT_PRODUCT
 			case DOT_PRODUCT:
 			
 			//SCALAR Arithmetic
@@ -407,6 +410,8 @@ public class CNodeBinary extends CNode {
 				_cols = 0;
 				_dataType= DataType.SCALAR;
 				break;
+			default: //NOTE(review): before this change, types not listed in this switch were silently ignored; now they throw - confirm every BinType is covered above
+					throw new RuntimeException("Unknown CNodeBinary type: " + _type);
 		}
 	}
 	
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Binary.java b/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Binary.java
index ecb7878f66..40496249e5 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Binary.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Binary.java
@@ -28,6 +28,9 @@ public class Binary extends CodeTemplate {
 		boolean scalarVector, boolean scalarInput, boolean vectorVector)
 	{
 		switch (type) {
+			case ROWMAXS_VECTMULT: //fused rowMaxs(a*b); emits a scalar like DOT_PRODUCT
+				return sparseLhs ? "\tdouble %TMP% = LibSpoofPrimitives.rowMaxsVectMult(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen);\n" : //NOTE(review): sparse call passes only alen (presumably the non-zero count), not %LEN%, so implicit zeros of the row cannot enter the max
+						"\tdouble %TMP% = LibSpoofPrimitives.rowMaxsVectMult(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n";
 			case DOT_PRODUCT:
 				return sparseLhs ? "    double %TMP% = LibSpoofPrimitives.dotProduct(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen);\n" :
 						"    double %TMP% = LibSpoofPrimitives.dotProduct(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n";
diff --git a/src/main/java/org/apache/sysds/hops/codegen/template/CPlanOpRewriter.java b/src/main/java/org/apache/sysds/hops/codegen/template/CPlanOpRewriter.java
index 2b981ee893..b81ddac401 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/template/CPlanOpRewriter.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/template/CPlanOpRewriter.java
@@ -21,11 +21,13 @@ package org.apache.sysds.hops.codegen.template;
 
 import java.util.ArrayList;
 
 import org.apache.sysds.hops.LiteralOp;
 import org.apache.sysds.hops.codegen.cplan.CNode;
+import org.apache.sysds.hops.codegen.cplan.CNodeBinary;
 import org.apache.sysds.hops.codegen.cplan.CNodeData;
 import org.apache.sysds.hops.codegen.cplan.CNodeMultiAgg;
 import org.apache.sysds.hops.codegen.cplan.CNodeOuterProduct;
+import org.apache.sysds.hops.codegen.cplan.CNodeRow;
 import org.apache.sysds.hops.codegen.cplan.CNodeTpl;
 import org.apache.sysds.hops.codegen.cplan.CNodeUnary;
 import org.apache.sysds.hops.codegen.cplan.CNodeBinary.BinType;
@@ -56,6 +58,10 @@ public class CPlanOpRewriter
 		}
 		else {
 			tpl.setOutput(rSimplifyCNode(tpl.getOutput()));
+			if(TemplateUtils.containsFusedRowVecAgg(tpl)) { //fused rowMaxs(a*b) no longer materializes the VECT_MULT vector
+				CNodeRow rowTpl = (CNodeRow) tpl; //safe: containsFusedRowVecAgg only returns true for CNodeRow
+				rowTpl.setNumVectorIntermediates(Math.max(0, rowTpl.getNumVectorIntermediates()-2)); //clamp so the count never goes negative
+			}
 		}
 		
 		return tpl;
@@ -73,10 +79,19 @@ public class CPlanOpRewriter
 		node = rewriteBinaryPow2Vect(node);  //X^2 -> X*X
 		node = rewriteBinaryMult2(node);     //x*2 -> x+x;
 		node = rewriteBinaryMult2Vect(node); //X*2 -> X+X;
-		
+		node = rewriteRowMaxsVectMult(node); // rowMaxs(G * t(c)); see components.dml
 		return node;
 	}
-	
+
+	private static CNode rewriteRowMaxsVectMult(CNode node) { //ROW_MAXS(VECT_MULT(a,b)) -> ROWMAXS_VECTMULT(a,b)
+		if(TemplateUtils.isUnary(node, UnaryType.ROW_MAXS)) {
+			CNode input = node.getInput().get(0);
+			if(TemplateUtils.isBinary(input, BinType.VECT_MULT))
+				return new CNodeBinary(input.getInput().get(0), input.getInput().get(1), BinType.ROWMAXS_VECTMULT);
+		}
+		return node;
+	}
+
 	private static CNode rewriteRowCountNnz(CNode node) {
 		return (TemplateUtils.isUnary(node, UnaryType.ROW_SUMS)
 			&& TemplateUtils.isBinary(node.getInput().get(0), BinType.VECT_NOTEQUAL_SCALAR)
diff --git a/src/main/java/org/apache/sysds/hops/codegen/template/TemplateUtils.java b/src/main/java/org/apache/sysds/hops/codegen/template/TemplateUtils.java
index f61305fa13..8a4e0f62c9 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/template/TemplateUtils.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/template/TemplateUtils.java
@@ -49,6 +49,7 @@ import org.apache.sysds.hops.codegen.cplan.CNode;
 import org.apache.sysds.hops.codegen.cplan.CNodeBinary;
 import org.apache.sysds.hops.codegen.cplan.CNodeData;
 import org.apache.sysds.hops.codegen.cplan.CNodeNary;
+import org.apache.sysds.hops.codegen.cplan.CNodeRow;
 import org.apache.sysds.hops.codegen.cplan.CNodeTernary;
 import org.apache.sysds.hops.codegen.cplan.CNodeTpl;
 import org.apache.sysds.hops.codegen.cplan.CNodeUnary;
@@ -279,7 +280,11 @@ public class TemplateUtils
 		return node instanceof CNodeUnary
 			&& ArrayUtils.contains(types, ((CNodeUnary)node).getType());
 	}
-	
+
+	public static boolean isUnaryRowAgg(CNode node) { //true for ROW_MAXS or ROW_SUMS unary nodes
+		return isUnary(node, UnaryType.ROW_MAXS, UnaryType.ROW_SUMS);
+	}
+
 	public static boolean isBinary(CNode node, BinType...types) {
 		return node instanceof CNodeBinary
 			&& ArrayUtils.contains(types, ((CNodeBinary)node).getType());
@@ -391,7 +396,8 @@ public class TemplateUtils
 				&& !TemplateUtils.isUnary(output, 
 					UnaryType.EXP, UnaryType.LOG, UnaryType.ROW_COUNTNNZS)) 
 			|| (output instanceof CNodeBinary
-				&& !TemplateUtils.isBinary(output, BinType.VECT_OUTERMULT_ADD))
+				&& !TemplateUtils.isBinary(output, //neither outer-mult-add nor the fused rowMaxs(a*b) (two ops) counts as a single op
+					BinType.VECT_OUTERMULT_ADD, BinType.ROWMAXS_VECTMULT))
 			|| output instanceof CNodeTernary 
 				&& ((CNodeTernary)output).getType() == TernaryType.IFELSE)
 			&& hasOnlyDataNodeOrLookupInputs(output);
@@ -687,4 +693,18 @@ public class TemplateUtils
 		for( CNode input : current.getInput() )
 			rFlipVectorLookups(input);
 	}
+
+	public static boolean containsFusedRowVecAgg(CNodeTpl tpl) { //true if a row template's output, or a direct input of it, is ROWMAXS_VECTMULT
+		if(!(tpl instanceof CNodeRow))
+			return false;
+
+		if(TemplateUtils.isBinary(tpl.getOutput(), BinType.ROWMAXS_VECTMULT))
+			return true;
+
+		for (CNode n : tpl.getOutput().getInput()) { //only one level deep; deeper fused nodes are not detected
+			if(TemplateUtils.isBinary(n, BinType.ROWMAXS_VECTMULT))
+				return true;
+		}
+		return false;
+	}
 }
diff --git a/src/main/java/org/apache/sysds/runtime/codegen/LibSpoofPrimitives.java b/src/main/java/org/apache/sysds/runtime/codegen/LibSpoofPrimitives.java
index 905b39226d..c618e79607 100644
--- a/src/main/java/org/apache/sysds/runtime/codegen/LibSpoofPrimitives.java
+++ b/src/main/java/org/apache/sysds/runtime/codegen/LibSpoofPrimitives.java
@@ -50,9 +50,23 @@ public class LibSpoofPrimitives
 	private static ThreadLocal<VectorBuffer> memPool = new ThreadLocal<VectorBuffer>() {
 		@Override protected VectorBuffer initialValue() { return new VectorBuffer(0,0,0); }
 	};
-	
+
+	public static double rowMaxsVectMult(double[] a, double[] b, int ai, int bi, int len) { //max_j a[ai+j]*b[bi+j] over j<len; -Infinity if len==0
+		double val = Double.NEGATIVE_INFINITY;
+		int j = bi; //b is read from its own offset bi, consistent with dotProduct
+		for( int i = ai; i < ai+len; i++ )
+			val = Math.max(a[i]*b[j++], val);
+		return val;
+	}
+
+	public static double rowMaxsVectMult(double[] a, double[] b, int[] aix, int ai, int bi, int len) { //sparse a: max over stored entries of a[i]*b[bi+aix[i]]
+		double val = Double.NEGATIVE_INFINITY; //NOTE(review): only the len stored entries are visited; products with implicit zeros (value 0) are not included
+		for( int i = ai; i < ai+len; i++ )
+			val = Math.max(a[i]*b[bi+aix[i]], val);
+		return val;
+	}
+
 	// forwarded calls to LibMatrixMult
-	
 	public static double dotProduct(double[] a, double[] b, int ai, int bi, int len) {
 		if( a == null || b == null ) return 0;
 		return LibMatrixMult.dotProduct(a, b, ai, bi, len);