You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2016/01/01 22:16:16 UTC

[3/3] incubator-systemml git commit: New broadcast-based spark grouped aggregate (compiler/runtime)

New broadcast-based spark grouped aggregate (compiler/runtime)

Includes various cleanups and a fix to use the reduceByKey instead of the
groupByKey code path for aggregate operators (e.g., sum()).

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/b308c09b
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/b308c09b
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/b308c09b

Branch: refs/heads/master
Commit: b308c09b90b3e88e7a30a7d77cc1b3f06bf82cc7
Parents: 4b71654
Author: Matthias Boehm <mb...@us.ibm.com>
Authored: Fri Jan 1 13:15:37 2016 -0800
Committer: Matthias Boehm <mb...@us.ibm.com>
Committed: Fri Jan 1 13:15:37 2016 -0800

----------------------------------------------------------------------
 .../sysml/hops/ParameterizedBuiltinOp.java      |  8 +-
 .../org/apache/sysml/lops/GroupedAggregate.java | 19 +++++
 .../context/SparkExecutionContext.java          | 16 ++++
 .../ParameterizedBuiltinSPInstruction.java      | 47 +++++++----
 .../spark/functions/ExtractGroup.java           | 82 ++++++++++++++++----
 .../functions/PerformGroupByAggInCombiner.java  | 27 ++++---
 6 files changed, 153 insertions(+), 46 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b308c09b/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java b/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java
index 99eebaa..44714d4 100644
--- a/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java
+++ b/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java
@@ -334,7 +334,13 @@ public class ParameterizedBuiltinOp extends Hop implements MultiThreadedHop
 			}
 			else if(et == ExecType.SPARK) 
 			{
-				grp_agg = new GroupedAggregate(inputlops, getDataType(), getValueType(), et);						
+				//physical operator selection
+				Hop groups = getInput().get(_paramIndexMap.get(Statement.GAGG_GROUPS));
+				boolean broadcastGroups = (_paramIndexMap.get(Statement.GAGG_WEIGHTS) == null &&
+						OptimizerUtils.checkSparkBroadcastMemoryBudget( groups.getDim1(), groups.getDim2(), 
+								groups.getRowsInBlock(), groups.getColsInBlock(), groups.getNnz()) );
+				
+				grp_agg = new GroupedAggregate(inputlops, getDataType(), getValueType(), et, broadcastGroups);						
 				grp_agg.getOutputParameters().setDimensions(outputDim1, outputDim2, -1, -1, -1);
 				setRequiresReblock( true );
 			}

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b308c09b/src/main/java/org/apache/sysml/lops/GroupedAggregate.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/GroupedAggregate.java b/src/main/java/org/apache/sysml/lops/GroupedAggregate.java
index f54857d..d07cfe0 100644
--- a/src/main/java/org/apache/sysml/lops/GroupedAggregate.java
+++ b/src/main/java/org/apache/sysml/lops/GroupedAggregate.java
@@ -40,6 +40,11 @@ public class GroupedAggregate extends Lop
 	public static final String COMBINEDINPUT = "combinedinput";
 	
 	private boolean _weights = false;	
+	
+	//spark-specific parameters
+	private boolean _broadcastGroups = false;
+	
+	//cp-specific parameters
 	private int _numThreads = 1;
 
 	/**
@@ -65,6 +70,14 @@ public class GroupedAggregate extends Lop
 	
 	public GroupedAggregate(
 			HashMap<String, Lop> inputParameterLops, 
+			DataType dt, ValueType vt, ExecType et, boolean broadcastGroups) {
+		super(Lop.Type.GroupedAgg, dt, vt);
+		init(inputParameterLops, dt, vt, et);
+		_broadcastGroups = broadcastGroups;
+	}
+	
+	public GroupedAggregate(
+			HashMap<String, Lop> inputParameterLops, 
 			DataType dt, ValueType vt, ExecType et, int k) {
 		super(Lop.Type.GroupedAgg, dt, vt);
 		init(inputParameterLops, dt, vt, et);
@@ -203,6 +216,12 @@ public class GroupedAggregate extends Lop
 			sb.append( Lop.NAME_VALUE_SEPARATOR );
 			sb.append( _numThreads );	
 		}
+		else if( getExecType()==ExecType.SPARK ) {
+			sb.append( OPERAND_DELIMITOR );
+			sb.append( "broadcast" );
+			sb.append( Lop.NAME_VALUE_SEPARATOR );
+			sb.append( _broadcastGroups );	
+		}
 		
 		sb.append( OPERAND_DELIMITOR );
 		sb.append( prepOutputOperand(output));

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b308c09b/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java b/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
index 2ee6041..b46a3af 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
@@ -876,6 +876,22 @@ public class SparkExecutionContext extends ExecutionContext
 		parent.addLineageChild( child );
 	}
 	
+	/**
+	 * 
+	 * @param varParent
+	 * @param varChild
+	 * @param broadcast
+	 * @throws DMLRuntimeException
+	 */
+	public void addLineage(String varParent, String varChild, boolean broadcast) 
+		throws DMLRuntimeException
+	{
+		if( broadcast )
+			addLineageBroadcast(varParent, varChild);
+		else
+			addLineageRDD(varParent, varChild);
+	}
+	
 	@Override
 	public void cleanupMatrixObject( MatrixObject mo ) 
 		throws DMLRuntimeException

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b308c09b/src/main/java/org/apache/sysml/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
index 257e174..505e232 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
@@ -44,7 +44,8 @@ import org.apache.sysml.runtime.instructions.InstructionUtils;
 import org.apache.sysml.runtime.instructions.cp.CPOperand;
 import org.apache.sysml.runtime.instructions.mr.GroupedAggregateInstruction;
 import org.apache.sysml.runtime.instructions.spark.data.PartitionedBroadcastMatrix;
-import org.apache.sysml.runtime.instructions.spark.functions.ExtractGroup;
+import org.apache.sysml.runtime.instructions.spark.functions.ExtractGroup.ExtractGroupBroadcast;
+import org.apache.sysml.runtime.instructions.spark.functions.ExtractGroup.ExtractGroupJoin;
 import org.apache.sysml.runtime.instructions.spark.functions.ExtractGroupNWeights;
 import org.apache.sysml.runtime.instructions.spark.functions.PerformGroupByAggInCombiner;
 import org.apache.sysml.runtime.instructions.spark.functions.PerformGroupByAggInReducer;
@@ -58,6 +59,7 @@ import org.apache.sysml.runtime.matrix.data.MatrixCell;
 import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
 import org.apache.sysml.runtime.matrix.data.WeightedCell;
 import org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue;
+import org.apache.sysml.runtime.matrix.operators.AggregateOperator;
 import org.apache.sysml.runtime.matrix.operators.CMOperator;
 import org.apache.sysml.runtime.matrix.operators.CMOperator.AggregateOperationTypes;
 import org.apache.sysml.runtime.matrix.operators.Operator;
@@ -177,13 +179,16 @@ public class ParameterizedBuiltinSPInstruction  extends ComputationSPInstruction
 		//opcode guaranteed to be a valid opcode (see parsing)
 		if ( opcode.equalsIgnoreCase("groupedagg") ) 
 		{	
+			boolean broadcastGroups = Boolean.parseBoolean(params.get("broadcast"));
+			
 			//get input rdd handle
+			String groupsVar = params.get(Statement.GAGG_GROUPS);
 			JavaPairRDD<MatrixIndexes,MatrixBlock> target = sec.getBinaryBlockRDDHandleForVariable( params.get(Statement.GAGG_TARGET) );
-			JavaPairRDD<MatrixIndexes,MatrixBlock> groups = sec.getBinaryBlockRDDHandleForVariable( params.get(Statement.GAGG_GROUPS) );
+			JavaPairRDD<MatrixIndexes,MatrixBlock> groups = broadcastGroups ? null : sec.getBinaryBlockRDDHandleForVariable( groupsVar );
 			JavaPairRDD<MatrixIndexes,MatrixBlock> weights = null;
 			
 			MatrixCharacteristics mc1 = sec.getMatrixCharacteristics( params.get(Statement.GAGG_TARGET) );
-			MatrixCharacteristics mc2 = sec.getMatrixCharacteristics( params.get(Statement.GAGG_GROUPS) );
+			MatrixCharacteristics mc2 = sec.getMatrixCharacteristics( groupsVar );
 			if(mc1.dimsKnown() && mc2.dimsKnown() && (mc1.getRows() != mc2.getRows() || mc2.getCols() !=1)) {
 				throw new DMLRuntimeException("Grouped Aggregate dimension mismatch between target and groups.");
 			}
@@ -195,7 +200,7 @@ public class ParameterizedBuiltinSPInstruction  extends ComputationSPInstruction
 			if ( params.get(Statement.GAGG_WEIGHTS) != null ) {
 				weights = sec.getBinaryBlockRDDHandleForVariable( params.get(Statement.GAGG_WEIGHTS) );
 				
-				MatrixCharacteristics mc3 = sec.getMatrixCharacteristics( params.get(Statement.GAGG_GROUPS) );
+				MatrixCharacteristics mc3 = sec.getMatrixCharacteristics( params.get(Statement.GAGG_WEIGHTS) );
 				if(mc1.dimsKnown() && mc3.dimsKnown() && (mc1.getRows() != mc3.getRows() || mc1.getCols() != mc3.getCols())) {
 					throw new DMLRuntimeException("Grouped Aggregate dimension mismatch between target, groups, and weights.");
 				}
@@ -205,19 +210,26 @@ public class ParameterizedBuiltinSPInstruction  extends ComputationSPInstruction
 			}
 			else //input vector or matrix
 			{
-				long ngroups = -1;
-				if ( params.get(Statement.GAGG_NUM_GROUPS) != null) {
-					ngroups = (long) Double.parseDouble(params.get(Statement.GAGG_NUM_GROUPS));
-				}
+				String ngroupsStr = params.get(Statement.GAGG_NUM_GROUPS);
+				long ngroups = (ngroupsStr != null) ? (long) Double.parseDouble(ngroupsStr) : -1;
 				
-				//replicate groups if necessary
-				if( mc1.getNumColBlocks() > 1 ) {
-					groups = groups.flatMapToPair(
+				//execute basic grouped aggregate (extract and preagg)
+				if( broadcastGroups ) {
+					PartitionedBroadcastMatrix pbm = sec.getBroadcastForVariable(groupsVar);
+					groupWeightedCells = target
+							.flatMapToPair(new ExtractGroupBroadcast(pbm, mc1.getColsPerBlock(), ngroups, _optr));						
+				}
+				else { //general case
+					
+					//replicate groups if necessary
+					if( mc1.getNumColBlocks() > 1 ) {
+						groups = groups.flatMapToPair(
 							new ReplicateVectorFunction(false, mc1.getNumColBlocks() ));
+					}
+					
+					groupWeightedCells = groups.join(target)
+							.flatMapToPair(new ExtractGroupJoin(mc1.getColsPerBlock(), ngroups, _optr));		
 				}
-				
-				groupWeightedCells = groups.join(target)
-						.flatMapToPair(new ExtractGroup(mc1.getColsPerBlock(), ngroups, _optr));
 			}
 			
 			// Step 2: Make sure we have brlen required while creating <MatrixIndexes, MatrixCell> 
@@ -228,7 +240,8 @@ public class ParameterizedBuiltinSPInstruction  extends ComputationSPInstruction
 			
 			// Step 3: Now perform grouped aggregate operation (either on combiner side or reducer side)
 			JavaPairRDD<MatrixIndexes, MatrixCell> out = null;
-			if(_optr instanceof CMOperator && ((CMOperator) _optr).isPartialAggregateOperator() ) {
+			if(_optr instanceof CMOperator && ((CMOperator) _optr).isPartialAggregateOperator() 
+				|| _optr instanceof AggregateOperator ) {
 				out = groupWeightedCells.reduceByKey(new PerformGroupByAggInCombiner(_optr))
 						.mapValues(new CreateMatrixCell(brlen, _optr));
 			}
@@ -244,8 +257,8 @@ public class ParameterizedBuiltinSPInstruction  extends ComputationSPInstruction
 			
 			//store output rdd handle
 			sec.setRDDHandleForVariable(output.getName(), out);			
-			sec.addLineageRDD(output.getName(), params.get(Statement.GAGG_TARGET) );
-			sec.addLineageRDD(output.getName(), params.get(Statement.GAGG_GROUPS) );
+			sec.addLineageRDD( output.getName(), params.get(Statement.GAGG_TARGET) );
+			sec.addLineage( output.getName(), groupsVar, broadcastGroups );
 			if ( params.get(Statement.GAGG_WEIGHTS) != null ) {
 				sec.addLineageRDD(output.getName(), params.get(Statement.GAGG_WEIGHTS) );
 			}

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b308c09b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/ExtractGroup.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/ExtractGroup.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/ExtractGroup.java
index 283e710..e8c6dbe 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/ExtractGroup.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/ExtractGroup.java
@@ -19,6 +19,7 @@
 
 package org.apache.sysml.runtime.instructions.spark.functions;
 
+import java.io.Serializable;
 import java.util.ArrayList;
 
 import org.apache.spark.api.java.function.PairFlatMapFunction;
@@ -26,6 +27,7 @@ import org.apache.spark.api.java.function.PairFlatMapFunction;
 import scala.Tuple2;
 
 import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.runtime.instructions.spark.data.PartitionedBroadcastMatrix;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
 import org.apache.sysml.runtime.matrix.data.WeightedCell;
@@ -33,13 +35,13 @@ import org.apache.sysml.runtime.matrix.operators.AggregateOperator;
 import org.apache.sysml.runtime.matrix.operators.Operator;
 import org.apache.sysml.runtime.util.UtilFunctions;
 
-public class ExtractGroup  implements PairFlatMapFunction<Tuple2<MatrixIndexes,Tuple2<MatrixBlock, MatrixBlock>>, MatrixIndexes, WeightedCell> {
-
+public abstract class ExtractGroup implements Serializable 
+{
 	private static final long serialVersionUID = -7059358143841229966L;
 
-	private long _bclen = -1; 
-	private long _ngroups = -1; 
-	private Operator _op = null;
+	protected long _bclen = -1; 
+	protected long _ngroups = -1; 
+	protected Operator _op = null;
 	
 	public ExtractGroup( long bclen, long ngroups, Operator op ) {
 		_bclen = bclen;
@@ -47,15 +49,16 @@ public class ExtractGroup  implements PairFlatMapFunction<Tuple2<MatrixIndexes,T
 		_op = op;
 	}
 	
-	@Override
-	public Iterable<Tuple2<MatrixIndexes, WeightedCell>> call(
-			Tuple2<MatrixIndexes, Tuple2<MatrixBlock, MatrixBlock>> arg)
-			throws Exception 
+	/**
+	 * 
+	 * @param ix
+	 * @param group
+	 * @param target
+	 * @return
+	 * @throws Exception 
+	 */
+	protected Iterable<Tuple2<MatrixIndexes, WeightedCell>> execute(MatrixIndexes ix, MatrixBlock group, MatrixBlock target) throws Exception
 	{
-		MatrixIndexes ix = arg._1;
-		MatrixBlock group = arg._2._1;
-		MatrixBlock target = arg._2._2;
-		
 		//sanity check matching block dimensions
 		if(group.getNumRows() != target.getNumRows()) {
 			throw new Exception("The blocksize for group and target blocks are mismatched: " + group.getNumRows()  + " != " + target.getNumRows());
@@ -102,6 +105,57 @@ public class ExtractGroup  implements PairFlatMapFunction<Tuple2<MatrixIndexes,T
 			}
 		}
 		
-		return groupValuePairs;
+		return groupValuePairs;	
+	}
+	
+	/**
+	 * 
+	 */
+	public static class ExtractGroupJoin extends ExtractGroup implements PairFlatMapFunction<Tuple2<MatrixIndexes,Tuple2<MatrixBlock, MatrixBlock>>, MatrixIndexes, WeightedCell> 
+	{
+		private static final long serialVersionUID = 8890978615936560266L;
+
+		public ExtractGroupJoin(long bclen, long ngroups, Operator op) {
+			super(bclen, ngroups, op);
+		}
+		
+		@Override
+		public Iterable<Tuple2<MatrixIndexes, WeightedCell>> call(
+				Tuple2<MatrixIndexes, Tuple2<MatrixBlock, MatrixBlock>> arg)
+				throws Exception 
+		{
+			MatrixIndexes ix = arg._1;
+			MatrixBlock group = arg._2._1;
+			MatrixBlock target = arg._2._2;
+	
+			return execute(ix, group, target);
+		}	
+	}
+	
+	/**
+	 * 
+	 */
+	public static class ExtractGroupBroadcast extends ExtractGroup implements PairFlatMapFunction<Tuple2<MatrixIndexes,MatrixBlock>, MatrixIndexes, WeightedCell> 
+	{
+		private static final long serialVersionUID = 5709955602290131093L;
+		
+		private PartitionedBroadcastMatrix _pbm = null;
+		
+		public ExtractGroupBroadcast( PartitionedBroadcastMatrix pbm, long bclen, long ngroups, Operator op ) {
+			super(bclen, ngroups, op);
+			_pbm = pbm;
+		}
+		
+		@Override
+		public Iterable<Tuple2<MatrixIndexes, WeightedCell>> call(
+				Tuple2<MatrixIndexes, MatrixBlock> arg)
+				throws Exception 
+		{
+			MatrixIndexes ix = arg._1;
+			MatrixBlock group = _pbm.getMatrixBlock((int)ix.getRowIndex(), 1);
+			MatrixBlock target = arg._2;
+			
+			return execute(ix, group, target);
+		}	
 	}
 }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b308c09b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/PerformGroupByAggInCombiner.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/PerformGroupByAggInCombiner.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/PerformGroupByAggInCombiner.java
index 812ff15..8b0cbc8 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/PerformGroupByAggInCombiner.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/PerformGroupByAggInCombiner.java
@@ -35,31 +35,30 @@ public class PerformGroupByAggInCombiner implements Function2<WeightedCell, Weig
 
 	private static final long serialVersionUID = -813530414567786509L;
 	
-	Operator op;
+	private Operator _op;
+	
 	public PerformGroupByAggInCombiner(Operator op) {
-		this.op = op;
+		_op = op;
 	}
 
 	@Override
-	public WeightedCell call(WeightedCell value1, WeightedCell value2) throws Exception {
-		return doAggregation(op, value1, value2);
-	}
-
-	public WeightedCell doAggregation(Operator op, WeightedCell value1, WeightedCell value2) throws DMLRuntimeException {
+	public WeightedCell call(WeightedCell value1, WeightedCell value2) 
+		throws Exception 
+	{
 		WeightedCell outCell = new WeightedCell();
 		CM_COV_Object cmObj = new CM_COV_Object(); 
-		if(op instanceof CMOperator) //everything except sum
+		if(_op instanceof CMOperator) //everything except sum
 		{
-			if( ((CMOperator) op).isPartialAggregateOperator() )
+			if( ((CMOperator) _op).isPartialAggregateOperator() )
 			{
 				cmObj.reset();
-				CM lcmFn = CM.getCMFnObject(((CMOperator) op).aggOpType); // cmFn.get(key.getTag());
+				CM lcmFn = CM.getCMFnObject(((CMOperator) _op).aggOpType); // cmFn.get(key.getTag());
 				
 				//partial aggregate cm operator
 				lcmFn.execute(cmObj, value1.getValue(), value1.getWeight());
 				lcmFn.execute(cmObj, value2.getValue(), value2.getWeight());
 				
-				outCell.setValue(cmObj.getRequiredPartialResult(op));
+				outCell.setValue(cmObj.getRequiredPartialResult(_op));
 				outCell.setWeight(cmObj.getWeight());	
 			}
 			else //forward tuples to reducer
@@ -67,9 +66,9 @@ public class PerformGroupByAggInCombiner implements Function2<WeightedCell, Weig
 				throw new DMLRuntimeException("Incorrect usage, should have used PerformGroupByAggInReducer");
 			}				
 		}
-		else if(op instanceof AggregateOperator) //sum
+		else if(_op instanceof AggregateOperator) //sum
 		{
-			AggregateOperator aggop=(AggregateOperator) op;
+			AggregateOperator aggop=(AggregateOperator) _op;
 				
 			if( aggop.correctionExists ) {
 				KahanObject buffer=new KahanObject(aggop.initialValue, 0);
@@ -96,7 +95,7 @@ public class PerformGroupByAggInCombiner implements Function2<WeightedCell, Weig
 			}				
 		}
 		else
-			throw new DMLRuntimeException("Unsupported operator in grouped aggregate instruction:" + op);
+			throw new DMLRuntimeException("Unsupported operator in grouped aggregate instruction:" + _op);
 		
 		return outCell;
 	}