You are viewing a plain-text version of this content. The canonical version is available at the original mail-archive link, which is not preserved in this plain-text copy.
Posted to commits@systemml.apache.org by mb...@apache.org on 2016/01/04 01:01:06 UTC
[4/4] incubator-systemml git commit: New spark map-grouped-aggregate
(compiler/runtime), for naive-bayes
New spark map-grouped-aggregate (compiler/runtime), for naive-bayes
Includes (1) refactoring for code reuse across Spark/MapReduce, and (2)
additional cleanup of instruction parsing (SP instruction parser).
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/ff2aea54
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/ff2aea54
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/ff2aea54
Branch: refs/heads/master
Commit: ff2aea54251948add5fd17c30a8c53536828d512
Parents: 3ea3cdb
Author: Matthias Boehm <mb...@us.ibm.com>
Authored: Sat Jan 2 18:22:44 2016 -0800
Committer: Matthias Boehm <mb...@us.ibm.com>
Committed: Sat Jan 2 18:22:44 2016 -0800
----------------------------------------------------------------------
.../sysml/hops/ParameterizedBuiltinOp.java | 22 +-
.../apache/sysml/lops/GroupedAggregateM.java | 35 +++-
.../instructions/SPInstructionParser.java | 39 ++--
.../mr/GroupedAggregateMInstruction.java | 33 +--
.../spark/AppendGAlignedSPInstruction.java | 3 +-
.../spark/AppendGSPInstruction.java | 3 +-
.../spark/AppendMSPInstruction.java | 3 +-
.../spark/AppendRSPInstruction.java | 3 +-
.../spark/BuiltinBinarySPInstruction.java | 3 +-
.../spark/BuiltinUnarySPInstruction.java | 3 +-
.../spark/MatrixReshapeSPInstruction.java | 3 +-
.../ParameterizedBuiltinSPInstruction.java | 203 +++++++++++++------
.../spark/QuantilePickSPInstruction.java | 3 +-
.../spark/QuantileSortSPInstruction.java | 3 +-
.../instructions/spark/RandSPInstruction.java | 3 +-
.../instructions/spark/WriteSPInstruction.java | 3 +-
.../matrix/data/OperationsOnMatrixValues.java | 41 ++++
17 files changed, 267 insertions(+), 139 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ff2aea54/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java b/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java
index 466d008..3a8445f 100644
--- a/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java
+++ b/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java
@@ -375,7 +375,7 @@ public class ParameterizedBuiltinOp extends Hop implements MultiThreadedHop
}
else //CP/Spark
{
- GroupedAggregate grp_agg = null;
+ Lop grp_agg = null;
if( et == ExecType.CP)
{
@@ -391,9 +391,23 @@ public class ParameterizedBuiltinOp extends Hop implements MultiThreadedHop
OptimizerUtils.checkSparkBroadcastMemoryBudget( groups.getDim1(), groups.getDim2(),
groups.getRowsInBlock(), groups.getColsInBlock(), groups.getNnz()) );
- grp_agg = new GroupedAggregate(inputlops, getDataType(), getValueType(), et, broadcastGroups);
- grp_agg.getOutputParameters().setDimensions(outputDim1, outputDim2, -1, -1, -1);
- setRequiresReblock( true );
+ if( broadcastGroups //mapgroupedagg
+ && getInput().get(_paramIndexMap.get(Statement.GAGG_FN)) instanceof LiteralOp
+ && ((LiteralOp)getInput().get(_paramIndexMap.get(Statement.GAGG_FN))).getStringValue().equals("sum")
+ && inputlops.get(Statement.GAGG_NUM_GROUPS) != null )
+ {
+ Hop target = getInput().get(_paramIndexMap.get(Statement.GAGG_TARGET));
+
+ grp_agg = new GroupedAggregateM(inputlops, getDataType(), getValueType(), true, ExecType.SPARK);
+ grp_agg.getOutputParameters().setDimensions(outputDim1, outputDim2, target.getRowsInBlock(), target.getColsInBlock(), -1);
+ //no reblock required (directly output binary block)
+ }
+ else //groupedagg (w/ or w/o broadcast)
+ {
+ grp_agg = new GroupedAggregate(inputlops, getDataType(), getValueType(), et, broadcastGroups);
+ grp_agg.getOutputParameters().setDimensions(outputDim1, outputDim2, -1, -1, -1);
+ setRequiresReblock( true );
+ }
}
setLineNumbers(grp_agg);
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ff2aea54/src/main/java/org/apache/sysml/lops/GroupedAggregateM.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/GroupedAggregateM.java b/src/main/java/org/apache/sysml/lops/GroupedAggregateM.java
index 5edba62..2d73abc 100644
--- a/src/main/java/org/apache/sysml/lops/GroupedAggregateM.java
+++ b/src/main/java/org/apache/sysml/lops/GroupedAggregateM.java
@@ -68,13 +68,25 @@ public class GroupedAggregateM extends Lop
addInput(inputParameterLops.get(Statement.GAGG_GROUPS));
inputParameterLops.get(Statement.GAGG_GROUPS).addOutput(this);
- //setup MR parameters
- boolean breaksAlignment = true;
- boolean aligner = false;
- boolean definesMRJob = false;
- lps.addCompatibility(JobType.GMR);
- lps.addCompatibility(JobType.DATAGEN);
- lps.setProperties( inputs, ExecType.MR, ExecLocation.Map, breaksAlignment, aligner, definesMRJob );
+ if( et == ExecType.MR )
+ {
+ //setup MR parameters
+ boolean breaksAlignment = true;
+ boolean aligner = false;
+ boolean definesMRJob = false;
+ lps.addCompatibility(JobType.GMR);
+ lps.addCompatibility(JobType.DATAGEN);
+ lps.setProperties( inputs, ExecType.MR, ExecLocation.Map, breaksAlignment, aligner, definesMRJob );
+ }
+ else //SPARK
+ {
+ //setup Spark parameters
+ boolean breaksAlignment = false;
+ boolean aligner = false;
+ boolean definesMRJob = false;
+ lps.addCompatibility(JobType.INVALID);
+ lps.setProperties( inputs, et, ExecLocation.ControlProgram, breaksAlignment, aligner, definesMRJob );
+ }
}
@Override
@@ -85,6 +97,15 @@ public class GroupedAggregateM extends Lop
@Override
public String getInstructions(int input1, int input2, int output)
{
+ return getInstructions(
+ String.valueOf(input1),
+ String.valueOf(input2),
+ String.valueOf(output) );
+ }
+
+ @Override
+ public String getInstructions(String input1, String input2, String output)
+ {
StringBuilder sb = new StringBuilder();
sb.append( getExecType() );
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ff2aea54/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java
index 1f5f961..2683c43 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java
@@ -201,11 +201,12 @@ public class SPInstructionParser extends InstructionParser
String2SPInstructionType.put( "sel+", SPINSTRUCTION_TYPE.BuiltinUnary);
// Parameterized Builtin Functions
- String2SPInstructionType.put( "groupedagg" , SPINSTRUCTION_TYPE.ParameterizedBuiltin);
- String2SPInstructionType.put( "rmempty" , SPINSTRUCTION_TYPE.ParameterizedBuiltin);
- String2SPInstructionType.put( "replace" , SPINSTRUCTION_TYPE.ParameterizedBuiltin);
- String2SPInstructionType.put( "rexpand" , SPINSTRUCTION_TYPE.ParameterizedBuiltin);
- String2SPInstructionType.put( "transform" , SPINSTRUCTION_TYPE.ParameterizedBuiltin);
+ String2SPInstructionType.put( "groupedagg" , SPINSTRUCTION_TYPE.ParameterizedBuiltin);
+ String2SPInstructionType.put( "mapgroupedagg", SPINSTRUCTION_TYPE.ParameterizedBuiltin);
+ String2SPInstructionType.put( "rmempty" , SPINSTRUCTION_TYPE.ParameterizedBuiltin);
+ String2SPInstructionType.put( "replace" , SPINSTRUCTION_TYPE.ParameterizedBuiltin);
+ String2SPInstructionType.put( "rexpand" , SPINSTRUCTION_TYPE.ParameterizedBuiltin);
+ String2SPInstructionType.put( "transform" , SPINSTRUCTION_TYPE.ParameterizedBuiltin);
String2SPInstructionType.put( "mappend", SPINSTRUCTION_TYPE.MAppend);
String2SPInstructionType.put( "rappend", SPINSTRUCTION_TYPE.RAppend);
@@ -338,10 +339,10 @@ public class SPInstructionParser extends InstructionParser
if ( parts[0].equals("log") || parts[0].equals("log_nz") ) {
if ( parts.length == 3 ) {
// B=log(A), y=log(x)
- return (SPInstruction) BuiltinUnarySPInstruction.parseInstruction(str);
+ return BuiltinUnarySPInstruction.parseInstruction(str);
} else if ( parts.length == 4 ) {
// B=log(A,10), y=log(x,10)
- return (SPInstruction) BuiltinBinarySPInstruction.parseInstruction(str);
+ return BuiltinBinarySPInstruction.parseInstruction(str);
}
}
else {
@@ -349,40 +350,40 @@ public class SPInstructionParser extends InstructionParser
}
case BuiltinBinary:
- return (SPInstruction) BuiltinBinarySPInstruction.parseInstruction(str);
+ return BuiltinBinarySPInstruction.parseInstruction(str);
case BuiltinUnary:
- return (SPInstruction) BuiltinUnarySPInstruction.parseInstruction(str);
+ return BuiltinUnarySPInstruction.parseInstruction(str);
case ParameterizedBuiltin:
- return (SPInstruction) ParameterizedBuiltinSPInstruction.parseInstruction(str);
+ return ParameterizedBuiltinSPInstruction.parseInstruction(str);
case MatrixReshape:
- return (SPInstruction) MatrixReshapeSPInstruction.parseInstruction(str);
+ return MatrixReshapeSPInstruction.parseInstruction(str);
case MAppend:
- return (SPInstruction) AppendMSPInstruction.parseInstruction(str);
+ return AppendMSPInstruction.parseInstruction(str);
case GAppend:
- return (SPInstruction) AppendGSPInstruction.parseInstruction(str);
+ return AppendGSPInstruction.parseInstruction(str);
case GAlignedAppend:
- return (SPInstruction) AppendGAlignedSPInstruction.parseInstruction(str);
+ return AppendGAlignedSPInstruction.parseInstruction(str);
case RAppend:
- return (SPInstruction) AppendRSPInstruction.parseInstruction(str);
+ return AppendRSPInstruction.parseInstruction(str);
case Rand:
- return (SPInstruction) RandSPInstruction.parseInstruction(str);
+ return RandSPInstruction.parseInstruction(str);
case QSort:
- return (SPInstruction) QuantileSortSPInstruction.parseInstruction(str);
+ return QuantileSortSPInstruction.parseInstruction(str);
case QPick:
- return (SPInstruction) QuantilePickSPInstruction.parseInstruction(str);
+ return QuantilePickSPInstruction.parseInstruction(str);
case Write:
- return (SPInstruction) WriteSPInstruction.parseInstruction(str);
+ return WriteSPInstruction.parseInstruction(str);
case CumsumAggregate:
return CumulativeAggregateSPInstruction.parseInstruction(str);
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ff2aea54/src/main/java/org/apache/sysml/runtime/instructions/mr/GroupedAggregateMInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/mr/GroupedAggregateMInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/mr/GroupedAggregateMInstruction.java
index 4130fd9..34f0945 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/mr/GroupedAggregateMInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/mr/GroupedAggregateMInstruction.java
@@ -30,6 +30,7 @@ import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.data.MatrixValue;
+import org.apache.sysml.runtime.matrix.data.OperationsOnMatrixValues;
import org.apache.sysml.runtime.matrix.mapred.CachedValueMap;
import org.apache.sysml.runtime.matrix.mapred.DistributedCacheInput;
import org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue;
@@ -90,37 +91,19 @@ public class GroupedAggregateMInstruction extends BinaryMRInstructionBase implem
//get all inputs
MatrixIndexes ix = in1.getIndexes();
- MatrixBlock target = (MatrixBlock)in1.getValue();
MatrixBlock groups = (MatrixBlock)dcInput.getDataBlock((int)ix.getRowIndex(), 1).getValue();
- //execute grouped aggregate operations
- MatrixBlock out = groups.groupedAggOperations(target, null, new MatrixBlock(), _ngroups, getOperator());
-
//output blocked result
int brlen = dcInput.getNumRowsPerBlock();
int bclen = dcInput.getNumColsPerBlock();
- if( out.getNumRows()<=brlen && out.getNumColumns()<=bclen )
- {
- //single output block
- cachedValues.add(output, new IndexedMatrixValue(new MatrixIndexes(1,ix.getColumnIndex()), out));
- }
- else
- {
- //multiple output blocks (by op def, single column block )
- for(int blockRow = 0; blockRow < (int)Math.ceil(out.getNumRows()/(double)brlen); blockRow++)
- {
- int maxRow = (blockRow*brlen + brlen < out.getNumRows()) ? brlen : out.getNumRows() - blockRow*brlen;
- int row_offset = blockRow*brlen;
-
- //copy submatrix to block
- MatrixBlock tmp = out.sliceOperations( row_offset, row_offset+maxRow-1,
- 0, out.getNumColumns()-1, new MatrixBlock() );
-
- //append block to result cache
- cachedValues.add(output, new IndexedMatrixValue(
- new MatrixIndexes(blockRow+1,ix.getColumnIndex()), tmp));
- }
+ //execute map grouped aggregate operations
+ ArrayList<IndexedMatrixValue> outlist = new ArrayList<IndexedMatrixValue>();
+ OperationsOnMatrixValues.performMapGroupedAggregate(getOperator(), in1, groups, _ngroups, brlen, bclen, outlist);
+
+ //output all result blocks
+ for( IndexedMatrixValue out : outlist ) {
+ cachedValues.add(output, out);
}
}
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ff2aea54/src/main/java/org/apache/sysml/runtime/instructions/spark/AppendGAlignedSPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/AppendGAlignedSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/AppendGAlignedSPInstruction.java
index 5a930da..f1ddcdf 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/AppendGAlignedSPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/AppendGAlignedSPInstruction.java
@@ -29,7 +29,6 @@ import org.apache.sysml.runtime.DMLUnsupportedOperationException;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.functionobjects.OffsetColumnIndex;
-import org.apache.sysml.runtime.instructions.Instruction;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
@@ -49,7 +48,7 @@ public class AppendGAlignedSPInstruction extends BinarySPInstruction
_cbind = cbind;
}
- public static Instruction parseInstruction ( String str )
+ public static AppendGAlignedSPInstruction parseInstruction ( String str )
throws DMLRuntimeException
{
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ff2aea54/src/main/java/org/apache/sysml/runtime/instructions/spark/AppendGSPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/AppendGSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/AppendGSPInstruction.java
index ecf2bc3..30ca4f7 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/AppendGSPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/AppendGSPInstruction.java
@@ -33,7 +33,6 @@ import org.apache.sysml.runtime.DMLUnsupportedOperationException;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.functionobjects.OffsetColumnIndex;
-import org.apache.sysml.runtime.instructions.Instruction;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
@@ -54,7 +53,7 @@ public class AppendGSPInstruction extends BinarySPInstruction
_cbind = cbind;
}
- public static Instruction parseInstruction ( String str )
+ public static AppendGSPInstruction parseInstruction ( String str )
throws DMLRuntimeException
{
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ff2aea54/src/main/java/org/apache/sysml/runtime/instructions/spark/AppendMSPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/AppendMSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/AppendMSPInstruction.java
index a1a0bf3..245f234 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/AppendMSPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/AppendMSPInstruction.java
@@ -32,7 +32,6 @@ import org.apache.sysml.runtime.DMLUnsupportedOperationException;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.functionobjects.OffsetColumnIndex;
-import org.apache.sysml.runtime.instructions.Instruction;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.spark.data.LazyIterableIterator;
@@ -59,7 +58,7 @@ public class AppendMSPInstruction extends BinarySPInstruction
_cbind = cbind;
}
- public static Instruction parseInstruction ( String str )
+ public static AppendMSPInstruction parseInstruction ( String str )
throws DMLRuntimeException
{
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ff2aea54/src/main/java/org/apache/sysml/runtime/instructions/spark/AppendRSPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/AppendRSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/AppendRSPInstruction.java
index 7beba7d..e9e6e35 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/AppendRSPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/AppendRSPInstruction.java
@@ -29,7 +29,6 @@ import org.apache.sysml.runtime.DMLUnsupportedOperationException;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.functionobjects.OffsetColumnIndex;
-import org.apache.sysml.runtime.instructions.Instruction;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
@@ -48,7 +47,7 @@ public class AppendRSPInstruction extends BinarySPInstruction
_cbind = cbind;
}
- public static Instruction parseInstruction ( String str )
+ public static AppendRSPInstruction parseInstruction ( String str )
throws DMLRuntimeException
{
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ff2aea54/src/main/java/org/apache/sysml/runtime/instructions/spark/BuiltinBinarySPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/BuiltinBinarySPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/BuiltinBinarySPInstruction.java
index da85a89..01657ec 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/BuiltinBinarySPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/BuiltinBinarySPInstruction.java
@@ -27,7 +27,6 @@ import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.DMLUnsupportedOperationException;
import org.apache.sysml.runtime.functionobjects.Builtin;
import org.apache.sysml.runtime.functionobjects.ValueFunction;
-import org.apache.sysml.runtime.instructions.Instruction;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
@@ -50,7 +49,7 @@ public abstract class BuiltinBinarySPInstruction extends BinarySPInstruction
* @throws DMLRuntimeException
* @throws DMLUnsupportedOperationException
*/
- public static Instruction parseInstruction ( String str )
+ public static BuiltinBinarySPInstruction parseInstruction ( String str )
throws DMLRuntimeException, DMLUnsupportedOperationException
{
CPOperand in1 = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ff2aea54/src/main/java/org/apache/sysml/runtime/instructions/spark/BuiltinUnarySPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/BuiltinUnarySPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/BuiltinUnarySPInstruction.java
index c9d5ab6..0bd272c 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/BuiltinUnarySPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/BuiltinUnarySPInstruction.java
@@ -25,7 +25,6 @@ import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.DMLUnsupportedOperationException;
import org.apache.sysml.runtime.functionobjects.Builtin;
import org.apache.sysml.runtime.functionobjects.ValueFunction;
-import org.apache.sysml.runtime.instructions.Instruction;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.matrix.operators.Operator;
import org.apache.sysml.runtime.matrix.operators.UnaryOperator;
@@ -47,7 +46,7 @@ public abstract class BuiltinUnarySPInstruction extends UnarySPInstruction
* @throws DMLRuntimeException
* @throws DMLUnsupportedOperationException
*/
- public static Instruction parseInstruction ( String str )
+ public static BuiltinUnarySPInstruction parseInstruction ( String str )
throws DMLRuntimeException, DMLUnsupportedOperationException
{
CPOperand in = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ff2aea54/src/main/java/org/apache/sysml/runtime/instructions/spark/MatrixReshapeSPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/MatrixReshapeSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/MatrixReshapeSPInstruction.java
index 5d30c94..c2d180e 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/MatrixReshapeSPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/MatrixReshapeSPInstruction.java
@@ -31,7 +31,6 @@ import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.DMLUnsupportedOperationException;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
-import org.apache.sysml.runtime.instructions.Instruction;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils;
@@ -70,7 +69,7 @@ public class MatrixReshapeSPInstruction extends UnarySPInstruction
* @return
* @throws DMLRuntimeException
*/
- public static Instruction parseInstruction ( String str )
+ public static MatrixReshapeSPInstruction parseInstruction ( String str )
throws DMLRuntimeException
{
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ff2aea54/src/main/java/org/apache/sysml/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
index 505e232..f8e2669 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
@@ -29,6 +29,7 @@ import org.apache.spark.api.java.function.PairFlatMapFunction;
import scala.Tuple2;
import org.apache.sysml.lops.Lop;
+import org.apache.sysml.lops.PartialAggregate.CorrectionLocationType;
import org.apache.sysml.parser.Expression.ValueType;
import org.apache.sysml.parser.ParameterizedBuiltinFunctionExpression;
import org.apache.sysml.parser.Statement;
@@ -37,9 +38,9 @@ import org.apache.sysml.runtime.DMLUnsupportedOperationException;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
+import org.apache.sysml.runtime.functionobjects.KahanPlus;
import org.apache.sysml.runtime.functionobjects.ParameterizedBuiltin;
import org.apache.sysml.runtime.functionobjects.ValueFunction;
-import org.apache.sysml.runtime.instructions.Instruction;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.mr.GroupedAggregateInstruction;
@@ -57,6 +58,7 @@ import org.apache.sysml.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixCell;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
+import org.apache.sysml.runtime.matrix.data.OperationsOnMatrixValues;
import org.apache.sysml.runtime.matrix.data.WeightedCell;
import org.apache.sysml.runtime.matrix.mapred.IndexedMatrixValue;
import org.apache.sysml.runtime.matrix.operators.AggregateOperator;
@@ -68,10 +70,10 @@ import org.apache.sysml.runtime.transform.DataTransform;
import org.apache.sysml.runtime.util.UtilFunctions;
public class ParameterizedBuiltinSPInstruction extends ComputationSPInstruction
-{
-
- private int arity;
+{
protected HashMap<String,String> params;
+
+ //removeEmpty-specific attributes
private boolean _bRmEmptyBC = false;
public ParameterizedBuiltinSPInstruction(Operator op, HashMap<String,String> paramsMap, CPOperand out, String opcode, String istr, boolean bRmEmptyBC )
@@ -82,10 +84,6 @@ public class ParameterizedBuiltinSPInstruction extends ComputationSPInstruction
_bRmEmptyBC = bRmEmptyBC;
}
- public int getArity() {
- return arity;
- }
-
public HashMap<String,String> getParams() { return params; }
public static HashMap<String, String> constructParameterMap(String[] params) {
@@ -102,70 +100,89 @@ public class ParameterizedBuiltinSPInstruction extends ComputationSPInstruction
return paramMap;
}
- public static Instruction parseInstruction ( String str )
+ public static ParameterizedBuiltinSPInstruction parseInstruction ( String str )
throws DMLRuntimeException, DMLUnsupportedOperationException
{
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
// first part is always the opcode
String opcode = parts[0];
- // last part is always the output
- CPOperand out = new CPOperand( parts[parts.length-1] );
- // process remaining parts and build a hash map
- HashMap<String,String> paramsMap = constructParameterMap(parts);
+ if( opcode.equalsIgnoreCase("mapgroupedagg") )
+ {
+ CPOperand target = new CPOperand( parts[1] );
+ CPOperand groups = new CPOperand( parts[2] );
+ CPOperand out = new CPOperand( parts[3] );
- // determine the appropriate value function
- ValueFunction func = null;
- if ( opcode.equalsIgnoreCase("groupedagg")) {
- // check for mandatory arguments
- String fnStr = paramsMap.get("fn");
- if ( fnStr == null )
- throw new DMLRuntimeException("Function parameter is missing in groupedAggregate.");
- if ( fnStr.equalsIgnoreCase("centralmoment") ) {
- if ( paramsMap.get("order") == null )
- throw new DMLRuntimeException("Mandatory \"order\" must be specified when fn=\"centralmoment\" in groupedAggregate.");
- }
+ HashMap<String,String> paramsMap = new HashMap<String, String>();
+ paramsMap.put(Statement.GAGG_TARGET, target.getName());
+ paramsMap.put(Statement.GAGG_GROUPS, groups.getName());
+ paramsMap.put(Statement.GAGG_NUM_GROUPS, parts[4]);
- Operator op = GroupedAggregateInstruction.parseGroupedAggOperator(fnStr, paramsMap.get("order"));
- return new ParameterizedBuiltinSPInstruction(op, paramsMap, out, opcode, str, false);
- }
- else if( opcode.equalsIgnoreCase("rmempty") )
- {
- boolean bRmEmptyBC = false;
- if(parts.length > 6)
- bRmEmptyBC = (parts[5].compareTo("true") == 0)?true:false;
-
- func = ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode);
- return new ParameterizedBuiltinSPInstruction(new SimpleOperator(func), paramsMap, out, opcode, str, bRmEmptyBC);
- }
- else if( opcode.equalsIgnoreCase("rexpand") )
- {
- func = ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode);
- return new ParameterizedBuiltinSPInstruction(new SimpleOperator(func), paramsMap, out, opcode, str, false);
- }
- else if( opcode.equalsIgnoreCase("replace") )
- {
- func = ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode);
- return new ParameterizedBuiltinSPInstruction(new SimpleOperator(func), paramsMap, out, opcode, str, false);
+ Operator op = new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), true, CorrectionLocationType.LASTCOLUMN);
+
+ return new ParameterizedBuiltinSPInstruction(op, paramsMap, out, opcode, str, false);
}
- else if ( opcode.equalsIgnoreCase("transform") )
+ else
{
- func = ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode);
- String specFile = paramsMap.get(ParameterizedBuiltinFunctionExpression.TF_FN_PARAM_TXSPEC);
- String applyTxPath = paramsMap.get(ParameterizedBuiltinFunctionExpression.TF_FN_PARAM_APPLYMTD);
- if ( specFile != null && applyTxPath != null)
- throw new DMLRuntimeException(
- "Invalid parameters to transform(). Only one of '"
- + ParameterizedBuiltinFunctionExpression.TF_FN_PARAM_TXSPEC
- + "' or '"
- + ParameterizedBuiltinFunctionExpression.TF_FN_PARAM_APPLYMTD
- + "' can be specified.");
- return new ParameterizedBuiltinSPInstruction(new SimpleOperator(func), paramsMap, out, opcode, str, false);
- }
- else {
- throw new DMLRuntimeException("Unknown opcode (" + opcode + ") for ParameterizedBuiltin Instruction.");
- }
+ // last part is always the output
+ CPOperand out = new CPOperand( parts[parts.length-1] );
+ // process remaining parts and build a hash map
+ HashMap<String,String> paramsMap = constructParameterMap(parts);
+
+ // determine the appropriate value function
+ ValueFunction func = null;
+
+ if ( opcode.equalsIgnoreCase("groupedagg")) {
+ // check for mandatory arguments
+ String fnStr = paramsMap.get("fn");
+ if ( fnStr == null )
+ throw new DMLRuntimeException("Function parameter is missing in groupedAggregate.");
+ if ( fnStr.equalsIgnoreCase("centralmoment") ) {
+ if ( paramsMap.get("order") == null )
+ throw new DMLRuntimeException("Mandatory \"order\" must be specified when fn=\"centralmoment\" in groupedAggregate.");
+ }
+
+ Operator op = GroupedAggregateInstruction.parseGroupedAggOperator(fnStr, paramsMap.get("order"));
+ return new ParameterizedBuiltinSPInstruction(op, paramsMap, out, opcode, str, false);
+ }
+ else if( opcode.equalsIgnoreCase("rmempty") )
+ {
+ boolean bRmEmptyBC = false;
+ if(parts.length > 6)
+ bRmEmptyBC = (parts[5].compareTo("true") == 0)?true:false;
+
+ func = ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode);
+ return new ParameterizedBuiltinSPInstruction(new SimpleOperator(func), paramsMap, out, opcode, str, bRmEmptyBC);
+ }
+ else if( opcode.equalsIgnoreCase("rexpand") )
+ {
+ func = ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode);
+ return new ParameterizedBuiltinSPInstruction(new SimpleOperator(func), paramsMap, out, opcode, str, false);
+ }
+ else if( opcode.equalsIgnoreCase("replace") )
+ {
+ func = ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode);
+ return new ParameterizedBuiltinSPInstruction(new SimpleOperator(func), paramsMap, out, opcode, str, false);
+ }
+ else if ( opcode.equalsIgnoreCase("transform") )
+ {
+ func = ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode);
+ String specFile = paramsMap.get(ParameterizedBuiltinFunctionExpression.TF_FN_PARAM_TXSPEC);
+ String applyTxPath = paramsMap.get(ParameterizedBuiltinFunctionExpression.TF_FN_PARAM_APPLYMTD);
+ if ( specFile != null && applyTxPath != null)
+ throw new DMLRuntimeException(
+ "Invalid parameters to transform(). Only one of '"
+ + ParameterizedBuiltinFunctionExpression.TF_FN_PARAM_TXSPEC
+ + "' or '"
+ + ParameterizedBuiltinFunctionExpression.TF_FN_PARAM_APPLYMTD
+ + "' can be specified.");
+ return new ParameterizedBuiltinSPInstruction(new SimpleOperator(func), paramsMap, out, opcode, str, false);
+ }
+ else {
+ throw new DMLRuntimeException("Unknown opcode (" + opcode + ") for ParameterizedBuiltin Instruction.");
+ }
+ }
}
@@ -177,7 +194,31 @@ public class ParameterizedBuiltinSPInstruction extends ComputationSPInstruction
String opcode = getOpcode();
//opcode guaranteed to be a valid opcode (see parsing)
- if ( opcode.equalsIgnoreCase("groupedagg") )
+ if( opcode.equalsIgnoreCase("mapgroupedagg") )
+ {
+ //get input rdd handle
+ String targetVar = params.get(Statement.GAGG_TARGET);
+ String groupsVar = params.get(Statement.GAGG_GROUPS);
+ JavaPairRDD<MatrixIndexes,MatrixBlock> target = sec.getBinaryBlockRDDHandleForVariable(targetVar);
+ PartitionedBroadcastMatrix groups = sec.getBroadcastForVariable(groupsVar);
+ MatrixCharacteristics mc1 = sec.getMatrixCharacteristics( targetVar );
+ MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(output.getName());
+ CPOperand ngrpOp = new CPOperand(params.get(Statement.GAGG_NUM_GROUPS));
+ int ngroups = (int)sec.getScalarInput(ngrpOp.getName(), ngrpOp.getValueType(), ngrpOp.isLiteral()).getLongValue();
+
+ //execute map grouped aggregate
+ JavaPairRDD<MatrixIndexes, MatrixBlock> out =
+ target.flatMapToPair(new RDDMapGroupedAggFunction(groups, _optr,
+ ngroups, mc1.getRowsPerBlock(), mc1.getColsPerBlock()));
+ out = RDDAggregateUtils.sumByKeyStable(out);
+
+ //updated characteristics and handle outputs
+ mcOut.set(ngroups, mc1.getCols(), mc1.getRowsPerBlock(), mc1.getColsPerBlock(), -1);
+ sec.setRDDHandleForVariable(output.getName(), out);
+ sec.addLineageRDD( output.getName(), targetVar );
+ sec.addLineageBroadcast( output.getName(), groupsVar );
+ }
+ else if ( opcode.equalsIgnoreCase("groupedagg") )
{
boolean broadcastGroups = Boolean.parseBoolean(params.get("broadcast"));
@@ -519,6 +560,44 @@ public class ParameterizedBuiltinSPInstruction extends ComputationSPInstruction
}
}
+ /**
+ * Spark flatMap function for map-side grouped aggregation: aggregates each
+ * target block against the broadcast groups block at the same block row index
+ * (block column 1) and emits all resulting partial output blocks.
+ */
+ public static class RDDMapGroupedAggFunction implements PairFlatMapFunction<Tuple2<MatrixIndexes,MatrixBlock>,MatrixIndexes,MatrixBlock>
+ {
+ private static final long serialVersionUID = 6795402640178679851L;
+
+ private final PartitionedBroadcastMatrix _pbm; //broadcast groups matrix
+ private final Operator _op; //grouped aggregate operator
+ private final int _ngroups; //number of groups (rows of the aggregate)
+ private final int _brlen; //rows per output block
+ private final int _bclen; //columns per output block
+
+ public RDDMapGroupedAggFunction(PartitionedBroadcastMatrix pbm, Operator op, int ngroups, int brlen, int bclen) {
+ _pbm = pbm;
+ _op = op;
+ _ngroups = ngroups;
+ _brlen = brlen;
+ _bclen = bclen;
+ }
+
+ @Override
+ public Iterable<Tuple2<MatrixIndexes, MatrixBlock>> call(Tuple2<MatrixIndexes, MatrixBlock> arg0) throws Exception
+ {
+ //lookup the groups block aligned with the target block (by block row index)
+ MatrixIndexes ix = arg0._1();
+ MatrixBlock groups = _pbm.getMatrixBlock((int)ix.getRowIndex(), 1);
+
+ //execute map grouped aggregate, possibly producing multiple output blocks
+ ArrayList<IndexedMatrixValue> outlist = new ArrayList<IndexedMatrixValue>();
+ OperationsOnMatrixValues.performMapGroupedAggregate(_op,
+ SparkUtils.toIndexedMatrixBlock(ix, arg0._2()), groups, _ngroups, _brlen, _bclen, outlist);
+ return SparkUtils.fromIndexedMatrixBlock(outlist);
+ }
+ }
+
/**
*
*/
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ff2aea54/src/main/java/org/apache/sysml/runtime/instructions/spark/QuantilePickSPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/QuantilePickSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/QuantilePickSPInstruction.java
index a4f7a83..5cea24f 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/QuantilePickSPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/QuantilePickSPInstruction.java
@@ -32,7 +32,6 @@ import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.DMLUnsupportedOperationException;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
-import org.apache.sysml.runtime.instructions.Instruction;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.cp.DoubleObject;
@@ -66,7 +65,7 @@ public class QuantilePickSPInstruction extends BinarySPInstruction
* @return
* @throws DMLRuntimeException
*/
- public static Instruction parseInstruction ( String str )
+ public static QuantilePickSPInstruction parseInstruction ( String str )
throws DMLRuntimeException
{
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ff2aea54/src/main/java/org/apache/sysml/runtime/instructions/spark/QuantileSortSPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/QuantileSortSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/QuantileSortSPInstruction.java
index abf837b..793354f 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/QuantileSortSPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/QuantileSortSPInstruction.java
@@ -28,7 +28,6 @@ import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.DMLUnsupportedOperationException;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
-import org.apache.sysml.runtime.instructions.Instruction;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.spark.utils.RDDSortUtils;
@@ -59,7 +58,7 @@ public class QuantileSortSPInstruction extends UnarySPInstruction
_sptype = SPINSTRUCTION_TYPE.QSort;
}
- public static Instruction parseInstruction ( String str )
+ public static QuantileSortSPInstruction parseInstruction ( String str )
throws DMLRuntimeException {
CPOperand in1 = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
CPOperand in2 = null;
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ff2aea54/src/main/java/org/apache/sysml/runtime/instructions/spark/RandSPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/RandSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/RandSPInstruction.java
index 3b4f798..b1eca1e 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/RandSPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/RandSPInstruction.java
@@ -45,7 +45,6 @@ import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
-import org.apache.sysml.runtime.instructions.Instruction;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtils;
@@ -197,7 +196,7 @@ public class RandSPInstruction extends UnarySPInstruction
* @return
* @throws DMLRuntimeException
*/
- public static Instruction parseInstruction(String str)
+ public static RandSPInstruction parseInstruction(String str)
throws DMLRuntimeException
{
String[] s = InstructionUtils.getInstructionPartsWithValueType ( str );
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ff2aea54/src/main/java/org/apache/sysml/runtime/instructions/spark/WriteSPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/WriteSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/WriteSPInstruction.java
index 47cdfd1..3a6ad01 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/WriteSPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/WriteSPInstruction.java
@@ -34,7 +34,6 @@ import org.apache.sysml.runtime.DMLUnsupportedOperationException;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
-import org.apache.sysml.runtime.instructions.Instruction;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.spark.functions.ComputeBinaryBlockNnzFunction;
@@ -72,7 +71,7 @@ public class WriteSPInstruction extends SPInstruction
formatProperties = null; // set in case of csv
}
- public static Instruction parseInstruction ( String str )
+ public static WriteSPInstruction parseInstruction ( String str )
throws DMLRuntimeException
{
String[] parts = InstructionUtils.getInstructionPartsWithValueType ( str );
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/ff2aea54/src/main/java/org/apache/sysml/runtime/matrix/data/OperationsOnMatrixValues.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/OperationsOnMatrixValues.java b/src/main/java/org/apache/sysml/runtime/matrix/data/OperationsOnMatrixValues.java
index 65e2c22..485b559 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/OperationsOnMatrixValues.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/OperationsOnMatrixValues.java
@@ -350,4 +350,45 @@ public class OperationsOnMatrixValues
//execute actual slice operation
in.getValue().sliceOperations(outlist, tmpRange, rowCut, colCut, brlen, bclen, boundaryRlen, boundaryClen);
}
+
+ /**
+ * Executes a grouped aggregate of a single target block w.r.t. its groups
+ * block and splits the partial result into output blocks of at most brlen
+ * rows, which are appended to outlist (keeping the input block column index).
+ *
+ * @param op grouped aggregate operator
+ * @param inTarget indexed target block
+ * @param groups groups block passed to groupedAggOperations for this target block
+ * @param ngroups number of groups, i.e., number of rows of the aggregate
+ * @param brlen number of rows per output block
+ * @param bclen number of columns per output block
+ * @param outlist list of indexed output blocks (results are appended)
+ * @throws DMLRuntimeException
+ * @throws DMLUnsupportedOperationException
+ */
+ public static void performMapGroupedAggregate( Operator op, IndexedMatrixValue inTarget, MatrixBlock groups, int ngroups, int brlen, int bclen, ArrayList<IndexedMatrixValue> outlist ) throws DMLRuntimeException, DMLUnsupportedOperationException
+ {
+ MatrixIndexes ix = inTarget.getIndexes();
+ MatrixBlock target = (MatrixBlock)inTarget.getValue();
+
+ //execute grouped aggregate operations
+ MatrixBlock out = groups.groupedAggOperations(target, null, new MatrixBlock(), ngroups, op);
+ int nrow = out.getNumRows();
+ int ncol = out.getNumColumns();
+
+ if( nrow<=brlen && ncol<=bclen ) {
+ //single output block (no copy required)
+ outlist.add( new IndexedMatrixValue(new MatrixIndexes(1,ix.getColumnIndex()), out) );
+ }
+ else {
+ //multiple output blocks (by op def, single column block); loop bound hoisted
+ int numBlocks = (int)Math.ceil(nrow/(double)brlen);
+ for( int blockRow = 0; blockRow < numBlocks; blockRow++ ) {
+ int rowOffset = blockRow * brlen;
+ int maxRow = Math.min(brlen, nrow - rowOffset); //last block may be partial
+ MatrixBlock tmp = out.sliceOperations(rowOffset, rowOffset+maxRow-1, 0, ncol-1, new MatrixBlock());
+ outlist.add(new IndexedMatrixValue(new MatrixIndexes(blockRow+1,ix.getColumnIndex()), tmp));
+ }
+ }
+ }
}