You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2016/01/01 22:16:16 UTC
[3/3] incubator-systemml git commit: New broadcast-based spark
grouped aggregate (compiler/runtime)
New broadcast-based spark grouped aggregate (compiler/runtime)
Includes various cleanups and a fix so that aggregate operators (e.g., sum())
use the reduceByKey code path instead of the groupByKey code path.
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/b308c09b
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/b308c09b
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/b308c09b
Branch: refs/heads/master
Commit: b308c09b90b3e88e7a30a7d77cc1b3f06bf82cc7
Parents: 4b71654
Author: Matthias Boehm <mb...@us.ibm.com>
Authored: Fri Jan 1 13:15:37 2016 -0800
Committer: Matthias Boehm <mb...@us.ibm.com>
Committed: Fri Jan 1 13:15:37 2016 -0800
----------------------------------------------------------------------
.../sysml/hops/ParameterizedBuiltinOp.java | 8 +-
.../org/apache/sysml/lops/GroupedAggregate.java | 19 +++++
.../context/SparkExecutionContext.java | 16 ++++
.../ParameterizedBuiltinSPInstruction.java | 47 +++++++----
.../spark/functions/ExtractGroup.java | 82 ++++++++++++++++----
.../functions/PerformGroupByAggInCombiner.java | 27 ++++---
6 files changed, 153 insertions(+), 46 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b308c09b/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java b/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java
index 99eebaa..44714d4 100644
--- a/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java
+++ b/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java
@@ -334,7 +334,13 @@ public class ParameterizedBuiltinOp extends Hop implements MultiThreadedHop
}
else if(et == ExecType.SPARK)
{
- grp_agg = new GroupedAggregate(inputlops, getDataType(), getValueType(), et);
+ //physical operator selection
+ Hop groups = getInput().get(_paramIndexMap.get(Statement.GAGG_GROUPS));
+ boolean broadcastGroups = (_paramIndexMap.get(Statement.GAGG_WEIGHTS) == null &&
+ OptimizerUtils.checkSparkBroadcastMemoryBudget( groups.getDim1(), groups.getDim2(),
+ groups.getRowsInBlock(), groups.getColsInBlock(), groups.getNnz()) );
+
+ grp_agg = new GroupedAggregate(inputlops, getDataType(), getValueType(), et, broadcastGroups);
grp_agg.getOutputParameters().setDimensions(outputDim1, outputDim2, -1, -1, -1);
setRequiresReblock( true );
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b308c09b/src/main/java/org/apache/sysml/lops/GroupedAggregate.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/GroupedAggregate.java b/src/main/java/org/apache/sysml/lops/GroupedAggregate.java
index f54857d..d07cfe0 100644
--- a/src/main/java/org/apache/sysml/lops/GroupedAggregate.java
+++ b/src/main/java/org/apache/sysml/lops/GroupedAggregate.java
@@ -40,6 +40,11 @@ public class GroupedAggregate extends Lop
public static final String COMBINEDINPUT = "combinedinput";
private boolean _weights = false;
+
+ //spark-specific parameters
+ private boolean _broadcastGroups = false;
+
+ //cp-specific parameters
private int _numThreads = 1;
/**
@@ -65,6 +70,14 @@ public class GroupedAggregate extends Lop
public GroupedAggregate(
HashMap<String, Lop> inputParameterLops,
+ DataType dt, ValueType vt, ExecType et, boolean broadcastGroups) {
+ super(Lop.Type.GroupedAgg, dt, vt);
+ init(inputParameterLops, dt, vt, et);
+ _broadcastGroups = broadcastGroups;
+ }
+
+ public GroupedAggregate(
+ HashMap<String, Lop> inputParameterLops,
DataType dt, ValueType vt, ExecType et, int k) {
super(Lop.Type.GroupedAgg, dt, vt);
init(inputParameterLops, dt, vt, et);
@@ -203,6 +216,12 @@ public class GroupedAggregate extends Lop
sb.append( Lop.NAME_VALUE_SEPARATOR );
sb.append( _numThreads );
}
+ else if( getExecType()==ExecType.SPARK ) {
+ sb.append( OPERAND_DELIMITOR );
+ sb.append( "broadcast" );
+ sb.append( Lop.NAME_VALUE_SEPARATOR );
+ sb.append( _broadcastGroups );
+ }
sb.append( OPERAND_DELIMITOR );
sb.append( prepOutputOperand(output));
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b308c09b/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java b/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
index 2ee6041..b46a3af 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
@@ -876,6 +876,22 @@ public class SparkExecutionContext extends ExecutionContext
parent.addLineageChild( child );
}
+ /**
+ *
+ * @param varParent
+ * @param varChild
+ * @param broadcast
+ * @throws DMLRuntimeException
+ */
+ public void addLineage(String varParent, String varChild, boolean broadcast)
+ throws DMLRuntimeException
+ {
+ if( broadcast )
+ addLineageBroadcast(varParent, varChild);
+ else
+ addLineageRDD(varParent, varChild);
+ }
+
@Override
public void cleanupMatrixObject( MatrixObject mo )
throws DMLRuntimeException
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b308c09b/src/main/java/org/apache/sysml/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
index 257e174..505e232 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
@@ -44,7 +44,8 @@ import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.mr.GroupedAggregateInstruction;
import org.apache.sysml.runtime.instructions.spark.data.PartitionedBroadcastMatrix;
-import org.apache.sysml.runtime.instructions.spark.functions.ExtractGroup;
+import org.apache.sysml.runtime.instructions.spark.functions.ExtractGroup.ExtractGroupBroadcast;
+import org.apache.sysml.runtime.instructions.spark.functions.ExtractGroup.ExtractGroupJoin;
import org.apache.sysml.runtime.instructions.spark.functions.ExtractGroupNWeights;
import org.apache.sysml.runtime.instructions.spark.functions.PerformGroupByAggInCombiner;
import org.apache.sysml.runtime.instructions.spark.functions.PerformGroupByAggInReducer;
@@ -58,6 +59,7 @@ import org.apache.sysml.runtime.matrix.data.MatrixCell;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.data.WeightedCell;
import org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue;
+import org.apache.sysml.runtime.matrix.operators.AggregateOperator;
import org.apache.sysml.runtime.matrix.operators.CMOperator;
import org.apache.sysml.runtime.matrix.operators.CMOperator.AggregateOperationTypes;
import org.apache.sysml.runtime.matrix.operators.Operator;
@@ -177,13 +179,16 @@ public class ParameterizedBuiltinSPInstruction extends ComputationSPInstruction
//opcode guaranteed to be a valid opcode (see parsing)
if ( opcode.equalsIgnoreCase("groupedagg") )
{
+ boolean broadcastGroups = Boolean.parseBoolean(params.get("broadcast"));
+
//get input rdd handle
+ String groupsVar = params.get(Statement.GAGG_GROUPS);
JavaPairRDD<MatrixIndexes,MatrixBlock> target = sec.getBinaryBlockRDDHandleForVariable( params.get(Statement.GAGG_TARGET) );
- JavaPairRDD<MatrixIndexes,MatrixBlock> groups = sec.getBinaryBlockRDDHandleForVariable( params.get(Statement.GAGG_GROUPS) );
+ JavaPairRDD<MatrixIndexes,MatrixBlock> groups = broadcastGroups ? null : sec.getBinaryBlockRDDHandleForVariable( groupsVar );
JavaPairRDD<MatrixIndexes,MatrixBlock> weights = null;
MatrixCharacteristics mc1 = sec.getMatrixCharacteristics( params.get(Statement.GAGG_TARGET) );
- MatrixCharacteristics mc2 = sec.getMatrixCharacteristics( params.get(Statement.GAGG_GROUPS) );
+ MatrixCharacteristics mc2 = sec.getMatrixCharacteristics( groupsVar );
if(mc1.dimsKnown() && mc2.dimsKnown() && (mc1.getRows() != mc2.getRows() || mc2.getCols() !=1)) {
throw new DMLRuntimeException("Grouped Aggregate dimension mismatch between target and groups.");
}
@@ -195,7 +200,7 @@ public class ParameterizedBuiltinSPInstruction extends ComputationSPInstruction
if ( params.get(Statement.GAGG_WEIGHTS) != null ) {
weights = sec.getBinaryBlockRDDHandleForVariable( params.get(Statement.GAGG_WEIGHTS) );
- MatrixCharacteristics mc3 = sec.getMatrixCharacteristics( params.get(Statement.GAGG_GROUPS) );
+ MatrixCharacteristics mc3 = sec.getMatrixCharacteristics( params.get(Statement.GAGG_WEIGHTS) );
if(mc1.dimsKnown() && mc3.dimsKnown() && (mc1.getRows() != mc3.getRows() || mc1.getCols() != mc3.getCols())) {
throw new DMLRuntimeException("Grouped Aggregate dimension mismatch between target, groups, and weights.");
}
@@ -205,19 +210,26 @@ public class ParameterizedBuiltinSPInstruction extends ComputationSPInstruction
}
else //input vector or matrix
{
- long ngroups = -1;
- if ( params.get(Statement.GAGG_NUM_GROUPS) != null) {
- ngroups = (long) Double.parseDouble(params.get(Statement.GAGG_NUM_GROUPS));
- }
+ String ngroupsStr = params.get(Statement.GAGG_NUM_GROUPS);
+ long ngroups = (ngroupsStr != null) ? (long) Double.parseDouble(ngroupsStr) : -1;
- //replicate groups if necessary
- if( mc1.getNumColBlocks() > 1 ) {
- groups = groups.flatMapToPair(
+ //execute basic grouped aggregate (extract and preagg)
+ if( broadcastGroups ) {
+ PartitionedBroadcastMatrix pbm = sec.getBroadcastForVariable(groupsVar);
+ groupWeightedCells = target
+ .flatMapToPair(new ExtractGroupBroadcast(pbm, mc1.getColsPerBlock(), ngroups, _optr));
+ }
+ else { //general case
+
+ //replicate groups if necessary
+ if( mc1.getNumColBlocks() > 1 ) {
+ groups = groups.flatMapToPair(
new ReplicateVectorFunction(false, mc1.getNumColBlocks() ));
+ }
+
+ groupWeightedCells = groups.join(target)
+ .flatMapToPair(new ExtractGroupJoin(mc1.getColsPerBlock(), ngroups, _optr));
}
-
- groupWeightedCells = groups.join(target)
- .flatMapToPair(new ExtractGroup(mc1.getColsPerBlock(), ngroups, _optr));
}
// Step 2: Make sure we have brlen required while creating <MatrixIndexes, MatrixCell>
@@ -228,7 +240,8 @@ public class ParameterizedBuiltinSPInstruction extends ComputationSPInstruction
// Step 3: Now perform grouped aggregate operation (either on combiner side or reducer side)
JavaPairRDD<MatrixIndexes, MatrixCell> out = null;
- if(_optr instanceof CMOperator && ((CMOperator) _optr).isPartialAggregateOperator() ) {
+ if(_optr instanceof CMOperator && ((CMOperator) _optr).isPartialAggregateOperator()
+ || _optr instanceof AggregateOperator ) {
out = groupWeightedCells.reduceByKey(new PerformGroupByAggInCombiner(_optr))
.mapValues(new CreateMatrixCell(brlen, _optr));
}
@@ -244,8 +257,8 @@ public class ParameterizedBuiltinSPInstruction extends ComputationSPInstruction
//store output rdd handle
sec.setRDDHandleForVariable(output.getName(), out);
- sec.addLineageRDD(output.getName(), params.get(Statement.GAGG_TARGET) );
- sec.addLineageRDD(output.getName(), params.get(Statement.GAGG_GROUPS) );
+ sec.addLineageRDD( output.getName(), params.get(Statement.GAGG_TARGET) );
+ sec.addLineage( output.getName(), groupsVar, broadcastGroups );
if ( params.get(Statement.GAGG_WEIGHTS) != null ) {
sec.addLineageRDD(output.getName(), params.get(Statement.GAGG_WEIGHTS) );
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b308c09b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/ExtractGroup.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/ExtractGroup.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/ExtractGroup.java
index 283e710..e8c6dbe 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/ExtractGroup.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/ExtractGroup.java
@@ -19,6 +19,7 @@
package org.apache.sysml.runtime.instructions.spark.functions;
+import java.io.Serializable;
import java.util.ArrayList;
import org.apache.spark.api.java.function.PairFlatMapFunction;
@@ -26,6 +27,7 @@ import org.apache.spark.api.java.function.PairFlatMapFunction;
import scala.Tuple2;
import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.runtime.instructions.spark.data.PartitionedBroadcastMatrix;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.data.WeightedCell;
@@ -33,13 +35,13 @@ import org.apache.sysml.runtime.matrix.operators.AggregateOperator;
import org.apache.sysml.runtime.matrix.operators.Operator;
import org.apache.sysml.runtime.util.UtilFunctions;
-public class ExtractGroup implements PairFlatMapFunction<Tuple2<MatrixIndexes,Tuple2<MatrixBlock, MatrixBlock>>, MatrixIndexes, WeightedCell> {
-
+public abstract class ExtractGroup implements Serializable
+{
private static final long serialVersionUID = -7059358143841229966L;
- private long _bclen = -1;
- private long _ngroups = -1;
- private Operator _op = null;
+ protected long _bclen = -1;
+ protected long _ngroups = -1;
+ protected Operator _op = null;
public ExtractGroup( long bclen, long ngroups, Operator op ) {
_bclen = bclen;
@@ -47,15 +49,16 @@ public class ExtractGroup implements PairFlatMapFunction<Tuple2<MatrixIndexes,T
_op = op;
}
- @Override
- public Iterable<Tuple2<MatrixIndexes, WeightedCell>> call(
- Tuple2<MatrixIndexes, Tuple2<MatrixBlock, MatrixBlock>> arg)
- throws Exception
+ /**
+ *
+ * @param ix
+ * @param group
+ * @param target
+ * @return
+ * @throws Exception
+ */
+ protected Iterable<Tuple2<MatrixIndexes, WeightedCell>> execute(MatrixIndexes ix, MatrixBlock group, MatrixBlock target) throws Exception
{
- MatrixIndexes ix = arg._1;
- MatrixBlock group = arg._2._1;
- MatrixBlock target = arg._2._2;
-
//sanity check matching block dimensions
if(group.getNumRows() != target.getNumRows()) {
throw new Exception("The blocksize for group and target blocks are mismatched: " + group.getNumRows() + " != " + target.getNumRows());
@@ -102,6 +105,57 @@ public class ExtractGroup implements PairFlatMapFunction<Tuple2<MatrixIndexes,T
}
}
- return groupValuePairs;
+ return groupValuePairs;
+ }
+
+ /**
+ *
+ */
+ public static class ExtractGroupJoin extends ExtractGroup implements PairFlatMapFunction<Tuple2<MatrixIndexes,Tuple2<MatrixBlock, MatrixBlock>>, MatrixIndexes, WeightedCell>
+ {
+ private static final long serialVersionUID = 8890978615936560266L;
+
+ public ExtractGroupJoin(long bclen, long ngroups, Operator op) {
+ super(bclen, ngroups, op);
+ }
+
+ @Override
+ public Iterable<Tuple2<MatrixIndexes, WeightedCell>> call(
+ Tuple2<MatrixIndexes, Tuple2<MatrixBlock, MatrixBlock>> arg)
+ throws Exception
+ {
+ MatrixIndexes ix = arg._1;
+ MatrixBlock group = arg._2._1;
+ MatrixBlock target = arg._2._2;
+
+ return execute(ix, group, target);
+ }
+ }
+
+ /**
+ *
+ */
+ public static class ExtractGroupBroadcast extends ExtractGroup implements PairFlatMapFunction<Tuple2<MatrixIndexes,MatrixBlock>, MatrixIndexes, WeightedCell>
+ {
+ private static final long serialVersionUID = 5709955602290131093L;
+
+ private PartitionedBroadcastMatrix _pbm = null;
+
+ public ExtractGroupBroadcast( PartitionedBroadcastMatrix pbm, long bclen, long ngroups, Operator op ) {
+ super(bclen, ngroups, op);
+ _pbm = pbm;
+ }
+
+ @Override
+ public Iterable<Tuple2<MatrixIndexes, WeightedCell>> call(
+ Tuple2<MatrixIndexes, MatrixBlock> arg)
+ throws Exception
+ {
+ MatrixIndexes ix = arg._1;
+ MatrixBlock group = _pbm.getMatrixBlock((int)ix.getRowIndex(), 1);
+ MatrixBlock target = arg._2;
+
+ return execute(ix, group, target);
+ }
}
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b308c09b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/PerformGroupByAggInCombiner.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/PerformGroupByAggInCombiner.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/PerformGroupByAggInCombiner.java
index 812ff15..8b0cbc8 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/PerformGroupByAggInCombiner.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/PerformGroupByAggInCombiner.java
@@ -35,31 +35,30 @@ public class PerformGroupByAggInCombiner implements Function2<WeightedCell, Weig
private static final long serialVersionUID = -813530414567786509L;
- Operator op;
+ private Operator _op;
+
public PerformGroupByAggInCombiner(Operator op) {
- this.op = op;
+ _op = op;
}
@Override
- public WeightedCell call(WeightedCell value1, WeightedCell value2) throws Exception {
- return doAggregation(op, value1, value2);
- }
-
- public WeightedCell doAggregation(Operator op, WeightedCell value1, WeightedCell value2) throws DMLRuntimeException {
+ public WeightedCell call(WeightedCell value1, WeightedCell value2)
+ throws Exception
+ {
WeightedCell outCell = new WeightedCell();
CM_COV_Object cmObj = new CM_COV_Object();
- if(op instanceof CMOperator) //everything except sum
+ if(_op instanceof CMOperator) //everything except sum
{
- if( ((CMOperator) op).isPartialAggregateOperator() )
+ if( ((CMOperator) _op).isPartialAggregateOperator() )
{
cmObj.reset();
- CM lcmFn = CM.getCMFnObject(((CMOperator) op).aggOpType); // cmFn.get(key.getTag());
+ CM lcmFn = CM.getCMFnObject(((CMOperator) _op).aggOpType); // cmFn.get(key.getTag());
//partial aggregate cm operator
lcmFn.execute(cmObj, value1.getValue(), value1.getWeight());
lcmFn.execute(cmObj, value2.getValue(), value2.getWeight());
- outCell.setValue(cmObj.getRequiredPartialResult(op));
+ outCell.setValue(cmObj.getRequiredPartialResult(_op));
outCell.setWeight(cmObj.getWeight());
}
else //forward tuples to reducer
@@ -67,9 +66,9 @@ public class PerformGroupByAggInCombiner implements Function2<WeightedCell, Weig
throw new DMLRuntimeException("Incorrect usage, should have used PerformGroupByAggInReducer");
}
}
- else if(op instanceof AggregateOperator) //sum
+ else if(_op instanceof AggregateOperator) //sum
{
- AggregateOperator aggop=(AggregateOperator) op;
+ AggregateOperator aggop=(AggregateOperator) _op;
if( aggop.correctionExists ) {
KahanObject buffer=new KahanObject(aggop.initialValue, 0);
@@ -96,7 +95,7 @@ public class PerformGroupByAggInCombiner implements Function2<WeightedCell, Weig
}
}
else
- throw new DMLRuntimeException("Unsupported operator in grouped aggregate instruction:" + op);
+ throw new DMLRuntimeException("Unsupported operator in grouped aggregate instruction:" + _op);
return outCell;
}