You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2015/12/19 19:27:59 UTC
[1/4] incubator-systemml git commit: Fix spark grouped aggregate
(wrong group-value join), incl cleanup
Repository: incubator-systemml
Updated Branches:
refs/heads/master 4124f196a -> 7290510ec
Fix spark grouped aggregate (wrong group-value join), incl cleanup
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/635907d4
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/635907d4
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/635907d4
Branch: refs/heads/master
Commit: 635907d4ba388e7873cfe74f8fe5c7066e695f6c
Parents: 4124f19
Author: Matthias Boehm <mb...@us.ibm.com>
Authored: Mon Dec 14 07:38:29 2015 -0800
Committer: Matthias Boehm <mb...@us.ibm.com>
Committed: Fri Dec 18 22:44:31 2015 +0100
----------------------------------------------------------------------
.../ParameterizedBuiltinSPInstruction.java | 10 +-
.../spark/functions/ExtractGroup.java | 51 ++---------
.../spark/functions/ExtractGroupNWeights.java | 96 ++++----------------
.../UnflattenIterablesAfterCogroup.java | 58 ------------
4 files changed, 33 insertions(+), 182 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/635907d4/src/main/java/org/apache/sysml/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
index e48b83f..84b4821 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
@@ -50,7 +50,6 @@ import org.apache.sysml.runtime.instructions.spark.functions.ExtractGroupNWeight
import org.apache.sysml.runtime.instructions.spark.functions.PerformGroupByAggInCombiner;
import org.apache.sysml.runtime.instructions.spark.functions.PerformGroupByAggInReducer;
import org.apache.sysml.runtime.instructions.spark.functions.ReplicateVectorFunction;
-import org.apache.sysml.runtime.instructions.spark.functions.UnflattenIterablesAfterCogroup;
import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysml.runtime.instructions.spark.utils.SparkUtils;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
@@ -202,15 +201,12 @@ public class ParameterizedBuiltinSPInstruction extends ComputationSPInstruction
throw new DMLRuntimeException("Grouped Aggregate SPInstruction is not supported for dimension of target != weights");
}
- groupWeightedCells = groups.cogroup(target)
- .mapToPair(new UnflattenIterablesAfterCogroup())
- .cogroup(weights)
+ groupWeightedCells = groups.join(target).join(weights)
.flatMapToPair(new ExtractGroupNWeights());
}
else {
- groupWeightedCells = groups.cogroup(target)
- .mapToPair(new UnflattenIterablesAfterCogroup())
- .flatMapToPair(new ExtractGroup());
+ groupWeightedCells = groups.join(target)
+ .flatMapToPair(new ExtractGroup());
}
// Step 2: Make sure we have brlen required while creating <MatrixIndexes, MatrixCell>
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/635907d4/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/ExtractGroup.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/ExtractGroup.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/ExtractGroup.java
index 6bde0da..fcd0166 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/ExtractGroup.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/ExtractGroup.java
@@ -24,11 +24,8 @@ import java.util.ArrayList;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import scala.Tuple2;
-
-import org.apache.sysml.runtime.matrix.data.IJV;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
-import org.apache.sysml.runtime.matrix.data.SparseRowsIterator;
import org.apache.sysml.runtime.matrix.data.WeightedCell;
import org.apache.sysml.runtime.util.UtilFunctions;
@@ -39,28 +36,22 @@ public class ExtractGroup implements PairFlatMapFunction<Tuple2<MatrixIndexes,T
@Override
public Iterable<Tuple2<Long, WeightedCell>> call(
Tuple2<MatrixIndexes, Tuple2<MatrixBlock, MatrixBlock>> arg)
- throws Exception {
+ throws Exception
+ {
MatrixBlock group = arg._2._1;
MatrixBlock target = arg._2._2;
- ArrayList<Double> groupIDs = getColumn(group);
- ArrayList<Double> values = getColumn(target);
- ArrayList<Tuple2<Long, WeightedCell>> groupValuePairs = new ArrayList<Tuple2<Long, WeightedCell>>();
-
- if(groupIDs.size() != values.size()) {
- throw new Exception("The blocksize for group and target block are mismatched: " + groupIDs.size() + " != " + values.size());
+ //sanity check matching block dimensions
+ if(group.getNumRows() != target.getNumRows()) {
+ throw new Exception("The blocksize for group and target blocks are mismatched: " + group.getNumRows() + " != " + target.getNumRows());
}
- for(int i = 0; i < groupIDs.size(); i++) {
+ //output weighted cells
+ ArrayList<Tuple2<Long, WeightedCell>> groupValuePairs = new ArrayList<Tuple2<Long, WeightedCell>>();
+ for(int i = 0; i < group.getNumRows(); i++) {
WeightedCell weightedCell = new WeightedCell();
- try {
- weightedCell.setValue(values.get(i));
- }
- catch(Exception e) {
- weightedCell.setValue(0);
- }
- weightedCell.setWeight(1.0);
- long groupVal = UtilFunctions.toLong(groupIDs.get(i));
+ weightedCell.setValue(target.quickGetValue(i, 0));
+ long groupVal = UtilFunctions.toLong(group.quickGetValue(i, 0));
if(groupVal < 1) {
throw new Exception("Expected group values to be greater than equal to 1 but found " + groupVal);
}
@@ -68,26 +59,4 @@ public class ExtractGroup implements PairFlatMapFunction<Tuple2<MatrixIndexes,T
}
return groupValuePairs;
}
-
-
- public ArrayList<Double> getColumn(MatrixBlock blk) throws Exception {
- ArrayList<Double> retVal = new ArrayList<Double>();
- if(blk != null) {
- if (blk.isInSparseFormat()) {
- SparseRowsIterator iter = blk.getSparseRowsIterator();
- while( iter.hasNext() ) {
- IJV cell = iter.next();
- retVal.add(cell.v);
- }
- }
- else {
- double[] valuesInBlock = blk.getDenseArray();
- for(int i = 0; i < valuesInBlock.length; i++) {
- retVal.add(valuesInBlock[i]);
- }
- }
- }
- return retVal;
- }
-
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/635907d4/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/ExtractGroupNWeights.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/ExtractGroupNWeights.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/ExtractGroupNWeights.java
index e00f28e..17c58c5 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/ExtractGroupNWeights.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/ExtractGroupNWeights.java
@@ -25,98 +25,42 @@ import org.apache.spark.api.java.function.PairFlatMapFunction;
import scala.Tuple2;
-import org.apache.sysml.runtime.matrix.data.IJV;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
-import org.apache.sysml.runtime.matrix.data.SparseRowsIterator;
import org.apache.sysml.runtime.matrix.data.WeightedCell;
import org.apache.sysml.runtime.util.UtilFunctions;
-public class ExtractGroupNWeights implements PairFlatMapFunction<Tuple2<MatrixIndexes,Tuple2<Iterable<Tuple2<MatrixBlock,MatrixBlock>>,Iterable<MatrixBlock>>>, Long, WeightedCell> {
+public class ExtractGroupNWeights implements PairFlatMapFunction<Tuple2<MatrixIndexes,Tuple2<Tuple2<MatrixBlock,MatrixBlock>,MatrixBlock>>, Long, WeightedCell> {
private static final long serialVersionUID = -188180042997588072L;
@Override
public Iterable<Tuple2<Long, WeightedCell>> call(
- Tuple2<MatrixIndexes, Tuple2<Iterable<Tuple2<MatrixBlock, MatrixBlock>>, Iterable<MatrixBlock>>> arg)
- throws Exception {
- MatrixBlock group = null;
- MatrixBlock target = null;
- for(Tuple2<MatrixBlock, MatrixBlock> kv : arg._2._1) {
- if(group == null) {
- group = kv._1;
- target = kv._2;
- }
- else {
- throw new Exception("More than 1 block with same MatrixIndexes");
- }
- }
- MatrixBlock weight = null;
- for(MatrixBlock blk : arg._2._2) {
- if(weight == null) {
- weight = blk;
- }
- else {
- throw new Exception("More than 1 block with same MatrixIndexes");
- }
+ Tuple2<MatrixIndexes, Tuple2<Tuple2<MatrixBlock, MatrixBlock>, MatrixBlock>> arg)
+ throws Exception
+ {
+ MatrixBlock group = arg._2._1._1;
+ MatrixBlock target = arg._2._1._2;
+ MatrixBlock weight = arg._2._2;
+
+ //sanity check matching block dimensions
+ if(group.getNumRows() != target.getNumRows() || group.getNumRows()!=weight.getNumRows()) {
+ throw new Exception("The blocksize for group/target/weight blocks are mismatched: " + group.getNumRows() + ", " + target.getNumRows() + ", " + weight.getNumRows());
}
- ArrayList<Double> groupIDs = getColumn(group);
- ArrayList<Double> values = getColumn(target);
- ArrayList<Double> w = getColumn(weight);
+ //output weighted cells
ArrayList<Tuple2<Long, WeightedCell>> groupValuePairs = new ArrayList<Tuple2<Long, WeightedCell>>();
-
- if(groupIDs != null) {
- if(groupIDs.size() != values.size() || groupIDs.size() != w.size()) {
- throw new Exception("The blocksize for group, weight and target block are mismatched: "
- + groupIDs.size() + " != " + values.size() + " || " + groupIDs.size() + " != " + w.size());
+ for(int i = 0; i < group.getNumRows(); i++) {
+ WeightedCell weightedCell = new WeightedCell();
+ weightedCell.setValue(target.quickGetValue(i, 0));
+ weightedCell.setWeight(weight.quickGetValue(i, 0));
+ long groupVal = UtilFunctions.toLong(group.quickGetValue(i, 0));
+ if(groupVal < 1) {
+ throw new Exception("Expected group values to be greater than equal to 1 but found " + groupVal);
}
- for(int i = 0; i < groupIDs.size(); i++) {
- WeightedCell weightedCell = new WeightedCell();
- try {
- weightedCell.setValue(values.get(i));
- }
- catch(Exception e) {
- weightedCell.setValue(0);
- }
- try {
- weightedCell.setWeight(w.get(i));
- }
- catch(Exception e) {
- weightedCell.setValue(1);
- }
- long groupVal = UtilFunctions.toLong(groupIDs.get(i));
- if(groupVal < 1) {
- throw new Exception("Expected group values to be greater than equal to 1 but found " + groupVal);
- }
- groupValuePairs.add(new Tuple2<Long, WeightedCell>(groupVal, weightedCell));
- }
- }
- else {
- throw new Exception("group ids block shouldn't be empty");
+ groupValuePairs.add(new Tuple2<Long, WeightedCell>(groupVal, weightedCell));
}
return groupValuePairs;
}
-
- public ArrayList<Double> getColumn(MatrixBlock blk) throws Exception {
- ArrayList<Double> retVal = new ArrayList<Double>();
- if(blk != null) {
- if (blk.isInSparseFormat()) {
- SparseRowsIterator iter = blk.getSparseRowsIterator();
- while( iter.hasNext() ) {
- IJV cell = iter.next();
- retVal.add(cell.v);
- }
- }
- else {
- double[] valuesInBlock = blk.getDenseArray();
- for(int i = 0; i < valuesInBlock.length; i++) {
- retVal.add(valuesInBlock[i]);
- }
- }
- }
- return retVal;
- }
-
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/635907d4/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/UnflattenIterablesAfterCogroup.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/UnflattenIterablesAfterCogroup.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/UnflattenIterablesAfterCogroup.java
deleted file mode 100644
index 5bb5a96..0000000
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/UnflattenIterablesAfterCogroup.java
+++ /dev/null
@@ -1,58 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.runtime.instructions.spark.functions;
-
-import org.apache.spark.api.java.function.PairFunction;
-
-import scala.Tuple2;
-
-import org.apache.sysml.runtime.matrix.data.MatrixBlock;
-import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
-
-public class UnflattenIterablesAfterCogroup implements PairFunction<Tuple2<MatrixIndexes,Tuple2<Iterable<MatrixBlock>,Iterable<MatrixBlock>>>, MatrixIndexes, Tuple2<MatrixBlock, MatrixBlock>> {
-
- private static final long serialVersionUID = 5367350062892272775L;
-
- @Override
- public Tuple2<MatrixIndexes, Tuple2<MatrixBlock, MatrixBlock>> call(
- Tuple2<MatrixIndexes, Tuple2<Iterable<MatrixBlock>, Iterable<MatrixBlock>>> arg)
- throws Exception {
- MatrixBlock left = null;
- MatrixBlock right = null;
- for(MatrixBlock blk : arg._2._1) {
- if(left == null) {
- left = blk;
- }
- else {
- throw new Exception("More than 1 block with same MatrixIndexes");
- }
- }
- for(MatrixBlock blk : arg._2._2) {
- if(right == null) {
- right = blk;
- }
- else {
- throw new Exception("More than 1 block with same MatrixIndexes");
- }
- }
- return new Tuple2<MatrixIndexes, Tuple2<MatrixBlock, MatrixBlock>>(arg._1, new Tuple2<MatrixBlock, MatrixBlock>(left, right));
- }
-
-}
\ No newline at end of file
[3/4] incubator-systemml git commit: Fix table-rexpand simplification
rewrite (apply for table w/o weights)
Posted by mb...@apache.org.
Fix table-rexpand simplification rewrite (apply for table w/o weights)
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/112ba90c
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/112ba90c
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/112ba90c
Branch: refs/heads/master
Commit: 112ba90c891cce418bde2d54e277f055efff39be
Parents: f73569b
Author: Matthias Boehm <mb...@us.ibm.com>
Authored: Sat Dec 19 19:25:50 2015 +0100
Committer: Matthias Boehm <mb...@us.ibm.com>
Committed: Sat Dec 19 19:25:50 2015 +0100
----------------------------------------------------------------------
.../rewrite/RewriteAlgebraicSimplificationStatic.java | 10 ++++++----
1 file changed, 6 insertions(+), 4 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/112ba90c/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index 58c2f6b..adda6a2 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -1524,13 +1524,15 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
//pattern: table(seq(1,nrow(v)), v, nrow(v), m) -> rexpand(v, max=m, dir=row, ignore=false, cast=true)
//note: this rewrite supports both left/right sequence
- if( hi instanceof TernaryOp && hi.getInput().size()==5
- && hi.getInput().get(3) instanceof LiteralOp && hi.getInput().get(4) instanceof LiteralOp )
+ if( hi instanceof TernaryOp && hi.getInput().size()==5 //table without weights
+ && hi.getInput().get(2) instanceof LiteralOp
+ && HopRewriteUtils.getDoubleValue((LiteralOp)hi.getInput().get(2))==1
+ && hi.getInput().get(3) instanceof LiteralOp && hi.getInput().get(4) instanceof LiteralOp)
{
if( (HopRewriteUtils.isBasic1NSequence(hi.getInput().get(0)) &&
- hi.getInput().get(3) instanceof LiteralOp) //pattern a: table(seq(1,nrow(v)), v, nrow(v), m)
+ hi.getInput().get(4) instanceof LiteralOp) //pattern a: table(seq(1,nrow(v)), v, nrow(v), m)
||(HopRewriteUtils.isBasic1NSequence(hi.getInput().get(1)) &&
- hi.getInput().get(2) instanceof LiteralOp) ) //pattern b: table(v, seq(1,nrow(v)), m, nrow(v))
+ hi.getInput().get(3) instanceof LiteralOp) ) //pattern b: table(v, seq(1,nrow(v)), m, nrow(v))
{
//determine variable parameters for pattern a/b
int ixTgt = HopRewriteUtils.isBasic1NSequence(hi.getInput().get(0)) ? 1 : 0;
[2/4] incubator-systemml git commit: New grouped aggregate over
matrices (generalized cp/mr/sp), incl tests
Posted by mb...@apache.org.
New grouped aggregate over matrices (generalized cp/mr/sp), incl tests
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/f73569b4
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/f73569b4
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/f73569b4
Branch: refs/heads/master
Commit: f73569b4a8fa0c8e9e10f3e022e4920ce1fd7b37
Parents: 635907d
Author: Matthias Boehm <mb...@us.ibm.com>
Authored: Fri Dec 18 22:44:12 2015 +0100
Committer: Matthias Boehm <mb...@us.ibm.com>
Committed: Fri Dec 18 22:44:39 2015 +0100
----------------------------------------------------------------------
.../sysml/hops/ParameterizedBuiltinOp.java | 75 +++-
.../RewriteAlgebraicSimplificationStatic.java | 3 +-
.../org/apache/sysml/lops/GroupedAggregate.java | 73 ++--
.../ParameterizedBuiltinFunctionExpression.java | 111 ++---
.../mr/GroupedAggregateInstruction.java | 24 +-
.../ParameterizedBuiltinSPInstruction.java | 62 +--
.../spark/functions/ExtractGroup.java | 24 +-
.../spark/functions/ExtractGroupNWeights.java | 9 +-
.../functions/PerformGroupByAggInReducer.java | 34 +-
.../sysml/runtime/matrix/GroupedAggMR.java | 5 +-
.../sysml/runtime/matrix/data/MatrixBlock.java | 181 +++++++--
.../matrix/data/TaggedMatrixIndexes.java | 99 +++++
.../matrix/mapred/GroupedAggMRCombiner.java | 12 +-
.../matrix/mapred/GroupedAggMRMapper.java | 34 +-
.../matrix/mapred/GroupedAggMRReducer.java | 9 +-
.../matrix/mapred/MRJobConfiguration.java | 21 +-
.../FullGroupedAggregateMatrixTest.java | 402 +++++++++++++++++++
.../aggregate/GroupedAggregateMatrix.R | 70 ++++
.../aggregate/GroupedAggregateMatrix.dml | 51 +++
.../aggregate/GroupedAggregateMatrixNoDims.dml | 51 +++
20 files changed, 1091 insertions(+), 259 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f73569b4/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java b/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java
index 89f6959..3a0a4f5 100644
--- a/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java
+++ b/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java
@@ -24,6 +24,7 @@ import java.util.Map.Entry;
import org.apache.sysml.hops.rewrite.HopRewriteUtils;
import org.apache.sysml.lops.Aggregate;
+import org.apache.sysml.lops.AppendR;
import org.apache.sysml.lops.Data;
import org.apache.sysml.lops.Group;
import org.apache.sysml.lops.GroupedAggregate;
@@ -237,21 +238,53 @@ public class ParameterizedBuiltinOp extends Hop
}
else
{
- Lop append = BinaryOp.constructMRAppendLop(
- getInput().get(_paramIndexMap.get(Statement.GAGG_TARGET)),
- getInput().get(_paramIndexMap.get(Statement.GAGG_GROUPS)),
+ Hop target = getInput().get(_paramIndexMap.get(Statement.GAGG_TARGET));
+ Hop groups = getInput().get(_paramIndexMap.get(Statement.GAGG_GROUPS));
+ Lop append = null;
+
+ if( target.getDim2()>=target.getColsInBlock() ) //multi-column-block result matrix
+ {
+ long m1_dim1 = target.getDim1();
+ long m1_dim2 = target.getDim2();
+ long m2_dim1 = groups.getDim1();
+ long m2_dim2 = groups.getDim2();
+ long m3_dim1 = m1_dim1;
+ long m3_dim2 = ((m1_dim2>0 && m2_dim2>0) ? (m1_dim2 + m2_dim2) : -1);
+ long m3_nnz = (target.getNnz()>0 && groups.getNnz()>0) ? (target.getNnz() + groups.getNnz()) : -1;
+ long brlen = target.getRowsInBlock();
+ long bclen = target.getColsInBlock();
+
+ Lop offset = createOffsetLop(target, true);
+ Lop rep = new RepMat(groups.constructLops(), offset, true, groups.getDataType(), groups.getValueType());
+ setOutputDimensions(rep);
+ setLineNumbers(rep);
+
+ Group group1 = new Group(target.constructLops(), Group.OperationTypes.Sort, DataType.MATRIX, target.getValueType());
+ group1.getOutputParameters().setDimensions(m1_dim1, m1_dim2, brlen, bclen, target.getNnz());
+ setLineNumbers(group1);
+
+ Group group2 = new Group(rep, Group.OperationTypes.Sort, DataType.MATRIX, groups.getValueType());
+ group2.getOutputParameters().setDimensions(m2_dim1, m2_dim2, brlen, bclen, groups.getNnz());
+ setLineNumbers(group2);
+
+ append = new AppendR(group1, group2, DataType.MATRIX, ValueType.DOUBLE, true, ExecType.MR);
+ append.getOutputParameters().setDimensions(m3_dim1, m3_dim2, brlen, bclen, m3_nnz);
+ setLineNumbers(append);
+ }
+ else //single-column-block vector or matrix
+ {
+ append = BinaryOp.constructMRAppendLop(
+ target, groups,
DataType.MATRIX, getValueType(), true,
getInput().get(_paramIndexMap.get(Statement.GAGG_TARGET)));
+ }
- // add the combine lop to parameter list, with a new name
- // "combinedinput"
+ // add the combine lop to parameter list, with a new name "combinedinput"
inputlops.put(GroupedAggregate.COMBINEDINPUT, append);
inputlops.remove(Statement.GAGG_TARGET);
inputlops.remove(Statement.GAGG_GROUPS);
-
}
- int colwise = -1;
long outputDim1=-1, outputDim2=-1;
Lop numGroups = inputlops.get(Statement.GAGG_NUM_GROUPS);
if ( !dimsKnown() && numGroups != null && numGroups instanceof Data && ((Data)numGroups).isLiteral() ) {
@@ -260,25 +293,20 @@ public class ParameterizedBuiltinOp extends Hop
Lop input = inputlops.get(GroupedAggregate.COMBINEDINPUT);
long inDim1 = input.getOutputParameters().getNumRows();
long inDim2 = input.getOutputParameters().getNumCols();
- if(inDim1 > 0 && inDim2 > 0 ) {
- if ( inDim1 > inDim2 )
- colwise = 1;
- else
- colwise = 0;
- }
+ boolean rowwise = (inDim1==1 && inDim2 > 1 );
- if ( colwise == 1 ) {
+ if( rowwise ) { //vector
outputDim1 = ngroups;
outputDim2 = 1;
}
- else if ( colwise == 0 ) {
- outputDim1 = 1;
+ else { //vector or matrix
+ outputDim1 = inDim2;
outputDim2 = ngroups;
}
}
- GroupedAggregate grp_agg = new GroupedAggregate(inputlops, getDataType(), getValueType());
+ GroupedAggregate grp_agg = new GroupedAggregate(inputlops, isWeighted, getDataType(), getValueType());
// output dimensions are unknown at compilation time
grp_agg.getOutputParameters().setDimensions(outputDim1, outputDim2, getRowsInBlock(), getColsInBlock(), -1);
@@ -774,7 +802,8 @@ public class ParameterizedBuiltinOp extends Hop
Hop ngroups = getInput().get(_paramIndexMap.get(Statement.GAGG_NUM_GROUPS));
if(ngroups != null && ngroups instanceof LiteralOp) {
long m = HopRewriteUtils.getIntValueSafe((LiteralOp)ngroups);
- return new long[]{m,1,m};
+ long n = (mc.getRows()==1)?1:mc.getCols();
+ return new long[]{m, n, m};
}
}
@@ -782,9 +811,10 @@ public class ParameterizedBuiltinOp extends Hop
// #groups = #rows in the grouping attribute (e.g., categorical attribute is an ID column, say EmployeeID).
// In such a case, #rows in the output = #rows in the input. Also, output sparsity is
// likely to be 1.0 (e.g., groupedAgg(groups=<a ID column>, fn="count"))
- long m = mc.getRows();
+ long m = mc.getRows();
+ long n = (mc.getRows()==1)?1:mc.getCols();
if ( m >= 1 ) {
- ret = new long[]{m, 1, m};
+ ret = new long[]{m, n, m};
}
}
else if( _op == ParamBuiltinOp.RMEMPTY )
@@ -923,8 +953,11 @@ public class ParameterizedBuiltinOp extends Hop
}
}
+ Hop target = getInput().get(_paramIndexMap.get(Statement.GAGG_TARGET));
+ long ldim2 = (target.getDim1()==1)?1:target.getDim2();
+
setDim1( ldim1 );
- setDim2( 1 );
+ setDim2( ldim2 );
break;
}
case RMEMPTY: {
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f73569b4/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index 9dc1380..58c2f6b 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -1291,7 +1291,8 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
{
ParameterizedBuiltinOp phi = (ParameterizedBuiltinOp)hi;
- if( phi.isCountFunction() ) //aggregate(fn="count")
+ if( phi.isCountFunction() //aggregate(fn="count")
+ && phi.getTargetHop().getDim2()==1 ) //only for vector
{
HashMap<String, Integer> params = phi.getParamIndexMap();
int ix1 = params.get(Statement.GAGG_TARGET);
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f73569b4/src/main/java/org/apache/sysml/lops/GroupedAggregate.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/GroupedAggregate.java b/src/main/java/org/apache/sysml/lops/GroupedAggregate.java
index a30ad4d..9498e94 100644
--- a/src/main/java/org/apache/sysml/lops/GroupedAggregate.java
+++ b/src/main/java/org/apache/sysml/lops/GroupedAggregate.java
@@ -34,19 +34,41 @@ import org.apache.sysml.parser.Expression.*;
*
*/
public class GroupedAggregate extends Lop
-{
-
+{
private HashMap<String, Lop> _inputParams;
private static final String opcode = "groupedagg";
public static final String COMBINEDINPUT = "combinedinput";
+ private boolean _weights = false;
+
/**
* Constructor to perform grouped aggregate.
* inputParameterLops <- parameters required to compute different aggregates (hashmap)
* "combinedinput" -- actual data
* "function" -- aggregate function
*/
+
+ public GroupedAggregate(
+ HashMap<String, Lop> inputParameterLops, boolean weights,
+ DataType dt, ValueType vt) {
+ this(inputParameterLops, dt, vt, ExecType.MR);
+ _weights = weights;
+ }
+ public GroupedAggregate(
+ HashMap<String, Lop> inputParameterLops,
+ DataType dt, ValueType vt, ExecType et) {
+ super(Lop.Type.GroupedAgg, dt, vt);
+ init(inputParameterLops, dt, vt, et);
+ }
+
+ /**
+ *
+ * @param inputParameterLops
+ * @param dt
+ * @param vt
+ * @param et
+ */
private void init(HashMap<String, Lop> inputParameterLops,
DataType dt, ValueType vt, ExecType et) {
if ( et == ExecType.MR ) {
@@ -104,19 +126,6 @@ public class GroupedAggregate extends Lop
this.lps.setProperties(inputs, et, ExecLocation.ControlProgram, breaksAlignment, aligner, definesMRJob);
}
}
-
- public GroupedAggregate(
- HashMap<String, Lop> inputParameterLops,
- DataType dt, ValueType vt) {
- this(inputParameterLops, dt, vt, ExecType.MR);
- }
-
- public GroupedAggregate(
- HashMap<String, Lop> inputParameterLops,
- DataType dt, ValueType vt, ExecType et) {
- super(Lop.Type.GroupedAgg, dt, vt);
- init(inputParameterLops, dt, vt, et);
- }
@Override
public String toString() {
@@ -184,33 +193,6 @@ public class GroupedAggregate extends Lop
return sb.toString();
}
-
- /*@Override
- public String getInstructions(String input1, String input2, String output)
- {
- StringBuilder sb = new StringBuilder();
- sb.append( getExecType() );
- sb.append( Lop.OPERAND_DELIMITOR );
- sb.append( "groupedagg" );
- sb.append( OPERAND_DELIMITOR );
- sb.append( input1 );
- sb.append( DATATYPE_PREFIX );
- sb.append( getInputs().get(0).getDataType() );
- sb.append( VALUETYPE_PREFIX );
- sb.append( getInputs().get(0).getValueType() );
- sb.append( input2 );
- sb.append( DATATYPE_PREFIX );
- sb.append( getInputs().get(1).getDataType() );
- sb.append( VALUETYPE_PREFIX );
- sb.append( getInputs().get(1).getValueType() );
- sb.append( output );
- sb.append( DATATYPE_PREFIX );
- sb.append( getDataType() );
- sb.append( VALUETYPE_PREFIX );
- sb.append( getValueType() );
-
- return sb.toString();
- }*/
@Override
public String getInstructions(int input_index, int output_index)
@@ -237,8 +219,11 @@ public class GroupedAggregate extends Lop
// add output_index to instruction
sb.append( OPERAND_DELIMITOR );
- sb.append( this.prepOutputOperand(output_index));
-
+ sb.append( prepOutputOperand(output_index) );
+
+ sb.append( OPERAND_DELIMITOR );
+ sb.append( _weights );
+
return sb.toString();
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f73569b4/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java b/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java
index 326c4ae..b0fe09d 100644
--- a/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysml/parser/ParameterizedBuiltinFunctionExpression.java
@@ -403,59 +403,76 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier
}
- private void validateGroupedAgg(DataIdentifier output, boolean conditional) throws LanguageException {
- int colwise = -1;
+ /**
+ *
+ * @param output
+ * @param conditional
+ * @throws LanguageException
+ */
+ private void validateGroupedAgg(DataIdentifier output, boolean conditional)
+ throws LanguageException
+ {
+ //check existing target and groups
if (getVarParam(Statement.GAGG_TARGET) == null || getVarParam(Statement.GAGG_GROUPS) == null){
- raiseValidateError("Must define both target and groups and both must have same dimensions", conditional);
+ raiseValidateError("Must define both target and groups.", conditional);
}
- if (getVarParam(Statement.GAGG_TARGET) instanceof DataIdentifier && getVarParam(Statement.GAGG_GROUPS) instanceof DataIdentifier && (getVarParam(Statement.GAGG_WEIGHTS) == null || getVarParam(Statement.GAGG_WEIGHTS) instanceof DataIdentifier))
- {
-
- DataIdentifier targetid = (DataIdentifier)getVarParam(Statement.GAGG_TARGET);
- DataIdentifier groupsid = (DataIdentifier)getVarParam(Statement.GAGG_GROUPS);
- DataIdentifier weightsid = (DataIdentifier)getVarParam(Statement.GAGG_WEIGHTS);
- if ( targetid.dimsKnown() ) {
- colwise = targetid.getDim1() > targetid.getDim2() ? 1 : 0;
- }
- else if ( groupsid.dimsKnown() ) {
- colwise = groupsid.getDim1() > groupsid.getDim2() ? 1 : 0;
+ Expression exprTarget = getVarParam(Statement.GAGG_TARGET);
+ Expression exprGroups = getVarParam(Statement.GAGG_GROUPS);
+ Expression exprNGroups = getVarParam(Statement.GAGG_NUM_GROUPS);
+
+ //check valid input dimensions
+ boolean colwise = true;
+ boolean matrix = false;
+ if( exprGroups.getOutput().dimsKnown() && exprTarget.getOutput().dimsKnown() )
+ {
+ //check for valid matrix input
+ if( exprGroups.getOutput().getDim2()==1 && exprTarget.getOutput().getDim2()>1 )
+ {
+ if( getVarParam(Statement.GAGG_WEIGHTS) != null ) {
+ raiseValidateError("Matrix input not supported with weights.", conditional);
+ }
+ if( getVarParam(Statement.GAGG_NUM_GROUPS) == null ) {
+ raiseValidateError("Matrix input not supported without specified numgroups.", conditional);
+ }
+ if( exprGroups.getOutput().getDim1() != exprTarget.getOutput().getDim1() ) {
+ raiseValidateError("Target and groups must have same dimensions -- " + " target dims: " +
+ exprTarget.getOutput().getDim1() +" x "+exprTarget.getOutput().getDim2()+", groups dims: " + exprGroups.getOutput().getDim1() + " x 1.", conditional);
+ }
+ matrix = true;
}
- else if ( weightsid != null && weightsid.dimsKnown() ) {
- colwise = weightsid.getDim1() > weightsid.getDim2() ? 1 : 0;
+ //check for valid col vector input
+ else if( exprGroups.getOutput().getDim2()==1 && exprTarget.getOutput().getDim2()==1 )
+ {
+ if( exprGroups.getOutput().getDim1() != exprTarget.getOutput().getDim1() ) {
+ raiseValidateError("Target and groups must have same dimensions -- " + " target dims: " +
+ exprTarget.getOutput().getDim1() +" x 1, groups dims: " + exprGroups.getOutput().getDim1() + " x 1.", conditional);
+ }
}
-
- //precompute number of rows and columns because target can be row or column vector
- long rowsTarget = targetid.getDim1(); // Math.max(targetid.getDim1(),targetid.getDim2());
- long colsTarget = targetid.getDim2(); // Math.min(targetid.getDim1(),targetid.getDim2());
-
- if( targetid.dimsKnown() && groupsid.dimsKnown() &&
- (rowsTarget != groupsid.getDim1() || colsTarget != groupsid.getDim2()) )
- {
- raiseValidateError("target and groups must have same dimensions -- "
- + " targetid dims: " + targetid.getDim1() +" rows, " + targetid.getDim2() + " cols -- groupsid dims: " + groupsid.getDim1() + " rows, " + groupsid.getDim2() + " cols ", conditional);
+ //check for valid row vector input
+ else if( exprGroups.getOutput().getDim1()==1 && exprTarget.getOutput().getDim1()==1 )
+ {
+ if( exprGroups.getOutput().getDim2() != exprTarget.getOutput().getDim2() ) {
+ raiseValidateError("Target and groups must have same dimensions -- " + " target dims: " +
+ "1 x " + exprTarget.getOutput().getDim2() +", groups dims: 1 x " + exprGroups.getOutput().getDim2() + ".", conditional);
+ }
+ colwise = true;
}
-
- if( weightsid != null && (targetid.dimsKnown() && weightsid.dimsKnown()) &&
- (rowsTarget != weightsid.getDim1() || colsTarget != weightsid.getDim2() ))
- {
- raiseValidateError("target and weights must have same dimensions -- "
- + " targetid dims: " + targetid.getDim1() +" rows, " + targetid.getDim2() + " cols -- weightsid dims: " + weightsid.getDim1() + " rows, " + weightsid.getDim2() + " cols ", conditional);
+ else {
+ raiseValidateError("Invalid target and groups inputs - dimension mismatch.", conditional);
}
}
- if (getVarParam(Statement.GAGG_FN) == null){
+ //check function parameter
+ Expression functParam = getVarParam(Statement.GAGG_FN);
+ if( functParam == null ) {
raiseValidateError("must define function name (fn=<function name>) for aggregate()", conditional);
}
-
- Expression functParam = getVarParam(Statement.GAGG_FN);
-
- if (functParam instanceof Identifier)
+ else if (functParam instanceof Identifier)
{
// standardize to lowercase and dequote fname
- String fnameStr = getVarParam(Statement.GAGG_FN).toString();
-
+ String fnameStr = functParam.toString();
// check that IF fname="centralmoment" THEN order=m is defined, where m=2,3,4
// check ELSE IF fname is allowed
@@ -474,25 +491,25 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier
}
}
- Expression ngroupsParam = getVarParam(Statement.GAGG_NUM_GROUPS);
+ //determine output dimensions
long outputDim1 = -1, outputDim2 = -1;
- if( ngroupsParam != null && ngroupsParam instanceof Identifier )
+ if( exprNGroups != null && exprNGroups instanceof Identifier )
{
- Identifier numGroups = (Identifier) ngroupsParam;
+ Identifier numGroups = (Identifier) exprNGroups;
if ( numGroups != null && numGroups instanceof ConstIdentifier) {
long ngroups = ((ConstIdentifier)numGroups).getLongValue();
- if ( colwise == 1 ) {
+ if ( colwise ) {
outputDim1 = ngroups;
- outputDim2 = 1;
+ outputDim2 = matrix ? exprTarget.getOutput().getDim2() : 1;
}
- else if ( colwise == 0 ) {
- outputDim1 = 1;
+ else {
+ outputDim1 = 1; //no support for matrix
outputDim2 = ngroups;
}
}
}
- // Output is a matrix with unknown dims
+ //set output meta data
output.setDataType(DataType.MATRIX);
output.setValueType(ValueType.DOUBLE);
output.setDimensions(outputDim1, outputDim2);
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f73569b4/src/main/java/org/apache/sysml/runtime/instructions/mr/GroupedAggregateInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/mr/GroupedAggregateInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/mr/GroupedAggregateInstruction.java
index fe5ee86..a104ca8 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/mr/GroupedAggregateInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/mr/GroupedAggregateInstruction.java
@@ -37,12 +37,27 @@ import org.apache.sysml.runtime.matrix.operators.CMOperator.AggregateOperationTy
public class GroupedAggregateInstruction extends UnaryMRInstructionBase
{
+ private boolean _weights = false;
+ private long _bclen = -1;
-
- public GroupedAggregateInstruction(Operator op, byte in, byte out, String istr) {
+ public GroupedAggregateInstruction(Operator op, byte in, byte out, boolean weights, String istr) {
super(op, in, out);
mrtype = MRINSTRUCTION_TYPE.GroupedAggregate;
instString = istr;
+
+ _weights = weights;
+ }
+
+ public boolean hasWeights() {
+ return _weights;
+ }
+
+ public void setBclen(long bclen){
+ _bclen = bclen;
+ }
+
+ public long getBclen(){
+ return _bclen;
}
@Override
@@ -63,14 +78,15 @@ public class GroupedAggregateInstruction extends UnaryMRInstructionBase
byte in, out;
String opcode = parts[0];
in = Byte.parseByte(parts[1]);
- out = Byte.parseByte(parts[parts.length - 1]);
+ out = Byte.parseByte(parts[parts.length - 2]);
+ boolean weights = Boolean.parseBoolean(parts[parts.length-1]);
if ( !opcode.equalsIgnoreCase("groupedagg") ) {
throw new DMLRuntimeException("Invalid opcode in GroupedAggregateInstruction: " + opcode);
}
Operator optr = parseGroupedAggOperator(parts[2], parts[3]);
- return new GroupedAggregateInstruction(optr, in, out, str);
+ return new GroupedAggregateInstruction(optr, in, out, weights, str);
}
public static Operator parseGroupedAggOperator(String fn, String other) throws DMLRuntimeException {
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f73569b4/src/main/java/org/apache/sysml/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
index 84b4821..75da3b9 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
@@ -25,7 +25,6 @@ import java.util.HashMap;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
-import org.apache.spark.api.java.function.PairFunction;
import scala.Tuple2;
@@ -185,12 +184,12 @@ public class ParameterizedBuiltinSPInstruction extends ComputationSPInstruction
MatrixCharacteristics mc1 = sec.getMatrixCharacteristics( params.get(Statement.GAGG_TARGET) );
MatrixCharacteristics mc2 = sec.getMatrixCharacteristics( params.get(Statement.GAGG_GROUPS) );
- if(mc1.dimsKnown() && mc2.dimsKnown() && (mc1.getRows() != mc2.getRows() || mc1.getCols() != mc2.getCols())) {
- throw new DMLRuntimeException("Grouped Aggregate SPInstruction is not supported for dimension of target != groups");
+ if(mc1.dimsKnown() && mc2.dimsKnown() && (mc1.getRows() != mc2.getRows() || mc2.getCols() !=1)) {
+ throw new DMLRuntimeException("Grouped Aggregate dimension mismatch between target and groups.");
}
MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(output.getName());
- JavaPairRDD<Long, WeightedCell> groupWeightedCells = null;
+ JavaPairRDD<MatrixIndexes, WeightedCell> groupWeightedCells = null;
// Step 1: First extract groupWeightedCells from group, target and weights
if ( params.get(Statement.GAGG_WEIGHTS) != null ) {
@@ -198,15 +197,22 @@ public class ParameterizedBuiltinSPInstruction extends ComputationSPInstruction
MatrixCharacteristics mc3 = sec.getMatrixCharacteristics( params.get(Statement.GAGG_GROUPS) );
if(mc1.dimsKnown() && mc3.dimsKnown() && (mc1.getRows() != mc3.getRows() || mc1.getCols() != mc3.getCols())) {
- throw new DMLRuntimeException("Grouped Aggregate SPInstruction is not supported for dimension of target != weights");
+ throw new DMLRuntimeException("Grouped Aggregate dimension mismatch between target, groups, and weights.");
}
groupWeightedCells = groups.join(target).join(weights)
.flatMapToPair(new ExtractGroupNWeights());
}
- else {
+ else //input vector or matrix
+ {
+ //replicate groups if necessary
+ if( mc1.getNumColBlocks() > 1 ) {
+ groups = groups.flatMapToPair(
+ new ReplicateVectorFunction(false, mc1.getNumColBlocks() ));
+ }
+
groupWeightedCells = groups.join(target)
- .flatMapToPair(new ExtractGroup());
+ .flatMapToPair(new ExtractGroup(mc1.getColsPerBlock()));
}
// Step 2: Make sure we have brlen required while creating <MatrixIndexes, MatrixCell>
@@ -219,21 +225,20 @@ public class ParameterizedBuiltinSPInstruction extends ComputationSPInstruction
JavaPairRDD<MatrixIndexes, MatrixCell> out = null;
if(_optr instanceof CMOperator && ((CMOperator) _optr).isPartialAggregateOperator() ) {
out = groupWeightedCells.reduceByKey(new PerformGroupByAggInCombiner(_optr))
- .mapToPair(new CreateMatrixCell(brlen, _optr));
+ .mapValues(new CreateMatrixCell(brlen, _optr));
}
else {
// Use groupby key because partial aggregation is not supported
out = groupWeightedCells.groupByKey()
- .mapToPair(new PerformGroupByAggInReducer(_optr))
- .mapToPair(new CreateMatrixCell(brlen, _optr));
+ .mapValues(new PerformGroupByAggInReducer(_optr))
+ .mapValues(new CreateMatrixCell(brlen, _optr));
}
// Step 4: Set output characteristics and rdd handle
setOutputCharacteristicsForGroupedAgg(mc1, mcOut, out);
//store output rdd handle
- sec.setRDDHandleForVariable(output.getName(), out);
-
+ sec.setRDDHandleForVariable(output.getName(), out);
sec.addLineageRDD(output.getName(), params.get(Statement.GAGG_TARGET) );
sec.addLineageRDD(output.getName(), params.get(Statement.GAGG_GROUPS) );
if ( params.get(Statement.GAGG_WEIGHTS) != null ) {
@@ -499,7 +504,7 @@ public class ParameterizedBuiltinSPInstruction extends ComputationSPInstruction
/**
*
*/
- public static class CreateMatrixCell implements PairFunction<Tuple2<Long,WeightedCell>, MatrixIndexes, MatrixCell>
+ public static class CreateMatrixCell implements Function<WeightedCell, MatrixCell>
{
private static final long serialVersionUID = -5783727852453040737L;
@@ -510,8 +515,9 @@ public class ParameterizedBuiltinSPInstruction extends ComputationSPInstruction
}
@Override
- public Tuple2<MatrixIndexes, MatrixCell> call(Tuple2<Long, WeightedCell> kv) throws Exception {
-
+ public MatrixCell call(WeightedCell kv)
+ throws Exception
+ {
double val = -1;
if(op instanceof CMOperator)
{
@@ -519,23 +525,22 @@ public class ParameterizedBuiltinSPInstruction extends ComputationSPInstruction
switch(agg)
{
case COUNT:
- val = kv._2.getWeight();
+ val = kv.getWeight();
break;
case MEAN:
- val = kv._2.getValue();
+ val = kv.getValue();
break;
case CM2:
- val = kv._2.getValue()/ kv._2.getWeight();
+ val = kv.getValue()/ kv.getWeight();
break;
case CM3:
- val = kv._2.getValue()/ kv._2.getWeight();
+ val = kv.getValue()/ kv.getWeight();
break;
case CM4:
- val = kv._2.getValue()/ kv._2.getWeight();
+ val = kv.getValue()/ kv.getWeight();
break;
case VARIANCE:
- val = kv._2.getValue()/kv._2.getWeight();
- // val = kv._2.getWeight() ==1.0? 0:kv._2.getValue()/(kv._2.getWeight() - 1);
+ val = kv.getValue()/kv.getWeight();
break;
default:
throw new DMLRuntimeException("Invalid aggreagte in CM_CV_Object: " + agg);
@@ -544,15 +549,11 @@ public class ParameterizedBuiltinSPInstruction extends ComputationSPInstruction
else
{
//avoid division by 0
- val = kv._2.getValue()/kv._2.getWeight();
+ val = kv.getValue()/kv.getWeight();
}
- MatrixIndexes indx = new MatrixIndexes(kv._1, 1);
- MatrixCell cell = new MatrixCell(val);
-
- return new Tuple2<MatrixIndexes, MatrixCell>(indx, cell);
+ return new MatrixCell(val);
}
-
}
/**
@@ -572,14 +573,13 @@ public class ParameterizedBuiltinSPInstruction extends ComputationSPInstruction
if ( params.get(Statement.GAGG_NUM_GROUPS) != null) {
int ngroups = (int) Double.parseDouble(params.get(Statement.GAGG_NUM_GROUPS));
- mcOut.set(ngroups, 1, mc1.getRowsPerBlock(), mc1.getColsPerBlock());
+ mcOut.set(ngroups, mc1.getCols(), -1, -1); //grouped aggregate with cell output
}
else {
out = SparkUtils.cacheBinaryCellRDD(out);
mcOut.set(SparkUtils.computeMatrixCharacteristics(out));
- mcOut.setBlockSize(mc1.getRowsPerBlock(), mc1.getColsPerBlock());
+ mcOut.setBlockSize(-1, -1); //grouped aggregate with cell output
}
}
}
-
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f73569b4/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/ExtractGroup.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/ExtractGroup.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/ExtractGroup.java
index fcd0166..6259955 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/ExtractGroup.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/ExtractGroup.java
@@ -29,15 +29,22 @@ import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.data.WeightedCell;
import org.apache.sysml.runtime.util.UtilFunctions;
-public class ExtractGroup implements PairFlatMapFunction<Tuple2<MatrixIndexes,Tuple2<MatrixBlock, MatrixBlock>>, Long, WeightedCell> {
+public class ExtractGroup implements PairFlatMapFunction<Tuple2<MatrixIndexes,Tuple2<MatrixBlock, MatrixBlock>>, MatrixIndexes, WeightedCell> {
private static final long serialVersionUID = -7059358143841229966L;
+ private long _bclen = -1;
+
+ public ExtractGroup( long bclen ) {
+ _bclen = bclen;
+ }
+
@Override
- public Iterable<Tuple2<Long, WeightedCell>> call(
+ public Iterable<Tuple2<MatrixIndexes, WeightedCell>> call(
Tuple2<MatrixIndexes, Tuple2<MatrixBlock, MatrixBlock>> arg)
throws Exception
{
+ MatrixIndexes ix = arg._1;
MatrixBlock group = arg._2._1;
MatrixBlock target = arg._2._2;
@@ -47,15 +54,20 @@ public class ExtractGroup implements PairFlatMapFunction<Tuple2<MatrixIndexes,T
}
//output weighted cells
- ArrayList<Tuple2<Long, WeightedCell>> groupValuePairs = new ArrayList<Tuple2<Long, WeightedCell>>();
+ ArrayList<Tuple2<MatrixIndexes, WeightedCell>> groupValuePairs = new ArrayList<Tuple2<MatrixIndexes, WeightedCell>>();
+ long coloff = (ix.getColumnIndex()-1)*_bclen;
for(int i = 0; i < group.getNumRows(); i++) {
- WeightedCell weightedCell = new WeightedCell();
- weightedCell.setValue(target.quickGetValue(i, 0));
long groupVal = UtilFunctions.toLong(group.quickGetValue(i, 0));
if(groupVal < 1) {
throw new Exception("Expected group values to be greater than equal to 1 but found " + groupVal);
}
- groupValuePairs.add(new Tuple2<Long, WeightedCell>(groupVal, weightedCell));
+ for( int j=0; j<target.getNumColumns(); j++ ) {
+ WeightedCell weightedCell = new WeightedCell();
+ weightedCell.setValue(target.quickGetValue(i, j));
+ weightedCell.setWeight(1);
+ MatrixIndexes ixout = new MatrixIndexes(groupVal,coloff+j+1);
+ groupValuePairs.add(new Tuple2<MatrixIndexes, WeightedCell>(ixout, weightedCell));
+ }
}
return groupValuePairs;
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f73569b4/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/ExtractGroupNWeights.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/ExtractGroupNWeights.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/ExtractGroupNWeights.java
index 17c58c5..372ebea 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/ExtractGroupNWeights.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/ExtractGroupNWeights.java
@@ -30,12 +30,12 @@ import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.data.WeightedCell;
import org.apache.sysml.runtime.util.UtilFunctions;
-public class ExtractGroupNWeights implements PairFlatMapFunction<Tuple2<MatrixIndexes,Tuple2<Tuple2<MatrixBlock,MatrixBlock>,MatrixBlock>>, Long, WeightedCell> {
+public class ExtractGroupNWeights implements PairFlatMapFunction<Tuple2<MatrixIndexes,Tuple2<Tuple2<MatrixBlock,MatrixBlock>,MatrixBlock>>, MatrixIndexes, WeightedCell> {
private static final long serialVersionUID = -188180042997588072L;
@Override
- public Iterable<Tuple2<Long, WeightedCell>> call(
+ public Iterable<Tuple2<MatrixIndexes, WeightedCell>> call(
Tuple2<MatrixIndexes, Tuple2<Tuple2<MatrixBlock, MatrixBlock>, MatrixBlock>> arg)
throws Exception
{
@@ -49,7 +49,7 @@ public class ExtractGroupNWeights implements PairFlatMapFunction<Tuple2<MatrixIn
}
//output weighted cells
- ArrayList<Tuple2<Long, WeightedCell>> groupValuePairs = new ArrayList<Tuple2<Long, WeightedCell>>();
+ ArrayList<Tuple2<MatrixIndexes, WeightedCell>> groupValuePairs = new ArrayList<Tuple2<MatrixIndexes, WeightedCell>>();
for(int i = 0; i < group.getNumRows(); i++) {
WeightedCell weightedCell = new WeightedCell();
weightedCell.setValue(target.quickGetValue(i, 0));
@@ -58,7 +58,8 @@ public class ExtractGroupNWeights implements PairFlatMapFunction<Tuple2<MatrixIn
if(groupVal < 1) {
throw new Exception("Expected group values to be greater than equal to 1 but found " + groupVal);
}
- groupValuePairs.add(new Tuple2<Long, WeightedCell>(groupVal, weightedCell));
+ MatrixIndexes ix = new MatrixIndexes(groupVal, 1);
+ groupValuePairs.add(new Tuple2<MatrixIndexes, WeightedCell>(ix, weightedCell));
}
return groupValuePairs;
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f73569b4/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/PerformGroupByAggInReducer.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/PerformGroupByAggInReducer.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/PerformGroupByAggInReducer.java
index eb5d104..2d9f40b 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/PerformGroupByAggInReducer.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/PerformGroupByAggInReducer.java
@@ -19,9 +19,7 @@
package org.apache.sysml.runtime.instructions.spark.functions;
-import org.apache.spark.api.java.function.PairFunction;
-
-import scala.Tuple2;
+import org.apache.spark.api.java.function.Function;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.functionobjects.CM;
@@ -33,8 +31,8 @@ import org.apache.sysml.runtime.matrix.operators.AggregateOperator;
import org.apache.sysml.runtime.matrix.operators.CMOperator;
import org.apache.sysml.runtime.matrix.operators.Operator;
-public class PerformGroupByAggInReducer implements PairFunction<Tuple2<Long,Iterable<WeightedCell>>, Long, WeightedCell> {
-
+public class PerformGroupByAggInReducer implements Function<Iterable<WeightedCell>, WeightedCell>
+{
private static final long serialVersionUID = 8160556441153227417L;
Operator op;
@@ -43,32 +41,21 @@ public class PerformGroupByAggInReducer implements PairFunction<Tuple2<Long,Iter
}
@Override
- public Tuple2<Long, WeightedCell> call(
- Tuple2<Long, Iterable<WeightedCell>> kv) throws Exception {
- return new Tuple2<Long, WeightedCell>(kv._1, doAggregation(op, kv._2));
- }
-
- public WeightedCell doAggregation(Operator op, Iterable<WeightedCell> values) throws DMLRuntimeException {
+ public WeightedCell call(Iterable<WeightedCell> kv)
+ throws Exception
+ {
WeightedCell outCell = new WeightedCell();
CM_COV_Object cmObj = new CM_COV_Object();
if(op instanceof CMOperator) //everything except sum
{
cmObj.reset();
CM lcmFn = CM.getCMFnObject(((CMOperator) op).aggOpType); // cmFn.get(key.getTag());
- if( ((CMOperator) op).isPartialAggregateOperator() )
- {
+ if( ((CMOperator) op).isPartialAggregateOperator() ) {
throw new DMLRuntimeException("Incorrect usage, should have used PerformGroupByAggInCombiner");
-
-// //partial aggregate cm operator
-// for(WeightedCell value : values)
-// lcmFn.execute(cmObj, value.getValue(), value.getWeight());
-//
-// outCell.setValue(cmObj.getRequiredPartialResult(op));
-// outCell.setWeight(cmObj.getWeight());
}
else //forward tuples to reducer
{
- for(WeightedCell value : values)
+ for(WeightedCell value : kv)
lcmFn.execute(cmObj, value.getValue(), value.getWeight());
outCell.setValue(cmObj.getRequiredResult(op));
@@ -85,7 +72,7 @@ public class PerformGroupByAggInReducer implements PairFunction<Tuple2<Long,Iter
KahanPlus.getKahanPlusFnObject();
//partial aggregate with correction
- for(WeightedCell value : values)
+ for(WeightedCell value : kv)
aggop.increOp.fn.execute(buffer, value.getValue()*value.getWeight());
outCell.setValue(buffer._sum);
@@ -96,7 +83,7 @@ public class PerformGroupByAggInReducer implements PairFunction<Tuple2<Long,Iter
double v = aggop.initialValue;
//partial aggregate without correction
- for(WeightedCell value : values)
+ for(WeightedCell value : kv)
v=aggop.increOp.fn.execute(v, value.getValue()*value.getWeight());
outCell.setValue(v);
@@ -108,5 +95,4 @@ public class PerformGroupByAggInReducer implements PairFunction<Tuple2<Long,Iter
return outCell;
}
-
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f73569b4/src/main/java/org/apache/sysml/runtime/matrix/GroupedAggMR.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/GroupedAggMR.java b/src/main/java/org/apache/sysml/runtime/matrix/GroupedAggMR.java
index c327692..7a253f4 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/GroupedAggMR.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/GroupedAggMR.java
@@ -26,11 +26,10 @@ import org.apache.hadoop.mapred.JobClient;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.RunningJob;
import org.apache.hadoop.mapred.Counters.Group;
-
import org.apache.sysml.runtime.instructions.MRJobInstruction;
import org.apache.sysml.runtime.matrix.data.InputInfo;
import org.apache.sysml.runtime.matrix.data.OutputInfo;
-import org.apache.sysml.runtime.matrix.data.TaggedInt;
+import org.apache.sysml.runtime.matrix.data.TaggedMatrixIndexes;
import org.apache.sysml.runtime.matrix.data.WeightedCell;
import org.apache.sysml.runtime.matrix.mapred.GroupedAggMRCombiner;
import org.apache.sysml.runtime.matrix.mapred.GroupedAggMRMapper;
@@ -116,7 +115,7 @@ public class GroupedAggMR
// configure mapper and the mapper output key value pairs
job.setMapperClass(GroupedAggMRMapper.class);
job.setCombinerClass(GroupedAggMRCombiner.class);
- job.setMapOutputKeyClass(TaggedInt.class);
+ job.setMapOutputKeyClass(TaggedMatrixIndexes.class);
job.setMapOutputValueClass(WeightedCell.class);
//configure reducer
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f73569b4/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
index b7819b8..0fe9095 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
@@ -5295,13 +5295,14 @@ public class MatrixBlock extends MatrixValue implements Externalizable
MatrixBlock weights = checkType(wghts);
//check valid dimensions
+ boolean validMatrixOp = (weights == null && ngroups>=1);
if( this.getNumColumns() != 1 || (weights!=null && weights.getNumColumns()!=1) )
throw new DMLRuntimeException("groupedAggregate can only operate on 1-dimensional column matrices for groups and weights.");
- if( target.getNumColumns() != 1 && op instanceof CMOperator )
+ if( target.getNumColumns() != 1 && op instanceof CMOperator && !validMatrixOp )
throw new DMLRuntimeException("groupedAggregate can only operate on 1-dimensional column matrices for target (for this aggregation function).");
- if( target.getNumColumns() != 1 && target.getNumRows()!=1 )
+ if( target.getNumColumns() != 1 && target.getNumRows()!=1 && !validMatrixOp )
throw new DMLRuntimeException("groupedAggregate can only operate on 1-dimensional column or row matrix for target.");
- if( this.getNumRows() != Math.max(target.getNumRows(),target.getNumColumns()) || (weights != null && this.getNumRows() != weights.getNumRows()) )
+ if( this.getNumRows() != target.getNumRows() && this.getNumRows() !=Math.max(target.getNumRows(),target.getNumColumns()) || (weights != null && this.getNumRows() != weights.getNumRows()) )
throw new DMLRuntimeException("groupedAggregate can only operate on matrices with equal dimensions.");
// obtain numGroups from instruction, if provided
@@ -5323,40 +5324,22 @@ public class MatrixBlock extends MatrixValue implements Externalizable
}
// Allocate result matrix
+ boolean rowVector = (target.getNumRows()==1 && target.getNumColumns()>1);
MatrixBlock result = checkType(ret);
boolean result_sparsity = estimateSparsityOnGroupedAgg(rlen, numGroups);
if(result==null)
- result=new MatrixBlock(numGroups, 1, result_sparsity);
+ result=new MatrixBlock(numGroups, rowVector?1:target.getNumRows(), result_sparsity);
else
- result.reset(numGroups, 1, result_sparsity);
+ result.reset(numGroups, rowVector?1:target.getNumRows(), result_sparsity);
- // Compute the result
- double w = 1; // default weight
-
//CM operator for count, mean, variance
//note: current support only for column vectors
- if(op instanceof CMOperator) {
+ if(op instanceof CMOperator)
+ {
// initialize required objects for storing the result of CM operations
- CM cmFn = CM.getCMFnObject(((CMOperator) op).getAggOpType());
- CM_COV_Object[] cmValues = new CM_COV_Object[numGroups];
- for ( int i=0; i < numGroups; i++ )
- cmValues[i] = new CM_COV_Object();
-
- for ( int i=0; i < this.getNumRows(); i++ ) {
- int g = (int) this.quickGetValue(i, 0);
- if ( g > numGroups )
- continue;
- double d = target.quickGetValue(i,0);
- if ( weights != null )
- w = weights.quickGetValue(i,0);
- // cmValues is 0-indexed, whereas range of values for g = [1,numGroups]
- cmFn.execute(cmValues[g-1], d, w);
- }
+ CMOperator cmOp = (CMOperator) op;
- // extract the required value from each CM_COV_Object
- for ( int i=0; i < numGroups; i++ )
- // result is 0-indexed, so is cmValues
- result.quickSetValue(i, 0, cmValues[i].getRequiredResult(op));
+ groupedAggregateCM(target, weights, result, cmOp);
}
//Aggregate operator for sum (via kahan sum)
//note: support for row/column vectors and dense/sparse
@@ -5387,9 +5370,11 @@ public class MatrixBlock extends MatrixValue implements Externalizable
* @param op
* @throws DMLRuntimeException
*/
- private void groupedAggregateKahanPlus( MatrixBlock target, MatrixBlock weights, MatrixBlock result, AggregateOperator aggop ) throws DMLRuntimeException
+ private void groupedAggregateKahanPlus( MatrixBlock target, MatrixBlock weights, MatrixBlock result, AggregateOperator aggop )
+ throws DMLRuntimeException
{
- boolean rowVector = target.getNumColumns()>1;
+ boolean rowVector = (target.getNumRows()==1 && target.getNumColumns()>1);
+ int numCols = (!rowVector) ? target.getNumColumns() : 1;
double w = 1; //default weight
//skip empty blocks (sparse-safe operation)
@@ -5397,9 +5382,10 @@ public class MatrixBlock extends MatrixValue implements Externalizable
return;
//init group buffers
- KahanObject[] buffer = new KahanObject[numGroups];
- for(int i=0; i < numGroups; i++ )
- buffer[i] = new KahanObject(aggop.initialValue, 0);
+ KahanObject[][] buffer = new KahanObject[numGroups][numCols];
+ for( int i=0; i<numGroups; i++ )
+ for( int j=0; j<numCols; j++ )
+ buffer[i][j] = new KahanObject(aggop.initialValue, 0);
if( rowVector ) //target is rowvector
{
@@ -5417,7 +5403,7 @@ public class MatrixBlock extends MatrixValue implements Externalizable
continue;
if ( weights != null )
w = weights.quickGetValue(aix[j],0);
- aggop.increOp.fn.execute(buffer[g-1], avals[j]*w);
+ aggop.increOp.fn.execute(buffer[g-1][0], avals[j]*w);
}
}
@@ -5434,34 +5420,141 @@ public class MatrixBlock extends MatrixValue implements Externalizable
if ( weights != null )
w = weights.quickGetValue(i,0);
// buffer is 0-indexed, whereas range of values for g = [1,numGroups]
- aggop.increOp.fn.execute(buffer[g-1], d*w);
+ aggop.increOp.fn.execute(buffer[g-1][0], d*w);
}
}
}
}
- else //column vector (always dense, but works for sparse as well)
+ else //column vector or matrix
{
- for ( int i=0; i < this.getNumRows(); i++ )
+ if( target.sparse ) //SPARSE target
{
- double d = target.quickGetValue(i,0);
- if( d != 0 ) //sparse-safe
+ SparseRow[] a = target.sparseRows;
+
+ for( int i=0; i < getNumRows(); i++ )
{
int g = (int) this.quickGetValue(i, 0);
if ( g > numGroups )
continue;
+
+ if( a[i] != null && !a[i].isEmpty() )
+ {
+ int len = a[i].size();
+ int[] aix = a[i].getIndexContainer();
+ double[] avals = a[i].getValueContainer();
+ for( int j=0; j<len; j++ ) //for each nnz
+ {
+ if ( weights != null )
+ w = weights.quickGetValue(i,0); //weight of row i (same lookup as the dense path), not of column aix[j]
+ aggop.increOp.fn.execute(buffer[g-1][aix[j]], avals[j]*w);
+ }
+ }
+ }
+ }
+ else //DENSE target
+ {
+ double[] a = target.denseBlock;
+
+ for( int i=0, aix=0; i < getNumRows(); i++, aix+=numCols )
+ {
+ int g = (int) this.quickGetValue(i, 0);
+ if ( g > numGroups )
+ continue;
+
+ for( int j=0; j < numCols; j++ ) {
+ double d = a[ aix+j ];
+ if( d != 0 ) { //sparse-safe
+ if ( weights != null )
+ w = weights.quickGetValue(i,0);
+ // buffer is 0-indexed, whereas range of values for g = [1,numGroups]
+ aggop.increOp.fn.execute(buffer[g-1][j], d*w);
+ }
+ }
+ }
+ }
+ }
+
+ // extract the results from group buffers
+ for( int i=0; i < numGroups; i++ )
+ for( int j=0; j < numCols; j++ )
+ result.appendValue(i, j, buffer[i][j]._sum);
+ }
+
+ /**
+ * Aggregates the cells of target per (group, column) with the central-moment function selected by cmOp.
+ * @param target matrix (or column vector) whose cells are aggregated; one buffer per group and column
+ * @param weights optional weights read from column 0 at the target row index; null means weight 1
+ * @param result output block that receives one value per (group, column) via appendValue
+ * @param cmOp operator that selects the CM function and the value extracted from each buffer
+ * @throws DMLRuntimeException
+ */
+ private void groupedAggregateCM( MatrixBlock target, MatrixBlock weights, MatrixBlock result, CMOperator cmOp )
+ throws DMLRuntimeException
+ {
+ CM cmFn = CM.getCMFnObject(((CMOperator) cmOp).getAggOpType());
+ double w = 1; //default weight
+
+ //init group buffers
+ CM_COV_Object[][] cmValues = new CM_COV_Object[numGroups][target.clen];
+ for ( int i=0; i < numGroups; i++ )
+ for( int j=0; j < target.clen; j++ )
+ cmValues[i][j] = new CM_COV_Object();
+
+
+ //column vector or matrix
+ if( target.sparse ) //SPARSE target
+ {
+ SparseRow[] a = target.sparseRows;
+
+ for( int i=0; i < getNumRows(); i++ )
+ {
+ int g = (int) this.quickGetValue(i, 0);
+ if ( g > numGroups )
+ continue;
+
+ if( a[i] != null && !a[i].isEmpty() )
+ {
+ int len = a[i].size();
+ int[] aix = a[i].getIndexContainer();
+ double[] avals = a[i].getValueContainer();
+ for( int j=0; j<len; j++ ) //for each nnz
+ {
+ if ( weights != null )
+ w = weights.quickGetValue(i,0); //weight of row i (same lookup as the dense path), not of column aix[j]
+ cmFn.execute(cmValues[g-1][aix[j]], avals[j], w);
+ }
+ //TODO sparse unsafe correction
+ }
+ }
+ }
+ else //DENSE target
+ {
+ double[] a = target.denseBlock;
+
+ for( int i=0, aix=0; i < getNumRows(); i++, aix+=target.clen )
+ {
+ int g = (int) this.quickGetValue(i, 0);
+ if ( g > numGroups )
+ continue;
+
+ for( int j=0; j < target.clen; j++ ) {
+ double d = a[ aix+j ]; //sparse unsafe
if ( weights != null )
w = weights.quickGetValue(i,0);
// buffer is 0-indexed, whereas range of values for g = [1,numGroups]
- aggop.increOp.fn.execute(buffer[g-1], d*w);
+ cmFn.execute(cmValues[g-1][j], d, w);
}
}
}
- // extract the results from group buffers
- for ( int i=0; i < numGroups; i++ )
- result.quickSetValue(i, 0, buffer[i]._sum);
+ // extract the required value from each CM_COV_Object
+ for( int i=0; i < numGroups; i++ )
+ for( int j=0; j < target.clen; j++ ) {
+ // result is 0-indexed, so is cmValues
+ result.appendValue(i, j, cmValues[i][j].getRequiredResult(cmOp));
+ }
}
-
+
/**
*
* @param ret
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f73569b4/src/main/java/org/apache/sysml/runtime/matrix/data/TaggedMatrixIndexes.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/TaggedMatrixIndexes.java b/src/main/java/org/apache/sysml/runtime/matrix/data/TaggedMatrixIndexes.java
new file mode 100644
index 0000000..632de4c
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/TaggedMatrixIndexes.java
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+
+package org.apache.sysml.runtime.matrix.data;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+import org.apache.hadoop.io.RawComparator;
+import org.apache.hadoop.io.WritableComparator;
+
+public class TaggedMatrixIndexes extends Tagged<MatrixIndexes> //MatrixIndexes key combined with a byte tag
+{
+ public TaggedMatrixIndexes(){} //leaves base unset; readFields allocates it on first use
+
+ public TaggedMatrixIndexes(MatrixIndexes ix, byte t) {
+ super(ix, t);
+ }
+
+ public TaggedMatrixIndexes(TaggedMatrixIndexes that) {
+ tag = that.tag;
+ base = that.base; //NOTE: copies the reference, so both objects share the same MatrixIndexes
+ }
+
+ @Override
+ public String toString() {
+ return "k: "+base+", tag: "+tag;
+ }
+
+ public void readFields(DataInput in) throws IOException { //reads base first, then the tag byte
+ if( base == null ){
+ base = new MatrixIndexes();
+ }
+ base.readFields(in);
+ tag=in.readByte();
+ }
+
+ public void write(DataOutput out) throws IOException { //same field order as readFields: base, then tag
+ base.write(out);
+ out.writeByte(tag);
+ }
+
+ public int compareTo(TaggedMatrixIndexes other) { //orders by base; ties are broken by tag
+ int tmp = base.compareTo(other.base);
+ if( tmp != 0 )
+ return tmp;
+ else if( tag!=other.tag )
+ return tag-other.tag;
+ return 0;
+ }
+
+ @Override
+ public boolean equals(Object other) //equal iff both base and tag are equal
+ {
+ if( !(other instanceof TaggedMatrixIndexes))
+ return false;
+
+ TaggedMatrixIndexes tother = (TaggedMatrixIndexes)other;
+ return (base.equals(tother.base) && tag==tother.tag);
+ }
+
+ @Override
+ public int hashCode() { //built from the same two fields that equals compares
+ return base.hashCode() + tag;
+ }
+
+ public static class Comparator implements RawComparator<TaggedMatrixIndexes>
+ {
+ @Override
+ public int compare(byte[] b1, int s1, int l1,
+ byte[] b2, int s2, int l2)
+ {
+ return WritableComparator.compareBytes(b1, s1, l1, b2, s2, l2); //byte-wise comparison of the two given byte ranges
+ }
+
+ @Override
+ public int compare(TaggedMatrixIndexes m1, TaggedMatrixIndexes m2) {
+ return m1.compareTo(m2); //object comparison delegates to compareTo
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f73569b4/src/main/java/org/apache/sysml/runtime/matrix/mapred/GroupedAggMRCombiner.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/mapred/GroupedAggMRCombiner.java b/src/main/java/org/apache/sysml/runtime/matrix/mapred/GroupedAggMRCombiner.java
index ce62993..e561f3c 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/mapred/GroupedAggMRCombiner.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/mapred/GroupedAggMRCombiner.java
@@ -27,13 +27,12 @@ import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.OutputCollector;
import org.apache.hadoop.mapred.Reducer;
import org.apache.hadoop.mapred.Reporter;
-
import org.apache.sysml.runtime.functionobjects.CM;
import org.apache.sysml.runtime.functionobjects.KahanPlus;
import org.apache.sysml.runtime.instructions.cp.CM_COV_Object;
import org.apache.sysml.runtime.instructions.cp.KahanObject;
import org.apache.sysml.runtime.instructions.mr.GroupedAggregateInstruction;
-import org.apache.sysml.runtime.matrix.data.TaggedInt;
+import org.apache.sysml.runtime.matrix.data.TaggedMatrixIndexes;
import org.apache.sysml.runtime.matrix.data.WeightedCell;
import org.apache.sysml.runtime.matrix.operators.AggregateOperator;
import org.apache.sysml.runtime.matrix.operators.CMOperator;
@@ -41,9 +40,8 @@ import org.apache.sysml.runtime.matrix.operators.Operator;
public class GroupedAggMRCombiner extends ReduceBase
- implements Reducer<TaggedInt, WeightedCell, TaggedInt, WeightedCell>
-{
-
+ implements Reducer<TaggedMatrixIndexes, WeightedCell, TaggedMatrixIndexes, WeightedCell>
+{
//grouped aggregate instructions
private HashMap<Byte, GroupedAggregateInstruction> grpaggInstructions = new HashMap<Byte, GroupedAggregateInstruction>();
@@ -53,8 +51,8 @@ public class GroupedAggMRCombiner extends ReduceBase
private WeightedCell outCell = new WeightedCell();
@Override
- public void reduce(TaggedInt key, Iterator<WeightedCell> values,
- OutputCollector<TaggedInt, WeightedCell> out, Reporter reporter)
+ public void reduce(TaggedMatrixIndexes key, Iterator<WeightedCell> values,
+ OutputCollector<TaggedMatrixIndexes, WeightedCell> out, Reporter reporter)
throws IOException
{
long start = System.currentTimeMillis();
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f73569b4/src/main/java/org/apache/sysml/runtime/matrix/mapred/GroupedAggMRMapper.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/mapred/GroupedAggMRMapper.java b/src/main/java/org/apache/sysml/runtime/matrix/mapred/GroupedAggMRMapper.java
index 5d92af1..079e1fb 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/mapred/GroupedAggMRMapper.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/mapred/GroupedAggMRMapper.java
@@ -23,34 +23,32 @@ package org.apache.sysml.runtime.matrix.mapred;
import java.io.IOException;
import java.util.ArrayList;
-import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.Mapper;
import org.apache.hadoop.mapred.OutputCollector;
import org.apache.hadoop.mapred.Reporter;
-
import org.apache.sysml.runtime.instructions.mr.GroupedAggregateInstruction;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.data.MatrixValue;
-import org.apache.sysml.runtime.matrix.data.TaggedInt;
+import org.apache.sysml.runtime.matrix.data.TaggedMatrixIndexes;
import org.apache.sysml.runtime.matrix.data.WeightedCell;
public class GroupedAggMRMapper extends MapperBase
- implements Mapper<MatrixIndexes, MatrixValue, TaggedInt, WeightedCell>
+ implements Mapper<MatrixIndexes, MatrixValue, TaggedMatrixIndexes, WeightedCell>
{
//block instructions that need to be performed in part by mapper
protected ArrayList<ArrayList<GroupedAggregateInstruction>> groupAgg_instructions=new ArrayList<ArrayList<GroupedAggregateInstruction>>();
- private IntWritable outKeyValue=new IntWritable();
- private TaggedInt outKey=new TaggedInt(outKeyValue, (byte)0);
+ private MatrixIndexes outKeyValue=new MatrixIndexes();
+ private TaggedMatrixIndexes outKey=new TaggedMatrixIndexes(outKeyValue, (byte)0);
private WeightedCell outValue=new WeightedCell();
@Override
public void map(MatrixIndexes key, MatrixValue value,
- OutputCollector<TaggedInt, WeightedCell> out, Reporter reporter)
+ OutputCollector<TaggedMatrixIndexes, WeightedCell> out, Reporter reporter)
throws IOException
{
for(int i=0; i<representativeMatrixes.size(); i++)
@@ -65,21 +63,25 @@ public class GroupedAggMRMapper extends MapperBase
int rlen = block.getNumRows();
int clen = block.getNumColumns();
- if( clen == 2 ) //w/o weights
+ if( !ins.hasWeights() ) //w/o weights (input vector or matrix)
{
- for( int r=0; r<rlen; r++ )
- {
- outKeyValue.set((int)block.quickGetValue(r, 1));
- outValue.setValue(block.quickGetValue(r, 0));
- outValue.setWeight(1);
- out.collect(outKey, outValue);
+ long coloff = (key.getColumnIndex()-1)*ins.getBclen();
+
+ for( int r=0; r<rlen; r++ ) {
+ int group = (int)block.quickGetValue(r, clen-1);
+ for( int c=0; c<clen-1; c++ ) {
+ outKeyValue.setIndexes(group, coloff+c+1);
+ outValue.setValue(block.quickGetValue(r, c));
+ outValue.setWeight(1);
+ out.collect(outKey, outValue);
+ }
}
}
- else //w/ weights
+ else //w/ weights (input vector)
{
for( int r=0; r<rlen; r++ )
{
- outKeyValue.set((int)block.quickGetValue(r, 1));
+ outKeyValue.setIndexes((int)block.quickGetValue(r, 1),1);
outValue.setValue(block.quickGetValue(r, 0));
outValue.setWeight(block.quickGetValue(r, 2));
out.collect(outKey, outValue);
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f73569b4/src/main/java/org/apache/sysml/runtime/matrix/mapred/GroupedAggMRReducer.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/mapred/GroupedAggMRReducer.java b/src/main/java/org/apache/sysml/runtime/matrix/mapred/GroupedAggMRReducer.java
index 9ccf806..bebace2 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/mapred/GroupedAggMRReducer.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/mapred/GroupedAggMRReducer.java
@@ -29,14 +29,13 @@ import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.OutputCollector;
import org.apache.hadoop.mapred.Reducer;
import org.apache.hadoop.mapred.Reporter;
-
import org.apache.sysml.runtime.functionobjects.CM;
import org.apache.sysml.runtime.instructions.cp.CM_COV_Object;
import org.apache.sysml.runtime.instructions.cp.KahanObject;
import org.apache.sysml.runtime.instructions.mr.GroupedAggregateInstruction;
import org.apache.sysml.runtime.matrix.data.MatrixCell;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
-import org.apache.sysml.runtime.matrix.data.TaggedInt;
+import org.apache.sysml.runtime.matrix.data.TaggedMatrixIndexes;
import org.apache.sysml.runtime.matrix.data.WeightedCell;
import org.apache.sysml.runtime.matrix.operators.AggregateOperator;
import org.apache.sysml.runtime.matrix.operators.CMOperator;
@@ -44,7 +43,7 @@ import org.apache.sysml.runtime.matrix.operators.Operator;
public class GroupedAggMRReducer extends ReduceBase
- implements Reducer<TaggedInt, WeightedCell, MatrixIndexes, MatrixCell >
+ implements Reducer<TaggedMatrixIndexes, WeightedCell, MatrixIndexes, MatrixCell >
{
private MatrixIndexes outIndex=new MatrixIndexes(1, 1);
@@ -55,7 +54,7 @@ public class GroupedAggMRReducer extends ReduceBase
private HashMap<Byte, ArrayList<Integer>> outputIndexesMapping=new HashMap<Byte, ArrayList<Integer>>();
@Override
- public void reduce(TaggedInt key,Iterator<WeightedCell> values,
+ public void reduce(TaggedMatrixIndexes key,Iterator<WeightedCell> values,
OutputCollector<MatrixIndexes, MatrixCell> out, Reporter report)
throws IOException
{
@@ -112,7 +111,7 @@ public class GroupedAggMRReducer extends ReduceBase
throw new IOException(ex);
}
- outIndex.setIndexes((long)key.getBaseObject().get(), 1);
+ outIndex.setIndexes(key.getBaseObject());
cachedValues.reset();
cachedValues.set(key.getTag(), outIndex, outCell);
processReducerInstructions();
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f73569b4/src/main/java/org/apache/sysml/runtime/matrix/mapred/MRJobConfiguration.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/mapred/MRJobConfiguration.java b/src/main/java/org/apache/sysml/runtime/matrix/mapred/MRJobConfiguration.java
index 669dfaf..5f40a65 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/mapred/MRJobConfiguration.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/mapred/MRJobConfiguration.java
@@ -922,11 +922,28 @@ public class MRJobConfiguration
String str=job.get(CM_N_COV_INSTRUCTIONS_CONFIG);
return MRInstructionParser.parseCM_N_COVInstructions(str);
}
+
+ /**
+ *
+ * @param job
+ * @return
+ * @throws DMLUnsupportedOperationException
+ * @throws DMLRuntimeException
+ */
public static GroupedAggregateInstruction[] getGroupedAggregateInstructions(JobConf job)
- throws DMLUnsupportedOperationException, DMLRuntimeException
+ throws DMLUnsupportedOperationException, DMLRuntimeException
{
+ //parse all grouped aggregate instructions
String str=job.get(GROUPEDAGG_INSTRUCTIONS_CONFIG);
- return MRInstructionParser.parseGroupedAggInstructions(str);
+ GroupedAggregateInstruction[] tmp = MRInstructionParser.parseGroupedAggInstructions(str);
+
+ //obtain bclen for all instructions
+ for( int i=0; i< tmp.length; i++ ) {
+ byte tag = tmp[i].input;
+ tmp[i].setBclen(getMatrixCharacteristicsForInput(job, tag).getColsPerBlock());
+ }
+
+ return tmp;
}
public static String[] getOutputs(JobConf job)
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f73569b4/src/test/java/org/apache/sysml/test/integration/functions/aggregate/FullGroupedAggregateMatrixTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/aggregate/FullGroupedAggregateMatrixTest.java b/src/test/java/org/apache/sysml/test/integration/functions/aggregate/FullGroupedAggregateMatrixTest.java
new file mode 100644
index 0000000..ec1f989
--- /dev/null
+++ b/src/test/java/org/apache/sysml/test/integration/functions/aggregate/FullGroupedAggregateMatrixTest.java
@@ -0,0 +1,402 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.functions.aggregate;
+
+import java.io.IOException;
+import java.util.HashMap;
+
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.apache.sysml.api.DMLException;
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
+import org.apache.sysml.lops.LopProperties.ExecType;
+import org.apache.sysml.parser.Expression.ValueType;
+import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
+import org.apache.sysml.runtime.matrix.data.OutputInfo;
+import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysml.runtime.util.MapReduceTool;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.TestConfiguration;
+import org.apache.sysml.test.utils.TestUtils;
+
+/**
+ *
+ */
+public class FullGroupedAggregateMatrixTest extends AutomatedTestBase
+{
+ private final static String TEST_NAME1 = "GroupedAggregateMatrix";
+ private final static String TEST_NAME2 = "GroupedAggregateMatrixNoDims";
+
+ private final static String TEST_DIR = "functions/aggregate/";
+ private static final String TEST_CLASS_DIR = TEST_DIR + FullGroupedAggregateMatrixTest.class.getSimpleName() + "/";
+ private final static double eps = 1e-10;
+
+ private final static int rows = 1765;
+ private final static int cols = 19;
+ private final static int cols2 = 1007;
+
+ private final static double sparsity1 = 0.1;
+ private final static double sparsity2 = 0.7;
+
+ private final static int numGroups = 17;
+
+ private enum OpType{
+ SUM,
+ COUNT,
+ MEAN,
+ VARIANCE,
+ MOMENT3,
+ MOMENT4,
+ }
+
+
+ @Override
+ public void setUp()
+ {
+ addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[]{"C"}));
+ addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[]{"C"}));
+ TestUtils.clearAssertionInformation();
+
+ if (TEST_CACHE_ENABLED) {
+ setOutAndExpectedDeletionDisabled(true);
+ }
+ }
+
+ @BeforeClass
+ public static void init()
+ {
+ TestUtils.clearDirectory(TEST_DATA_DIR + TEST_CLASS_DIR);
+ }
+
+ @AfterClass
+ public static void cleanUp()
+ {
+ if (TEST_CACHE_ENABLED) {
+ TestUtils.clearDirectory(TEST_DATA_DIR + TEST_CLASS_DIR);
+ }
+ }
+
+ //CP testcases
+
+ @Test
+ public void testGroupedAggSumDenseCP() {
+ runGroupedAggregateOperationTest(TEST_NAME1, OpType.SUM, false, ExecType.CP);
+ }
+
+ @Test
+ public void testGroupedAggSumSparseCP() {
+ runGroupedAggregateOperationTest(TEST_NAME1, OpType.SUM, true, ExecType.CP);
+ }
+
+ @Test
+ public void testGroupedAggCountDenseCP() {
+ runGroupedAggregateOperationTest(TEST_NAME1, OpType.COUNT, false, ExecType.CP);
+ }
+
+ @Test
+ public void testGroupedAggCountSparseCP() {
+ runGroupedAggregateOperationTest(TEST_NAME1, OpType.COUNT, true, ExecType.CP);
+ }
+
+ @Test
+ public void testGroupedAggMeanDenseCP() {
+ runGroupedAggregateOperationTest(TEST_NAME1, OpType.MEAN, false, ExecType.CP);
+ }
+
+ @Test
+ public void testGroupedAggMeanSparseCP() {
+ runGroupedAggregateOperationTest(TEST_NAME1, OpType.MEAN, true, ExecType.CP);
+ }
+
+ @Test
+ public void testGroupedAggVarDenseCP() {
+ runGroupedAggregateOperationTest(TEST_NAME1, OpType.VARIANCE, false, ExecType.CP);
+ }
+
+ @Test
+ public void testGroupedAggVarSparseCP() {
+ runGroupedAggregateOperationTest(TEST_NAME1, OpType.VARIANCE, true, ExecType.CP);
+ }
+
+ @Test
+ public void testGroupedAggMoment3DenseCP() {
+ runGroupedAggregateOperationTest(TEST_NAME1, OpType.MOMENT3, false, ExecType.CP);
+ }
+
+ @Test
+ public void testGroupedAggMoment3SparseCP() {
+ runGroupedAggregateOperationTest(TEST_NAME1, OpType.MOMENT3, true, ExecType.CP);
+ }
+
+ @Test
+ public void testGroupedAggMoment4DenseCP() {
+ runGroupedAggregateOperationTest(TEST_NAME1, OpType.MOMENT4, false, ExecType.CP);
+ }
+
+ @Test
+ public void testGroupedAggMoment4SparseCP() {
+ runGroupedAggregateOperationTest(TEST_NAME1, OpType.MOMENT4, true, ExecType.CP);
+ }
+
+ //special CP testcases (negative)
+
+ @Test
+ public void testGroupedAggSumDenseCPNoDims() {
+ runGroupedAggregateOperationTest(TEST_NAME2, OpType.SUM, false, ExecType.CP);
+ }
+
+
+ //SP testcases
+
+ @Test
+ public void testGroupedAggSumDenseSP() {
+ runGroupedAggregateOperationTest(TEST_NAME1, OpType.SUM, false, ExecType.SPARK);
+ }
+
+ @Test
+ public void testGroupedAggSumSparseSP() {
+ runGroupedAggregateOperationTest(TEST_NAME1, OpType.SUM, true, ExecType.SPARK);
+ }
+
+ @Test
+ public void testGroupedAggCountDenseSP() {
+ runGroupedAggregateOperationTest(TEST_NAME1, OpType.COUNT, false, ExecType.SPARK);
+ }
+
+ @Test
+ public void testGroupedAggCountSparseSP() {
+ runGroupedAggregateOperationTest(TEST_NAME1, OpType.COUNT, true, ExecType.SPARK);
+ }
+
+ @Test
+ public void testGroupedAggMeanDenseSP() {
+ runGroupedAggregateOperationTest(TEST_NAME1, OpType.MEAN, false, ExecType.SPARK);
+ }
+
+ @Test
+ public void testGroupedAggMeanSparseSP() {
+ runGroupedAggregateOperationTest(TEST_NAME1, OpType.MEAN, true, ExecType.SPARK);
+ }
+
+ @Test
+ public void testGroupedAggVarDenseSP() {
+ runGroupedAggregateOperationTest(TEST_NAME1, OpType.VARIANCE, false, ExecType.SPARK);
+ }
+
+ @Test
+ public void testGroupedAggVarSparseSP() {
+ runGroupedAggregateOperationTest(TEST_NAME1, OpType.VARIANCE, true, ExecType.SPARK);
+ }
+
+ @Test
+ public void testGroupedAggMoment3DenseSP() {
+ runGroupedAggregateOperationTest(TEST_NAME1, OpType.MOMENT3, false, ExecType.SPARK);
+ }
+
+ @Test
+ public void testGroupedAggMoment3SparseSP() {
+ runGroupedAggregateOperationTest(TEST_NAME1, OpType.MOMENT3, true, ExecType.SPARK);
+ }
+
+ @Test
+ public void testGroupedAggMoment4DenseSP() {
+ runGroupedAggregateOperationTest(TEST_NAME1, OpType.MOMENT4, false, ExecType.SPARK);
+ }
+
+ @Test
+ public void testGroupedAggMoment4SparseSP() {
+ runGroupedAggregateOperationTest(TEST_NAME1, OpType.MOMENT4, true, ExecType.SPARK);
+ }
+
+ @Test
+ public void testGroupedAggSumDenseWideSP() {
+ runGroupedAggregateOperationTest(TEST_NAME1, OpType.SUM, false, ExecType.SPARK, cols2);
+ }
+
+ @Test
+ public void testGroupedAggSumSparseWideSP() {
+ runGroupedAggregateOperationTest(TEST_NAME1, OpType.SUM, true, ExecType.SPARK, cols2);
+ }
+
+
+ //MR testcases
+
+ @Test
+ public void testGroupedAggSumDenseMR() {
+ runGroupedAggregateOperationTest(TEST_NAME1, OpType.SUM, false, ExecType.MR);
+ }
+
+ @Test
+ public void testGroupedAggSumSparseMR() {
+ runGroupedAggregateOperationTest(TEST_NAME1, OpType.SUM, true, ExecType.MR);
+ }
+
+ @Test
+ public void testGroupedAggCountDenseMR() {
+ runGroupedAggregateOperationTest(TEST_NAME1, OpType.COUNT, false, ExecType.MR);
+ }
+
+ @Test
+ public void testGroupedAggCountSparseMR() {
+ runGroupedAggregateOperationTest(TEST_NAME1, OpType.COUNT, true, ExecType.MR);
+ }
+
+ @Test
+ public void testGroupedAggMeanDenseMR() {
+ runGroupedAggregateOperationTest(TEST_NAME1, OpType.MEAN, false, ExecType.MR);
+ }
+
+ @Test
+ public void testGroupedAggMeanSparseMR() {
+ runGroupedAggregateOperationTest(TEST_NAME1, OpType.MEAN, true, ExecType.MR);
+ }
+
+ @Test
+ public void testGroupedAggVarDenseMR() {
+ runGroupedAggregateOperationTest(TEST_NAME1, OpType.VARIANCE, false, ExecType.MR);
+ }
+
+ @Test
+ public void testGroupedAggVarSparseMR() {
+ runGroupedAggregateOperationTest(TEST_NAME1, OpType.VARIANCE, true, ExecType.MR);
+ }
+
+ @Test
+ public void testGroupedAggMoment3DenseMR() {
+ runGroupedAggregateOperationTest(TEST_NAME1, OpType.MOMENT3, false, ExecType.MR);
+ }
+
+ @Test
+ public void testGroupedAggMoment3SparseMR() {
+ runGroupedAggregateOperationTest(TEST_NAME1, OpType.MOMENT3, true, ExecType.MR);
+ }
+
+ @Test
+ public void testGroupedAggMoment4DenseMR() {
+ runGroupedAggregateOperationTest(TEST_NAME1, OpType.MOMENT4, false, ExecType.MR);
+ }
+
+ @Test
+ public void testGroupedAggMoment4SparseMR() {
+ runGroupedAggregateOperationTest(TEST_NAME1, OpType.MOMENT4, true, ExecType.MR);
+ }
+
+ @Test
+ public void testGroupedAggSumDenseWideMR() {
+ runGroupedAggregateOperationTest(TEST_NAME1, OpType.SUM, false, ExecType.MR, cols2);
+ }
+
+ @Test
+ public void testGroupedAggSumSparseWideMR() {
+ runGroupedAggregateOperationTest(TEST_NAME1, OpType.SUM, true, ExecType.MR, cols2);
+ }
+
+ /**
+ *
+ * @param testname
+ * @param type
+ * @param sparse
+ * @param instType
+ */
+ private void runGroupedAggregateOperationTest( String testname, OpType type, boolean sparse, ExecType instType) {
+ runGroupedAggregateOperationTest(testname, type, sparse, instType, cols);
+ }
+
+ /**
+ * Runs one grouped-aggregate test on a generated rows x numCols input A and, unless an exception is expected, compares the DML result with R.
+ * @param testname name of the DML/R script pair; any name other than TEST_NAME1 expects a DMLException
+ * @param type aggregation function; its ordinal is passed to the scripts as the function id
+ * @param sparse if true, A is generated with sparsity1, otherwise with sparsity2
+ * @param instType execution type that selects the runtime platform (MR, SPARK, otherwise HYBRID)
+ */
+ @SuppressWarnings("rawtypes")
+ private void runGroupedAggregateOperationTest( String testname, OpType type, boolean sparse, ExecType instType, int numCols)
+ {
+ //rtplatform for MR
+ RUNTIME_PLATFORM platformOld = rtplatform;
+ switch( instType ){
+ case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break;
+ case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break;
+ default: rtplatform = RUNTIME_PLATFORM.HYBRID; break;
+ }
+
+ boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+ if( rtplatform == RUNTIME_PLATFORM.SPARK )
+ DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+ try
+ {
+ //determine script and function name
+ String TEST_NAME = testname;
+ int fn = type.ordinal();
+ double sparsity = (sparse) ? sparsity1 : sparsity2;
+ String TEST_CACHE_DIR = TEST_CACHE_ENABLED ? TEST_NAME + type.ordinal() + "_" + sparsity + "_" + numCols + "/" : ""; //numCols in the key: wide (cols2) and default-width runs get distinct cache directories
+ boolean exceptionExpected = !TEST_NAME.equals(TEST_NAME1);
+
+ TestConfiguration config = getTestConfiguration(TEST_NAME);
+ loadTestConfiguration(config, TEST_CACHE_DIR);
+
+ // This is for running the junit test the new way, i.e., construct the arguments directly
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[]{"-explain","-args", input("A"), input("B"),
+ String.valueOf(fn), String.valueOf(numGroups), output("C") };
+ fullRScriptName = HOME + TEST_NAME + ".R";
+ rCmd = "Rscript" + " " + fullRScriptName + " " + inputDir() + " " + fn + " " + expectedDir();
+
+ //generate actual dataset
+ double[][] A = getRandomMatrix(rows, numCols, -0.05, 1, sparsity, 7);
+ writeInputMatrix("A", A, true);
+ MatrixCharacteristics mc1 = new MatrixCharacteristics(rows, numCols,1000,1000);
+ MapReduceTool.writeMetaDataFile(input("A.mtd"), ValueType.DOUBLE, mc1, OutputInfo.TextCellOutputInfo);
+ double[][] B = TestUtils.round(getRandomMatrix(rows, 1, 1, numGroups, 1.0, 3));
+ writeInputMatrix("B", B, true);
+ MatrixCharacteristics mc2 = new MatrixCharacteristics(rows,1,1000,1000);
+ MapReduceTool.writeMetaDataFile(input("B.mtd"), ValueType.DOUBLE, mc2, OutputInfo.TextCellOutputInfo);
+
+ //run tests
+ Class cla = (exceptionExpected ? DMLException.class : null);
+ runTest(true, exceptionExpected, cla, -1);
+
+ //compare matrices
+ if( !exceptionExpected ){
+ runRScript(true);
+
+ HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("C");
+ HashMap<CellIndex, Double> rfile = readRMatrixFromFS("C");
+ TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
+ }
+ }
+ catch(IOException ex)
+ {
+ ex.printStackTrace();
+ throw new RuntimeException(ex);
+ }
+ finally
+ {
+ rtplatform = platformOld;
+ DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ }
+ }
+
+
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f73569b4/src/test/scripts/functions/aggregate/GroupedAggregateMatrix.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/aggregate/GroupedAggregateMatrix.R b/src/test/scripts/functions/aggregate/GroupedAggregateMatrix.R
new file mode 100644
index 0000000..76e2d79
--- /dev/null
+++ b/src/test/scripts/functions/aggregate/GroupedAggregateMatrix.R
@@ -0,0 +1,70 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args <- commandArgs(TRUE)
+
+library("Matrix")
+library("moments")
+
+A <- as.matrix(readMM(paste(args[1], "A.mtx", sep="")))
+B <- as.matrix(readMM(paste(args[1], "B.mtx", sep="")));
+fn = as.integer(args[2]);
+
+
+R = matrix(0,17,ncol(A));
+for( j in 1:ncol(A) )
+{
+Ai = A[,j];
+
+if( fn==0 )
+{
+ C = aggregate(as.vector(Ai), by=list(as.vector(B)), FUN=sum)[,2]
+}
+
+if( fn==1 )
+{
+ C = aggregate(as.vector(Ai), by=list(as.vector(B)), FUN=length)[,2]
+}
+
+if( fn==2 )
+{
+ C = aggregate(as.vector(Ai), by=list(as.vector(B)), FUN=mean)[,2]
+}
+
+if( fn==3 )
+{
+ C = aggregate(as.vector(Ai), by=list(as.vector(B)), FUN=var)[,2]
+}
+
+if( fn==4 )
+{
+ C = aggregate(as.vector(Ai), by=list(as.vector(B)), FUN=moment, order=3, central=TRUE)[,2]
+}
+
+if( fn==5 )
+{
+ C = aggregate(as.vector(Ai), by=list(as.vector(B)), FUN=moment, order=4, central=TRUE)[,2]
+}
+
+R[,j] = C;
+}
+
+writeMM(as(R, "CsparseMatrix"), paste(args[3], "C", sep=""));
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f73569b4/src/test/scripts/functions/aggregate/GroupedAggregateMatrix.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/aggregate/GroupedAggregateMatrix.dml b/src/test/scripts/functions/aggregate/GroupedAggregateMatrix.dml
new file mode 100644
index 0000000..c4e70c8
--- /dev/null
+++ b/src/test/scripts/functions/aggregate/GroupedAggregateMatrix.dml
@@ -0,0 +1,51 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+A = read($1);
+B = read($2);
+fn = $3;
+
+if( fn==0 )
+{
+ C = aggregate(target=A, groups=B, fn="sum", ngroups=$4);
+}
+else if( fn==1 )
+{
+ C = aggregate(target=A, groups=B, fn="count", ngroups=$4);
+}
+else if( fn==2 )
+{
+ C = aggregate(target=A, groups=B, fn="mean", ngroups=$4);
+}
+else if( fn==3 )
+{
+ C = aggregate(target=A, groups=B, fn="variance", ngroups=$4);
+}
+else if( fn==4 )
+{
+ C = aggregate(target=A, groups=B, fn="centralmoment", order="3", ngroups=$4);
+}
+else if( fn==5 )
+{
+ C = aggregate(target=A, groups=B, fn="centralmoment", order="4", ngroups=$4);
+}
+
+write(C, $5, format="text");
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f73569b4/src/test/scripts/functions/aggregate/GroupedAggregateMatrixNoDims.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/aggregate/GroupedAggregateMatrixNoDims.dml b/src/test/scripts/functions/aggregate/GroupedAggregateMatrixNoDims.dml
new file mode 100644
index 0000000..d92366d
--- /dev/null
+++ b/src/test/scripts/functions/aggregate/GroupedAggregateMatrixNoDims.dml
@@ -0,0 +1,51 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+A = read($1);
+B = read($2);
+fn = $3;
+
+if( fn==0 )
+{
+ C = aggregate(target=A, groups=B, fn="sum");
+}
+else if( fn==1 )
+{
+ C = aggregate(target=A, groups=B, fn="count");
+}
+else if( fn==2 )
+{
+ C = aggregate(target=A, groups=B, fn="mean");
+}
+else if( fn==3 )
+{
+ C = aggregate(target=A, groups=B, fn="variance");
+}
+else if( fn==4 )
+{
+ C = aggregate(target=A, groups=B, fn="centralmoment", order="3");
+}
+else if( fn==5 )
+{
+ C = aggregate(target=A, groups=B, fn="centralmoment", order="4");
+}
+
+write(C, $5, format="text");
\ No newline at end of file
[4/4] incubator-systemml git commit: Fix grouped aggregate output
size, incl extended tests (caching, meta)
Posted by mb...@apache.org.
Fix grouped aggregate output size, incl extended tests (caching, meta)
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/7290510e
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/7290510e
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/7290510e
Branch: refs/heads/master
Commit: 7290510ec26ac2949b0d21d31eac768726929550
Parents: 112ba90
Author: Matthias Boehm <mb...@us.ibm.com>
Authored: Sat Dec 19 19:26:39 2015 +0100
Committer: Matthias Boehm <mb...@us.ibm.com>
Committed: Sat Dec 19 19:26:39 2015 +0100
----------------------------------------------------------------------
.../apache/sysml/hops/ParameterizedBuiltinOp.java | 15 ++++++---------
.../sysml/runtime/matrix/data/MatrixBlock.java | 4 ++--
.../sysml/test/integration/applications/ID3Test.java | 2 +-
.../aggregate/FullGroupedAggregateMatrixTest.java | 11 +++++++----
4 files changed, 16 insertions(+), 16 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/7290510e/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java b/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java
index 3a0a4f5..51d4ba7 100644
--- a/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java
+++ b/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java
@@ -242,7 +242,8 @@ public class ParameterizedBuiltinOp extends Hop
Hop groups = getInput().get(_paramIndexMap.get(Statement.GAGG_GROUPS));
Lop append = null;
- if( target.getDim2()>=target.getColsInBlock() ) //multi-column-block result matrix
+ if( target.getDim2()>=target.getColsInBlock() // multi-column-block result matrix
+ || target.getDim2()<=0 ) // unknown
{
long m1_dim1 = target.getDim1();
long m1_dim2 = target.getDim2();
@@ -273,10 +274,8 @@ public class ParameterizedBuiltinOp extends Hop
}
else //single-column-block vector or matrix
{
- append = BinaryOp.constructMRAppendLop(
- target, groups,
- DataType.MATRIX, getValueType(), true,
- getInput().get(_paramIndexMap.get(Statement.GAGG_TARGET)));
+ append = BinaryOp.constructMRAppendLop(target, groups,
+ DataType.MATRIX, getValueType(), true, target);
}
// add the combine lop to parameter list, with a new name "combinedinput"
@@ -307,11 +306,9 @@ public class ParameterizedBuiltinOp extends Hop
}
GroupedAggregate grp_agg = new GroupedAggregate(inputlops, isWeighted, getDataType(), getValueType());
-
- // output dimensions are unknown at compilation time
grp_agg.getOutputParameters().setDimensions(outputDim1, outputDim2, getRowsInBlock(), getColsInBlock(), -1);
- grp_agg.setAllPositions(this.getBeginLine(), this.getBeginColumn(), this.getEndLine(), this.getEndColumn());
-
+ setLineNumbers(grp_agg);
+
setLops(grp_agg);
setRequiresReblock( true );
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/7290510e/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
index 0fe9095..ca569f3 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
@@ -5328,9 +5328,9 @@ public class MatrixBlock extends MatrixValue implements Externalizable
MatrixBlock result = checkType(ret);
boolean result_sparsity = estimateSparsityOnGroupedAgg(rlen, numGroups);
if(result==null)
- result=new MatrixBlock(numGroups, rowVector?1:target.getNumRows(), result_sparsity);
+ result=new MatrixBlock(numGroups, rowVector?1:target.getNumColumns(), result_sparsity);
else
- result.reset(numGroups, rowVector?1:target.getNumRows(), result_sparsity);
+ result.reset(numGroups, rowVector?1:target.getNumColumns(), result_sparsity);
//CM operator for count, mean, variance
//note: current support only for column vectors
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/7290510e/src/test/java/org/apache/sysml/test/integration/applications/ID3Test.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/applications/ID3Test.java b/src/test/java/org/apache/sysml/test/integration/applications/ID3Test.java
index e80ce29..c7d3cb4 100644
--- a/src/test/java/org/apache/sysml/test/integration/applications/ID3Test.java
+++ b/src/test/java/org/apache/sysml/test/integration/applications/ID3Test.java
@@ -97,7 +97,7 @@ public abstract class ID3Test extends AutomatedTestBase
//run tests
//(changed expected MR from 62 to 66 because we now also count MR jobs in predicates)
//(changed expected MR from 66 to 68 because we now rewrite sum(v1*v2) to t(v1)%*%v2 which rarely creates more jobs due to MMCJ incompatibility of other operations)
- runTest(true, EXCEPTION_NOT_EXPECTED, null, 68); //max 68 compiled jobs
+ runTest(true, EXCEPTION_NOT_EXPECTED, null, 70); //max 70 compiled jobs
runRScript(true);
//check also num actually executed jobs
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/7290510e/src/test/java/org/apache/sysml/test/integration/functions/aggregate/FullGroupedAggregateMatrixTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/aggregate/FullGroupedAggregateMatrixTest.java b/src/test/java/org/apache/sysml/test/integration/functions/aggregate/FullGroupedAggregateMatrixTest.java
index ec1f989..6a165f2 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/aggregate/FullGroupedAggregateMatrixTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/aggregate/FullGroupedAggregateMatrixTest.java
@@ -349,7 +349,7 @@ public class FullGroupedAggregateMatrixTest extends AutomatedTestBase
String TEST_NAME = testname;
int fn = type.ordinal();
double sparsity = (sparse) ? sparsity1 : sparsity2;
- String TEST_CACHE_DIR = TEST_CACHE_ENABLED ? TEST_NAME + type.ordinal() + "_" + sparsity + "/" : "";
+ String TEST_CACHE_DIR = TEST_CACHE_ENABLED ? TEST_NAME + type.ordinal() + "_" + sparsity + "_" + numCols + "/" : "";
boolean exceptionExpected = !TEST_NAME.equals(TEST_NAME1);
TestConfiguration config = getTestConfiguration(TEST_NAME);
@@ -379,11 +379,16 @@ public class FullGroupedAggregateMatrixTest extends AutomatedTestBase
//compare matrices
if( !exceptionExpected ){
+ //run R script for comparison
runRScript(true);
+ //compare output matrices
HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("C");
HashMap<CellIndex, Double> rfile = readRMatrixFromFS("C");
TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
+
+ //check dml output meta data
+ checkDMLMetaDataFile("C", new MatrixCharacteristics(numGroups,numCols,1,1));
}
}
catch(IOException ex)
@@ -397,6 +402,4 @@ public class FullGroupedAggregateMatrixTest extends AutomatedTestBase
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
}
}
-
-
-}
\ No newline at end of file
+}