You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2017/02/27 18:36:06 UTC
[4/9] incubator-systemml git commit: [SYSTEMML-1287] Code generator
runtime integration
[SYSTEMML-1287] Code generator runtime integration
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/982ecb1a
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/982ecb1a
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/982ecb1a
Branch: refs/heads/master
Commit: 982ecb1a4be69685a8e124eccfa3a12331f998b0
Parents: d7fd587
Author: Matthias Boehm <mb...@gmail.com>
Authored: Sun Feb 26 19:01:36 2017 -0800
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Sun Feb 26 19:01:36 2017 -0800
----------------------------------------------------------------------
.../instructions/CPInstructionParser.java | 19 +-
.../instructions/SPInstructionParser.java | 14 +-
.../runtime/instructions/cp/CPInstruction.java | 8 +-
.../instructions/cp/SpoofCPInstruction.java | 98 +++++
.../instructions/spark/SPInstruction.java | 2 +-
.../instructions/spark/SpoofSPInstruction.java | 407 +++++++++++++++++++
.../spark/utils/RDDAggregateUtils.java | 8 +-
7 files changed, 541 insertions(+), 15 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/982ecb1a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
index f3c1605..f0603b4 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
@@ -61,6 +61,7 @@ import org.apache.sysml.runtime.instructions.cp.QuantileSortCPInstruction;
import org.apache.sysml.runtime.instructions.cp.QuaternaryCPInstruction;
import org.apache.sysml.runtime.instructions.cp.RelationalBinaryCPInstruction;
import org.apache.sysml.runtime.instructions.cp.ReorgCPInstruction;
+import org.apache.sysml.runtime.instructions.cp.SpoofCPInstruction;
import org.apache.sysml.runtime.instructions.cp.StringInitCPInstruction;
import org.apache.sysml.runtime.instructions.cp.TernaryCPInstruction;
import org.apache.sysml.runtime.instructions.cp.UaggOuterChainCPInstruction;
@@ -271,8 +272,9 @@ public class CPInstructionParser extends InstructionParser
String2CPInstructionType.put( "lu", CPINSTRUCTION_TYPE.MultiReturnBuiltin);
String2CPInstructionType.put( "eigen", CPINSTRUCTION_TYPE.MultiReturnBuiltin);
- String2CPInstructionType.put( "partition", CPINSTRUCTION_TYPE.Partition);
- String2CPInstructionType.put( "compress", CPINSTRUCTION_TYPE.Compression);
+ String2CPInstructionType.put( "partition", CPINSTRUCTION_TYPE.Partition);
+ String2CPInstructionType.put( "compress", CPINSTRUCTION_TYPE.Compression);
+ String2CPInstructionType.put( "spoof", CPINSTRUCTION_TYPE.SpoofFused);
//CP FILE instruction
String2CPFileInstructionType = new HashMap<String, CPINSTRUCTION_TYPE>();
@@ -424,16 +426,19 @@ public class CPInstructionParser extends InstructionParser
case Partition:
return DataPartitionCPInstruction.parseInstruction(str);
-
- case Compression:
- return (CPInstruction) CompressionCPInstruction.parseInstruction(str);
-
+
case CentralMoment:
return CentralMomentCPInstruction.parseInstruction(str);
case Covariance:
return CovarianceCPInstruction.parseInstruction(str);
-
+
+ case Compression:
+ return (CPInstruction) CompressionCPInstruction.parseInstruction(str);
+
+ case SpoofFused:
+ return SpoofCPInstruction.parseInstruction(str);
+
case INVALID:
default:
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/982ecb1a/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java
index 6658a88..5ca3847 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java
@@ -73,6 +73,7 @@ import org.apache.sysml.runtime.instructions.spark.ReorgSPInstruction;
import org.apache.sysml.runtime.instructions.spark.RmmSPInstruction;
import org.apache.sysml.runtime.instructions.spark.SPInstruction;
import org.apache.sysml.runtime.instructions.spark.SPInstruction.SPINSTRUCTION_TYPE;
+import org.apache.sysml.runtime.instructions.spark.SpoofSPInstruction;
import org.apache.sysml.runtime.instructions.spark.TernarySPInstruction;
import org.apache.sysml.runtime.instructions.spark.Tsmm2SPInstruction;
import org.apache.sysml.runtime.instructions.spark.TsmmSPInstruction;
@@ -277,10 +278,12 @@ public class SPInstructionParser extends InstructionParser
String2SPInstructionType.put( "binuaggchain", SPINSTRUCTION_TYPE.BinUaggChain);
- String2SPInstructionType.put( "write" , SPINSTRUCTION_TYPE.Write);
+ String2SPInstructionType.put( "write" , SPINSTRUCTION_TYPE.Write);
- String2SPInstructionType.put( "castdtm" , SPINSTRUCTION_TYPE.Cast);
- String2SPInstructionType.put( "castdtf" , SPINSTRUCTION_TYPE.Cast);
+ String2SPInstructionType.put( "castdtm" , SPINSTRUCTION_TYPE.Cast);
+ String2SPInstructionType.put( "castdtf" , SPINSTRUCTION_TYPE.Cast);
+
+ String2SPInstructionType.put( "spoof" , SPINSTRUCTION_TYPE.SpoofFused);
}
public static SPInstruction parseSingleInstruction (String str )
@@ -443,10 +446,13 @@ public class SPInstructionParser extends InstructionParser
case Checkpoint:
return CheckpointSPInstruction.parseInstruction(str);
-
+
case Compression:
return CompressionSPInstruction.parseInstruction(str);
+ case SpoofFused:
+ return SpoofSPInstruction.parseInstruction(str);
+
case Cast:
return CastSPInstruction.parseInstruction(str);
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/982ecb1a/src/main/java/org/apache/sysml/runtime/instructions/cp/CPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/CPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/CPInstruction.java
index 1d192d5..dcd8d89 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/CPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/CPInstruction.java
@@ -29,7 +29,13 @@ import org.apache.sysml.runtime.matrix.operators.Operator;
public abstract class CPInstruction extends Instruction
{
- public enum CPINSTRUCTION_TYPE { INVALID, AggregateUnary, AggregateBinary, AggregateTernary, ArithmeticBinary, Ternary, Quaternary, BooleanBinary, BooleanUnary, BuiltinBinary, BuiltinUnary, BuiltinMultiple, MultiReturnParameterizedBuiltin, ParameterizedBuiltin, MultiReturnBuiltin, Builtin, Reorg, RelationalBinary, File, Variable, External, Append, Rand, QSort, QPick, MatrixIndexing, MMTSJ, PMMJ, MMChain, MatrixReshape, Partition, Compression, StringInit, CentralMoment, Covariance, UaggOuterChain, Convolution };
+ //Instruction types of CP instructions. SpoofFused is the type registered for the
+ //"spoof" opcode in CPInstructionParser, which parses it via SpoofCPInstruction.
+ public enum CPINSTRUCTION_TYPE { INVALID,
+ AggregateUnary, AggregateBinary, AggregateTernary, ArithmeticBinary,
+ Ternary, Quaternary, BooleanBinary, BooleanUnary, BuiltinBinary, BuiltinUnary,
+ BuiltinMultiple, MultiReturnParameterizedBuiltin, ParameterizedBuiltin, MultiReturnBuiltin,
+ Builtin, Reorg, RelationalBinary, File, Variable, External, Append, Rand, QSort, QPick,
+ MatrixIndexing, MMTSJ, PMMJ, MMChain, MatrixReshape, Partition, Compression, SpoofFused,
+ StringInit, CentralMoment, Covariance, UaggOuterChain, Convolution };
protected CPINSTRUCTION_TYPE _cptype;
protected Operator _optr;
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/982ecb1a/src/main/java/org/apache/sysml/runtime/instructions/cp/SpoofCPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/SpoofCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/SpoofCPInstruction.java
new file mode 100644
index 0000000..61313d7
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/SpoofCPInstruction.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.instructions.cp;
+
+import java.util.ArrayList;
+
+import org.apache.sysml.parser.Expression.DataType;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.codegen.CodegenUtils;
+import org.apache.sysml.runtime.codegen.SpoofOperator;
+import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysml.runtime.instructions.InstructionUtils;
+import org.apache.sysml.runtime.instructions.cp.CPOperand;
+import org.apache.sysml.runtime.instructions.cp.ComputationCPInstruction;
+import org.apache.sysml.runtime.instructions.cp.ScalarObject;
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+
+/**
+ * CP instruction that runs a code-generated ("spoof") fused operator.
+ * The operator class is resolved by name when the instruction is parsed,
+ * and a new operator instance is created on every execution.
+ */
+public class SpoofCPInstruction extends ComputationCPInstruction
+{
+    //generated operator class, number of threads, and all input operands
+    private final Class<?> _class;
+    private final int _numThreads;
+    private final CPOperand[] _in;
+
+    /**
+     * Creates a spoof CP instruction.
+     *
+     * @param cla generated operator class
+     * @param k number of threads handed to the operator
+     * @param in input operands (matrices and scalars)
+     * @param out output operand (matrix or scalar)
+     * @param opcode instruction opcode
+     * @param str instruction string
+     */
+    public SpoofCPInstruction(Class<?> cla, int k, CPOperand[] in, CPOperand out, String opcode, String str) {
+        super(null, null, null, out, opcode, str);
+        _class = cla;
+        _numThreads = k;
+        _in = in;
+    }
+
+    /**
+     * Parses a spoof instruction string. After splitting, the parts are read as:
+     * [0] opcode prefix, [1] name of the generated class, [2 .. n-3] input
+     * operands, [n-2] output operand, [n-1] number of threads.
+     *
+     * @param str instruction string
+     * @return parsed instruction
+     * @throws DMLRuntimeException declared by the original signature
+     */
+    public static SpoofCPInstruction parseInstruction(String str)
+        throws DMLRuntimeException
+    {
+        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+
+        //load the generated class and derive the full opcode from its spoof type
+        ArrayList<CPOperand> inlist = new ArrayList<CPOperand>();
+        Class<?> cla = CodegenUtils.loadClass(parts[1], null);
+        String opcode = parts[0] + CodegenUtils.getSpoofType(cla);
+
+        for( int i=2; i<parts.length-2; i++ )
+            inlist.add(new CPOperand(parts[i]));
+        CPOperand out = new CPOperand(parts[parts.length-2]);
+        int k = Integer.parseInt(parts[parts.length-1]);
+
+        return new SpoofCPInstruction(cla, k, inlist.toArray(new CPOperand[0]), out, opcode, str);
+    }
+
+    /**
+     * Instantiates the generated operator, executes it over the matrix and
+     * scalar inputs, and binds the matrix or scalar result to the output
+     * variable. Inputs and outputs of any other data type are ignored.
+     *
+     * @param ec execution context providing inputs and receiving the output
+     * @throws DMLRuntimeException declared by the original signature
+     */
+    @Override
+    public void processInstruction(ExecutionContext ec)
+        throws DMLRuntimeException
+    {
+        SpoofOperator op = (SpoofOperator) CodegenUtils.createInstance(_class);
+
+        ArrayList<MatrixBlock> inputs = new ArrayList<MatrixBlock>();
+        ArrayList<ScalarObject> scalars = new ArrayList<ScalarObject>();
+        //names of matrix inputs acquired so far, so exactly those get released
+        ArrayList<String> pinned = new ArrayList<String>();
+
+        try {
+            //get input matrices and scalars, incl pinning of matrices
+            for (CPOperand input : _in) {
+                if(input.getDataType()==DataType.MATRIX) {
+                    inputs.add(ec.getMatrixInput(input.getName()));
+                    pinned.add(input.getName());
+                }
+                else if(input.getDataType()==DataType.SCALAR)
+                    scalars.add(ec.getScalarInput(input.getName(), input.getValueType(), input.isLiteral()));
+            }
+
+            //execute the operator and bind its result to the output variable
+            if( output.getDataType() == DataType.MATRIX) {
+                MatrixBlock out = new MatrixBlock();
+                op.execute(inputs, scalars, out, _numThreads);
+                ec.setMatrixOutput(output.getName(), out);
+            }
+            else if (output.getDataType() == DataType.SCALAR) {
+                ScalarObject out = op.execute(inputs, scalars, _numThreads);
+                ec.setScalarOutput(output.getName(), out);
+            }
+        }
+        finally {
+            //release input matrices even if the operator failed, so that a
+            //failing instruction does not leave its inputs acquired
+            for( String name : pinned )
+                ec.releaseMatrixInput(name);
+        }
+    }
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/982ecb1a/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java
index b28e408..17d1561 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java
@@ -37,7 +37,7 @@ public abstract class SPInstruction extends Instruction
CentralMoment, Covariance, QSort, QPick,
ParameterizedBuiltin, MAppend, RAppend, GAppend, GAlignedAppend, Rand,
MatrixReshape, Ternary, Quaternary, CumsumAggregate, CumsumOffset, BinUaggChain, UaggOuterChain,
- Write, INVALID,
+ Write, SpoofFused, INVALID,
Convolution
};
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/982ecb1a/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java
new file mode 100644
index 0000000..15b0751
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/SpoofSPInstruction.java
@@ -0,0 +1,407 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.instructions.spark;
+
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.function.PairFlatMapFunction;
+import org.apache.spark.api.java.function.PairFunction;
+import org.apache.sysml.parser.Expression.DataType;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.codegen.CodegenUtils;
+import org.apache.sysml.runtime.codegen.SpoofCellwise;
+import org.apache.sysml.runtime.codegen.SpoofCellwise.CellType;
+import org.apache.sysml.runtime.codegen.SpoofOperator;
+import org.apache.sysml.runtime.codegen.SpoofOuterProduct;
+import org.apache.sysml.runtime.codegen.SpoofOuterProduct.OutProdType;
+import org.apache.sysml.runtime.codegen.SpoofRowAggregate;
+import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
+import org.apache.sysml.runtime.instructions.InstructionUtils;
+import org.apache.sysml.runtime.instructions.cp.CPOperand;
+import org.apache.sysml.runtime.instructions.cp.DoubleObject;
+import org.apache.sysml.runtime.instructions.cp.ScalarObject;
+import org.apache.sysml.runtime.instructions.spark.SPInstruction;
+import org.apache.sysml.runtime.instructions.spark.data.PartitionedBroadcast;
+import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils;
+import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
+
+import scala.Tuple2;
+
+public class SpoofSPInstruction extends SPInstruction
+{
+ //generated operator class and its byte code (passed on to the Spark functions)
+ private final Class<?> _class;
+ private final byte[] _classBytes;
+ //input operands (the first one is read as RDD) and the output operand
+ private final CPOperand[] _in;
+ private final CPOperand _out;
+
+ /**
+  * Creates a spoof Spark instruction of type SpoofFused.
+  *
+  * @param cls generated operator class
+  * @param classBytes byte code of the generated class
+  * @param in input operands
+  * @param out output operand
+  * @param opcode instruction opcode
+  * @param str instruction string
+  */
+ public SpoofSPInstruction(Class<?> cls , byte[] classBytes, CPOperand[] in, CPOperand out, String opcode, String str) {
+     super(opcode, str);
+     _sptype = SPINSTRUCTION_TYPE.SpoofFused;
+     this._class = cls;
+     this._classBytes = classBytes;
+     this._in = in;
+     this._out = out;
+ }
+
+ /**
+  * Parses a spoof instruction string. After splitting, the tokens are read as:
+  * [0] opcode prefix, [1] name of the generated class, [2 .. n-3] input
+  * operands, [n-2] output operand. The last token is not read here.
+  *
+  * @param str instruction string
+  * @return parsed instruction
+  * @throws DMLRuntimeException declared by the original signature
+  */
+ public static SpoofSPInstruction parseInstruction(String str)
+     throws DMLRuntimeException
+ {
+     final String[] tokens = InstructionUtils.getInstructionPartsWithValueType(str);
+
+     //resolve the generated class and obtain its byte code
+     final String className = tokens[1];
+     Class<?> cls = CodegenUtils.loadClass(className, null);
+     byte[] classBytes = CodegenUtils.getClassAsByteArray(className);
+     String opcode = tokens[0] + CodegenUtils.getSpoofType(cls);
+
+     //all tokens between the class name and the output are input operands
+     final int outPos = tokens.length - 2;
+     ArrayList<CPOperand> inputs = new ArrayList<CPOperand>();
+     int pos = 2;
+     while( pos < outPos ) {
+         inputs.add(new CPOperand(tokens[pos]));
+         pos++;
+     }
+     CPOperand out = new CPOperand(tokens[outPos]);
+     //note: number of threads (last token) always ignored
+
+     CPOperand[] in = inputs.toArray(new CPOperand[inputs.size()]);
+     return new SpoofSPInstruction(cls, classBytes, in, out, opcode, str);
+ }
+
+ /**
+ * Executes the generated operator on Spark. The first input is read as a
+ * binary-block RDD; every further matrix input is obtained as a broadcast
+ * and every further scalar input is read from the execution context. The
+ * execution path is selected by the direct superclass of the generated
+ * class: SpoofCellwise, SpoofOuterProduct, or SpoofRowAggregate.
+ *
+ * @param ec execution context; cast to SparkExecutionContext
+ * @throws DMLRuntimeException if the superclass of the generated class is
+ * none of the three supported operator types
+ */
+ @Override
+ public void processInstruction(ExecutionContext ec)
+ throws DMLRuntimeException
+ {
+ SparkExecutionContext sec = (SparkExecutionContext)ec;
+
+ //get input rdd and variable name
+ ArrayList<String> bcVars = new ArrayList<String>();
+ MatrixCharacteristics mcIn = sec.getMatrixCharacteristics(_in[0].getName());
+ JavaPairRDD<MatrixIndexes,MatrixBlock> in = sec.getBinaryBlockRDDHandleForVariable( _in[0].getName() );
+ JavaPairRDD<MatrixIndexes, MatrixBlock> out = null;
+
+ //simple case: map-side only operation (one rdd input, broadcast all)
+ //keep track of broadcast variables
+ ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices = new ArrayList<PartitionedBroadcast<MatrixBlock>>();
+ ArrayList<ScalarObject> scalars = new ArrayList<ScalarObject>();
+ //inputs other than the first: matrices become broadcasts, scalars are collected,
+ //operands of any other data type are skipped
+ for( int i=1; i<_in.length; i++ ) {
+ if( _in[i].getDataType()==DataType.MATRIX) {
+ bcMatrices.add(sec.getBroadcastForVariable(_in[i].getName()));
+ bcVars.add(_in[i].getName());
+ }
+ else if(_in[i].getDataType()==DataType.SCALAR) {
+ scalars.add(sec.getScalarInput(_in[i].getName(), _in[i].getValueType(), _in[i].isLiteral()));
+ }
+ }
+
+ //initialize Spark Operator
+ if(_class.getSuperclass() == SpoofCellwise.class) // cellwise operator
+ {
+ if( _out.getDataType()==DataType.MATRIX ) {
+ //this instance is only used here to query the cell type and to update the
+ //output characteristics; CellwiseFunction creates its own operator instance
+ SpoofOperator op = (SpoofOperator) CodegenUtils.createInstance(_class);
+
+ out = in.mapPartitionsToPair(new CellwiseFunction(_class.getName(), _classBytes, bcMatrices, scalars), true);
+ //row aggregation over more than one column block: CellwiseFunction emits one
+ //partial result per input block under key (row index, 1), so sum them by key
+ if( ((SpoofCellwise)op).getCellType()==CellType.ROW_AGG && mcIn.getCols() > mcIn.getColsPerBlock() ) {
+ //NOTE: workaround with partition size needed due to potential bug in SPARK
+ //TODO investigate if some other side effect of correct blocks
+ if( out.partitions().size() > mcIn.getNumRowBlocks() )
+ out = RDDAggregateUtils.sumByKeyStable(out, (int)mcIn.getNumRowBlocks());
+ else
+ out = RDDAggregateUtils.sumByKeyStable(out);
+ }
+ sec.setRDDHandleForVariable(_out.getName(), out);
+
+ //maintain lineage information for output rdd
+ sec.addLineageRDD(_out.getName(), _in[0].getName());
+ for( String bcVar : bcVars )
+ sec.addLineageBroadcast(_out.getName(), bcVar);
+
+ //update matrix characteristics
+ updateOutputMatrixCharacteristics(sec, op);
+ }
+ else { //SCALAR
+ //sum all per-block results and use cell (0,0) of the sum as scalar output
+ //NOTE(review): presumably only reached for full aggregates - confirm
+ out = in.mapPartitionsToPair(new CellwiseFunction(_class.getName(), _classBytes, bcMatrices, scalars), true);
+ MatrixBlock tmpMB = RDDAggregateUtils.sumStable(out);
+ sec.setVariable(_out.getName(), new DoubleObject(tmpMB.getValue(0, 0)));
+ }
+ }
+ else if(_class.getSuperclass() == SpoofOuterProduct.class) // outer product operator
+ {
+ if( _out.getDataType()==DataType.MATRIX ) {
+ //this instance is only used here to query the outer product type and to
+ //update the output characteristics
+ SpoofOperator op = (SpoofOperator) CodegenUtils.createInstance(_class);
+ OutProdType type = ((SpoofOuterProduct)op).getOuterProdType();
+
+ //update matrix characteristics
+ //(done before the aggregation because mcOut determines the number of partitions below)
+ updateOutputMatrixCharacteristics(sec, op);
+ MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(_out.getName());
+
+ out = in.mapPartitionsToPair(new OuterProductFunction(_class.getName(), _classBytes, bcMatrices, scalars), true);
+ //left/right outer products re-key their blocks to (index, 1) in
+ //OuterProductFunction, hence partial results are summed by key
+ if(type == OutProdType.LEFT_OUTER_PRODUCT || type == OutProdType.RIGHT_OUTER_PRODUCT ) {
+ //NOTE: workaround with partition size needed due to potential bug in SPARK
+ //TODO investigate if some other side effect of correct blocks
+ if( in.partitions().size() > mcOut.getNumRowBlocks()*mcOut.getNumColBlocks() )
+ out = RDDAggregateUtils.sumByKeyStable( out, (int)(mcOut.getNumRowBlocks()*mcOut.getNumColBlocks()) );
+ else
+ out = RDDAggregateUtils.sumByKeyStable( out );
+ }
+ sec.setRDDHandleForVariable(_out.getName(), out);
+
+ //maintain lineage information for output rdd
+ sec.addLineageRDD(_out.getName(), _in[0].getName());
+ for( String bcVar : bcVars )
+ sec.addLineageBroadcast(_out.getName(), bcVar);
+
+ }
+ else {
+ //non-matrix output: sum all per-block results and use cell (0,0) as scalar
+ out = in.mapPartitionsToPair(new OuterProductFunction(_class.getName(), _classBytes, bcMatrices, scalars), true);
+ MatrixBlock tmp = RDDAggregateUtils.sumStable(out);
+ sec.setVariable(_out.getName(), new DoubleObject(tmp.getValue(0, 0)));
+ }
+ }
+ else if( _class.getSuperclass() == SpoofRowAggregate.class ) { //row aggregate operator
+ //RowAggregateFunction maps every block to key (1,1); the summed block is set
+ //as matrix output directly instead of registering an output RDD
+ RowAggregateFunction fmmc = new RowAggregateFunction(_class.getName(), _classBytes, bcMatrices, scalars);
+ JavaPairRDD<MatrixIndexes,MatrixBlock> tmpRDD = in.mapToPair(fmmc);
+ MatrixBlock tmpMB = RDDAggregateUtils.sumStable(tmpRDD);
+ sec.setMatrixOutput(_out.getName(), tmpMB);
+ return;
+ }
+ else {
+ throw new DMLRuntimeException("Operator " + _class.getSuperclass() + " is not supported on Spark");
+ }
+ }
+
+ /**
+  * Sets the matrix characteristics of the output variable according to the
+  * operator type. Cellwise: ROW_AGG yields a single-column matrix with the
+  * rows of the first input, NO_AGG copies the characteristics of the first
+  * input. Outer product: dimensions and block sizes are taken from the
+  * first, third, or second input for the cellwise, left, and right variant.
+  * In all other cases the output characteristics are left untouched.
+  *
+  * @param sec Spark execution context holding the characteristics
+  * @param op instance of the generated operator
+  * @throws DMLRuntimeException declared by the original signature
+  */
+ private void updateOutputMatrixCharacteristics(SparkExecutionContext sec, SpoofOperator op)
+     throws DMLRuntimeException
+ {
+     if( op instanceof SpoofCellwise ) {
+         MatrixCharacteristics mcIn = sec.getMatrixCharacteristics(_in[0].getName());
+         MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(_out.getName());
+         CellType ctype = ((SpoofCellwise)op).getCellType();
+         if( ctype == CellType.ROW_AGG )
+             mcOut.set(mcIn.getRows(), 1, mcIn.getRowsPerBlock(), mcIn.getColsPerBlock());
+         else if( ctype == CellType.NO_AGG )
+             mcOut.set(mcIn);
+         return;
+     }
+     if( !(op instanceof SpoofOuterProduct) )
+         return;
+
+     //outer product: look up all three inputs first, then pick the source by type
+     MatrixCharacteristics mcX = sec.getMatrixCharacteristics(_in[0].getName()); //X
+     MatrixCharacteristics mcU = sec.getMatrixCharacteristics(_in[1].getName()); //U
+     MatrixCharacteristics mcV = sec.getMatrixCharacteristics(_in[2].getName()); //V
+     MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(_out.getName());
+     OutProdType type = ((SpoofOuterProduct)op).getOuterProdType();
+
+     MatrixCharacteristics src = null;
+     if( type == OutProdType.CELLWISE_OUTER_PRODUCT )
+         src = mcX;
+     else if( type == OutProdType.LEFT_OUTER_PRODUCT )
+         src = mcV;
+     else if( type == OutProdType.RIGHT_OUTER_PRODUCT )
+         src = mcU;
+     else
+         return; //any other type: nothing to update
+
+     mcOut.set(src.getRows(), src.getCols(), src.getRowsPerBlock(), src.getColsPerBlock());
+ }
+
+ /**
+  * Spark pair function for row aggregate operators. Each input block is
+  * passed through the generated operator and emitted under the fixed key
+  * (1,1). The generated class is loaded from the shipped byte code and
+  * instantiated on first use.
+  */
+ private static class RowAggregateFunction implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock>
+ {
+     private static final long serialVersionUID = -7926980450209760212L;
+
+     private ArrayList<PartitionedBroadcast<MatrixBlock>> _vectors = null;
+     private ArrayList<ScalarObject> _scalars = null;
+     private byte[] _classBytes = null;
+     private String _className = null;
+     private SpoofOperator _op = null;
+
+     public RowAggregateFunction(String className, byte[] classBytes, ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices, ArrayList<ScalarObject> scalars)
+         throws DMLRuntimeException
+     {
+         _vectors = bcMatrices;
+         _scalars = scalars;
+         _classBytes = classBytes;
+         _className = className;
+     }
+
+     @Override
+     public Tuple2<MatrixIndexes, MatrixBlock> call( Tuple2<MatrixIndexes, MatrixBlock> arg0 )
+         throws Exception
+     {
+         //lazy load of shipped class
+         if( _op == null ) {
+             Class<?> loadedClass = CodegenUtils.loadClass(_className, _classBytes);
+             _op = (SpoofOperator) CodegenUtils.createInstance(loadedClass);
+         }
+
+         //main input block and its row block index
+         final MatrixBlock block = arg0._2();
+         final int blockRow = (int) arg0._1().getRowIndex();
+
+         //run the operator on the block plus the matching broadcast blocks
+         MatrixBlock result = new MatrixBlock();
+         _op.execute(getVectorInputsFromBroadcast(block, blockRow), _scalars, result);
+
+         //all results share key (1,1)
+         return new Tuple2<MatrixIndexes, MatrixBlock>(new MatrixIndexes(1,1), result);
+     }
+
+     /**
+      * Builds the operator inputs: the main block followed by one block per
+      * broadcast. The block at (rowIndex, 1) is used if the broadcast has at
+      * least rowIndex row blocks, otherwise the block at (1, 1).
+      */
+     private ArrayList<MatrixBlock> getVectorInputsFromBroadcast(MatrixBlock blkIn, int rowIndex)
+         throws DMLRuntimeException
+     {
+         ArrayList<MatrixBlock> ret = new ArrayList<MatrixBlock>();
+         ret.add(blkIn);
+         for( PartitionedBroadcast<MatrixBlock> vector : _vectors ) {
+             int rix = 1;
+             if( vector.getNumRowBlocks() >= rowIndex )
+                 rix = rowIndex;
+             ret.add(vector.getBlock(rix, 1));
+         }
+         return ret;
+     }
+ }
+
+ /**
+  * Spark partition function for cellwise operators. Every block of a
+  * partition is passed through the generated operator: FULL_AGG yields a
+  * 1x1 block under the input key, ROW_AGG a block under key (row index, 1),
+  * and any other cell type a block under the unchanged input key. The
+  * generated class is loaded from the shipped byte code on first use.
+  */
+ private static class CellwiseFunction implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock>>, MatrixIndexes, MatrixBlock>
+ {
+     private static final long serialVersionUID = -8209188316939435099L;
+
+     private ArrayList<PartitionedBroadcast<MatrixBlock>> _vectors = null;
+     private ArrayList<ScalarObject> _scalars = null;
+     private byte[] _classBytes = null;
+     private String _className = null;
+     private SpoofOperator _op = null;
+
+     public CellwiseFunction(String className, byte[] classBytes, ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices, ArrayList<ScalarObject> scalars)
+         throws DMLRuntimeException
+     {
+         _vectors = bcMatrices;
+         _scalars = scalars;
+         _classBytes = classBytes;
+         _className = className;
+     }
+
+     @Override
+     public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Iterator<Tuple2<MatrixIndexes, MatrixBlock>> arg)
+         throws Exception
+     {
+         //lazy load of shipped class
+         if( _op == null ) {
+             Class<?> loadedClass = CodegenUtils.loadClass(_className, _classBytes);
+             _op = (SpoofOperator) CodegenUtils.createInstance(loadedClass);
+         }
+
+         //materialize the results of all blocks of this partition
+         List<Tuple2<MatrixIndexes, MatrixBlock>> results = new ArrayList<Tuple2<MatrixIndexes,MatrixBlock>>();
+         while( arg.hasNext() )
+             results.add(computeBlock(arg.next()));
+         return results.iterator();
+     }
+
+     /**
+      * Executes the operator for a single input block and returns the
+      * output key together with the output block.
+      */
+     private Tuple2<MatrixIndexes, MatrixBlock> computeBlock(Tuple2<MatrixIndexes, MatrixBlock> entry)
+         throws Exception
+     {
+         MatrixIndexes ixIn = entry._1();
+         MatrixBlock blkIn = entry._2();
+         MatrixBlock blkOut = new MatrixBlock();
+         ArrayList<MatrixBlock> inputs = getVectorInputsFromBroadcast(blkIn, (int)ixIn.getRowIndex());
+
+         //execute core operation
+         CellType ctype = ((SpoofCellwise)_op).getCellType();
+         MatrixIndexes ixOut = ixIn;
+         if( ctype == CellType.FULL_AGG ) {
+             //scalar result is wrapped into a 1x1 block
+             ScalarObject obj = _op.execute(inputs, _scalars, 1);
+             blkOut.reset(1, 1);
+             blkOut.quickSetValue(0, 0, obj.getDoubleValue());
+         }
+         else {
+             //row aggregates are keyed by (row index, 1)
+             if( ctype == CellType.ROW_AGG )
+                 ixOut = new MatrixIndexes(ixIn.getRowIndex(), 1);
+             _op.execute(inputs, _scalars, blkOut);
+         }
+         return new Tuple2<MatrixIndexes,MatrixBlock>(ixOut, blkOut);
+     }
+
+     /**
+      * Builds the operator inputs: the main block followed by one block per
+      * broadcast. The block at (rowIndex, 1) is used if the broadcast has at
+      * least rowIndex row blocks, otherwise the block at (1, 1).
+      */
+     private ArrayList<MatrixBlock> getVectorInputsFromBroadcast(MatrixBlock blkIn, int rowIndex)
+         throws DMLRuntimeException
+     {
+         ArrayList<MatrixBlock> ret = new ArrayList<MatrixBlock>();
+         ret.add(blkIn);
+         for( PartitionedBroadcast<MatrixBlock> vector : _vectors ) {
+             int rix = 1;
+             if( vector.getNumRowBlocks() >= rowIndex )
+                 rix = rowIndex;
+             ret.add(vector.getBlock(rix, 1));
+         }
+         return ret;
+     }
+ }
+
+ /**
+  * Spark partition function for outer product operators. For every block
+  * of X, the operator receives the block itself, the block of the first
+  * broadcast at (row index, 1) as U, and the block of the second broadcast
+  * at (column index, 1) as V. AGG_OUTER_PRODUCT yields a 1x1 block; all
+  * other types write into an output block. The generated class is loaded
+  * from the shipped byte code on first use.
+  */
+ private static class OuterProductFunction implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock>>, MatrixIndexes, MatrixBlock>
+ {
+     //NOTE(review): same value as CellwiseFunction.serialVersionUID - looks copied
+     private static final long serialVersionUID = -8209188316939435099L;
+
+     private ArrayList<PartitionedBroadcast<MatrixBlock>> _bcMatrices = null;
+     private ArrayList<ScalarObject> _scalars = null;
+     private byte[] _classBytes = null;
+     private String _className = null;
+     private SpoofOperator _op = null;
+
+     public OuterProductFunction(String className, byte[] classBytes, ArrayList<PartitionedBroadcast<MatrixBlock>> bcMatrices, ArrayList<ScalarObject> scalars)
+         throws DMLRuntimeException
+     {
+         _bcMatrices = bcMatrices;
+         _scalars = scalars;
+         _classBytes = classBytes;
+         _className = className;
+     }
+
+     @Override
+     public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Iterator<Tuple2<MatrixIndexes, MatrixBlock>> arg)
+         throws Exception
+     {
+         //lazy load of shipped class
+         if( _op == null ) {
+             Class<?> loadedClass = CodegenUtils.loadClass(_className, _classBytes);
+             _op = (SpoofOperator) CodegenUtils.createInstance(loadedClass);
+         }
+
+         //materialize the results of all blocks of this partition
+         List<Tuple2<MatrixIndexes, MatrixBlock>> results = new ArrayList<Tuple2<MatrixIndexes,MatrixBlock>>();
+         while( arg.hasNext() )
+         {
+             Tuple2<MatrixIndexes,MatrixBlock> entry = arg.next();
+             MatrixIndexes ixIn = entry._1();
+             MatrixBlock blkOut = new MatrixBlock();
+
+             //assemble X block with its aligned U and V blocks
+             ArrayList<MatrixBlock> inputs = new ArrayList<MatrixBlock>();
+             inputs.add(entry._2());
+             inputs.add(_bcMatrices.get(0).getBlock((int)ixIn.getRowIndex(), 1)); // U
+             inputs.add(_bcMatrices.get(1).getBlock((int)ixIn.getColumnIndex(), 1)); // V
+
+             //execute core operation
+             boolean fullAgg = ((SpoofOuterProduct)_op).getOuterProdType()==OutProdType.AGG_OUTER_PRODUCT;
+             if( !fullAgg ) {
+                 _op.execute(inputs, _scalars, blkOut);
+             }
+             else {
+                 //scalar result is wrapped into a 1x1 block
+                 ScalarObject obj = _op.execute(inputs, _scalars,1);
+                 blkOut.reset(1, 1);
+                 blkOut.quickSetValue(0, 0, obj.getDoubleValue());
+             }
+
+             results.add(new Tuple2<MatrixIndexes,MatrixBlock>(createOutputIndexes(ixIn,_op), blkOut));
+         }
+
+         return results.iterator();
+     }
+
+     /**
+      * Returns the output key of a block: (column index, 1) for left outer
+      * products, (row index, 1) for right outer products, and the input
+      * key for all other types.
+      */
+     private MatrixIndexes createOutputIndexes(MatrixIndexes in, SpoofOperator spoofOp) {
+         OutProdType type = ((SpoofOuterProduct)spoofOp).getOuterProdType();
+         if( type == OutProdType.LEFT_OUTER_PRODUCT )
+             return new MatrixIndexes(in.getColumnIndex(), 1);
+         if( type == OutProdType.RIGHT_OUTER_PRODUCT )
+             return new MatrixIndexes(in.getRowIndex(), 1);
+         return in;
+     }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/982ecb1a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDAggregateUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDAggregateUtils.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDAggregateUtils.java
index 61c950a..2dfff74 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDAggregateUtils.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDAggregateUtils.java
@@ -69,13 +69,17 @@ public class RDDAggregateUtils
}
}
- public static JavaPairRDD<MatrixIndexes, MatrixBlock> sumByKeyStable( JavaPairRDD<MatrixIndexes, MatrixBlock> in )
+ /**
+  * Sums all blocks per key, delegating to the two-argument variant with
+  * the number of partitions of the input RDD.
+  *
+  * @param in input RDD of keyed matrix blocks
+  * @return RDD as returned by sumByKeyStable(in, numPartitions)
+  */
+ public static JavaPairRDD<MatrixIndexes, MatrixBlock> sumByKeyStable( JavaPairRDD<MatrixIndexes, MatrixBlock> in ) {
+     //default: use as many output partitions as the input has
+     final int numParts = in.getNumPartitions();
+     return sumByKeyStable(in, numParts);
+ }
+
+ public static JavaPairRDD<MatrixIndexes, MatrixBlock> sumByKeyStable( JavaPairRDD<MatrixIndexes, MatrixBlock> in, int numPartitions )
{
//stable sum of blocks per key, by passing correction blocks along with aggregates
JavaPairRDD<MatrixIndexes, CorrMatrixBlock> tmp =
in.combineByKey( new CreateCorrBlockCombinerFunction(),
new MergeSumBlockValueFunction(),
- new MergeSumBlockCombinerFunction() );
+ new MergeSumBlockCombinerFunction(), numPartitions );
//strip-off correction blocks from
JavaPairRDD<MatrixIndexes, MatrixBlock> out =