You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2017/06/12 18:12:59 UTC
[2/2] systemml git commit: [SYSTEMML-1683] Improved codegen row
template (indexing, cbind)
[SYSTEMML-1683] Improved codegen row template (indexing, cbind)
This patch makes two improvements to the code generator row-wise
template in order to further reduce the number of intermediates in
scripts such as MLogreg (as well as minor explain improvements):
(1) Column indexing support w/ unknown indexing expressions.
(2) Fusion of cbind with empty matrix after row-wise template (row
aggregates).
(3) Extended explain to show the line numbers (of the original script)
for generated operators.
For example, on MLogreg and a dense 100M x 10 scenario, this reduces the
buffer pool writes from (76/30/1) to (60/30/1) and execution time from
282s to 256s (compared to the baseline w/ existing fused operators of
529s).
Note that there is substantial additional potential, which can be
exploited to reduce the number of evictions. However, this will be
addressed in separate changes as it requires optimizer changes (e.g.,
considering multi-aggregates while considering materialization, and
materialization points per consumer).
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/1f508911
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/1f508911
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/1f508911
Branch: refs/heads/master
Commit: 1f508911052f035580c2b9120912dc63c47804d2
Parents: e54ed71
Author: Matthias Boehm <mb...@gmail.com>
Authored: Sun Jun 11 22:53:12 2017 -0700
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Sun Jun 11 22:53:29 2017 -0700
----------------------------------------------------------------------
.../sysml/hops/codegen/SpoofCompiler.java | 8 +--
.../apache/sysml/hops/codegen/SpoofFusedOp.java | 7 ++-
.../sysml/hops/codegen/cplan/CNodeRow.java | 9 +++-
.../sysml/hops/codegen/cplan/CNodeTpl.java | 10 ++++
.../sysml/hops/codegen/cplan/CNodeUnary.java | 12 +++--
.../hops/codegen/template/TemplateCell.java | 4 +-
.../hops/codegen/template/TemplateMultiAgg.java | 1 +
.../codegen/template/TemplateOuterProduct.java | 4 +-
.../hops/codegen/template/TemplateRow.java | 15 +++++-
.../hops/codegen/template/TemplateUtils.java | 4 +-
.../sysml/runtime/codegen/SpoofRowwise.java | 10 +++-
.../instructions/spark/SpoofSPInstruction.java | 3 +-
.../functions/codegen/RowAggTmplTest.java | 55 ++++++++++++++++++--
.../scripts/functions/codegen/rowAggPattern17.R | 33 ++++++++++++
.../functions/codegen/rowAggPattern17.dml | 33 ++++++++++++
.../scripts/functions/codegen/rowAggPattern18.R | 31 +++++++++++
.../functions/codegen/rowAggPattern18.dml | 26 +++++++++
.../scripts/functions/codegen/rowAggPattern19.R | 35 +++++++++++++
.../functions/codegen/rowAggPattern19.dml | 32 ++++++++++++
19 files changed, 311 insertions(+), 21 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/1f508911/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
index d96dda1..857c187 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/SpoofCompiler.java
@@ -367,12 +367,14 @@ public class SpoofCompiler
//explain debug output cplans or generated source code
if( LOG.isTraceEnabled() || DMLScript.EXPLAIN.isHopsType(recompile) ) {
- LOG.info("Codegen EXPLAIN (generated cplan for HopID: " + cplan.getKey() +"):");
+ LOG.info("Codegen EXPLAIN (generated cplan for HopID: "
+ + cplan.getKey() + ", line "+tmp.getValue().getBeginLine() + "):");
LOG.info(tmp.getValue().getClassname()
- +Explain.explainCPlan(cplan.getValue().getValue()));
+ + Explain.explainCPlan(cplan.getValue().getValue()));
}
if( LOG.isTraceEnabled() || DMLScript.EXPLAIN.isRuntimeType(recompile) ) {
- LOG.info("Codegen EXPLAIN (generated code for HopID: " + cplan.getKey() +"):");
+ LOG.info("Codegen EXPLAIN (generated code for HopID: "
+ + cplan.getKey() + ", line "+tmp.getValue().getBeginLine() + "):");
LOG.info(src);
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/1f508911/src/main/java/org/apache/sysml/hops/codegen/SpoofFusedOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/SpoofFusedOp.java b/src/main/java/org/apache/sysml/hops/codegen/SpoofFusedOp.java
index 779df4f..9f426f6 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/SpoofFusedOp.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/SpoofFusedOp.java
@@ -38,6 +38,7 @@ public class SpoofFusedOp extends Hop implements MultiThreadedHop
public enum SpoofOutputDimsType {
INPUT_DIMS,
ROW_DIMS,
+ ROW_DIMS2,
COLUMN_DIMS_ROWS,
COLUMN_DIMS_COLS,
SCALAR,
@@ -148,6 +149,10 @@ public class SpoofFusedOp extends Hop implements MultiThreadedHop
setDim1(getInput().get(0).getDim1());
setDim2(1);
break;
+ case ROW_DIMS2:
+ setDim1(getInput().get(0).getDim1());
+ setDim2(2);
+ break;
case COLUMN_DIMS_ROWS:
setDim1(getInput().get(0).getDim2());
setDim2(1);
@@ -185,7 +190,7 @@ public class SpoofFusedOp extends Hop implements MultiThreadedHop
@Override
public Object clone() throws CloneNotSupportedException
{
- SpoofFusedOp ret = new SpoofFusedOp();
+ SpoofFusedOp ret = new SpoofFusedOp();
//copy generic attributes
ret.clone(this, false);
http://git-wip-us.apache.org/repos/asf/systemml/blob/1f508911/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeRow.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeRow.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeRow.java
index ac2a394..27565c1 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeRow.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeRow.java
@@ -23,6 +23,8 @@ import java.util.ArrayList;
import java.util.Arrays;
import org.apache.sysml.hops.codegen.SpoofFusedOp.SpoofOutputDimsType;
+import org.apache.sysml.hops.codegen.cplan.CNodeUnary.UnaryType;
+import org.apache.sysml.hops.codegen.template.TemplateUtils;
import org.apache.sysml.runtime.codegen.SpoofRowwise.RowType;
public class CNodeRow extends CNodeTpl
@@ -36,7 +38,7 @@ public class CNodeRow extends CNodeTpl
+ "\n"
+ "public final class %TMP% extends SpoofRowwise { \n"
+ " public %TMP%() {\n"
- + " super(RowType.%TYPE%, %VECT_MEM%);\n"
+ + " super(RowType.%TYPE%, %CBIND0%, %VECT_MEM%);\n"
+ " }\n"
+ " protected void genexecRowDense( double[] a, int ai, double[][] b, double[] scalars, double[] c, int len, int rowIndex ) { \n"
+ "%BODY_dense%"
@@ -101,6 +103,8 @@ public class CNodeRow extends CNodeTpl
//replace colvector information and number of vector intermediates
tmp = tmp.replaceAll("%TYPE%", _type.name());
+ tmp = tmp.replaceAll("%CBIND0%", String.valueOf(
+ TemplateUtils.isUnary(_output, UnaryType.CBIND0)));
tmp = tmp.replaceAll("%VECT_MEM%", String.valueOf(_numVectors));
return tmp;
@@ -125,7 +129,8 @@ public class CNodeRow extends CNodeTpl
public SpoofOutputDimsType getOutputDimType() {
switch( _type ) {
case NO_AGG: return SpoofOutputDimsType.INPUT_DIMS;
- case ROW_AGG: return SpoofOutputDimsType.ROW_DIMS;
+ case ROW_AGG: return TemplateUtils.isUnary(_output, UnaryType.CBIND0) ?
+ SpoofOutputDimsType.ROW_DIMS2 : SpoofOutputDimsType.ROW_DIMS;
case COL_AGG: return SpoofOutputDimsType.COLUMN_DIMS_COLS; //row vector
case COL_AGG_T: return SpoofOutputDimsType.COLUMN_DIMS_ROWS; //column vector
default:
http://git-wip-us.apache.org/repos/asf/systemml/blob/1f508911/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTpl.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTpl.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTpl.java
index ca474d2..81351e6 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTpl.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTpl.java
@@ -32,6 +32,8 @@ import org.apache.sysml.parser.Expression.DataType;
public abstract class CNodeTpl extends CNode implements Cloneable
{
+ private int _beginLine = -1;
+
public CNodeTpl(ArrayList<CNode> inputs, CNode output ) {
if(inputs.size() < 1)
throw new RuntimeException("Cannot pass empty inputs to the CNodeTpl");
@@ -243,6 +245,14 @@ public abstract class CNodeTpl extends CNode implements Cloneable
return false;
}
+ public void setBeginLine(int line) {
+ _beginLine = line;
+ }
+
+ public int getBeginLine() {
+ return _beginLine;
+ }
+
@Override
public int hashCode() {
return super.hashCode();
http://git-wip-us.apache.org/repos/asf/systemml/blob/1f508911/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
index b9c7cbe..4c60824 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeUnary.java
@@ -28,7 +28,7 @@ import org.apache.sysml.parser.Expression.DataType;
public class CNodeUnary extends CNode
{
public enum UnaryType {
- LOOKUP_R, LOOKUP_C, LOOKUP_RC, LOOKUP0, //codegen specific
+ LOOKUP_R, LOOKUP_C, LOOKUP_RC, LOOKUP0, CBIND0, //codegen specific
ROW_SUMS, ROW_MINS, ROW_MAXS, //codegen specific
VECT_EXP, VECT_POW2, VECT_MULT2, VECT_SQRT, VECT_LOG,
VECT_ABS, VECT_ROUND, VECT_CEIL, VECT_FLOOR, VECT_SIGN,
@@ -79,6 +79,8 @@ public class CNodeUnary extends CNode
return " double %TMP% = getValue(%IN1%, n, rowIndex, colIndex);\n";
case LOOKUP0:
return " double %TMP% = %IN1%[0];\n" ;
+ case CBIND0:
+ return " double %TMP% = %IN1%; rowIndex *= 2;\n" ;
case POW2:
return " double %TMP% = %IN1% * %IN1%;\n" ;
case MULT2:
@@ -208,10 +210,11 @@ public class CNodeUnary extends CNode
case VECT_CEIL:
case VECT_FLOOR:
case VECT_SIGN: return "u(v"+_type.name().toLowerCase()+")";
- case LOOKUP_R: return "u(ixr)";
- case LOOKUP_C: return "u(ixc)";
+ case LOOKUP_R: return "u(ixr)";
+ case LOOKUP_C: return "u(ixc)";
case LOOKUP_RC: return "u(ixrc)";
- case LOOKUP0: return "u(ix0)";
+ case LOOKUP0: return "u(ix0)";
+ case CBIND0: return "u(cbind0)";
case POW2: return "^2";
default: return "u("+_type.name().toLowerCase()+")";
}
@@ -243,6 +246,7 @@ public class CNodeUnary extends CNode
case LOOKUP_C:
case LOOKUP_RC:
case LOOKUP0:
+ case CBIND0:
case POW2:
case MULT2:
case ABS:
http://git-wip-us.apache.org/repos/asf/systemml/blob/1f508911/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
index 434fa59..d6dcdf6 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
@@ -76,7 +76,8 @@ public class TemplateCell extends TemplateBase
@Override
public boolean open(Hop hop) {
return isValidOperation(hop)
- || (hop instanceof IndexingOp && ((IndexingOp)hop).isColLowerEqualsUpper());
+ || (hop instanceof IndexingOp && (((IndexingOp)hop)
+ .isColLowerEqualsUpper() || hop.getDim2()==1));
}
@Override
@@ -135,6 +136,7 @@ public class TemplateCell extends TemplateBase
tpl.setSparseSafe((HopRewriteUtils.isBinary(hop, OpOp2.MULT) && hop.getInput().contains(sinHops.get(0)))
|| (HopRewriteUtils.isBinary(hop, OpOp2.DIV) && hop.getInput().get(0) == sinHops.get(0)));
tpl.setRequiresCastDtm(hop instanceof AggBinaryOp);
+ tpl.setBeginLine(hop.getBeginLine());
// return cplan instance
return new Pair<Hop[],CNodeTpl>(sinHops.toArray(new Hop[0]), tpl);
http://git-wip-us.apache.org/repos/asf/systemml/blob/1f508911/src/main/java/org/apache/sysml/hops/codegen/template/TemplateMultiAgg.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateMultiAgg.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateMultiAgg.java
index 56477da..a75e07f 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateMultiAgg.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateMultiAgg.java
@@ -110,6 +110,7 @@ public class TemplateMultiAgg extends TemplateCell
CNodeMultiAgg tpl = new CNodeMultiAgg(inputs, outputs);
tpl.setAggOps(aggOps);
tpl.setRootNodes(roots);
+ tpl.setBeginLine(hop.getBeginLine());
// return cplan instance
return new Pair<Hop[],CNodeTpl>(sinHops.toArray(new Hop[0]), tpl);
http://git-wip-us.apache.org/repos/asf/systemml/blob/1f508911/src/main/java/org/apache/sysml/hops/codegen/template/TemplateOuterProduct.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateOuterProduct.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateOuterProduct.java
index 1dfe773..9f5b191 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateOuterProduct.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateOuterProduct.java
@@ -127,7 +127,9 @@ public class TemplateOuterProduct extends TemplateBase {
CNodeOuterProduct tpl = new CNodeOuterProduct(inputs, output);
tpl.setOutProdType(TemplateUtils.getOuterProductType(X, U, V, hop));
tpl.setTransposeOutput(!HopRewriteUtils.isTransposeOperation(hop)
- && tpl.getOutProdType()==OutProdType.LEFT_OUTER_PRODUCT);
+ && tpl.getOutProdType()==OutProdType.LEFT_OUTER_PRODUCT);
+ tpl.setBeginLine(hop.getBeginLine());
+
return new Pair<Hop[],CNodeTpl>(sinHops.toArray(new Hop[0]), tpl);
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/1f508911/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
index ae25ded..a0b4572 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
@@ -89,6 +89,9 @@ public class TemplateRow extends TemplateBase
&& (HopRewriteUtils.isBinaryMatrixColVectorOperation(hop)
|| HopRewriteUtils.isBinaryMatrixScalarOperation(hop)
|| HopRewriteUtils.isBinaryMatrixMatrixOperationWithSharedInput(hop)) )
+ || (HopRewriteUtils.isBinary(hop, OpOp2.CBIND) && hop.getInput().indexOf(input)==0
+ && input.getDim2()==1 && hop.getInput().get(1).getDim2()==1
+ && HopRewriteUtils.isEmpty(hop.getInput().get(1)))
|| ((hop instanceof UnaryOp || hop instanceof ParameterizedBuiltinOp)
&& TemplateCell.isValidOperation(hop))
|| (hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getDirection()!=Direction.RowCol
@@ -111,8 +114,9 @@ public class TemplateRow extends TemplateBase
@Override
public CloseType close(Hop hop) {
//close on column aggregate (e.g., colSums, t(X)%*%y)
- if( hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getDirection()==Direction.Col
- || (hop instanceof AggBinaryOp && HopRewriteUtils.isTransposeOperation(hop.getInput().get(0))) )
+ if( (hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getDirection()==Direction.Col)
+ || (hop instanceof AggBinaryOp && HopRewriteUtils.isTransposeOperation(hop.getInput().get(0)))
+ || HopRewriteUtils.isBinary(hop, OpOp2.CBIND) )
return CloseType.CLOSED_VALID;
else
return CloseType.OPEN;
@@ -144,6 +148,7 @@ public class TemplateRow extends TemplateBase
.countVectorIntermediates(output, new HashSet<Long>()));
tpl.getOutput().resetVisitStatus();
tpl.rReorderCommutativeBinaryOps(tpl.getOutput(), sinHops.get(0).getHopID());
+ tpl.setBeginLine(hop.getBeginLine());
// return cplan instance
return new Pair<Hop[],CNodeTpl>(sinHops.toArray(new Hop[0]), tpl);
@@ -241,6 +246,12 @@ public class TemplateRow extends TemplateBase
out = new CNodeUnary(cdata1, UnaryType.valueOf(primitiveOpName));
}
}
+ else if(HopRewriteUtils.isBinary(hop, OpOp2.CBIND))
+ {
+ //special case for cbind with zeros
+ CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
+ out = new CNodeUnary(cdata1, UnaryType.CBIND0);
+ }
else if(hop instanceof BinaryOp)
{
CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
http://git-wip-us.apache.org/repos/asf/systemml/blob/1f508911/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
index e8e3901..89a02da 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
@@ -35,6 +35,7 @@ import org.apache.sysml.hops.ParameterizedBuiltinOp;
import org.apache.sysml.hops.TernaryOp;
import org.apache.sysml.hops.Hop.AggOp;
import org.apache.sysml.hops.Hop.Direction;
+import org.apache.sysml.hops.Hop.OpOp2;
import org.apache.sysml.hops.UnaryOp;
import org.apache.sysml.hops.codegen.cplan.CNode;
import org.apache.sysml.hops.codegen.cplan.CNodeBinary;
@@ -239,7 +240,8 @@ public class TemplateUtils
public static RowType getRowType(Hop output, Hop input) {
if( HopRewriteUtils.isEqualSize(output, input) )
return RowType.NO_AGG;
- else if( output.getDim1()==input.getDim1() && output.getDim2()==1
+ else if( output.getDim1()==input.getDim1() && (output.getDim2()==1
+ || HopRewriteUtils.isBinary(output, OpOp2.CBIND))
&& !(output instanceof AggBinaryOp && HopRewriteUtils
.isTransposeOfItself(output.getInput().get(0),input)))
return RowType.ROW_AGG;
http://git-wip-us.apache.org/repos/asf/systemml/blob/1f508911/src/main/java/org/apache/sysml/runtime/codegen/SpoofRowwise.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/codegen/SpoofRowwise.java b/src/main/java/org/apache/sysml/runtime/codegen/SpoofRowwise.java
index 09d5b29..899f629 100644
--- a/src/main/java/org/apache/sysml/runtime/codegen/SpoofRowwise.java
+++ b/src/main/java/org/apache/sysml/runtime/codegen/SpoofRowwise.java
@@ -54,10 +54,12 @@ public abstract class SpoofRowwise extends SpoofOperator
}
protected final RowType _type;
+ protected final boolean _cbind0;
protected final int _reqVectMem;
- public SpoofRowwise(RowType type, int reqVectMem) {
+ public SpoofRowwise(RowType type, boolean cbind0, int reqVectMem) {
_type = type;
+ _cbind0 = cbind0;
_reqVectMem = reqVectMem;
}
@@ -65,6 +67,10 @@ public abstract class SpoofRowwise extends SpoofOperator
return _type;
}
+ public boolean isCBind0() {
+ return _cbind0;
+ }
+
public int getNumIntermediates() {
return _reqVectMem;
}
@@ -183,7 +189,7 @@ public abstract class SpoofRowwise extends SpoofOperator
private void allocateOutputMatrix(int m, int n, MatrixBlock out) {
switch( _type ) {
case NO_AGG: out.reset(m, n, false); break;
- case ROW_AGG: out.reset(m, 1, false); break;
+ case ROW_AGG: out.reset(m, 1+(_cbind0?1:0), false); break;
case COL_AGG: out.reset(1, n, false); break;
case COL_AGG_T: out.reset(n, 1, false); break;
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/1f508911/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java
index be3a76d..f11b3d0 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java
@@ -267,7 +267,8 @@ public class SpoofSPInstruction extends SPInstruction
if( type == RowType.NO_AGG )
mcOut.set(mcIn);
else if( type == RowType.ROW_AGG )
- mcOut.set(mcIn.getRows(), 1, mcIn.getRowsPerBlock(), mcIn.getColsPerBlock());
+ mcOut.set(mcIn.getRows(), ((SpoofRowwise)op).isCBind0()? 2:1,
+ mcIn.getRowsPerBlock(), mcIn.getColsPerBlock());
else if( type == RowType.COL_AGG )
mcOut.set(1, mcIn.getCols(), mcIn.getRowsPerBlock(), mcIn.getColsPerBlock());
else if( type == RowType.COL_AGG_T )
http://git-wip-us.apache.org/repos/asf/systemml/blob/1f508911/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java b/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
index 362f1dc..409fbc7 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/codegen/RowAggTmplTest.java
@@ -52,6 +52,9 @@ public class RowAggTmplTest extends AutomatedTestBase
private static final String TEST_NAME14 = TEST_NAME+"14"; //colSums(max(floor(round(abs(min(sign(X+Y),1)))),7))
private static final String TEST_NAME15 = TEST_NAME+"15"; //systemml nn - softmax backward
private static final String TEST_NAME16 = TEST_NAME+"16"; //Y=X-rowIndexMax(X); R=Y/rowSums(Y)
+ private static final String TEST_NAME17 = TEST_NAME+"17"; //MLogreg - vector-matrix w/ indexing
+ private static final String TEST_NAME18 = TEST_NAME+"18"; //MLogreg - matrix-vector cbind 0s
+ private static final String TEST_NAME19 = TEST_NAME+"19"; //MLogreg - rowwise dag
private static final String TEST_DIR = "functions/codegen/";
private static final String TEST_CLASS_DIR = TEST_DIR + RowAggTmplTest.class.getSimpleName() + "/";
@@ -63,7 +66,7 @@ public class RowAggTmplTest extends AutomatedTestBase
@Override
public void setUp() {
TestUtils.clearAssertionInformation();
- for(int i=1; i<=16; i++)
+ for(int i=1; i<=19; i++)
addTestConfiguration( TEST_NAME+i, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME+i, new String[] { String.valueOf(i) }) );
}
@@ -307,6 +310,51 @@ public class RowAggTmplTest extends AutomatedTestBase
testCodegenIntegration( TEST_NAME16, false, ExecType.SPARK );
}
+ @Test
+ public void testCodegenRowAggRewrite17CP() {
+ testCodegenIntegration( TEST_NAME17, true, ExecType.CP );
+ }
+
+ @Test
+ public void testCodegenRowAgg17CP() {
+ testCodegenIntegration( TEST_NAME17, false, ExecType.CP );
+ }
+
+ @Test
+ public void testCodegenRowAgg17SP() {
+ testCodegenIntegration( TEST_NAME17, false, ExecType.SPARK );
+ }
+
+ @Test
+ public void testCodegenRowAggRewrite18CP() {
+ testCodegenIntegration( TEST_NAME18, true, ExecType.CP );
+ }
+
+ @Test
+ public void testCodegenRowAgg18CP() {
+ testCodegenIntegration( TEST_NAME18, false, ExecType.CP );
+ }
+
+ @Test
+ public void testCodegenRowAgg18SP() {
+ testCodegenIntegration( TEST_NAME18, false, ExecType.SPARK );
+ }
+
+ @Test
+ public void testCodegenRowAggRewrite19CP() {
+ testCodegenIntegration( TEST_NAME19, true, ExecType.CP );
+ }
+
+ @Test
+ public void testCodegenRowAgg19CP() {
+ testCodegenIntegration( TEST_NAME19, false, ExecType.CP );
+ }
+
+ @Test
+ public void testCodegenRowAgg19SP() {
+ testCodegenIntegration( TEST_NAME19, false, ExecType.SPARK );
+ }
+
private void testCodegenIntegration( String testname, boolean rewrites, ExecType instType )
{
boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
@@ -328,7 +376,7 @@ public class RowAggTmplTest extends AutomatedTestBase
String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + testname + ".dml";
- programArgs = new String[]{"-explain", "-stats", "-args", output("S") };
+ programArgs = new String[]{"-explain", "recompile_hops", "-stats", "-args", output("S") };
fullRScriptName = HOME + testname + ".R";
rCmd = getRCmd(inputDir(), expectedDir());
@@ -348,7 +396,8 @@ public class RowAggTmplTest extends AutomatedTestBase
//ensure full aggregates for certain patterns
if( testname.equals(TEST_NAME15) )
Assert.assertTrue(!heavyHittersContainsSubString("uark+"));
-
+ if( testname.equals(TEST_NAME17) )
+ Assert.assertTrue(!heavyHittersContainsSubString("rangeReIndex"));
}
finally {
rtplatform = platformOld;
http://git-wip-us.apache.org/repos/asf/systemml/blob/1f508911/src/test/scripts/functions/codegen/rowAggPattern17.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/codegen/rowAggPattern17.R b/src/test/scripts/functions/codegen/rowAggPattern17.R
new file mode 100644
index 0000000..44d1474
--- /dev/null
+++ b/src/test/scripts/functions/codegen/rowAggPattern17.R
@@ -0,0 +1,33 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args<-commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+library("matrixStats")
+
+X = matrix(seq(1,1500), 150, 10, byrow=TRUE);
+P = matrix(2, nrow(X), 2);
+Y = matrix(1, nrow(X), 2);
+
+R = t(X) %*% (P [, 1] - Y [, 1]);
+
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep=""));
http://git-wip-us.apache.org/repos/asf/systemml/blob/1f508911/src/test/scripts/functions/codegen/rowAggPattern17.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/codegen/rowAggPattern17.dml b/src/test/scripts/functions/codegen/rowAggPattern17.dml
new file mode 100644
index 0000000..1a7ce93
--- /dev/null
+++ b/src/test/scripts/functions/codegen/rowAggPattern17.dml
@@ -0,0 +1,33 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = matrix(seq(1,1500), 150, 10);
+P = matrix(2, nrow(X), 2);
+Y0 = matrix(1, nrow(X), 1);
+max_y = max(Y0)+1;
+if(1==1){}
+Y = table(seq(1,nrow(Y0)), Y0, nrow(Y0), max_y);
+
+if(1==1){} #recompile w/o knowing K
+K = ncol(Y)-1;
+R = t(X) %*% (P [, 1:K] - Y [, 1:K]);
+
+write(R, $1)
http://git-wip-us.apache.org/repos/asf/systemml/blob/1f508911/src/test/scripts/functions/codegen/rowAggPattern18.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/codegen/rowAggPattern18.R b/src/test/scripts/functions/codegen/rowAggPattern18.R
new file mode 100644
index 0000000..c03732a
--- /dev/null
+++ b/src/test/scripts/functions/codegen/rowAggPattern18.R
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args<-commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+library("matrixStats")
+
+X = matrix(seq(1,1500), 150, 10, byrow=TRUE);
+v = seq(1, ncol(X));
+R = cbind((X %*% v), matrix (0, nrow(X), 1))
+
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep=""));
http://git-wip-us.apache.org/repos/asf/systemml/blob/1f508911/src/test/scripts/functions/codegen/rowAggPattern18.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/codegen/rowAggPattern18.dml b/src/test/scripts/functions/codegen/rowAggPattern18.dml
new file mode 100644
index 0000000..ff3020c
--- /dev/null
+++ b/src/test/scripts/functions/codegen/rowAggPattern18.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = matrix(seq(1,1500), 150, 10);
+v = seq(1, ncol(X));
+R = cbind((X %*% v), matrix (0, nrow(X), 1))
+
+write(R, $1)
http://git-wip-us.apache.org/repos/asf/systemml/blob/1f508911/src/test/scripts/functions/codegen/rowAggPattern19.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/codegen/rowAggPattern19.R b/src/test/scripts/functions/codegen/rowAggPattern19.R
new file mode 100644
index 0000000..bb26ff6
--- /dev/null
+++ b/src/test/scripts/functions/codegen/rowAggPattern19.R
@@ -0,0 +1,35 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args<-commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+library("matrixStats")
+
+LT = matrix(1, 1500, 2);
+Y = matrix(2, 1500, 2);
+
+LT = LT - rowMaxs (LT) %*% matrix (1, 1, 2);
+exp_LT = exp (LT);
+R = exp_LT / (rowSums (exp_LT) %*% matrix (1, 1, 2));
+print(sum(Y * LT) + sum(log(rowSums(exp_LT))));
+
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep=""));
http://git-wip-us.apache.org/repos/asf/systemml/blob/1f508911/src/test/scripts/functions/codegen/rowAggPattern19.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/codegen/rowAggPattern19.dml b/src/test/scripts/functions/codegen/rowAggPattern19.dml
new file mode 100644
index 0000000..cb1fa3d
--- /dev/null
+++ b/src/test/scripts/functions/codegen/rowAggPattern19.dml
@@ -0,0 +1,32 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+LT = matrix(1, 1500, 2);
+Y = matrix(2, 1500, 2);
+if(1==1) {}
+
+LT = LT - rowMaxs (LT);
+exp_LT = exp (LT);
+R = exp_LT / rowSums (exp_LT);
+print(sum(Y * LT) + sum(log(rowSums(exp_LT))));
+
+
+write(R, $1)