You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2017/09/25 01:58:35 UTC
systemml git commit: [SYSTEMML-1933] Generalized codegen cbind handling in row-wise ops
Repository: systemml
Updated Branches:
refs/heads/master 47ce14fc6 -> c1db484d6
[SYSTEMML-1933] Generalized codegen cbind handling in row-wise ops
This patch generalizes the compilation of cbind operations in codegen
row templates. Previously, we only supported cbind with a vector of zeros,
and a cbind always closed the row template. We now support cbind operations
(with vectors of arbitrary constants) in the middle of row templates, which
also allows for multiple cbind operations in a single fused operator.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/c1db484d
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/c1db484d
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/c1db484d
Branch: refs/heads/master
Commit: c1db484d6119459f7ef6a566ff2663cca286f7ab
Parents: 47ce14f
Author: Matthias Boehm <mb...@gmail.com>
Authored: Sun Sep 24 18:29:17 2017 -0700
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Sun Sep 24 18:29:17 2017 -0700
----------------------------------------------------------------------
.../java/org/apache/sysml/hops/DataGenOp.java | 4 +++
.../sysml/hops/codegen/SpoofCompiler.java | 2 ++
.../apache/sysml/hops/codegen/SpoofFusedOp.java | 21 ++++++++-----
.../sysml/hops/codegen/cplan/CNodeBinary.java | 28 ++++++++++++++---
.../sysml/hops/codegen/cplan/CNodeRow.java | 30 +++++++++++-------
.../sysml/hops/codegen/cplan/CNodeUnary.java | 6 +---
.../hops/codegen/template/TemplateRow.java | 15 ++++++---
.../hops/codegen/template/TemplateUtils.java | 6 ++--
.../sysml/hops/rewrite/HopRewriteUtils.java | 10 ++++++
.../runtime/codegen/LibSpoofPrimitives.java | 24 +++++++++++++++
.../sysml/runtime/codegen/SpoofRowwise.java | 20 +++++++-----
.../instructions/spark/SpoofSPInstruction.java | 5 +--
.../functions/codegen/RowAggTmplTest.java | 20 +++++++++++-
.../scripts/functions/codegen/rowAggPattern31.R | 32 ++++++++++++++++++++
.../functions/codegen/rowAggPattern31.dml | 27 +++++++++++++++++
15 files changed, 202 insertions(+), 48 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/c1db484d/src/main/java/org/apache/sysml/hops/DataGenOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/DataGenOp.java b/src/main/java/org/apache/sysml/hops/DataGenOp.java
index 89a5814..eb04ed3 100644
--- a/src/main/java/org/apache/sysml/hops/DataGenOp.java
+++ b/src/main/java/org/apache/sysml/hops/DataGenOp.java
@@ -434,6 +434,10 @@ public class DataGenOp extends Hop implements MultiThreadedHop
return ret;
}
+ public Hop getConstantValue() { //input hop bound to the rand 'min' parameter (RAND_MIN)
+ return getInput().get(_paramIndexMap.get(DataExpression.RAND_MIN));
+ }
+
public void setIncrementValue(double incr)
{
_incr = incr;
http://git-wip-us.apache.org/repos/asf/systemml/blob/c1db484d/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
index b98342c..a374cf1 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
@@ -645,6 +645,8 @@ public class SpoofCompiler
HopRewriteUtils.setOutputParametersForScalar(hnew);
hnew = HopRewriteUtils.createUnary(hnew, OpOp1.CAST_AS_MATRIX);
}
+ else if( tmpCNode instanceof CNodeRow && ((CNodeRow)tmpCNode).getRowType()==RowType.NO_AGG_CONST )
+ ((SpoofFusedOp)hnew).setConstDim2(((CNodeRow)tmpCNode).getConstDim2());
if( !(tmpCNode instanceof CNodeMultiAgg) )
HopRewriteUtils.rewireAllParentChildReferences(hop, hnew);
http://git-wip-us.apache.org/repos/asf/systemml/blob/c1db484d/src/main/java/org/apache/sysml/hops/codegen/SpoofFusedOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofFusedOp.java b/src/main/java/org/apache/sysml/hops/codegen/SpoofFusedOp.java
index 247a142..81b226d 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/SpoofFusedOp.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/SpoofFusedOp.java
@@ -38,8 +38,8 @@ public class SpoofFusedOp extends Hop implements MultiThreadedHop
{
public enum SpoofOutputDimsType {
INPUT_DIMS,
+ INPUT_DIMS_CONST2,
ROW_DIMS,
- ROW_DIMS2,
COLUMN_DIMS_ROWS,
COLUMN_DIMS_COLS,
SCALAR,
@@ -52,6 +52,7 @@ public class SpoofFusedOp extends Hop implements MultiThreadedHop
private Class<?> _class = null;
private boolean _distSupported = false;
private int _numThreads = -1;
+ private long _constDim2 = -1;
private SpoofOutputDimsType _dimsType;
public SpoofFusedOp ( ) {
@@ -82,6 +83,10 @@ public class SpoofFusedOp extends Hop implements MultiThreadedHop
public boolean allowsAllExecTypes() {
return _distSupported;
}
+
+ public void setConstDim2(long constDim2) { //output cols used by SpoofOutputDimsType.INPUT_DIMS_CONST2
+ _constDim2 = constDim2;
+ }
@Override
protected double computeOutputMemEstimate(long dim1, long dim2, long nnz) {
@@ -152,9 +157,6 @@ public class SpoofFusedOp extends Hop implements MultiThreadedHop
case ROW_DIMS:
ret = new long[]{mc.getRows(), 1, -1};
break;
- case ROW_DIMS2:
- ret = new long[]{mc.getRows(), 2, -1};
- break;
case COLUMN_DIMS_ROWS:
ret = new long[]{mc.getCols(), 1, -1};
break;
@@ -164,6 +166,9 @@ public class SpoofFusedOp extends Hop implements MultiThreadedHop
case INPUT_DIMS:
ret = new long[]{mc.getRows(), mc.getCols(), -1};
break;
+ case INPUT_DIMS_CONST2:
+ ret = new long[]{mc.getRows(), _constDim2, -1};
+ break;
case SCALAR:
ret = new long[]{0, 0, -1};
break;
@@ -206,10 +211,6 @@ public class SpoofFusedOp extends Hop implements MultiThreadedHop
setDim1(getInput().get(0).getDim1());
setDim2(1);
break;
- case ROW_DIMS2:
- setDim1(getInput().get(0).getDim1());
- setDim2(2);
- break;
case COLUMN_DIMS_ROWS:
setDim1(getInput().get(0).getDim2());
setDim2(1);
@@ -222,6 +223,10 @@ public class SpoofFusedOp extends Hop implements MultiThreadedHop
setDim1(getInput().get(0).getDim1());
setDim2(getInput().get(0).getDim2());
break;
+ case INPUT_DIMS_CONST2:
+ setDim1(getInput().get(0).getDim1());
+ setDim2(_constDim2);
+ break;
case SCALAR:
setDim1(0);
setDim2(0);
http://git-wip-us.apache.org/repos/asf/systemml/blob/c1db484d/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java
index 926dd4d..bff044d 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeBinary.java
@@ -40,6 +40,7 @@ public class CNodeBinary extends CNode
VECT_POW_SCALAR, VECT_MIN_SCALAR, VECT_MAX_SCALAR,
VECT_EQUAL_SCALAR, VECT_NOTEQUAL_SCALAR, VECT_LESS_SCALAR,
VECT_LESSEQUAL_SCALAR, VECT_GREATER_SCALAR, VECT_GREATEREQUAL_SCALAR,
+ VECT_CBIND,
//vector-vector operations
VECT_MULT, VECT_DIV, VECT_MINUS, VECT_PLUS, VECT_MIN, VECT_MAX, VECT_EQUAL,
VECT_NOTEQUAL, VECT_LESS, VECT_LESSEQUAL, VECT_GREATER, VECT_GREATEREQUAL,
@@ -67,7 +68,7 @@ public class CNodeBinary extends CNode
return ssComm || vsComm || vvComm;
}
- public String getTemplate(boolean sparse, boolean scalarVector) {
+ public String getTemplate(boolean sparse, boolean scalarVector, boolean scalarInput) {
switch (this) {
case DOT_PRODUCT:
return sparse ? " double %TMP% = LibSpoofPrimitives.dotProduct(%IN1v%, %IN2%, %IN1i%, %POS1%, %POS2%, alen);\n" :
@@ -125,6 +126,14 @@ public class CNodeBinary extends CNode
" double[] %TMP% = LibSpoofPrimitives.vect"+vectName+"Write(%IN1%, %IN2%, %POS1%, %LEN%);\n";
}
+ case VECT_CBIND:
+ if( scalarInput )
+ return " double[] %TMP% = LibSpoofPrimitives.vectCBindWrite(%IN1%, %IN2%);\n";
+ else
+ return sparse ?
+ " double[] %TMP% = LibSpoofPrimitives.vectCBindWrite(%IN1v%, %IN2%, %IN1i%, %POS1%, alen, %LEN%);\n" :
+ " double[] %TMP% = LibSpoofPrimitives.vectCBindWrite(%IN1%, %IN2%, %POS1%, %LEN%);\n";
+
//vector-vector operations
case VECT_MULT:
case VECT_DIV:
@@ -202,7 +211,8 @@ public class CNodeBinary extends CNode
|| this == VECT_MIN_SCALAR || this == VECT_MAX_SCALAR
|| this == VECT_EQUAL_SCALAR || this == VECT_NOTEQUAL_SCALAR
|| this == VECT_LESS_SCALAR || this == VECT_LESSEQUAL_SCALAR
- || this == VECT_GREATER_SCALAR || this == VECT_GREATEREQUAL_SCALAR;
+ || this == VECT_GREATER_SCALAR || this == VECT_GREATEREQUAL_SCALAR
+ || this == VECT_CBIND;
}
public boolean isVectorVectorPrimitive() {
return this == VECT_DIV || this == VECT_MULT
@@ -262,10 +272,11 @@ public class CNodeBinary extends CNode
boolean lsparse = sparse && (_inputs.get(0) instanceof CNodeData
&& !_inputs.get(0).getVarname().startsWith("b")
&& !_inputs.get(0).isLiteral());
+ boolean scalarInput = _inputs.get(0).getDataType().isScalar();
boolean scalarVector = (_inputs.get(0).getDataType().isScalar()
&& _inputs.get(1).getDataType().isMatrix());
String var = createVarname();
- String tmp = _type.getTemplate(lsparse, scalarVector);
+ String tmp = _type.getTemplate(lsparse, scalarVector, scalarInput);
tmp = tmp.replace("%TMP%", var);
//replace input references and start indexes
@@ -354,6 +365,7 @@ public class CNodeBinary extends CNode
case VECT_LESSEQUAL: return "b(v2lte)";
case VECT_GREATEREQUAL: return "b(v2gte)";
case VECT_GREATER: return "b(v2gt)";
+ case VECT_CBIND: return "b(cbind)";
case MULT: return "b(*)";
case DIV: return "b(/)";
case PLUS: return "b(+)";
@@ -399,6 +411,12 @@ public class CNodeBinary extends CNode
_dataType = DataType.MATRIX;
break;
+ case VECT_CBIND:
+ _rows = _inputs.get(0)._rows;
+ _cols = _inputs.get(0)._cols+1;
+ _dataType = DataType.MATRIX;
+ break;
+
case VECT_OUTERMULT_ADD:
_rows = _inputs.get(0)._cols;
_cols = _inputs.get(1)._cols;
@@ -465,9 +483,9 @@ public class CNodeBinary extends CNode
case MIN:
case MAX:
case AND:
- case OR:
+ case OR:
case LOG:
- case LOG_NZ:
+ case LOG_NZ:
case POW:
_rows = 0;
_cols = 0;
http://git-wip-us.apache.org/repos/asf/systemml/blob/c1db484d/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeRow.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeRow.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeRow.java
index b74b79d..07822d9 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeRow.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeRow.java
@@ -23,7 +23,6 @@ import java.util.ArrayList;
import org.apache.sysml.hops.codegen.SpoofFusedOp.SpoofOutputDimsType;
import org.apache.sysml.hops.codegen.cplan.CNodeBinary.BinType;
-import org.apache.sysml.hops.codegen.cplan.CNodeUnary.UnaryType;
import org.apache.sysml.hops.codegen.template.TemplateUtils;
import org.apache.sysml.runtime.codegen.SpoofRowwise.RowType;
import org.apache.sysml.runtime.util.UtilFunctions;
@@ -40,7 +39,7 @@ public class CNodeRow extends CNodeTpl
+ "\n"
+ "public final class %TMP% extends SpoofRowwise { \n"
+ " public %TMP%() {\n"
- + " super(RowType.%TYPE%, %CBIND0%, %TB1%, %VECT_MEM%);\n"
+ + " super(RowType.%TYPE%, %CONST_DIM2%, %TB1%, %VECT_MEM%);\n"
+ " }\n"
+ " protected void genexec(double[] a, int ai, SideInput[] b, double[] scalars, double[] c, int len, int rowIndex) { \n"
+ "%BODY_dense%"
@@ -59,6 +58,7 @@ public class CNodeRow extends CNodeTpl
}
private RowType _type = null; //access pattern
+ private long _constDim2 = -1; //constant number of output columns
private int _numVectors = -1; //number of intermediate vectors
public void setRowType(RowType type) {
@@ -79,6 +79,14 @@ public class CNodeRow extends CNodeTpl
return _numVectors;
}
+ public void setConstDim2(long dim2) { //set for RowType.NO_AGG_CONST; -1 otherwise
+ _constDim2 = dim2;
+ }
+
+ public long getConstDim2() { //emitted as %CONST_DIM2% into the generated SpoofRowwise constructor
+ return _constDim2;
+ }
+
@Override
public void renameInputs() {
rRenameDataNode(_output, _inputs.get(0), "a"); // input matrix
@@ -109,8 +117,7 @@ public class CNodeRow extends CNodeTpl
//replace colvector information and number of vector intermediates
tmp = tmp.replace("%TYPE%", _type.name());
- tmp = tmp.replace("%CBIND0%", String.valueOf(
- TemplateUtils.isUnary(_output, UnaryType.CBIND0)));
+ tmp = tmp.replace("%CONST_DIM2%", String.valueOf(_constDim2));
tmp = tmp.replace("%TB1%", String.valueOf(
TemplateUtils.containsBinary(_output, BinType.VECT_MATRIXMULT)));
tmp = tmp.replace("%VECT_MEM%", String.valueOf(_numVectors));
@@ -122,6 +129,7 @@ public class CNodeRow extends CNodeTpl
switch( _type ) {
case NO_AGG:
case NO_AGG_B1:
+ case NO_AGG_CONST:
return TEMPLATE_NOAGG_OUT.replace("%IN%", varName)
.replace("%LEN%", _output.getVarname()+".length");
case FULL_AGG:
@@ -142,13 +150,13 @@ public class CNodeRow extends CNodeTpl
@Override
public SpoofOutputDimsType getOutputDimType() {
switch( _type ) {
- case NO_AGG: return SpoofOutputDimsType.INPUT_DIMS;
- case NO_AGG_B1: return SpoofOutputDimsType.ROW_RANK_DIMS;
- case FULL_AGG: return SpoofOutputDimsType.SCALAR;
- case ROW_AGG: return TemplateUtils.isUnary(_output, UnaryType.CBIND0) ?
- SpoofOutputDimsType.ROW_DIMS2 : SpoofOutputDimsType.ROW_DIMS;
- case COL_AGG: return SpoofOutputDimsType.COLUMN_DIMS_COLS; //row vector
- case COL_AGG_T: return SpoofOutputDimsType.COLUMN_DIMS_ROWS; //column vector
+ case NO_AGG: return SpoofOutputDimsType.INPUT_DIMS;
+ case NO_AGG_B1: return SpoofOutputDimsType.ROW_RANK_DIMS;
+ case NO_AGG_CONST: return SpoofOutputDimsType.INPUT_DIMS_CONST2;
+ case FULL_AGG: return SpoofOutputDimsType.SCALAR;
+ case ROW_AGG: return SpoofOutputDimsType.ROW_DIMS;
+ case COL_AGG: return SpoofOutputDimsType.COLUMN_DIMS_COLS; //row vector
+ case COL_AGG_T: return SpoofOutputDimsType.COLUMN_DIMS_ROWS; //column vector
case COL_AGG_B1: return SpoofOutputDimsType.COLUMN_RANK_DIMS;
case COL_AGG_B1_T: return SpoofOutputDimsType.COLUMN_RANK_DIMS_T;
default:
http://git-wip-us.apache.org/repos/asf/systemml/blob/c1db484d/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
index 4bfb74b..343efb5 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
@@ -28,7 +28,7 @@ import org.apache.sysml.runtime.util.UtilFunctions;
public class CNodeUnary extends CNode
{
public enum UnaryType {
- LOOKUP_R, LOOKUP_C, LOOKUP_RC, LOOKUP0, CBIND0, //codegen specific
+ LOOKUP_R, LOOKUP_C, LOOKUP_RC, LOOKUP0, //codegen specific
ROW_SUMS, ROW_MINS, ROW_MAXS, //codegen specific
VECT_EXP, VECT_POW2, VECT_MULT2, VECT_SQRT, VECT_LOG,
VECT_ABS, VECT_ROUND, VECT_CEIL, VECT_FLOOR, VECT_SIGN,
@@ -94,8 +94,6 @@ public class CNodeUnary extends CNode
return " double %TMP% = getValue(%IN1%, n, rowIndex, colIndex);\n";
case LOOKUP0:
return " double %TMP% = %IN1%[0];\n" ;
- case CBIND0:
- return " double %TMP% = %IN1%; rowIndex *= 2;\n" ;
case POW2:
return " double %TMP% = %IN1% * %IN1%;\n" ;
case MULT2:
@@ -266,7 +264,6 @@ public class CNodeUnary extends CNode
case LOOKUP_C: return "u(ixc)";
case LOOKUP_RC: return "u(ixrc)";
case LOOKUP0: return "u(ix0)";
- case CBIND0: return "u(cbind0)";
case POW2: return "^2";
default: return "u("+_type.name().toLowerCase()+")";
}
@@ -310,7 +307,6 @@ public class CNodeUnary extends CNode
case LOOKUP_C:
case LOOKUP_RC:
case LOOKUP0:
- case CBIND0:
case POW2:
case MULT2:
case ABS:
http://git-wip-us.apache.org/repos/asf/systemml/blob/c1db484d/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
index de94969..d9209be 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
@@ -50,6 +50,7 @@ import org.apache.sysml.hops.Hop.Direction;
import org.apache.sysml.hops.Hop.OpOp1;
import org.apache.sysml.hops.Hop.OpOp2;
import org.apache.sysml.parser.Expression.DataType;
+import org.apache.sysml.runtime.codegen.SpoofRowwise.RowType;
import org.apache.sysml.runtime.matrix.data.LibMatrixMult;
import org.apache.sysml.runtime.matrix.data.Pair;
@@ -76,6 +77,8 @@ public class TemplateRow extends TemplateBase
public boolean open(Hop hop) {
return (hop instanceof BinaryOp && hop.dimsKnown() && isValidBinaryOperation(hop)
&& hop.getInput().get(0).getDim1()>1 && hop.getInput().get(0).getDim2()>1)
+ || (HopRewriteUtils.isBinary(hop, OpOp2.CBIND) && hop.getInput().get(0).isMatrix()
+ && HopRewriteUtils.isDataGenOpWithConstantValue(hop.getInput().get(1)))
|| (hop instanceof AggBinaryOp && hop.dimsKnown() && hop.getDim2()==1 //MV
&& hop.getInput().get(0).getDim1()>1 && hop.getInput().get(0).getDim2()>1)
|| (hop instanceof AggBinaryOp && hop.dimsKnown() && LibMatrixMult.isSkinnyRightHandSide(
@@ -98,8 +101,7 @@ public class TemplateRow extends TemplateBase
return !isClosed() &&
( (hop instanceof BinaryOp && isValidBinaryOperation(hop) )
|| (HopRewriteUtils.isBinary(hop, OpOp2.CBIND) && hop.getInput().indexOf(input)==0
- && input.getDim2()==1 && hop.getInput().get(1).getDim2()==1
- && HopRewriteUtils.isEmpty(hop.getInput().get(1)))
+ && HopRewriteUtils.isDataGenOpWithConstantValue(hop.getInput().get(1)))
|| ((hop instanceof UnaryOp || hop instanceof ParameterizedBuiltinOp)
&& TemplateCell.isValidOperation(hop))
|| (hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getDirection()!=Direction.RowCol
@@ -130,8 +132,7 @@ public class TemplateRow extends TemplateBase
public CloseType close(Hop hop) {
//close on column or full aggregate (e.g., colSums, t(X)%*%y)
if( (hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getDirection()!=Direction.Row)
- || (hop instanceof AggBinaryOp && HopRewriteUtils.isTransposeOperation(hop.getInput().get(0)))
- || HopRewriteUtils.isBinary(hop, OpOp2.CBIND) )
+ || (hop instanceof AggBinaryOp && HopRewriteUtils.isTransposeOperation(hop.getInput().get(0))))
return CloseType.CLOSED_VALID;
else
return CloseType.OPEN;
@@ -192,6 +193,8 @@ public class TemplateRow extends TemplateBase
CNodeRow tpl = new CNodeRow(inputs, output);
tpl.setRowType(TemplateUtils.getRowType(hop,
inHops2.get("X"), inHops2.get("B1")));
+ if( tpl.getRowType()==RowType.NO_AGG_CONST )
+ tpl.setConstDim2(hop.getDim2());
tpl.setNumVectorIntermediates(TemplateUtils
.determineMinVectorIntermediates(output));
tpl.getOutput().resetVisitStatus();
@@ -323,7 +326,9 @@ public class TemplateRow extends TemplateBase
{
//special case for cbind with zeros
CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
- out = new CNodeUnary(cdata1, UnaryType.CBIND0);
+ CNode cdata2 = TemplateUtils.createCNodeData(
+ HopRewriteUtils.getDataGenOpConstantValue(hop.getInput().get(1)), true);
+ out = new CNodeBinary(cdata1, cdata2, BinType.VECT_CBIND);
inHops.remove(hop.getInput().get(1)); //rm 0-matrix
}
else if(hop instanceof BinaryOp)
http://git-wip-us.apache.org/repos/asf/systemml/blob/c1db484d/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
index 1924914..95383e6 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
@@ -32,7 +32,6 @@ import org.apache.sysml.hops.ParameterizedBuiltinOp;
import org.apache.sysml.hops.TernaryOp;
import org.apache.sysml.hops.Hop.AggOp;
import org.apache.sysml.hops.Hop.Direction;
-import org.apache.sysml.hops.Hop.OpOp2;
import org.apache.sysml.hops.IndexingOp;
import org.apache.sysml.hops.UnaryOp;
import org.apache.sysml.hops.codegen.cplan.CNode;
@@ -190,8 +189,7 @@ public class TemplateUtils
|| (output instanceof IndexingOp && HopRewriteUtils.isColumnRangeIndexing((IndexingOp)output)))
&& !(output instanceof AggBinaryOp && HopRewriteUtils.isTransposeOfItself(output.getInput().get(0),X)) )
return RowType.NO_AGG_B1;
- else if( output.getDim1()==X.getDim1() && (output.getDim2()==1
- || HopRewriteUtils.isBinary(output, OpOp2.CBIND))
+ else if( output.getDim1()==X.getDim1() && (output.getDim2()==1)
&& !(output instanceof AggBinaryOp && HopRewriteUtils
.isTransposeOfItself(output.getInput().get(0),X)))
return RowType.ROW_AGG;
@@ -206,6 +204,8 @@ public class TemplateUtils
return RowType.COL_AGG_B1_T;
else if( B1 != null && output.getDim1()==B1.getDim2() && output.getDim2()==X.getDim2())
return RowType.COL_AGG_B1;
+ else if( X.getDim1() == output.getDim1() && X.getDim2() != output.getDim2() )
+ return RowType.NO_AGG_CONST;
else
throw new RuntimeException("Unknown row type for hop "+output.getHopID()+".");
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/c1db484d/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index 7093b0e..af9d593 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -474,12 +474,22 @@ public class HopRewriteUtils
&& ArrayUtils.contains(ops, ((DataGenOp)hop).getOp()));
}
+ public static boolean isDataGenOpWithConstantValue(Hop hop) { //RAND datagen with a constant value (any value)
+ return hop instanceof DataGenOp
+ && ((DataGenOp)hop).getOp()==DataGenMethod.RAND
+ && ((DataGenOp)hop).hasConstantValue();
+ }
+
public static boolean isDataGenOpWithConstantValue(Hop hop, double value) {
return hop instanceof DataGenOp
&& ((DataGenOp)hop).getOp()==DataGenMethod.RAND
&& ((DataGenOp)hop).hasConstantValue(value);
}
+ public static Hop getDataGenOpConstantValue(Hop hop) { //unchecked cast: call only if isDataGenOpWithConstantValue(hop)
+ return ((DataGenOp) hop).getConstantValue();
+ }
+
public static ReorgOp createTranspose(Hop input) {
return createReorg(input, ReOrgOp.TRANSPOSE);
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/c1db484d/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java b/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java
index 1a17793..6b4aad7 100644
--- a/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java
+++ b/src/main/java/org/apache/sysml/runtime/codegen/LibSpoofPrimitives.java
@@ -193,6 +193,30 @@ public class LibSpoofPrimitives
for( int i=0; i<aix.length; i++ )
c[aix[i]] = a[i];
}
+
+ // cbind handling: append scalar b to a scalar, dense row, or sparse row a
+
+ public static double[] vectCBindWrite(double a, double b) {
+ double[] c = allocVector(2, false); //no reset needed: both entries are overwritten
+ c[0] = a;
+ c[1] = b;
+ return c;
+ }
+
+ public static double[] vectCBindWrite(double[] a, double b, int aix, int len) {
+ double[] c = allocVector(len+1, false); //no reset needed: all len+1 entries are overwritten
+ System.arraycopy(a, aix, c, 0, len);
+ c[len] = b;
+ return c;
+ }
+
+ public static double[] vectCBindWrite(double[] a, double b, int[] aix, int ai, int alen, int len) {
+ double[] c = allocVector(len+1, true); //reset required: only the nonzeros of a are written below
+ for( int j = ai; j < ai+alen; j++ )
+ c[aix[j]] = a[j];
+ c[len] = b;
+ return c;
+ }
// custom vector sums, mins, maxs
http://git-wip-us.apache.org/repos/asf/systemml/blob/c1db484d/src/main/java/org/apache/sysml/runtime/codegen/SpoofRowwise.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/codegen/SpoofRowwise.java b/src/main/java/org/apache/sysml/runtime/codegen/SpoofRowwise.java
index 659059e..2464b15 100644
--- a/src/main/java/org/apache/sysml/runtime/codegen/SpoofRowwise.java
+++ b/src/main/java/org/apache/sysml/runtime/codegen/SpoofRowwise.java
@@ -49,6 +49,7 @@ public abstract class SpoofRowwise extends SpoofOperator
public enum RowType {
NO_AGG, //no aggregation
NO_AGG_B1, //no aggregation w/ matrix mult B1
+ NO_AGG_CONST, //no aggregation w/ expansion/contraction
FULL_AGG, //full row/col aggregation
ROW_AGG, //row aggregation (e.g., rowSums() or X %*% v)
COL_AGG, //col aggregation (e.g., colSums() or t(y) %*% X)
@@ -69,13 +70,13 @@ public abstract class SpoofRowwise extends SpoofOperator
}
protected final RowType _type;
- protected final boolean _cbind0;
+ protected final long _constDim2;
protected final boolean _tB1;
protected final int _reqVectMem;
- public SpoofRowwise(RowType type, boolean cbind0, boolean tB1, int reqVectMem) {
+ public SpoofRowwise(RowType type, long constDim2, boolean tB1, int reqVectMem) {
_type = type;
- _cbind0 = cbind0;
+ _constDim2 = constDim2; //output cols for NO_AGG_CONST; generated code passes -1 otherwise
_tB1 = tB1;
_reqVectMem = reqVectMem;
}
@@ -84,8 +85,8 @@ public abstract class SpoofRowwise extends SpoofOperator
return _type;
}
- public boolean isCBind0() {
- return _cbind0;
+ public long getConstDim2() { //constant number of output columns (used for RowType.NO_AGG_CONST)
+ return _constDim2;
}
public int getNumIntermediates() {
@@ -124,7 +125,8 @@ public abstract class SpoofRowwise extends SpoofOperator
//result allocation and preparations
final int m = inputs.get(0).getNumRows();
final int n = inputs.get(0).getNumColumns();
- final int n2 = _type.isRowTypeB1() || hasMatrixSideInput(inputs) ?
+ final int n2 = (_type==RowType.NO_AGG_CONST) ? (int)_constDim2 :
+ _type.isRowTypeB1() || hasMatrixSideInput(inputs) ?
getMinColsMatrixSideInputs(inputs) : -1;
if( !aggIncr || !out.isAllocated() )
allocateOutputMatrix(m, n, n2, out);
@@ -179,7 +181,8 @@ public abstract class SpoofRowwise extends SpoofOperator
//result allocation and preparations
final int m = inputs.get(0).getNumRows();
final int n = inputs.get(0).getNumColumns();
- final int n2 = _type.isRowTypeB1() || hasMatrixSideInput(inputs) ?
+ final int n2 = (_type==RowType.NO_AGG_CONST) ? (int)_constDim2 :
+ _type.isRowTypeB1() || hasMatrixSideInput(inputs) ?
getMinColsMatrixSideInputs(inputs) : -1;
allocateOutputMatrix(m, n, n2, out);
final boolean flipOut = _type.isRowTypeB1ColumnAgg()
@@ -258,8 +261,9 @@ public abstract class SpoofRowwise extends SpoofOperator
switch( _type ) {
case NO_AGG: out.reset(m, n, false); break;
case NO_AGG_B1: out.reset(m, n2, false); break;
+ case NO_AGG_CONST: out.reset(m, (int)_constDim2, false); break;
case FULL_AGG: out.reset(1, 1, false); break;
- case ROW_AGG: out.reset(m, 1+(_cbind0?1:0), false); break;
+ case ROW_AGG: out.reset(m, 1, false); break;
case COL_AGG: out.reset(1, n, false); break;
case COL_AGG_T: out.reset(n, 1, false); break;
case COL_AGG_B1: out.reset(n2, n, false); break;
http://git-wip-us.apache.org/repos/asf/systemml/blob/c1db484d/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java
index 2d609aa..a628dc0 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java
@@ -354,7 +354,7 @@ public class SpoofSPInstruction extends SPInstruction {
if( type == RowType.NO_AGG )
mcOut.set(mcIn);
else if( type == RowType.ROW_AGG )
- mcOut.set(mcIn.getRows(), ((SpoofRowwise)op).isCBind0()? 2:1,
+ mcOut.set(mcIn.getRows(), 1,
mcIn.getRowsPerBlock(), mcIn.getColsPerBlock());
else if( type == RowType.COL_AGG )
mcOut.set(1, mcIn.getCols(), mcIn.getRowsPerBlock(), mcIn.getColsPerBlock());
@@ -454,7 +454,8 @@ public class SpoofSPInstruction extends SPInstruction {
}
//setup local memory for reuse
- int clen2 = (int) (_op.getRowType().isRowTypeB1() ? _inputs.get(0).getNumCols() : -1);
+ int clen2 = (int) ((_op.getRowType()==RowType.NO_AGG_CONST) ? _op.getConstDim2() :
+ _op.getRowType().isRowTypeB1() ? _inputs.get(0).getNumCols() : -1);
LibSpoofPrimitives.setupThreadLocalMemory(_op.getNumIntermediates(), _clen, clen2);
ArrayList<Tuple2<MatrixIndexes,MatrixBlock>> ret = new ArrayList<Tuple2<MatrixIndexes,MatrixBlock>>();
http://git-wip-us.apache.org/repos/asf/systemml/blob/c1db484d/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java b/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
index 3ecfd6b..d4f87b3 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
@@ -67,6 +67,7 @@ public class RowAggTmplTest extends AutomatedTestBase
private static final String TEST_NAME28 = TEST_NAME+"28"; //Kmeans, final eval
private static final String TEST_NAME29 = TEST_NAME+"29"; //sum(rowMins(X))
private static final String TEST_NAME30 = TEST_NAME+"30"; //Mlogreg inner core, multi-class
+ private static final String TEST_NAME31 = TEST_NAME+"31"; //MLogreg - matrix-vector cbind 0s generalized
private static final String TEST_DIR = "functions/codegen/";
private static final String TEST_CLASS_DIR = TEST_DIR + RowAggTmplTest.class.getSimpleName() + "/";
@@ -78,7 +79,7 @@ public class RowAggTmplTest extends AutomatedTestBase
@Override
public void setUp() {
TestUtils.clearAssertionInformation();
- for(int i=1; i<=30; i++)
+ for(int i=1; i<=31; i++)
addTestConfiguration( TEST_NAME+i, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME+i, new String[] { String.valueOf(i) }) );
}
@@ -532,6 +533,21 @@ public class RowAggTmplTest extends AutomatedTestBase
testCodegenIntegration( TEST_NAME30, false, ExecType.SPARK );
}
+ @Test
+ public void testCodegenRowAggRewrite31CP() {
+ testCodegenIntegration( TEST_NAME31, true, ExecType.CP ); //rewrites enabled, single-node
+ }
+
+ @Test
+ public void testCodegenRowAgg31CP() {
+ testCodegenIntegration( TEST_NAME31, false, ExecType.CP ); //rewrites disabled, single-node
+ }
+
+ @Test
+ public void testCodegenRowAgg31SP() {
+ testCodegenIntegration( TEST_NAME31, false, ExecType.SPARK ); //rewrites disabled, Spark
+ }
+
+
private void testCodegenIntegration( String testname, boolean rewrites, ExecType instType )
{
boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
@@ -581,6 +597,8 @@ public class RowAggTmplTest extends AutomatedTestBase
if( testname.equals(TEST_NAME30) )
Assert.assertTrue(!heavyHittersContainsSubString("spoofRA", 2)
&& !heavyHittersContainsSubString(RightIndex.OPCODE));
+ if( testname.equals(TEST_NAME31) )
+ Assert.assertTrue(!heavyHittersContainsSubString("spoofRA", 2));
}
finally {
rtplatform = platformOld;
http://git-wip-us.apache.org/repos/asf/systemml/blob/c1db484d/src/test/scripts/functions/codegen/rowAggPattern31.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/codegen/rowAggPattern31.R b/src/test/scripts/functions/codegen/rowAggPattern31.R
new file mode 100644
index 0000000..036a3e2
--- /dev/null
+++ b/src/test/scripts/functions/codegen/rowAggPattern31.R
@@ -0,0 +1,32 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args<-commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+library("matrixStats")
+
+X = matrix(seq(1,1500), 150, 10, byrow=TRUE); # 150x10 input, filled row-wise
+v = seq(1, ncol(X));
+R = cbind((X %*% v), matrix (7, nrow(X), 1)) # matrix-vector product cbind'ed with a constant-7 column
+R = R - rowMaxs(R) %*% matrix(1, 1, ncol(R)); # subtract each row's max from every column
+
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep=""));
http://git-wip-us.apache.org/repos/asf/systemml/blob/c1db484d/src/test/scripts/functions/codegen/rowAggPattern31.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/codegen/rowAggPattern31.dml b/src/test/scripts/functions/codegen/rowAggPattern31.dml
new file mode 100644
index 0000000..8bdefc4
--- /dev/null
+++ b/src/test/scripts/functions/codegen/rowAggPattern31.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = matrix(seq(1,1500), 150, 10);
+v = seq(1, ncol(X));
+R = cbind((X %*% v), matrix (7, nrow(X), 1)) # cbind with a non-zero constant column (generalized cbind)
+R = R - rowMaxs (R) # row max is broadcast against all columns of R
+
+write(R, $1)