You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by ma...@apache.org on 2022/04/20 11:41:41 UTC
[systemds] branch main updated: [SYSTEMDS-3334] Codegen RowMaxs_VectMult rewrite
This is an automated email from the ASF dual-hosted git repository.
markd pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 811e3f474c [SYSTEMDS-3334] Codegen RowMaxs_VectMult rewrite
811e3f474c is described below
commit 811e3f474c7e4e1747e7b5e54ffa75e79afc1cd5
Author: Mark Dokter <ma...@dokter.cc>
AuthorDate: Tue Apr 19 23:16:52 2022 +0200
[SYSTEMDS-3334] Codegen RowMaxs_VectMult rewrite
This rewrite fuses a vector multiplication with a row-max aggregation, i.e., rowMaxs(X * v), to avoid materializing an intermediate vector in SPOOF's row template. The pattern occurs in components.dml when code generation is enabled.
Closes #1566
---
.../apache/sysds/hops/codegen/SpoofCompiler.java | 15 ++++++++------
.../sysds/hops/codegen/cplan/CNodeBinary.java | 7 ++++++-
.../sysds/hops/codegen/cplan/java/Binary.java | 3 +++
.../hops/codegen/template/CPlanOpRewriter.java | 19 +++++++++++++++--
.../sysds/hops/codegen/template/TemplateUtils.java | 24 ++++++++++++++++++++--
.../sysds/runtime/codegen/LibSpoofPrimitives.java | 18 ++++++++++++++--
6 files changed, 73 insertions(+), 13 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/codegen/SpoofCompiler.java b/src/main/java/org/apache/sysds/hops/codegen/SpoofCompiler.java
index ade88775e1..55d75b092a 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/SpoofCompiler.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/SpoofCompiler.java
@@ -38,6 +38,7 @@ import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.codegen.cplan.CNode;
+import org.apache.sysds.hops.codegen.cplan.CNodeBinary;
import org.apache.sysds.hops.codegen.cplan.CNodeCell;
import org.apache.sysds.hops.codegen.cplan.CNodeData;
import org.apache.sysds.hops.codegen.cplan.CNodeMultiAgg;
@@ -941,13 +942,15 @@ public class SpoofCompiler {
}
//remove cplan w/ single op and w/o agg
- if( (tpl instanceof CNodeCell && ((CNodeCell)tpl).getCellType()==CellType.NO_AGG
- && TemplateUtils.hasSingleOperation(tpl) )
- || (tpl instanceof CNodeRow && (((CNodeRow)tpl).getRowType()==RowType.NO_AGG
- || ((CNodeRow)tpl).getRowType()==RowType.NO_AGG_B1
- || ((CNodeRow)tpl).getRowType()==RowType.ROW_AGG )
+ if((tpl instanceof CNodeCell && ((CNodeCell)tpl).getCellType()==CellType.NO_AGG
&& TemplateUtils.hasSingleOperation(tpl))
- || TemplateUtils.hasNoOperation(tpl) )
+ || (tpl instanceof CNodeRow // row plans without (or with a plain row) aggregation
+ && (((CNodeRow)tpl).getRowType()==RowType.NO_AGG
+ || ((CNodeRow)tpl).getRowType()==RowType.NO_AGG_B1
+ || (((CNodeRow)tpl).getRowType()==RowType.ROW_AGG && !TemplateUtils.isBinary(tpl.getOutput(), // ROW_AGG plans whose output is the fused
+ CNodeBinary.BinType.ROWMAXS_VECTMULT))) // rowMaxs(X*v) op are kept: fusion avoids an intermediate vector
+ && TemplateUtils.hasSingleOperation(tpl))
+ || TemplateUtils.hasNoOperation(tpl))
{
cplans2.remove(e.getKey());
if( LOG.isTraceEnabled() )
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java
index 2e6bcd5d48..bebf0a221b 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java
@@ -30,6 +30,8 @@ import org.apache.sysds.hops.codegen.SpoofCompiler.GeneratorAPI;
public class CNodeBinary extends CNode {
public enum BinType {
+ // Fused vector op + row aggregation (produces one scalar per row)
+ ROWMAXS_VECTMULT, // rowMaxs(a * b) without materializing a * b
//matrix multiplication operations
DOT_PRODUCT, VECT_MATRIXMULT, VECT_OUTERMULT_ADD,
//vector-scalar-add operations
@@ -373,7 +375,8 @@ public class CNodeBinary extends CNode {
_cols = _inputs.get(1)._cols;
_dataType = DataType.MATRIX;
break;
-
+
+ case ROWMAXS_VECTMULT: // falls through to DOT_PRODUCT: scalar result like a dot product
case DOT_PRODUCT:
//SCALAR Arithmetic
@@ -407,6 +410,8 @@ public class CNodeBinary extends CNode {
_cols = 0;
_dataType= DataType.SCALAR;
break;
+ default: // fail fast for any BinType without output-dimension handling above
+ throw new RuntimeException("Unknown CNodeBinary type: " + _type);
}
}
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Binary.java b/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Binary.java
index ecb7878f66..40496249e5 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Binary.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Binary.java
@@ -28,6 +28,12 @@ public class Binary extends CodeTemplate {
boolean scalarVector, boolean scalarInput, boolean vectorVector)
{
switch (type) {
+ // Fused rowMaxs(a * b). The sparse primitive only visits the alen stored
+ // entries, so when the row also has implicit zeros (alen < %LEN%) their
+ // products (0 for finite b) are folded in via Math.max(..., 0).
+ case ROWMAXS_VECTMULT:
+ return sparseLhs ? "\tdouble %TMP% = LibSpoofPrimitives.rowMaxsVectMult(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen);\n\tif( alen < %LEN% ) %TMP% = Math.max(%TMP%, 0);\n" :
+ "\tdouble %TMP% = LibSpoofPrimitives.rowMaxsVectMult(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n";
case DOT_PRODUCT:
return sparseLhs ? " double %TMP% = LibSpoofPrimitives.dotProduct(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen);\n" :
" double %TMP% = LibSpoofPrimitives.dotProduct(%IN1%, %IN2%, %POS1%, %POS2%, %LEN%);\n";
diff --git a/src/main/java/org/apache/sysds/hops/codegen/template/CPlanOpRewriter.java b/src/main/java/org/apache/sysds/hops/codegen/template/CPlanOpRewriter.java
index 2b981ee893..b81ddac401 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/template/CPlanOpRewriter.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/template/CPlanOpRewriter.java
@@ -21,11 +21,13 @@ package org.apache.sysds.hops.codegen.template;
import java.util.ArrayList;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.codegen.cplan.CNode;
+import org.apache.sysds.hops.codegen.cplan.CNodeBinary;
import org.apache.sysds.hops.codegen.cplan.CNodeData;
import org.apache.sysds.hops.codegen.cplan.CNodeMultiAgg;
import org.apache.sysds.hops.codegen.cplan.CNodeOuterProduct;
+import org.apache.sysds.hops.codegen.cplan.CNodeRow;
import org.apache.sysds.hops.codegen.cplan.CNodeTpl;
import org.apache.sysds.hops.codegen.cplan.CNodeUnary;
import org.apache.sysds.hops.codegen.cplan.CNodeBinary.BinType;
@@ -56,6 +58,12 @@ public class CPlanOpRewriter
}
else {
tpl.setOutput(rSimplifyCNode(tpl.getOutput()));
+ if(TemplateUtils.containsFusedRowVecAgg(tpl)) {
+ // NOTE(review): fixed decrement of 2 for the intermediate vector(s) no longer
+ // materialized by the fused op; clamp at 0 so the count can never become negative
+ CNodeRow rowTpl = (CNodeRow) tpl;
+ rowTpl.setNumVectorIntermediates(Math.max(0, rowTpl.getNumVectorIntermediates() - 2));
+ }
}
return tpl;
@@ -73,10 +81,28 @@ public class CPlanOpRewriter
node = rewriteBinaryPow2Vect(node); //X^2 -> X*X
node = rewriteBinaryMult2(node); //x*2 -> x+x;
node = rewriteBinaryMult2Vect(node); //X*2 -> X+X;
-
+ node = rewriteRowMaxsVectMult(node); // rowMaxs(G * t(c)); see components.dml
return node;
}
-
+
+ /**
+ * Fuses ROW_MAXS(VECT_MULT(A, B)) into a single ROWMAXS_VECTMULT node, which
+ * computes rowMaxs(A * B) without materializing the product vector.
+ * NOTE(review): if the VECT_MULT node has other consumers, it is still
+ * computed for them, so no intermediate vector is saved in that case.
+ *
+ * @param node candidate node
+ * @return the fused node, or the unchanged input node if the pattern does not match
+ */
+ private static CNode rewriteRowMaxsVectMult(CNode node) {
+ if(TemplateUtils.isUnary(node, UnaryType.ROW_MAXS)) {
+ CNode input = node.getInput().get(0);
+ if(TemplateUtils.isBinary(input, BinType.VECT_MULT))
+ return new CNodeBinary(input.getInput().get(0), input.getInput().get(1), BinType.ROWMAXS_VECTMULT);
+ }
+ return node;
+ }
+
private static CNode rewriteRowCountNnz(CNode node) {
return (TemplateUtils.isUnary(node, UnaryType.ROW_SUMS)
&& TemplateUtils.isBinary(node.getInput().get(0), BinType.VECT_NOTEQUAL_SCALAR)
diff --git a/src/main/java/org/apache/sysds/hops/codegen/template/TemplateUtils.java b/src/main/java/org/apache/sysds/hops/codegen/template/TemplateUtils.java
index f61305fa13..8a4e0f62c9 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/template/TemplateUtils.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/template/TemplateUtils.java
@@ -49,6 +49,7 @@ import org.apache.sysds.hops.codegen.cplan.CNode;
import org.apache.sysds.hops.codegen.cplan.CNodeBinary;
import org.apache.sysds.hops.codegen.cplan.CNodeData;
import org.apache.sysds.hops.codegen.cplan.CNodeNary;
+import org.apache.sysds.hops.codegen.cplan.CNodeRow;
import org.apache.sysds.hops.codegen.cplan.CNodeTernary;
import org.apache.sysds.hops.codegen.cplan.CNodeTpl;
import org.apache.sysds.hops.codegen.cplan.CNodeUnary;
@@ -279,7 +280,12 @@ public class TemplateUtils
return node instanceof CNodeUnary
&& ArrayUtils.contains(types, ((CNodeUnary)node).getType());
}
-
+
+ /** Returns true if node is a ROW_MAXS or ROW_SUMS unary row aggregate. */
+ public static boolean isUnaryRowAgg(CNode node) {
+ return isUnary(node, UnaryType.ROW_MAXS, UnaryType.ROW_SUMS);
+ }
+
public static boolean isBinary(CNode node, BinType...types) {
return node instanceof CNodeBinary
&& ArrayUtils.contains(types, ((CNodeBinary)node).getType());
@@ -391,7 +397,8 @@ public class TemplateUtils
&& !TemplateUtils.isUnary(output,
UnaryType.EXP, UnaryType.LOG, UnaryType.ROW_COUNTNNZS))
|| (output instanceof CNodeBinary
- && !TemplateUtils.isBinary(output, BinType.VECT_OUTERMULT_ADD))
+ && !TemplateUtils.isBinary(output, // the fused ROWMAXS_VECTMULT is not a single op
+ BinType.VECT_OUTERMULT_ADD, BinType.ROWMAXS_VECTMULT))
|| output instanceof CNodeTernary
&& ((CNodeTernary)output).getType() == TernaryType.IFELSE)
&& hasOnlyDataNodeOrLookupInputs(output);
@@ -687,4 +694,26 @@ public class TemplateUtils
for( CNode input : current.getInput() )
rFlipVectorLookups(input);
}
+
+ /**
+ * Checks whether a row template contains a fused ROWMAXS_VECTMULT node,
+ * either as its output or as a direct input of its output. Deeper nodes
+ * are not inspected.
+ *
+ * @param tpl the code plan template
+ * @return true if tpl is a CNodeRow containing such a node at those positions
+ */
+ public static boolean containsFusedRowVecAgg(CNodeTpl tpl) {
+ if(!(tpl instanceof CNodeRow))
+ return false;
+
+ if(TemplateUtils.isBinary(tpl.getOutput(), BinType.ROWMAXS_VECTMULT))
+ return true;
+
+ for (CNode n : tpl.getOutput().getInput()) {
+ if(TemplateUtils.isBinary(n, BinType.ROWMAXS_VECTMULT))
+ return true;
+ }
+ return false;
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/codegen/LibSpoofPrimitives.java b/src/main/java/org/apache/sysds/runtime/codegen/LibSpoofPrimitives.java
index 905b39226d..c618e79607 100644
--- a/src/main/java/org/apache/sysds/runtime/codegen/LibSpoofPrimitives.java
+++ b/src/main/java/org/apache/sysds/runtime/codegen/LibSpoofPrimitives.java
@@ -50,9 +50,52 @@ public class LibSpoofPrimitives
private static ThreadLocal<VectorBuffer> memPool = new ThreadLocal<VectorBuffer>() {
@Override protected VectorBuffer initialValue() { return new VectorBuffer(0,0,0); }
};
-
+
+ /**
+ * Fused rowMaxs(a * b) for a dense row: returns the maximum of
+ * a[ai+j] * b[bi+j] over j in [0, len) without materializing the
+ * product vector. Offsets follow the same convention as
+ * dotProduct(a, b, ai, bi, len).
+ *
+ * @param a dense values of the left input
+ * @param b dense values of the right input
+ * @param ai start offset into a
+ * @param bi start offset into b
+ * @param len number of elements to process
+ * @return the maximum product; 0 if a or b is null (mirroring dotProduct),
+ * or Double.NEGATIVE_INFINITY if len is 0
+ */
+ public static double rowMaxsVectMult(double[] a, double[] b, int ai, int bi, int len) {
+ if( a == null || b == null ) return 0;
+ double val = Double.NEGATIVE_INFINITY;
+ for( int j = 0; j < len; j++ )
+ val = Math.max(a[ai+j]*b[bi+j], val);
+ return val;
+ }
+
+ /**
+ * Fused rowMaxs(a * b) for a sparse row: returns the maximum of
+ * a[i] * b[bi+aix[i]] over the len stored entries i in [ai, ai+len).
+ * Implicit zeros of the row are not visited, so callers must account
+ * for them when the row has fewer stored entries than columns.
+ *
+ * @param a stored values of the sparse row
+ * @param b dense values of the right input
+ * @param aix column indexes of the stored values
+ * @param ai start offset into a and aix
+ * @param bi start offset into b
+ * @param len number of stored entries to process
+ * @return the maximum product over stored entries, or
+ * Double.NEGATIVE_INFINITY if len is 0
+ */
+ public static double rowMaxsVectMult(double[] a, double[] b, int[] aix, int ai, int bi, int len) {
+ double val = Double.NEGATIVE_INFINITY;
+ for( int i = ai; i < ai+len; i++ )
+ val = Math.max(a[i]*b[bi+aix[i]], val);
+ return val;
+ }
+
// forwarded calls to LibMatrixMult
-
public static double dotProduct(double[] a, double[] b, int ai, int bi, int len) {
if( a == null || b == null ) return 0;
return LibMatrixMult.dotProduct(a, b, ai, bi, len);