You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2018/07/30 21:34:55 UTC
[3/3] systemml git commit: [SYSTEMML-2470] Fix distributed spark cumsumprod (aggregate 1st pass)
[SYSTEMML-2470] Fix distributed spark cumsumprod (aggregate 1st pass)
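This patch fixes result-correctness issues in the distributed Spark
operations of the new cumulative aggregate cumsumprod. Specifically,
during the forward (aggregation) pass of the generic two-pass algorithm,
each block is now aggregated with cumsumprod(AB)[n] instead of sum(AB).
Here cumsumprod(AB)[n] is the last entry of the block-local cumulative
sum-product of the block's two columns A and B.

This is necessary because cumsumprod is a recurrence,
C[i] = A[i] + B[i] * C[i-1]. The value that one block must carry into
the next is therefore the final element of its local cumsumprod, not a
plain sum. The test now also runs with a dense B (sparsity 0.9) instead
of the previous sparse B (0.1), which had been marked FIXME.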
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/50ddddb9
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/50ddddb9
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/50ddddb9
Branch: refs/heads/master
Commit: 50ddddb90b28c6e28e97195dded9696edcdc3b45
Parents: 252e498
Author: Matthias Boehm <mb...@gmail.com>
Authored: Mon Jul 30 14:34:49 2018 -0700
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Mon Jul 30 14:36:19 2018 -0700
----------------------------------------------------------------------
.../spark/CumulativeAggregateSPInstruction.java | 17 +++++++++++------
.../functions/unary/matrix/FullCumsumprodTest.java | 3 +--
2 files changed, 12 insertions(+), 8 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/50ddddb9/src/main/java/org/apache/sysml/runtime/instructions/spark/CumulativeAggregateSPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/CumulativeAggregateSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/CumulativeAggregateSPInstruction.java
index 74390e1..8514acc 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/CumulativeAggregateSPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/CumulativeAggregateSPInstruction.java
@@ -27,6 +27,7 @@ import scala.Tuple2;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
+import org.apache.sysml.runtime.functionobjects.Builtin;
import org.apache.sysml.runtime.functionobjects.PlusMultiply;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
@@ -36,6 +37,7 @@ import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.data.OperationsOnMatrixValues;
import org.apache.sysml.runtime.matrix.operators.AggregateUnaryOperator;
+import org.apache.sysml.runtime.matrix.operators.UnaryOperator;
public class CumulativeAggregateSPInstruction extends AggregateUnarySPInstruction {
@@ -79,10 +81,11 @@ public class CumulativeAggregateSPInstruction extends AggregateUnarySPInstructio
{
private static final long serialVersionUID = 11324676268945117L;
- private AggregateUnaryOperator _op = null;
- private long _rlen = -1;
- private int _brlen = -1;
- private int _bclen = -1;
+ private final AggregateUnaryOperator _op;
+ private UnaryOperator _uop = null;
+ private final long _rlen;
+ private final int _brlen;
+ private final int _bclen;
public RDDCumAggFunction( AggregateUnaryOperator op, long rlen, int brlen, int bclen ) {
_op = op;
@@ -105,10 +108,12 @@ public class CumulativeAggregateSPInstruction extends AggregateUnarySPInstructio
AggregateUnaryOperator aop = (AggregateUnaryOperator)_op;
if( aop.aggOp.increOp.fn instanceof PlusMultiply ) { //cumsumprod
aop.indexFn.execute(ixIn, ixOut);
- MatrixBlock t1 = blkIn.slice(0, blkIn.getNumRows()-1, 0, 0, new MatrixBlock());
+ if( _uop == null )
+ _uop = new UnaryOperator(Builtin.getBuiltinFnObject("ucumk+*"));
+ MatrixBlock t1 = (MatrixBlock) blkIn.unaryOperations(_uop, new MatrixBlock());
MatrixBlock t2 = blkIn.slice(0, blkIn.getNumRows()-1, 1, 1, new MatrixBlock());
blkOut.reset(1, 2);
- blkOut.quickSetValue(0, 0, t1.sum());
+ blkOut.quickSetValue(0, 0, t1.quickGetValue(t1.getNumRows()-1, 0));
blkOut.quickSetValue(0, 1, t2.prod());
}
else { //general case
http://git-wip-us.apache.org/repos/asf/systemml/blob/50ddddb9/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/FullCumsumprodTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/FullCumsumprodTest.java b/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/FullCumsumprodTest.java
index 7f02055..f13e765 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/FullCumsumprodTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/unary/matrix/FullCumsumprodTest.java
@@ -111,8 +111,7 @@ public class FullCumsumprodTest extends AutomatedTestBase
String.valueOf(reverse).toUpperCase(), output("C") };
double[][] A = getRandomMatrix(rows, 1, -10, 10, sparsity, 3);
- double[][] B = getRandomMatrix(rows, 1, -1, 1, 0.1, 7);
- //FIXME double[][] B = getRandomMatrix(rows, 1, -1, 1, 0.9, 7);
+ double[][] B = getRandomMatrix(rows, 1, -1, 1, 0.9, 7);
writeInputMatrixWithMTD("A", A, false);
writeInputMatrixWithMTD("B", B, false);