You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by du...@apache.org on 2016/01/12 23:41:50 UTC
[3/3] incubator-systemml git commit: [SYSTEMML-217] Add New Unary
Aggregates: Variance & Standard Deviation
[SYSTEMML-217] Add New Unary Aggregates: Variance & Standard Deviation
This introduces new "variance" and "standard deviation" unary operators to the DML language, including row and column variants for each: `var(X)`, `rowVars(X)`, `colVars(X)`, `sd(X)`, `rowSds(X)`, `colSds(X)`. The implementation includes parser, HOP, LOP, runtime, and optimizer components.
The following provides a slightly more detailed overview; a full discussion can be found in the design document on the associated JIRA issue:
* Adds `sd`, `colSds`, and `rowSds` built-in functions for standard deviation to the DML language.
* Adds `var`, `colVars`, and `rowVars` built-in functions for variance to the DML language.
* Creates a new `VAR` operation type for the existing `AggUnaryOp` HOP.
* Compiles the `sd`, `colSds`, and `rowSds` built-in DML functions to the square root of the variance via the `VAR` HOP type wrapped in an `OpOp1.SQRT` HOP type.
* Creates a new `Var` operation type for the existing `Aggregate` and `PartialAggregate` LOPs.
* Compiles the `VAR` HOP type to the `Var` LOP type.
* Creates new runtime opcodes: `"avar"`, `"uavar"`, `"uarvar"`, `"uacvar"`.
* Compiles the `Var` LOP type to new opcodes.
* Parses the new opcodes into runtime aggregate operators that use the `CM` runtime function with `Variance` type.
* Creates a new `VAR` aggregate type in the core unary aggregate library.
* Matches the runtime aggregate operators using the `CM` function of `Variance` type to the `VAR` case.
* Extends the core unary aggregate library with new variance functions for the `VAR` case.
* Adds logic for incremental aggregation of variances computed for multiple partitions.
* Extends existing optimizer rewrites for column-wise and row-wise aggregation to cover row and column variances.
* Adds new optimizer rewrites to optimize row and column variance operations under certain conditions.
Closes #32.
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/3206ac14
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/3206ac14
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/3206ac14
Branch: refs/heads/master
Commit: 3206ac140f0e389ec0d5cabbb854168a098590ca
Parents: 5bc92d4
Author: Mike Dusenberry <mw...@us.ibm.com>
Authored: Tue Jan 12 14:40:13 2016 -0800
Committer: Mike Dusenberry <mw...@us.ibm.com>
Committed: Tue Jan 12 14:40:13 2016 -0800
----------------------------------------------------------------------
.../java/org/apache/sysml/hops/AggUnaryOp.java | 10 +-
src/main/java/org/apache/sysml/hops/Hop.java | 4 +-
.../hops/cost/CostEstimatorStaticRuntime.java | 10 +-
.../RewriteAlgebraicSimplificationDynamic.java | 93 +++-
.../java/org/apache/sysml/lops/Aggregate.java | 14 +-
.../org/apache/sysml/lops/PartialAggregate.java | 66 ++-
.../sysml/parser/BuiltinFunctionExpression.java | 18 +
.../org/apache/sysml/parser/DMLTranslator.java | 43 +-
.../org/apache/sysml/parser/Expression.java | 60 +--
.../sysml/runtime/functionobjects/CM.java | 81 +---
.../instructions/CPInstructionParser.java | 3 +
.../runtime/instructions/InstructionUtils.java | 41 +-
.../instructions/MRInstructionParser.java | 6 +-
.../instructions/SPInstructionParser.java | 3 +
.../runtime/instructions/cp/CM_COV_Object.java | 38 +-
.../instructions/mr/AggregateInstruction.java | 2 +-
.../sysml/runtime/matrix/data/LibMatrixAgg.java | 396 ++++++++++++++--
.../sysml/runtime/matrix/data/MatrixBlock.java | 207 +++++++-
.../matrix/data/OperationsOnMatrixValues.java | 12 +
.../functions/aggregate/ColStdDevsTest.java | 295 ++++++++++++
.../functions/aggregate/ColVariancesTest.java | 470 +++++++++++++++++++
.../functions/aggregate/RowStdDevsTest.java | 295 ++++++++++++
.../functions/aggregate/RowVariancesTest.java | 470 +++++++++++++++++++
.../functions/aggregate/StdDevTest.java | 294 ++++++++++++
.../functions/aggregate/VarianceTest.java | 295 ++++++++++++
.../scripts/functions/aggregate/ColStdDevs.R | 35 ++
.../scripts/functions/aggregate/ColStdDevs.dml | 24 +
.../scripts/functions/aggregate/ColVariances.R | 35 ++
.../functions/aggregate/ColVariances.dml | 24 +
.../scripts/functions/aggregate/RowStdDevs.R | 35 ++
.../scripts/functions/aggregate/RowStdDevs.dml | 24 +
.../scripts/functions/aggregate/RowVariances.R | 35 ++
.../functions/aggregate/RowVariances.dml | 24 +
src/test/scripts/functions/aggregate/StdDev.R | 29 ++
src/test/scripts/functions/aggregate/StdDev.dml | 24 +
src/test/scripts/functions/aggregate/Variance.R | 29 ++
.../scripts/functions/aggregate/Variance.dml | 24 +
.../functions/aggregate/ZPackageSuite.java | 8 +-
38 files changed, 3370 insertions(+), 206 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/3206ac14/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
index 86c2847..4c4d7c8 100644
--- a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
@@ -357,6 +357,13 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop
else if( _direction == Direction.Row ) //(always dense)
val = OptimizerUtils.estimateSizeExactSparsity(dim1, 2, 1.0);
break;
+ case VAR:
+ //worst-case correction LASTFOURROWS / LASTFOURCOLUMNS
+ if( _direction == Direction.Col ) //(potentially sparse)
+ val = OptimizerUtils.estimateSizeExactSparsity(4, dim2, sparsity);
+ else if( _direction == Direction.Row ) //(always dense)
+ val = OptimizerUtils.estimateSizeExactSparsity(dim1, 4, 1.0);
+ break;
case MAXINDEX:
case MININDEX:
Hop hop = getInput().get(0);
@@ -700,7 +707,8 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop
boolean ret = (_direction == Direction.RowCol) && //full aggregate
(_op == AggOp.SUM || _op == AggOp.SUM_SQ || //valid aggregration functions
_op == AggOp.MIN || _op == AggOp.MAX ||
- _op == AggOp.PROD || _op == AggOp.MEAN);
+ _op == AggOp.PROD || _op == AggOp.MEAN ||
+ _op == AggOp.VAR);
//note: trace and maxindex are not transpose-safe.
return ret;
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/3206ac14/src/main/java/org/apache/sysml/hops/Hop.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/Hop.java b/src/main/java/org/apache/sysml/hops/Hop.java
index b3bed99..3d2c76b 100644
--- a/src/main/java/org/apache/sysml/hops/Hop.java
+++ b/src/main/java/org/apache/sysml/hops/Hop.java
@@ -1064,7 +1064,7 @@ public abstract class Hop
public enum AggOp {
- SUM, SUM_SQ, MIN, MAX, TRACE, PROD, MEAN, MAXINDEX, MININDEX
+ SUM, SUM_SQ, MIN, MAX, TRACE, PROD, MEAN, VAR, MAXINDEX, MININDEX
};
public enum ReOrgOp {
@@ -1124,6 +1124,7 @@ public abstract class Hop
HopsAgg2Lops.put(AggOp.MININDEX, org.apache.sysml.lops.Aggregate.OperationTypes.MinIndex);
HopsAgg2Lops.put(AggOp.PROD, org.apache.sysml.lops.Aggregate.OperationTypes.Product);
HopsAgg2Lops.put(AggOp.MEAN, org.apache.sysml.lops.Aggregate.OperationTypes.Mean);
+ HopsAgg2Lops.put(AggOp.VAR, org.apache.sysml.lops.Aggregate.OperationTypes.Var);
}
protected static final HashMap<ReOrgOp, org.apache.sysml.lops.Transform.OperationTypes> HopsTransf2Lops;
@@ -1402,6 +1403,7 @@ public abstract class Hop
HopsAgg2String.put(AggOp.MININDEX, "minindex");
HopsAgg2String.put(AggOp.TRACE, "trace");
HopsAgg2String.put(AggOp.MEAN, "mean");
+ HopsAgg2String.put(AggOp.VAR, "var");
}
protected static final HashMap<Hop.ReOrgOp, String> HopsTransf2String;
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/3206ac14/src/main/java/org/apache/sysml/hops/cost/CostEstimatorStaticRuntime.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/cost/CostEstimatorStaticRuntime.java b/src/main/java/org/apache/sysml/hops/cost/CostEstimatorStaticRuntime.java
index 15003c3..73203f0 100644
--- a/src/main/java/org/apache/sysml/hops/cost/CostEstimatorStaticRuntime.java
+++ b/src/main/java/org/apache/sysml/hops/cost/CostEstimatorStaticRuntime.java
@@ -922,7 +922,7 @@ public class CostEstimatorStaticRuntime extends CostEstimator
return 6 * d1m * d1n; //2*1(*) + 4 (k+)
case AggregateUnary: //opcodes: uak+, uark+, uack+, uasqk+, uarsqk+, uacsqk+,
- // uamean, uarmean, uacmean,
+ // uamean, uarmean, uacmean, uavar, uarvar, uacvar,
// uamax, uarmax, uarimax, uacmax, uamin, uarmin, uacmin,
// ua+, uar+, uac+, ua*, uatrace, uaktrace,
// nrow, ncol, length, cm
@@ -952,11 +952,13 @@ public class CostEstimatorStaticRuntime extends CostEstimator
}
else if( optype.equals("uak+") || optype.equals("uark+") || optype.equals("uack+"))
return 4 * d1m * d1n; //1*k+
- else if( optype.equals("uasqk+") || optype.equals("uarsqk+") || optype.equals("uacsqk+"))
+ else if( optype.equals("uasqk+") || optype.equals("uarsqk+") || optype.equals("uacsqk+"))
return 5 * d1m * d1n; // +1 for multiplication to square term
else if( optype.equals("uamean") || optype.equals("uarmean") || optype.equals("uacmean"))
return 7 * d1m * d1n; //1*k+
- else if( optype.equals("uamax") || optype.equals("uarmax") || optype.equals("uacmax")
+ else if( optype.equals("uavar") || optype.equals("uarvar") || optype.equals("uacvar"))
+ return 14 * d1m * d1n;
+ else if( optype.equals("uamax") || optype.equals("uarmax") || optype.equals("uacmax")
|| optype.equals("uamin") || optype.equals("uarmin") || optype.equals("uacmin")
|| optype.equals("uarimax") || optype.equals("ua*") )
return d1m * d1n;
@@ -1195,6 +1197,8 @@ public class CostEstimatorStaticRuntime extends CostEstimator
return 4 * numMap * d1m * d1n * d1s;
else if( optype.equals("asqk+") )
return 5 * numMap * d1m * d1n * d1s; // +1 for multiplication to square term
+ else if( optype.equals("avar") )
+ return 14 * numMap * d1m * d1n * d1s;
else
return numMap * d1m * d1n * d1s;
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/3206ac14/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 31c394b..4dd5f87 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -63,7 +63,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
private static final Log LOG = LogFactory.getLog(RewriteAlgebraicSimplificationDynamic.class.getName());
//valid aggregation operation types for rowOp to Op conversions (not all operations apply)
- private static AggOp[] LOOKUP_VALID_ROW_COL_AGGREGATE = new AggOp[]{AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX, AggOp.MEAN};
+ private static AggOp[] LOOKUP_VALID_ROW_COL_AGGREGATE = new AggOp[]{AggOp.SUM, AggOp.SUM_SQ, AggOp.MIN, AggOp.MAX, AggOp.MEAN, AggOp.VAR};
//valid aggregation operation types for empty (sparse-safe) operations (not all operations apply)
//AggOp.MEAN currently not due to missing count/corrections
@@ -177,7 +177,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
hi = reorderMinusMatrixMult(hop, hi, i); //e.g., (-t(X))%*%y->-(t(X)%*%y), TODO size
hi = simplifySumMatrixMult(hop, hi, i); //e.g., sum(A%*%B) -> sum(t(colSums(A))*rowSums(B)), if not dot product / wsloss
hi = simplifyEmptyBinaryOperation(hop, hi, i); //e.g., X*Y -> matrix(0,nrow(X), ncol(X)) / X+Y->X / X-Y -> X
- hi = simplifyScalarMVBinaryOperation(hi); //e.g., X*y -> X*as.scalar(y), if y is a 1-1 matrix
+ hi = simplifyScalarMVBinaryOperation(hi); //e.g., X*y -> X*as.scalar(y), if y is a 1-1 matrix
hi = simplifyNnzComputation(hop, hi, i); //e.g., sum(ppred(X,0,"!=")) -> literal(nnz(X)), if nnz known
//process childs recursively after rewrites (to investigate pattern newly created by rewrites)
@@ -543,7 +543,7 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
return hi;
}
-
+
/**
*
* @param parent
@@ -566,13 +566,38 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
{
if( input.getDim1() == 1 )
{
- //remove unnecessary col aggregation for 1 row
- HopRewriteUtils.removeChildReference(parent, hi);
- HopRewriteUtils.addChildReference(parent, input, pos);
- parent.refreshSizeInformation();
- hi = input;
-
- LOG.debug("Applied simplifyColwiseAggregate1");
+ if (uhi.getOp() == AggOp.VAR) {
+ // For the column variance aggregation, if the input is a row vector,
+ // the column variances will each be zero.
+ // Therefore, perform a rewrite from COLVAR(X) to a row vector of zeros.
+ Hop emptyRow = HopRewriteUtils.createDataGenOp(uhi, input, 0);
+ HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
+ HopRewriteUtils.addChildReference(parent, emptyRow, pos);
+ parent.refreshSizeInformation();
+
+ // cleanup
+ if (hi.getParent().isEmpty())
+ HopRewriteUtils.removeAllChildReferences(hi);
+ if (input.getParent().isEmpty())
+ HopRewriteUtils.removeAllChildReferences(input);
+
+ // replace current HOP with new empty row HOP
+ hi = emptyRow;
+
+ LOG.debug("Applied simplifyColwiseAggregate for colVars");
+ } else {
+ // All other valid column aggregations over a row vector will result
+ // in the row vector itself.
+ // Therefore, remove unnecessary col aggregation for 1 row.
+ HopRewriteUtils.removeChildReference(parent, hi);
+ HopRewriteUtils.addChildReference(parent, input, pos);
+ parent.refreshSizeInformation();
+
+ // replace current HOP with input HOP
+ hi = input;
+
+ LOG.debug("Applied simplifyColwiseAggregate1");
+ }
}
else if( input.getDim2() == 1 )
{
@@ -599,13 +624,13 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
LOG.debug("Applied simplifyColwiseAggregate2");
}
- }
+ }
}
}
return hi;
}
-
+
/**
*
* @param parent
@@ -628,13 +653,39 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
{
if( input.getDim2() == 1 )
{
- //remove unnecessary row aggregation for 1 col
- HopRewriteUtils.removeChildReference(parent, hi);
- HopRewriteUtils.addChildReference(parent, input, pos);
- parent.refreshSizeInformation();
- hi = input;
-
- LOG.debug("Applied simplifyRowwiseAggregate1");
+ if (uhi.getOp() == AggOp.VAR) {
+ // For the row variance aggregation, if the input is a column vector,
+ // the row variances will each be zero.
+ // Therefore, perform a rewrite from ROWVAR(X) to a column vector of
+ // zeros.
+ Hop emptyCol = HopRewriteUtils.createDataGenOp(input, uhi, 0);
+ HopRewriteUtils.removeChildReferenceByPos(parent, hi, pos);
+ HopRewriteUtils.addChildReference(parent, emptyCol, pos);
+ parent.refreshSizeInformation();
+
+ // cleanup
+ if (hi.getParent().isEmpty())
+ HopRewriteUtils.removeAllChildReferences(hi);
+ if (input.getParent().isEmpty())
+ HopRewriteUtils.removeAllChildReferences(input);
+
+ // replace current HOP with new empty column HOP
+ hi = emptyCol;
+
+ LOG.debug("Applied simplifyRowwiseAggregate for rowVars");
+ } else {
+ // All other valid row aggregations over a column vector will result
+ // in the column vector itself.
+ // Therefore, remove unnecessary row aggregation for 1 col
+ HopRewriteUtils.removeChildReference(parent, hi);
+ HopRewriteUtils.addChildReference(parent, input, pos);
+ parent.refreshSizeInformation();
+
+ // replace current HOP with input HOP
+ hi = input;
+
+ LOG.debug("Applied simplifyRowwiseAggregate1");
+ }
}
else if( input.getDim1() == 1 )
{
@@ -661,8 +712,8 @@ public class RewriteAlgebraicSimplificationDynamic extends HopRewriteRule
LOG.debug("Applied simplifyRowwiseAggregate2");
}
- }
- }
+ }
+ }
}
return hi;
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/3206ac14/src/main/java/org/apache/sysml/lops/Aggregate.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/Aggregate.java b/src/main/java/org/apache/sysml/lops/Aggregate.java
index d7749ae..fc229a0 100644
--- a/src/main/java/org/apache/sysml/lops/Aggregate.java
+++ b/src/main/java/org/apache/sysml/lops/Aggregate.java
@@ -38,7 +38,7 @@ public class Aggregate extends Lop
/** Aggregate operation types **/
public enum OperationTypes {
- Sum, Product, Min, Max, Trace, KahanSum, KahanSumSq, KahanTrace, Mean,MaxIndex, MinIndex
+ Sum, Product, Min, Max, Trace, KahanSum, KahanSumSq, KahanTrace, Mean, Var, MaxIndex, MinIndex
}
OperationTypes operation;
@@ -83,7 +83,8 @@ public class Aggregate extends Lop
// this function must be invoked during hop-to-lop translation
public void setupCorrectionLocation(CorrectionLocationType loc) {
if (operation == OperationTypes.KahanSum || operation == OperationTypes.KahanSumSq
- || operation == OperationTypes.KahanTrace || operation == OperationTypes.Mean) {
+ || operation == OperationTypes.KahanTrace || operation == OperationTypes.Mean
+ || operation == OperationTypes.Var) {
isCorrectionUsed = true;
correctionLocation = loc;
}
@@ -115,7 +116,9 @@ public class Aggregate extends Lop
case Trace:
return "a+";
case Mean:
- return "amean";
+ return "amean";
+ case Var:
+ return "avar";
case Product:
return "a*";
case Min:
@@ -160,8 +163,9 @@ public class Aggregate extends Lop
boolean isCorrectionApplicable = false;
String opcode = getOpcode();
- if (operation == OperationTypes.Mean || operation == OperationTypes.KahanSum
- || operation == OperationTypes.KahanSumSq || operation == OperationTypes.KahanTrace)
+ if (operation == OperationTypes.Mean || operation == OperationTypes.Var
+ || operation == OperationTypes.KahanSum || operation == OperationTypes.KahanSumSq
+ || operation == OperationTypes.KahanTrace)
isCorrectionApplicable = true;
StringBuilder sb = new StringBuilder();
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/3206ac14/src/main/java/org/apache/sysml/lops/PartialAggregate.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/PartialAggregate.java b/src/main/java/org/apache/sysml/lops/PartialAggregate.java
index 19e437a..fdcc5b0 100644
--- a/src/main/java/org/apache/sysml/lops/PartialAggregate.java
+++ b/src/main/java/org/apache/sysml/lops/PartialAggregate.java
@@ -47,8 +47,10 @@ public class PartialAggregate extends Lop
LASTROW,
LASTCOLUMN,
LASTTWOROWS,
- LASTTWOCOLUMNS,
- INVALID
+ LASTTWOCOLUMNS,
+ LASTFOURROWS,
+ LASTFOURCOLUMNS,
+ INVALID
};
private Aggregate.OperationTypes operation;
@@ -212,7 +214,35 @@ public class PartialAggregate extends Lop
+ "Unknown aggregate direction: " + direction);
}
break;
-
+
+ case Var:
+ // Computation of stable variance requires each mapper to
+ // output the running variance, the running mean, the
+ // count, a correction term for the squared deviations
+ // from the sample mean (m2), and a correction term for
+ // the mean. These values collectively allow all other
+ // necessary intermediates to be reconstructed, and the
+ // variance will be output by our unary aggregate framework.
+ // Thus, our outputs will be:
+ // { var | mean, count, m2 correction, mean correction }
+ switch (direction) {
+ case Col:
+ // colVars: { var | mean, count, m2 correction, mean correction },
+ // where each element is a column.
+ loc = CorrectionLocationType.LASTFOURROWS;
+ break;
+ case Row:
+ case RowCol:
+ // var, rowVars: { var | mean, count, m2 correction, mean correction },
+ // where each element is a row.
+ loc = CorrectionLocationType.LASTFOURCOLUMNS;
+ break;
+ default:
+ throw new LopsException("PartialAggregate.getCorrectionLocation() - "
+ + "Unknown aggregate direction: " + direction);
+ }
+ break;
+
case MaxIndex:
case MinIndex:
loc = CorrectionLocationType.LASTCOLUMN;
@@ -325,16 +355,6 @@ public class PartialAggregate extends Lop
break;
}
- case Mean: {
- if( dir == DirectionTypes.RowCol )
- return "uamean";
- else if( dir == DirectionTypes.Row )
- return "uarmean";
- else if( dir == DirectionTypes.Col )
- return "uacmean";
- break;
- }
-
case KahanSum: {
// instructions that use kahanSum are similar to ua+,uar+,uac+
// except that they also produce correction values along with partial
@@ -358,6 +378,26 @@ public class PartialAggregate extends Lop
break;
}
+ case Mean: {
+ if( dir == DirectionTypes.RowCol )
+ return "uamean";
+ else if( dir == DirectionTypes.Row )
+ return "uarmean";
+ else if( dir == DirectionTypes.Col )
+ return "uacmean";
+ break;
+ }
+
+ case Var: {
+ if( dir == DirectionTypes.RowCol )
+ return "uavar";
+ else if( dir == DirectionTypes.Row )
+ return "uarvar";
+ else if( dir == DirectionTypes.Col )
+ return "uacvar";
+ break;
+ }
+
case Product: {
if( dir == DirectionTypes.RowCol )
return "ua*";
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/3206ac14/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
index da5bcd9..d56aa08 100644
--- a/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysml/parser/BuiltinFunctionExpression.java
@@ -237,6 +237,8 @@ public class BuiltinFunctionExpression extends DataIdentifier
case COLMAX:
case COLMIN:
case COLMEAN:
+ case COLSD:
+ case COLVAR:
// colSums(X);
checkNumParameters(1);
checkMatrixParam(getFirstExpr());
@@ -251,6 +253,8 @@ public class BuiltinFunctionExpression extends DataIdentifier
case ROWMIN:
case ROWINDEXMIN:
case ROWMEAN:
+ case ROWSD:
+ case ROWVAR:
//rowSums(X);
checkNumParameters(1);
checkMatrixParam(getFirstExpr());
@@ -262,6 +266,8 @@ public class BuiltinFunctionExpression extends DataIdentifier
case SUM:
case PROD:
case TRACE:
+ case SD:
+ case VAR:
// sum(X);
checkNumParameters(1);
checkMatrixParam(getFirstExpr());
@@ -1319,6 +1325,10 @@ public class BuiltinFunctionExpression extends DataIdentifier
bifop = Expression.BuiltinFunctionOp.SUM;
else if (functionName.equals("mean"))
bifop = Expression.BuiltinFunctionOp.MEAN;
+ else if (functionName.equals("sd"))
+ bifop = Expression.BuiltinFunctionOp.SD;
+ else if (functionName.equals("var"))
+ bifop = Expression.BuiltinFunctionOp.VAR;
else if (functionName.equals("trace"))
bifop = Expression.BuiltinFunctionOp.TRACE;
else if (functionName.equals("t"))
@@ -1353,6 +1363,14 @@ public class BuiltinFunctionExpression extends DataIdentifier
bifop = Expression.BuiltinFunctionOp.ROWMEAN;
else if (functionName.equals("colMeans"))
bifop = Expression.BuiltinFunctionOp.COLMEAN;
+ else if (functionName.equals("rowSds"))
+ bifop = Expression.BuiltinFunctionOp.ROWSD;
+ else if (functionName.equals("colSds"))
+ bifop = Expression.BuiltinFunctionOp.COLSD;
+ else if (functionName.equals("rowVars"))
+ bifop = Expression.BuiltinFunctionOp.ROWVAR;
+ else if (functionName.equals("colVars"))
+ bifop = Expression.BuiltinFunctionOp.COLVAR;
else if (functionName.equals("cummax"))
bifop = Expression.BuiltinFunctionOp.CUMMAX;
else if (functionName.equals("cummin"))
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/3206ac14/src/main/java/org/apache/sysml/parser/DMLTranslator.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/DMLTranslator.java b/src/main/java/org/apache/sysml/parser/DMLTranslator.java
index e10f3a8..c643396 100644
--- a/src/main/java/org/apache/sysml/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysml/parser/DMLTranslator.java
@@ -2188,11 +2188,23 @@ public class DMLTranslator
break;
case COLMEAN:
- // hop to compute colSums
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.MEAN,
Direction.Col, expr);
break;
+ case COLSD:
+ // colStdDevs = sqrt(colVariances)
+ currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(),
+ target.getValueType(), AggOp.VAR, Direction.Col, expr);
+ currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(),
+ target.getValueType(), Hop.OpOp1.SQRT, currBuiltinOp);
+ break;
+
+ case COLVAR:
+ currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(),
+ target.getValueType(), AggOp.VAR, Direction.Col, expr);
+ break;
+
case ROWSUM:
currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(), target.getValueType(), AggOp.SUM,
Direction.Row, expr);
@@ -2223,6 +2235,19 @@ public class DMLTranslator
Direction.Row, expr);
break;
+ case ROWSD:
+ // rowStdDevs = sqrt(rowVariances)
+ currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(),
+ target.getValueType(), AggOp.VAR, Direction.Row, expr);
+ currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(),
+ target.getValueType(), Hop.OpOp1.SQRT, currBuiltinOp);
+ break;
+
+ case ROWVAR:
+ currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(),
+ target.getValueType(), AggOp.VAR, Direction.Row, expr);
+ break;
+
case NROW:
// If the dimensions are available at compile time, then create a LiteralOp (constant propagation)
// Else create a UnaryOp so that a control program instruction is generated
@@ -2283,7 +2308,20 @@ public class DMLTranslator
Hop.OpOp3.CENTRALMOMENT, expr, expr2, orderHop);
}
break;
-
+
+ case SD:
+ // stdDev = sqrt(variance)
+ currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(),
+ target.getValueType(), AggOp.VAR, Direction.RowCol, expr);
+ currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(),
+ target.getValueType(), Hop.OpOp1.SQRT, currBuiltinOp);
+ break;
+
+ case VAR:
+ currBuiltinOp = new AggUnaryOp(target.getName(), target.getDataType(),
+ target.getValueType(), AggOp.VAR, Direction.RowCol, expr);
+ break;
+
case MIN:
//construct AggUnary for min(X) but BinaryOp for min(X,Y)
if( expr2 == null ) {
@@ -2295,6 +2333,7 @@ public class DMLTranslator
expr, expr2);
}
break;
+
case MAX:
//construct AggUnary for max(X) but BinaryOp for max(X,Y)
if( expr2 == null ) {
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/3206ac14/src/main/java/org/apache/sysml/parser/Expression.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/parser/Expression.java b/src/main/java/org/apache/sysml/parser/Expression.java
index 7ec06aa..709a581 100644
--- a/src/main/java/org/apache/sysml/parser/Expression.java
+++ b/src/main/java/org/apache/sysml/parser/Expression.java
@@ -53,66 +53,72 @@ public abstract class Expression
ASIN,
ATAN,
AVG,
- CAST_AS_MATRIX,
- CAST_AS_SCALAR,
- CAST_AS_DOUBLE,
- CAST_AS_INT,
CAST_AS_BOOLEAN,
- COLMEAN,
+ CAST_AS_DOUBLE,
+ CAST_AS_INT,
+ CAST_AS_MATRIX,
+ CAST_AS_SCALAR,
+ CBIND, //previously APPEND
+ CEIL,
COLMAX,
- COLMIN,
+ COLMEAN,
+ COLMIN,
+ COLSD,
COLSUM,
+ COLVAR,
COS,
- COV,
+ COV,
CUMMAX,
CUMMIN,
CUMPROD,
CUMSUM,
DIAG,
+ EIGEN,
EXP,
- INTERQUANTILE,
- IQM,
+ FLOOR,
+ INTERQUANTILE,
+ INVERSE,
+ IQM,
LENGTH,
- LOG,
+ LOG,
+ LU,
MAX,
MEAN,
- MIN,
+ MEDIAN,
+ MIN,
MOMENT,
NCOL,
NROW,
OUTER,
PPRED,
PROD,
+ QR,
QUANTILE,
RANGE,
+ RBIND,
REV,
ROUND,
- ROWINDEXMAX,
+ ROWINDEXMAX,
+ ROWINDEXMIN,
ROWMAX,
ROWMEAN,
ROWMIN,
- ROWINDEXMIN,
- ROWSUM,
+ ROWSD,
+ ROWSUM,
+ ROWVAR,
+ SAMPLE,
+ SD,
SEQ,
- SIN,
+ SIN,
SIGN,
+ SOLVE,
SQRT,
- SUM,
+ SUM,
TABLE,
TAN,
TRACE,
TRANS,
- QR,
- LU,
- EIGEN,
- SOLVE,
- CEIL,
- FLOOR,
- CBIND, //previously APPEND
- RBIND,
- MEDIAN,
- INVERSE,
- SAMPLE
+ VAR
};
public enum ParameterizedBuiltinFunctionOp {
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/3206ac14/src/main/java/org/apache/sysml/runtime/functionobjects/CM.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/functionobjects/CM.java b/src/main/java/org/apache/sysml/runtime/functionobjects/CM.java
index 085720f..10b6e0d 100644
--- a/src/main/java/org/apache/sysml/runtime/functionobjects/CM.java
+++ b/src/main/java/org/apache/sysml/runtime/functionobjects/CM.java
@@ -53,14 +53,14 @@ public class CM extends ValueFunction
switch( _type ) //helper obj on demand
{
- case COUNT:
+ case COUNT:
break;
case CM4:
case CM3:
_buff3 = new KahanObject(0, 0);
- case CM2:
- case VARIANCE:
+ case CM2:
_buff2 = new KahanObject(0, 0);
+ case VARIANCE:
case MEAN:
_plus = KahanPlus.getKahanPlusFnObject();
break;
@@ -80,6 +80,10 @@ public class CM extends ValueFunction
throw new CloneNotSupportedException();
}
+ public AggregateOperationTypes getAggOpType() {
+ return _type;
+ }
+
/**
* Special case for weights w2==1
*/
@@ -174,9 +178,7 @@ public class CM extends ValueFunction
cm1.mean=(KahanObject) _plus.execute(cm1.mean, d/w);
double t1=cm1.w/w*d;
double lt1=t1*d;
- _buff2.set(cm1.m2);
- _buff2=(KahanObject) _plus.execute(_buff2, lt1);
- cm1.m2.set(_buff2);
+ cm1.m2=(KahanObject) _plus.execute(cm1.m2, lt1);
cm1.w=w;
break;
}
@@ -282,9 +284,7 @@ public class CM extends ValueFunction
cm1.mean=(KahanObject) _plus.execute(cm1.mean, w2*d/w);
double t1=cm1.w*w2/w*d;
double lt1=t1*d;
- _buff2.set(cm1.m2);
- _buff2=(KahanObject) _plus.execute(_buff2, lt1);
- cm1.m2.set(_buff2);
+ cm1.m2=(KahanObject) _plus.execute(cm1.m2, lt1);
cm1.w=w;
break;
}
@@ -295,27 +295,10 @@ public class CM extends ValueFunction
return cm1;
}
-
- /*
- //following the SPSS definition.
- public Data execute(Data in1, double in2, double w2) throws DMLRuntimeException {
- CMObject cm=(CMObject) in1;
- double oldweight=cm._weight;
- cm._weight+=w2;
- double v=w2/cm._weight*(in2-cm._mean);
- cm._mean+=v;
- double oldm2=cm._m2;
- double oldm3=cm._m3;
- double oldm4=cm._m4;
- double weightProduct=cm._weight*oldweight;
- double vsquare=Math.pow(v, 2);
- cm._m2=oldm2+weightProduct/w2*vsquare;
- cm._m3=oldm3-3*v*oldm2+weightProduct/Math.pow(w2,2)*(cm._weight-2*w2)*Math.pow(v, 3);
- cm._m4=oldm4-4*v*oldm3+6*vsquare*oldm2
- +((Math.pow(cm._weight, 2)-3*w2*oldweight)/Math.pow(w2,3))*Math.pow(v, 4)*weightProduct;
- return cm;
- }*/
+ /**
+ * Combining stats from two partitions of the data.
+ */
@Override
public Data execute(Data in1, Data in2) throws DMLRuntimeException
{
@@ -418,10 +401,8 @@ public class CM extends ValueFunction
cm1.mean=(KahanObject) _plus.execute(cm1.mean, cm2.w*d/w);
double t1=cm1.w*cm2.w/w*d;
double lt1=t1*d;
- _buff2.set(cm1.m2);
- _buff2=(KahanObject) _plus.execute(_buff2, cm2.m2._sum, cm2.m2._correction);
- _buff2=(KahanObject) _plus.execute(_buff2, lt1);
- cm1.m2.set(_buff2);
+ cm1.m2=(KahanObject) _plus.execute(cm1.m2, cm2.m2._sum, cm2.m2._correction);
+ cm1.m2=(KahanObject) _plus.execute(cm1.m2, lt1);
cm1.w=w;
break;
}
@@ -432,38 +413,4 @@ public class CM extends ValueFunction
return cm1;
}
- /*
- private double Q(CMObject cm1, CMObject cm2, int power)
- {
- return cm1._weight*Math.pow(cm1._mean,power)+cm2._weight*Math.pow(cm2._mean,power);
- }
-
- //following the SPSS definition, it is wrong
- public Data execute(Data in1, Data in2) throws DMLRuntimeException
- {
- CMObject cm1=(CMObject) in1;
- CMObject cm2=(CMObject) in2;
- double w=cm1._weight+cm2._weight;
- double q1=cm1._mean*cm1._weight+cm2._mean*cm2._weight;
- double mean=q1/w;
- double p1=mean-cm1._mean;
- double p2=mean-cm2._mean;
- double q2=Q(cm1, cm2, 2);
- double q3=Q(cm1, cm2, 3);
- double q4=Q(cm1, cm2, 4);
- double mean2=Math.pow(mean, 2);
- double mean3=Math.pow(mean, 3);
- double mean4=Math.pow(mean, 4);
- double m2 = cm1._m2+cm2._m2 + q2 - 2*mean*q1 + w*mean2;
- double m3 = cm1._m3+cm2._m3 - 3*(p1*cm1._m2+p2*cm2._m2)
- - 3*mean*(Math.pow(cm1._mean, 2)+Math.pow(cm2._mean, 2)) + 4*q3 - w*mean3;
- double m4 = cm1._m4+cm2._m4 - 4*(p1*cm1._m3+p2*cm2._m3) + 6*(Math.pow(p1, 2)*cm1._m2+Math.pow(p2, 2)*cm2._m2)-4*q4-4*mean*q3+6*mean2*q2-4*mean3*q1+2*w*mean4;
- cm1._m2=m2;
- cm1._m3=m3;
- cm1._m4=m4;
- cm1._mean=mean;
- cm1._weight=w;
- return cm1;
- }*/
-
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/3206ac14/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
index 81b6550..c70f6a9 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/CPInstructionParser.java
@@ -83,6 +83,9 @@ public class CPInstructionParser extends InstructionParser
String2CPInstructionType.put( "uamean" , CPINSTRUCTION_TYPE.AggregateUnary);
String2CPInstructionType.put( "uarmean" , CPINSTRUCTION_TYPE.AggregateUnary);
String2CPInstructionType.put( "uacmean" , CPINSTRUCTION_TYPE.AggregateUnary);
+ String2CPInstructionType.put( "uavar" , CPINSTRUCTION_TYPE.AggregateUnary);
+ String2CPInstructionType.put( "uarvar" , CPINSTRUCTION_TYPE.AggregateUnary);
+ String2CPInstructionType.put( "uacvar" , CPINSTRUCTION_TYPE.AggregateUnary);
String2CPInstructionType.put( "uamax" , CPINSTRUCTION_TYPE.AggregateUnary);
String2CPInstructionType.put( "uarmax" , CPINSTRUCTION_TYPE.AggregateUnary);
String2CPInstructionType.put( "uarimax", CPINSTRUCTION_TYPE.AggregateUnary);
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/3206ac14/src/main/java/org/apache/sysml/runtime/instructions/InstructionUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/InstructionUtils.java b/src/main/java/org/apache/sysml/runtime/instructions/InstructionUtils.java
index bec5ff7..6c5ec95 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/InstructionUtils.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/InstructionUtils.java
@@ -43,6 +43,7 @@ import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.DMLUnsupportedOperationException;
import org.apache.sysml.runtime.functionobjects.And;
import org.apache.sysml.runtime.functionobjects.Builtin;
+import org.apache.sysml.runtime.functionobjects.CM;
import org.apache.sysml.runtime.functionobjects.Divide;
import org.apache.sysml.runtime.functionobjects.Equals;
import org.apache.sysml.runtime.functionobjects.GreaterThan;
@@ -77,6 +78,7 @@ import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
import org.apache.sysml.runtime.matrix.operators.LeftScalarOperator;
import org.apache.sysml.runtime.matrix.operators.RightScalarOperator;
import org.apache.sysml.runtime.matrix.operators.ScalarOperator;
+import org.apache.sysml.runtime.matrix.operators.CMOperator.AggregateOperationTypes;
public class InstructionUtils
@@ -362,6 +364,27 @@ public class InstructionUtils
AggregateOperator agg = new AggregateOperator(0, Mean.getMeanFnObject(), true, CorrectionLocationType.LASTTWOROWS);
aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject());
}
+ else if ( opcode.equalsIgnoreCase("uavar") ) {
+ // Variance
+ CM varFn = CM.getCMFnObject(AggregateOperationTypes.VARIANCE);
+ CorrectionLocationType cloc = CorrectionLocationType.LASTFOURCOLUMNS;
+ AggregateOperator agg = new AggregateOperator(0, varFn, true, cloc);
+ aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject());
+ }
+ else if ( opcode.equalsIgnoreCase("uarvar") ) {
+ // RowVariances
+ CM varFn = CM.getCMFnObject(AggregateOperationTypes.VARIANCE);
+ CorrectionLocationType cloc = CorrectionLocationType.LASTFOURCOLUMNS;
+ AggregateOperator agg = new AggregateOperator(0, varFn, true, cloc);
+ aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject());
+ }
+ else if ( opcode.equalsIgnoreCase("uacvar") ) {
+ // ColVariances
+ CM varFn = CM.getCMFnObject(AggregateOperationTypes.VARIANCE);
+ CorrectionLocationType cloc = CorrectionLocationType.LASTFOURROWS;
+ AggregateOperator agg = new AggregateOperator(0, varFn, true, cloc);
+ aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject());
+ }
else if ( opcode.equalsIgnoreCase("ua+") ) {
AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject());
aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject());
@@ -468,7 +491,15 @@ public class InstructionUtils
CorrectionLocationType lcorrLoc = (corrLoc==null) ? CorrectionLocationType.LASTTWOCOLUMNS : CorrectionLocationType.valueOf(corrLoc);
agg = new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), lcorrExists, lcorrLoc);
}
-
+ else if ( opcode.equalsIgnoreCase("avar") ) {
+ boolean lcorrExists = (corrExists==null) ? true : Boolean.parseBoolean(corrExists);
+ CorrectionLocationType lcorrLoc = (corrLoc==null) ?
+ CorrectionLocationType.LASTFOURCOLUMNS :
+ CorrectionLocationType.valueOf(corrLoc);
+ CM varFn = CM.getCMFnObject(AggregateOperationTypes.VARIANCE);
+ agg = new AggregateOperator(0, varFn, lcorrExists, lcorrLoc);
+ }
+
return agg;
}
@@ -751,6 +782,8 @@ public class InstructionUtils
return "asqk+";
else if ( opcode.equalsIgnoreCase("uamean") || opcode.equalsIgnoreCase("uarmean") || opcode.equalsIgnoreCase("uacmean") )
return "amean";
+ else if ( opcode.equalsIgnoreCase("uavar") || opcode.equalsIgnoreCase("uarvar") || opcode.equalsIgnoreCase("uacvar") )
+ return "avar";
else if ( opcode.equalsIgnoreCase("ua+") || opcode.equalsIgnoreCase("uar+") || opcode.equalsIgnoreCase("uac+") )
return "a+";
else if ( opcode.equalsIgnoreCase("ua*") )
@@ -786,7 +819,11 @@ public class InstructionUtils
return CorrectionLocationType.LASTTWOCOLUMNS;
else if ( opcode.equalsIgnoreCase("uacmean") )
return CorrectionLocationType.LASTTWOROWS;
- else if (opcode.equalsIgnoreCase("uarimax") || opcode.equalsIgnoreCase("uarimin") )
+ else if ( opcode.equalsIgnoreCase("uavar") || opcode.equalsIgnoreCase("uarvar") )
+ return CorrectionLocationType.LASTFOURCOLUMNS;
+ else if ( opcode.equalsIgnoreCase("uacvar") )
+ return CorrectionLocationType.LASTFOURROWS;
+ else if (opcode.equalsIgnoreCase("uarimax") || opcode.equalsIgnoreCase("uarimin") )
return CorrectionLocationType.LASTCOLUMN;
return CorrectionLocationType.NONE;
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/3206ac14/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java
index 60555f5..993f0b1 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/MRInstructionParser.java
@@ -94,7 +94,8 @@ public class MRInstructionParser extends InstructionParser
String2MRInstructionType.put( "a*" , MRINSTRUCTION_TYPE.Aggregate);
String2MRInstructionType.put( "amax" , MRINSTRUCTION_TYPE.Aggregate);
String2MRInstructionType.put( "amin" , MRINSTRUCTION_TYPE.Aggregate);
- String2MRInstructionType.put( "amean" , MRINSTRUCTION_TYPE.Aggregate);
+ String2MRInstructionType.put( "amean" , MRINSTRUCTION_TYPE.Aggregate);
+ String2MRInstructionType.put( "avar" , MRINSTRUCTION_TYPE.Aggregate);
String2MRInstructionType.put( "arimax" , MRINSTRUCTION_TYPE.Aggregate);
String2MRInstructionType.put( "arimin" , MRINSTRUCTION_TYPE.Aggregate);
@@ -116,6 +117,9 @@ public class MRInstructionParser extends InstructionParser
String2MRInstructionType.put( "uamean", MRINSTRUCTION_TYPE.AggregateUnary);
String2MRInstructionType.put( "uarmean",MRINSTRUCTION_TYPE.AggregateUnary);
String2MRInstructionType.put( "uacmean",MRINSTRUCTION_TYPE.AggregateUnary);
+ String2MRInstructionType.put( "uavar", MRINSTRUCTION_TYPE.AggregateUnary);
+ String2MRInstructionType.put( "uarvar", MRINSTRUCTION_TYPE.AggregateUnary);
+ String2MRInstructionType.put( "uacvar", MRINSTRUCTION_TYPE.AggregateUnary);
String2MRInstructionType.put( "ua*" , MRINSTRUCTION_TYPE.AggregateUnary);
String2MRInstructionType.put( "uamax" , MRINSTRUCTION_TYPE.AggregateUnary);
String2MRInstructionType.put( "uamin" , MRINSTRUCTION_TYPE.AggregateUnary);
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/3206ac14/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java
index a64fc8f..2694fc9 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java
@@ -92,6 +92,9 @@ public class SPInstructionParser extends InstructionParser
String2SPInstructionType.put( "uamean" , SPINSTRUCTION_TYPE.AggregateUnary);
String2SPInstructionType.put( "uarmean" , SPINSTRUCTION_TYPE.AggregateUnary);
String2SPInstructionType.put( "uacmean" , SPINSTRUCTION_TYPE.AggregateUnary);
+ String2SPInstructionType.put( "uavar" , SPINSTRUCTION_TYPE.AggregateUnary);
+ String2SPInstructionType.put( "uarvar" , SPINSTRUCTION_TYPE.AggregateUnary);
+ String2SPInstructionType.put( "uacvar" , SPINSTRUCTION_TYPE.AggregateUnary);
String2SPInstructionType.put( "uamax" , SPINSTRUCTION_TYPE.AggregateUnary);
String2SPInstructionType.put( "uarmax" , SPINSTRUCTION_TYPE.AggregateUnary);
String2SPInstructionType.put( "uarimax", SPINSTRUCTION_TYPE.AggregateUnary);
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/3206ac14/src/main/java/org/apache/sysml/runtime/instructions/cp/CM_COV_Object.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/CM_COV_Object.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/CM_COV_Object.java
index 8d7c661..4336392 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/CM_COV_Object.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/CM_COV_Object.java
@@ -125,14 +125,35 @@ public class CM_COV_Object extends Data
{
return w==0 && mean.isAllZero() && mean_v.isAllZero() && c2.isAllZero() ;
}
-
+
+ /**
+ * Return the result of the aggregated operation given the
+ * operator.
+ */
public double getRequiredResult(Operator op) throws DMLRuntimeException
{
if(op instanceof CMOperator)
{
AggregateOperationTypes agg=((CMOperator)op).aggOpType;
- switch(agg)
- {
+ return getRequiredResult(agg);
+ }
+ else
+ {
+ //avoid division by 0
+ if(w==1.0)
+ return 0;
+ else
+ return c2._sum/(w-1.0);
+ }
+ }
+
+ /**
+ * Return the result of the aggregated operation given the
+ * operation type.
+ */
+ public double getRequiredResult(AggregateOperationTypes agg) throws DMLRuntimeException {
+ switch(agg)
+ {
case COUNT:
return w;
case MEAN:
@@ -147,18 +168,9 @@ public class CM_COV_Object extends Data
return w==1.0? 0:m2._sum/(w-1);
default:
throw new DMLRuntimeException("Invalid aggreagte in CM_CV_Object: " + agg);
- }
- }
- else
- {
- //avoid division by 0
- if(w==1.0)
- return 0;
- else
- return c2._sum/(w-1.0);
}
}
-
+
/**
*
* @param op
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/3206ac14/src/main/java/org/apache/sysml/runtime/instructions/mr/AggregateInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/mr/AggregateInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/mr/AggregateInstruction.java
index 06708b7..f030904 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/mr/AggregateInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/mr/AggregateInstruction.java
@@ -51,7 +51,7 @@ public class AggregateInstruction extends UnaryMRInstructionBase
AggregateOperator agg = null;
if(opcode.equalsIgnoreCase("ak+") || opcode.equalsIgnoreCase("asqk+")
- || opcode.equalsIgnoreCase("amean")) {
+ || opcode.equalsIgnoreCase("amean") || opcode.equalsIgnoreCase("avar")) {
InstructionUtils.checkNumFields ( str, 4 );
agg = InstructionUtils.parseAggregateOperator(opcode, parts[3], parts[4]);
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/3206ac14/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java
index ef45da6..4a487d9 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java
@@ -63,9 +63,10 @@ import org.apache.sysml.runtime.util.UtilFunctions;
* in order to prevent unnecessary worse asymptotic behavior.
*
* This library currently covers the following opcodes:
- * ak+, uak+, uark+, uack+, uamin, uarmin, uacmin, uamax, uarmax, uacmax,
- * ua*, uamean, uarmean, uacmean, uarimax, uaktrace.
- * cumk+, cummin, cummax, cum*, tak+
+ * ak+, uak+, uark+, uack+, uasqk+, uarsqk+, uacsqk+,
+ * uamin, uarmin, uacmin, uamax, uarmax, uacmax,
+ * ua*, uamean, uarmean, uacmean, uavar, uarvar, uacvar,
+ * uarimax, uaktrace, cumk+, cummin, cummax, cum*, tak+.
*
* TODO next opcode extensions: a+, colindexmax
*/
@@ -90,6 +91,7 @@ public class LibMatrixAgg
MIN,
MAX,
MEAN,
+ VAR,
MAX_INDEX,
MIN_INDEX,
PROD,
@@ -544,13 +546,23 @@ public class LibMatrixAgg
{
return AggType.MEAN;
}
-
+
+ //variance
+ if( vfn instanceof CM
+ && ((CM) vfn).getAggOpType() == AggregateOperationTypes.VARIANCE
+ && (op.aggOp.correctionLocation == CorrectionLocationType.LASTFOURCOLUMNS ||
+ op.aggOp.correctionLocation == CorrectionLocationType.LASTFOURROWS)
+ && (ifn instanceof ReduceAll || ifn instanceof ReduceCol || ifn instanceof ReduceRow) )
+ {
+ return AggType.VAR;
+ }
+
//prod
if( vfn instanceof Multiply && ifn instanceof ReduceAll )
{
return AggType.PROD;
}
-
+
//min / max
if( vfn instanceof Builtin &&
(ifn instanceof ReduceAll || ifn instanceof ReduceCol || ifn instanceof ReduceRow) )
@@ -1396,17 +1408,28 @@ public class LibMatrixAgg
case MEAN: //MEAN
{
KahanObject kbuff = new KahanObject(0, 0);
-
+
if( ixFn instanceof ReduceAll ) // MEAN
d_uamean(a, c, m, n, kbuff, (Mean)vFn, rl, ru);
else if( ixFn instanceof ReduceCol ) //ROWMEAN
d_uarmean(a, c, m, n, kbuff, (Mean)vFn, rl, ru);
else if( ixFn instanceof ReduceRow ) //COLMEAN
d_uacmean(a, c, m, n, kbuff, (Mean)vFn, rl, ru);
-
break;
}
- case PROD: //PROD
+ case VAR: //VAR
+ {
+ CM_COV_Object cbuff = new CM_COV_Object();
+
+ if( ixFn instanceof ReduceAll ) //VAR
+ d_uavar(a, c, m, n, cbuff, (CM)vFn, rl, ru);
+ else if( ixFn instanceof ReduceCol ) //ROWVAR
+ d_uarvar(a, c, m, n, cbuff, (CM)vFn, rl, ru);
+ else if( ixFn instanceof ReduceRow ) //COLVAR
+ d_uacvar(a, c, m, n, cbuff, (CM)vFn, rl, ru);
+ break;
+ }
+ case PROD: //PROD
{
if( ixFn instanceof ReduceAll ) // PROD
d_uam(a, c, m, n, rl, ru );
@@ -1521,10 +1544,21 @@ public class LibMatrixAgg
s_uarmean(a, c, m, n, kbuff, (Mean)vFn, rl, ru);
else if( ixFn instanceof ReduceRow ) //COLMEAN
s_uacmean(a, c, m, n, kbuff, (Mean)vFn, rl, ru);
-
break;
}
- case PROD: //PROD
+ case VAR: //VAR
+ {
+ CM_COV_Object cbuff = new CM_COV_Object();
+
+ if( ixFn instanceof ReduceAll ) //VAR
+ s_uavar(a, c, m, n, cbuff, (CM)vFn, rl, ru);
+ else if( ixFn instanceof ReduceCol ) //ROWVAR
+ s_uarvar(a, c, m, n, cbuff, (CM)vFn, rl, ru);
+ else if( ixFn instanceof ReduceRow ) //COLVAR
+ s_uacvar(a, c, m, n, cbuff, (CM)vFn, rl, ru);
+ break;
+ }
+ case PROD: //PROD
{
if( ixFn instanceof ReduceAll ) // PROD
s_uam(a, c, m, n, rl, ru );
@@ -1588,7 +1622,20 @@ public class LibMatrixAgg
out.quickSetValue(1, j, in.rlen); //count
break;
}
-
+ case VAR:
+ {
+ // results: { var | mean, count, m2 correction, mean correction }
+ if( ixFn instanceof ReduceAll ) //VAR
+ out.quickSetValue(0, 2, in.rlen*in.clen); //count
+ else if( ixFn instanceof ReduceCol ) //ROWVAR
+ for( int i=0; i<in.rlen; i++ )
+ out.quickSetValue(i, 2, in.clen); //count
+ else if( ixFn instanceof ReduceRow ) //COLVAR
+ for( int j=0; j<in.clen; j++ )
+ out.quickSetValue(2, j, in.rlen); //count
+ break;
+ }
+
default:
throw new DMLRuntimeException("Unsupported aggregation type: "+optype);
}
@@ -1963,8 +2010,91 @@ public class LibMatrixAgg
for( int i=rl, aix=rl*n; i<ru; i++, aix+=n )
meanAgg( a, c, aix, 0, n, kbuff, kmean );
}
-
-
+
+ /**
+ * VAR, opcode: uavar, dense input; aggregates rows [rl, ru) into one variance.
+ *
+ * @param a Array of values (row-major, n values per row).
+ * @param c Output array; c[0..4] receive variance, mean, count,
+ * m2 correction factor, and mean correction factor.
+ * @param m Number of rows (not used by this method).
+ * @param n Number of values per row.
+ * @param cbuff A CM_COV_Object to hold various intermediate
+ * values for the variance calculation.
+ * @param cm A CM object of type Variance to perform the variance
+ * calculation.
+ * @param rl Lower row limit (inclusive).
+ * @param ru Upper row limit (exclusive).
+ */
+ private static void d_uavar(double[] a, double[] c, int m, int n, CM_COV_Object cbuff, CM cm,
+ int rl, int ru) throws DMLRuntimeException
+ {
+ int len = Math.min((ru-rl)*n, a.length); // NOTE(review): clamps the count only; rl*n+len may exceed a.length if rl>0
+ var(a, rl*n, len, cbuff, cm);
+ // store results into c[0..4]: { var | mean, count, m2 correction, mean correction }
+ c[0] = cbuff.getRequiredResult(AggregateOperationTypes.VARIANCE);
+ c[1] = cbuff.mean._sum;
+ c[2] = cbuff.w;
+ c[3] = cbuff.m2._correction;
+ c[4] = cbuff.mean._correction;
+ }
+
+ /**
+ * ROWVAR, opcode: uarvar, dense input; computes one variance per row in [rl, ru).
+ *
+ * @param a Array of values (row-major, n values per row).
+ * @param c Output array to store variance, mean, count,
+ * m2 correction factor, and mean correction factor
+ * for each row, as 5 consecutive values starting at c[5*i].
+ * @param m Number of rows (not used by this method).
+ * @param n Number of values per row.
+ * @param cbuff A CM_COV_Object to hold various intermediate
+ * values for the variance calculation.
+ * @param cm A CM object of type Variance to perform the variance
+ * calculation.
+ * @param rl Lower row limit (inclusive).
+ * @param ru Upper row limit (exclusive).
+ */
+ private static void d_uarvar(double[] a, double[] c, int m, int n, CM_COV_Object cbuff, CM cm,
+ int rl, int ru) throws DMLRuntimeException
+ {
+ // calculate variance for each row independently
+ for (int i=rl, aix=rl*n, cix=rl*5; i<ru; i++, aix+=n, cix+=5) {
+ cbuff.reset(); // reset buffer so each row starts from an empty aggregate
+ var(a, aix, n, cbuff, cm);
+ // store row results: { var | mean, count, m2 correction, mean correction }
+ c[cix] = cbuff.getRequiredResult(AggregateOperationTypes.VARIANCE);
+ c[cix+1] = cbuff.mean._sum;
+ c[cix+2] = cbuff.w;
+ c[cix+3] = cbuff.m2._correction;
+ c[cix+4] = cbuff.mean._correction;
+ }
+ }
+
+ /**
+ * COLVAR, opcode: uacvar, dense input; folds rows [rl, ru) into per-column variances.
+ *
+ * @param a Array of values (row-major, n values per row).
+ * @param c Output array to store variance, mean, count,
+ * m2 correction factor, and mean correction factor
+ * for each column, laid out as 5 rows of n values.
+ * @param m Number of rows (not used by this method).
+ * @param n Number of values per row.
+ * @param cbuff A CM_COV_Object to hold various intermediate
+ * values for the variance calculation.
+ * @param cm A CM object of type Variance to perform the variance
+ * calculation.
+ * @param rl Lower row limit (inclusive).
+ * @param ru Upper row limit (exclusive).
+ */
+ private static void d_uacvar(double[] a, double[] c, int m, int n, CM_COV_Object cbuff, CM cm,
+ int rl, int ru) throws DMLRuntimeException
+ {
+ // fold each row into the running per-column state in c (c is read as well as written; assumes zero-initialized)
+ for (int i=rl, aix=rl*n; i<ru; i++, aix+=n)
+ varAgg(a, c, aix, 0, n, cbuff, cm);
+ }
+
/**
* PROD, opcode: ua*, dense input.
*
@@ -2535,7 +2665,7 @@ public class LibMatrixAgg
c[1] = len;
c[2] = kbuff._correction;
}
-
+
/**
* ROWMEAN, opcode: uarmean, sparse input.
*
@@ -2619,7 +2749,137 @@ public class LibMatrixAgg
}
}
}
-
+
+ /**
+ * VAR, opcode: uavar, sparse input; aggregates rows [rl, ru) into one variance.
+ *
+ * @param a Sparse rows of values (null rows are treated as all-empty).
+ * @param c Output array; c[0..4] receive variance, mean, count,
+ * m2 correction factor, and mean correction factor.
+ * @param m Number of rows (not used by this method).
+ * @param n Number of values per row.
+ * @param cbuff A CM_COV_Object to hold various intermediate
+ * values for the variance calculation.
+ * @param cm A CM object of type Variance to perform the variance
+ * calculation.
+ * @param rl Lower row limit (inclusive).
+ * @param ru Upper row limit (exclusive).
+ */
+ private static void s_uavar(SparseRow[] a, double[] c, int m, int n, CM_COV_Object cbuff, CM cm,
+ int rl, int ru) throws DMLRuntimeException
+ {
+ // seed w with the number of empty cells (implicit zeros); assumes cbuff starts at mean=0, m2=0
+ int count = 0;
+ for (int i=rl; i<ru; i++)
+ count += (a[i]==null) ? n : n-a[i].size();
+ cbuff.w = count;
+
+ // fold in the stored values of each non-empty row (zeros are already counted in w)
+ for (int i=rl; i<ru; i++) {
+ SparseRow arow = a[i];
+ if (arow!=null && !arow.isEmpty()) {
+ int alen = arow.size();
+ double[] avals = arow.getValueContainer();
+ var(avals, 0, alen, cbuff, cm);
+ }
+ }
+
+ // store results into c[0..4]: { var | mean, count, m2 correction, mean correction }
+ c[0] = cbuff.getRequiredResult(AggregateOperationTypes.VARIANCE);
+ c[1] = cbuff.mean._sum;
+ c[2] = cbuff.w;
+ c[3] = cbuff.m2._correction;
+ c[4] = cbuff.mean._correction;
+ }
+
+ /**
+ * ROWVAR, opcode: uarvar, sparse input; computes one variance per row in [rl, ru).
+ *
+ * @param a Sparse rows of values (null rows are treated as all-empty).
+ * @param c Output array to store variance, mean, count,
+ * m2 correction factor, and mean correction factor, as 5 consecutive values per row starting at c[5*i].
+ * @param m Number of rows (not used by this method).
+ * @param n Number of values per row.
+ * @param cbuff A CM_COV_Object to hold various intermediate
+ * values for the variance calculation.
+ * @param cm A CM object of type Variance to perform the variance
+ * calculation.
+ * @param rl Lower row limit (inclusive).
+ * @param ru Upper row limit (exclusive).
+ */
+ private static void s_uarvar(SparseRow[] a, double[] c, int m, int n, CM_COV_Object cbuff, CM cm,
+ int rl, int ru) throws DMLRuntimeException
+ {
+ // calculate aggregated variance for each row
+ for (int i=rl, cix=rl*5; i<ru; i++, cix+=5) {
+ cbuff.reset(); // reset buffer so each row starts from an empty aggregate
+
+ // seed w with this row's number of empty cells (implicit zeros);
+ // the stored values are folded in afterwards
+ int count = (a[i] == null) ? n : n-a[i].size();
+ cbuff.w = count;
+
+ SparseRow arow = a[i];
+ if (arow != null && !arow.isEmpty()) {
+ int alen = arow.size();
+ double[] avals = arow.getValueContainer();
+ var(avals, 0, alen, cbuff, cm);
+ }
+
+ // store results: { var | mean, count, m2 correction, mean correction }
+ c[cix] = cbuff.getRequiredResult(AggregateOperationTypes.VARIANCE);
+ c[cix+1] = cbuff.mean._sum;
+ c[cix+2] = cbuff.w;
+ c[cix+3] = cbuff.m2._correction;
+ c[cix+4] = cbuff.mean._correction;
+ }
+ }
+
+ /**
+ * COLVAR, opcode: uacvar, sparse input; folds rows [rl, ru) into per-column variances.
+ *
+ * @param a Sparse rows of values (null rows are treated as all-empty).
+ * @param c Output array to store variance, mean, count,
+ * m2 correction factor, and mean correction factor, laid out as 5 rows of n values.
+ * @param m Number of rows (not used by this method).
+ * @param n Number of values per row.
+ * @param cbuff A CM_COV_Object to hold various intermediate
+ * values for the variance calculation.
+ * @param cm A CM object of type Variance to perform the variance
+ * calculation.
+ * @param rl Lower row limit (inclusive).
+ * @param ru Upper row limit (exclusive).
+ */
+ private static void s_uacvar(SparseRow[] a, double[] c, int m, int n, CM_COV_Object cbuff, CM cm,
+ int rl, int ru) throws DMLRuntimeException
+ {
+ // compute and store counts of empty cells per column before aggregation
+ // note: column results are { var | mean, count, m2 correction, mean correction }
+ // - first, store total possible column counts (ru-rl) in 3rd row of output
+ Arrays.fill(c, n*2, n*3, ru-rl); // counts stored in 3rd row
+ // - then subtract one from the column count for each stored (non-empty) value in the column
+ for (int i=rl; i<ru; i++) {
+ SparseRow arow = a[i];
+ if (arow!=null && !arow.isEmpty()) {
+ int alen = arow.size();
+ double[] avals = arow.getValueContainer();
+ int[] aix = arow.getIndexContainer();
+ countDisAgg(avals, c, aix, n*2, alen); // counts stored in 3rd row
+ }
+ }
+
+ // fold stored values into each column; implicit zeros are already in the counts (assumes var/mean rows start at 0)
+ for (int i=rl; i<ru; i++) {
+ SparseRow arow = a[i];
+ if (arow != null && !arow.isEmpty()) {
+ int alen = arow.size();
+ double[] avals = arow.getValueContainer();
+ int[] aix = arow.getIndexContainer();
+ varAgg(avals, c, aix, alen, n, cbuff, cm);
+ }
+ }
+ }
+
/**
* PROD, opcode: ua*, sparse input.
*
@@ -2863,18 +3123,7 @@ public class LibMatrixAgg
mean.execute2(kbuff, a[ai], count+1);
}
}
-
- /*
- private static void mean( final double aval, final int len, int count, KahanObject kbuff, KahanPlus kplus )
- {
- for( int i=0; i<len; i++, count++ )
- {
- //delta: (newvalue-buffer._sum)/count
- kplus.execute2(kbuff, (aval-kbuff._sum)/(count+1));
- }
- }
- */
-
+
/**
*
* @param a
@@ -2921,7 +3170,96 @@ public class LibMatrixAgg
c[ai[i]+2*n] = kbuff._correction;
}
}
-
+
+ /**
+ * Variance: folds the values a[ai..ai+len) into the running state in cbuff.
+ *
+ * @param a Array of values.
+ * @param ai Index at which to start processing.
+ * @param len Number of values to process, starting at index ai.
+ * @param cbuff Running CM_COV_Object state; presumably updated in place by
+ * cm.execute, since the local reassignment below is not visible to callers.
+ * @param cm A CM object of type Variance to perform the variance
+ * calculation.
+ */
+ private static void var(double[] a, int ai, final int len, CM_COV_Object cbuff, CM cm)
+ throws DMLRuntimeException
+ {
+ for(int i=0; i<len; i++, ai++)
+ cbuff = (CM_COV_Object) cm.execute(cbuff, a[ai]);
+ }
+
+ /**
+ * Aggregated variance: folds len values into the per-column running state held in c.
+ *
+ * @param a Array of values to fold in (one value per output column).
+ * @param c Running state { var | mean, count, m2 correction, mean
+ * correction }, stored as 5 rows with row stride len.
+ * @param ai Index at which to start processing array `a`.
+ * @param ci Index at which to start storing aggregated results
+ * into array `c`.
+ * @param len Number of values to process; also the row stride of `c`.
+ * @param cbuff Scratch CM_COV_Object; its fields are reloaded from `c`
+ * for every value before the update.
+ * @param cm A CM object of type Variance to perform the variance
+ * calculation.
+ */
+ private static void varAgg(double[] a, double[] c, int ai, int ci, final int len,
+ CM_COV_Object cbuff, CM cm) throws DMLRuntimeException
+ {
+ for (int i=0; i<len; i++, ai++, ci++) {
+ // extract current values: { var | mean, count, m2 correction, mean correction }
+ cbuff.w = c[ci+2*len]; // count
+ cbuff.m2._sum = c[ci] * (cbuff.w - 1); // m2 = var * (n - 1)
+ cbuff.mean._sum = c[ci+len]; // mean
+ cbuff.m2._correction = c[ci+3*len];
+ cbuff.mean._correction = c[ci+4*len];
+ // calculate incremental aggregated variance
+ cbuff = (CM_COV_Object) cm.execute(cbuff, a[ai]);
+ // store updated values: { var | mean, count, m2 correction, mean correction }
+ c[ci] = cbuff.getRequiredResult(AggregateOperationTypes.VARIANCE);
+ c[ci+len] = cbuff.mean._sum;
+ c[ci+2*len] = cbuff.w;
+ c[ci+3*len] = cbuff.m2._correction;
+ c[ci+4*len] = cbuff.mean._correction;
+ }
+ }
+
+ /**
+ * Aggregated variance (sparse): folds a[0..len) into the columns ai[0..len) of c.
+ *
+ * @param a Stored values of one sparse row.
+ * @param c Running state { var | mean, count, m2 correction, mean
+ * correction }, stored as 5 rows of n values.
+ * @param ai Column index of each value in `a`; selects the entry of `c`.
+ * @param len Number of indices in `ai` to process.
+ * @param n Number of values per row; also the row stride of `c`.
+ * @param cbuff Scratch CM_COV_Object; its fields are reloaded from `c`
+ * for every value before the update.
+ * @param cm A CM object of type Variance to perform the variance
+ * calculation.
+ */
+ private static void varAgg(double[] a, double[] c, int[] ai, final int len, final int n,
+ CM_COV_Object cbuff, CM cm) throws DMLRuntimeException
+ {
+ for (int i=0; i<len; i++) {
+ // extract current values: { var | mean, count, m2 correction, mean correction }
+ cbuff.w = c[ai[i]+2*n]; // count
+ cbuff.m2._sum = c[ai[i]] * (cbuff.w - 1); // m2 = var * (n - 1)
+ cbuff.mean._sum = c[ai[i]+n]; // mean
+ cbuff.m2._correction = c[ai[i]+3*n];
+ cbuff.mean._correction = c[ai[i]+4*n];
+ // calculate incremental aggregated variance
+ cbuff = (CM_COV_Object) cm.execute(cbuff, a[i]);
+ // store updated values: { var | mean, count, m2 correction, mean correction }
+ c[ai[i]] = cbuff.getRequiredResult(AggregateOperationTypes.VARIANCE);
+ c[ai[i]+n] = cbuff.mean._sum;
+ c[ai[i]+2*n] = cbuff.w;
+ c[ai[i]+3*n] = cbuff.m2._correction;
+ c[ai[i]+4*n] = cbuff.mean._correction;
+ }
+ }
+
/**
* Meant for builtin function ops (min, max)
*
@@ -3072,7 +3410,7 @@ public class LibMatrixAgg
private static abstract class AggTask implements Callable<Object> {}
/**
- *
+ *
*
*/
private static class RowAggTask extends AggTask
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/3206ac14/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
index 720aaaa..f2c4030 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
@@ -44,6 +44,7 @@ import org.apache.sysml.parser.DMLTranslator;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.DMLUnsupportedOperationException;
import org.apache.sysml.runtime.functionobjects.Builtin;
+import org.apache.sysml.runtime.functionobjects.CM;
import org.apache.sysml.runtime.functionobjects.CTable;
import org.apache.sysml.runtime.functionobjects.DiagIndex;
import org.apache.sysml.runtime.functionobjects.Divide;
@@ -71,6 +72,7 @@ import org.apache.sysml.runtime.matrix.operators.AggregateOperator;
import org.apache.sysml.runtime.matrix.operators.AggregateUnaryOperator;
import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
import org.apache.sysml.runtime.matrix.operators.CMOperator;
+import org.apache.sysml.runtime.matrix.operators.CMOperator.AggregateOperationTypes;
import org.apache.sysml.runtime.matrix.operators.COVOperator;
import org.apache.sysml.runtime.matrix.operators.Operator;
import org.apache.sysml.runtime.matrix.operators.QuaternaryOperator;
@@ -3288,6 +3290,82 @@ public class MatrixBlock extends MatrixValue implements Externalizable
cor.quickSetValue(r, 1, buffer._correction);
}
}
+ else if (aggOp.correctionLocation == CorrectionLocationType.LASTFOURROWS
+ && aggOp.increOp.fn instanceof CM
+ && ((CM) aggOp.increOp.fn).getAggOpType() == AggregateOperationTypes.VARIANCE) {
+ // create buffers to store results
+ CM_COV_Object cbuff_curr = new CM_COV_Object();
+ CM_COV_Object cbuff_part = new CM_COV_Object();
+
+ // perform incremental aggregation
+ for (int r=0; r<rlen; r++)
+ for (int c=0; c<clen; c++) {
+ // extract current values: { var | mean, count, m2 correction, mean correction }
+ // note: m2 = var * (n - 1)
+ cbuff_curr.w = cor.quickGetValue(1, c); // count
+ cbuff_curr.m2._sum = quickGetValue(r, c) * (cbuff_curr.w - 1); // m2
+ cbuff_curr.mean._sum = cor.quickGetValue(0, c); // mean
+ cbuff_curr.m2._correction = cor.quickGetValue(2, c);
+ cbuff_curr.mean._correction = cor.quickGetValue(3, c);
+
+ // extract partial values: { var | mean, count, m2 correction, mean correction }
+ // note: m2 = var * (n - 1)
+ cbuff_part.w = newWithCor.quickGetValue(r+2, c); // count
+ cbuff_part.m2._sum = newWithCor.quickGetValue(r, c) * (cbuff_part.w - 1); // m2
+ cbuff_part.mean._sum = newWithCor.quickGetValue(r+1, c); // mean
+ cbuff_part.m2._correction = newWithCor.quickGetValue(r+3, c);
+ cbuff_part.mean._correction = newWithCor.quickGetValue(r+4, c);
+
+ // calculate incremental aggregated variance
+ cbuff_curr = (CM_COV_Object) aggOp.increOp.fn.execute(cbuff_curr, cbuff_part);
+
+ // store updated values: { var | mean, count, m2 correction, mean correction }
+ double var = cbuff_curr.getRequiredResult(AggregateOperationTypes.VARIANCE);
+ quickSetValue(r, c, var);
+ cor.quickSetValue(0, c, cbuff_curr.mean._sum); // mean
+ cor.quickSetValue(1, c, cbuff_curr.w); // count
+ cor.quickSetValue(2, c, cbuff_curr.m2._correction);
+ cor.quickSetValue(3, c, cbuff_curr.mean._correction);
+ }
+ }
+ else if (aggOp.correctionLocation == CorrectionLocationType.LASTFOURCOLUMNS
+ && aggOp.increOp.fn instanceof CM
+ && ((CM) aggOp.increOp.fn).getAggOpType() == AggregateOperationTypes.VARIANCE) {
+ // create buffers to store results
+ CM_COV_Object cbuff_curr = new CM_COV_Object();
+ CM_COV_Object cbuff_part = new CM_COV_Object();
+
+ // perform incremental aggregation
+ for (int r=0; r<rlen; r++)
+ for (int c=0; c<clen; c++) {
+ // extract current values: { var | mean, count, m2 correction, mean correction }
+ // note: m2 = var * (n - 1)
+ cbuff_curr.w = cor.quickGetValue(r, 1); // count
+ cbuff_curr.m2._sum = quickGetValue(r, c) * (cbuff_curr.w - 1); // m2
+ cbuff_curr.mean._sum = cor.quickGetValue(r, 0); // mean
+ cbuff_curr.m2._correction = cor.quickGetValue(r, 2);
+ cbuff_curr.mean._correction = cor.quickGetValue(r, 3);
+
+ // extract partial values: { var | mean, count, m2 correction, mean correction }
+ // note: m2 = var * (n - 1)
+ cbuff_part.w = newWithCor.quickGetValue(r, c+2); // count
+ cbuff_part.m2._sum = newWithCor.quickGetValue(r, c) * (cbuff_part.w - 1); // m2
+ cbuff_part.mean._sum = newWithCor.quickGetValue(r, c+1); // mean
+ cbuff_part.m2._correction = newWithCor.quickGetValue(r, c+3);
+ cbuff_part.mean._correction = newWithCor.quickGetValue(r, c+4);
+
+ // calculate incremental aggregated variance
+ cbuff_curr = (CM_COV_Object) aggOp.increOp.fn.execute(cbuff_curr, cbuff_part);
+
+ // store updated values: { var | mean, count, m2 correction, mean correction }
+ double var = cbuff_curr.getRequiredResult(AggregateOperationTypes.VARIANCE);
+ quickSetValue(r, c, var);
+ cor.quickSetValue(r, 0, cbuff_curr.mean._sum); // mean
+ cor.quickSetValue(r, 1, cbuff_curr.w); // count
+ cor.quickSetValue(r, 2, cbuff_curr.m2._correction);
+ cor.quickSetValue(r, 3, cbuff_curr.mean._correction);
+ }
+ }
else
throw new DMLRuntimeException("unrecognized correctionLocation: "+aggOp.correctionLocation);
}
@@ -3377,17 +3455,8 @@ public class MatrixBlock extends MatrixValue implements Externalizable
}
}
}
- }/*else if(aggOp.correctionLocation==0)
- {
- for(int r=0; r<rlen; r++)
- for(int c=0; c<clen; c++)
- {
- //buffer._sum=this.getValue(r, c);
- //buffer._correction=0;
- //buffer=(KahanObject) aggOp.increOp.fn.execute(buffer, newWithCor.getValue(r, c));
- setValue(r, c, this.getValue(r, c)+newWithCor.getValue(r, c));
- }
- }*/else if(aggOp.correctionLocation==CorrectionLocationType.LASTTWOROWS)
+ }
+ else if(aggOp.correctionLocation==CorrectionLocationType.LASTTWOROWS)
{
double n, n2, mu2;
for(int r=0; r<rlen-2; r++)
@@ -3425,6 +3494,82 @@ public class MatrixBlock extends MatrixValue implements Externalizable
quickSetValue(r, c+2, buffer._correction);
}
}
+ else if (aggOp.correctionLocation == CorrectionLocationType.LASTFOURROWS
+ && aggOp.increOp.fn instanceof CM
+ && ((CM) aggOp.increOp.fn).getAggOpType() == AggregateOperationTypes.VARIANCE) {
+ // create buffers to store results
+ CM_COV_Object cbuff_curr = new CM_COV_Object();
+ CM_COV_Object cbuff_part = new CM_COV_Object();
+
+ // perform incremental aggregation
+ for (int r=0; r<rlen-4; r++)
+ for (int c=0; c<clen; c++) {
+ // extract current values: { var | mean, count, m2 correction, mean correction }
+ // note: m2 = var * (n - 1)
+ cbuff_curr.w = quickGetValue(r+2, c); // count
+ cbuff_curr.m2._sum = quickGetValue(r, c) * (cbuff_curr.w - 1); // m2
+ cbuff_curr.mean._sum = quickGetValue(r+1, c); // mean
+ cbuff_curr.m2._correction = quickGetValue(r+3, c);
+ cbuff_curr.mean._correction = quickGetValue(r+4, c);
+
+ // extract partial values: { var | mean, count, m2 correction, mean correction }
+ // note: m2 = var * (n - 1)
+ cbuff_part.w = newWithCor.quickGetValue(r+2, c); // count
+ cbuff_part.m2._sum = newWithCor.quickGetValue(r, c) * (cbuff_part.w - 1); // m2
+ cbuff_part.mean._sum = newWithCor.quickGetValue(r+1, c); // mean
+ cbuff_part.m2._correction = newWithCor.quickGetValue(r+3, c);
+ cbuff_part.mean._correction = newWithCor.quickGetValue(r+4, c);
+
+ // calculate incremental aggregated variance
+ cbuff_curr = (CM_COV_Object) aggOp.increOp.fn.execute(cbuff_curr, cbuff_part);
+
+ // store updated values: { var | mean, count, m2 correction, mean correction }
+ double var = cbuff_curr.getRequiredResult(AggregateOperationTypes.VARIANCE);
+ quickSetValue(r, c, var);
+ quickSetValue(r+1, c, cbuff_curr.mean._sum); // mean
+ quickSetValue(r+2, c, cbuff_curr.w); // count
+ quickSetValue(r+3, c, cbuff_curr.m2._correction);
+ quickSetValue(r+4, c, cbuff_curr.mean._correction);
+ }
+ }
+ else if (aggOp.correctionLocation == CorrectionLocationType.LASTFOURCOLUMNS
+ && aggOp.increOp.fn instanceof CM
+ && ((CM) aggOp.increOp.fn).getAggOpType() == AggregateOperationTypes.VARIANCE) {
+ // create buffers to store results
+ CM_COV_Object cbuff_curr = new CM_COV_Object();
+ CM_COV_Object cbuff_part = new CM_COV_Object();
+
+ // perform incremental aggregation
+ for (int r=0; r<rlen; r++)
+ for (int c=0; c<clen-4; c++) {
+ // extract current values: { var | mean, count, m2 correction, mean correction }
+ // note: m2 = var * (n - 1)
+ cbuff_curr.w = quickGetValue(r, c+2); // count
+ cbuff_curr.m2._sum = quickGetValue(r, c) * (cbuff_curr.w - 1); // m2
+ cbuff_curr.mean._sum = quickGetValue(r, c+1); // mean
+ cbuff_curr.m2._correction = quickGetValue(r, c+3);
+ cbuff_curr.mean._correction = quickGetValue(r, c+4);
+
+ // extract partial values: { var | mean, count, m2 correction, mean correction }
+ // note: m2 = var * (n - 1)
+ cbuff_part.w = newWithCor.quickGetValue(r, c+2); // count
+ cbuff_part.m2._sum = newWithCor.quickGetValue(r, c) * (cbuff_part.w - 1); // m2
+ cbuff_part.mean._sum = newWithCor.quickGetValue(r, c+1); // mean
+ cbuff_part.m2._correction = newWithCor.quickGetValue(r, c+3);
+ cbuff_part.mean._correction = newWithCor.quickGetValue(r, c+4);
+
+ // calculate incremental aggregated variance
+ cbuff_curr = (CM_COV_Object) aggOp.increOp.fn.execute(cbuff_curr, cbuff_part);
+
+ // store updated values: { var | mean, count, m2 correction, mean correction }
+ double var = cbuff_curr.getRequiredResult(AggregateOperationTypes.VARIANCE);
+ quickSetValue(r, c, var);
+ quickSetValue(r, c+1, cbuff_curr.mean._sum); // mean
+ quickSetValue(r, c+2, cbuff_curr.w); // count
+ quickSetValue(r, c+3, cbuff_curr.m2._correction);
+ quickSetValue(r, c+4, cbuff_curr.mean._correction);
+ }
+ }
else
throw new DMLRuntimeException("unrecognized correctionLocation: "+aggOp.correctionLocation);
}
@@ -4293,6 +4438,12 @@ public class MatrixBlock extends MatrixValue implements Externalizable
case LASTTWOCOLUMNS:
tempCellIndex.column+=2;
break;
+ case LASTFOURROWS:
+ tempCellIndex.row+=4;
+ break;
+ case LASTFOURCOLUMNS:
+ tempCellIndex.column+=4;
+ break;
default:
throw new DMLRuntimeException("unrecognized correctionLocation: "+op.aggOp.correctionLocation);
}
@@ -4476,12 +4627,29 @@ public class MatrixBlock extends MatrixValue implements Externalizable
}
//determine number of rows/cols to be removed
- int step = ( correctionLocation==CorrectionLocationType.LASTTWOROWS
- || correctionLocation==CorrectionLocationType.LASTTWOCOLUMNS) ? 2 : 1;
+ int step;
+ switch (correctionLocation) {
+ case LASTROW:
+ case LASTCOLUMN:
+ step = 1;
+ break;
+ case LASTTWOROWS:
+ case LASTTWOCOLUMNS:
+ step = 2;
+ break;
+ case LASTFOURROWS:
+ case LASTFOURCOLUMNS:
+ step = 4;
+ break;
+ default:
+ step = 0;
+ }
+
- //e.g., colSums, colMeans, colMaxs, colMeans
- if( correctionLocation==CorrectionLocationType.LASTROW
- || correctionLocation==CorrectionLocationType.LASTTWOROWS )
+ //e.g., colSums, colMeans, colMaxs, colMeans, colVars
+ if( correctionLocation==CorrectionLocationType.LASTROW
+ || correctionLocation==CorrectionLocationType.LASTTWOROWS
+ || correctionLocation==CorrectionLocationType.LASTFOURROWS )
{
if( sparse ) //SPARSE
{
@@ -4502,9 +4670,10 @@ public class MatrixBlock extends MatrixValue implements Externalizable
rlen -= step;
}
- //e.g., rowSums, rowsMeans, rowsMaxs, rowsMeans
- if( correctionLocation==CorrectionLocationType.LASTCOLUMN
- || correctionLocation==CorrectionLocationType.LASTTWOCOLUMNS )
+ //e.g., rowSums, rowsMeans, rowsMaxs, rowsMeans, rowVars
+ else if( correctionLocation==CorrectionLocationType.LASTCOLUMN
+ || correctionLocation==CorrectionLocationType.LASTTWOCOLUMNS
+ || correctionLocation==CorrectionLocationType.LASTFOURCOLUMNS )
{
if(sparse) //SPARSE
{
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/3206ac14/src/main/java/org/apache/sysml/runtime/matrix/data/OperationsOnMatrixValues.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/OperationsOnMatrixValues.java b/src/main/java/org/apache/sysml/runtime/matrix/data/OperationsOnMatrixValues.java
index 485b559..d9c8853 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/OperationsOnMatrixValues.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/OperationsOnMatrixValues.java
@@ -204,6 +204,18 @@ public class OperationsOnMatrixValues
corRow=rlen;
corCol=2;
break;
+ case LASTFOURROWS:
+ outRow=rlen-4;
+ outCol=clen;
+ corRow=4;
+ corCol=clen;
+ break;
+ case LASTFOURCOLUMNS:
+ outRow=rlen;
+ outCol=clen-4;
+ corRow=rlen;
+ corCol=4;
+ break;
default:
throw new DMLRuntimeException("unrecognized correctionLocation: "+op.correctionLocation);
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/3206ac14/src/test/java/org/apache/sysml/test/integration/functions/aggregate/ColStdDevsTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/aggregate/ColStdDevsTest.java b/src/test/java/org/apache/sysml/test/integration/functions/aggregate/ColStdDevsTest.java
new file mode 100644
index 0000000..6208aa4
--- /dev/null
+++ b/src/test/java/org/apache/sysml/test/integration/functions/aggregate/ColStdDevsTest.java
@@ -0,0 +1,295 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.functions.aggregate;
+
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
+import org.apache.sysml.lops.LopProperties.ExecType;
+import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.TestConfiguration;
+import org.apache.sysml.test.utils.TestUtils;
+import org.junit.Test;
+
+import java.util.HashMap;
+
+/**
+ * Test the column standard deviations function, "colSds(X)", against an R reference on CP/Spark/MR.
+ */
+public class ColStdDevsTest extends AutomatedTestBase {
+
+ private static final String TEST_NAME = "ColStdDevs"; // base name of the .dml and .R scripts
+ private static final String TEST_DIR = "functions/aggregate/";
+ private static final String TEST_CLASS_DIR =
+ TEST_DIR + ColStdDevsTest.class.getSimpleName() + "/";
+ private static final String INPUT_NAME = "X"; // name of the generated input matrix
+ private static final String OUTPUT_NAME = "colStdDevs"; // result name read back from DML and R
+
+ private static final int rows = 1234; // row count for MATRIX and COLUMNVECTOR inputs
+ private static final int cols = 1432; // column count for MATRIX and ROWVECTOR inputs
+ private static final double sparsitySparse = 0.2; // sparsity value used for SPARSE inputs
+ private static final double sparsityDense = 0.7; // sparsity value used for DENSE inputs
+ private static final double eps = Math.pow(10, -10); // tolerance for the DML vs. R comparison
+
+ private enum Sparsity {EMPTY, SPARSE, DENSE} // kind of generated input data
+ private enum DataType {MATRIX, ROWVECTOR, COLUMNVECTOR} // shape of generated input data
+
+ @Override
+ public void setUp() { // register the TEST_NAME configuration
+ TestUtils.clearAssertionInformation();
+ TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME);
+ addTestConfiguration(TEST_NAME, config);
+ }
+
+ // Dense matrix (rows x cols): CP, Spark, MR
+ @Test
+ public void testColStdDevsDenseMatrixCP() {
+ testColStdDevs(TEST_NAME, Sparsity.DENSE, DataType.MATRIX, ExecType.CP);
+ }
+
+ @Test
+ public void testColStdDevsDenseMatrixSpark() {
+ testColStdDevs(TEST_NAME, Sparsity.DENSE, DataType.MATRIX, ExecType.SPARK);
+ }
+
+ @Test
+ public void testColStdDevsDenseMatrixMR() {
+ testColStdDevs(TEST_NAME, Sparsity.DENSE, DataType.MATRIX, ExecType.MR);
+ }
+
+ // Dense row vector (1 x cols): CP, Spark, MR
+ @Test
+ public void testColStdDevsDenseRowVectorCP() {
+ testColStdDevs(TEST_NAME, Sparsity.DENSE, DataType.ROWVECTOR, ExecType.CP);
+ }
+
+ @Test
+ public void testColStdDevsDenseRowVectorSpark() {
+ testColStdDevs(TEST_NAME, Sparsity.DENSE, DataType.ROWVECTOR, ExecType.SPARK);
+ }
+
+ @Test
+ public void testColStdDevsDenseRowVectorMR() {
+ testColStdDevs(TEST_NAME, Sparsity.DENSE, DataType.ROWVECTOR, ExecType.MR);
+ }
+
+ // Dense column vector (rows x 1): CP, Spark, MR
+ @Test
+ public void testColStdDevsDenseColVectorCP() {
+ testColStdDevs(TEST_NAME, Sparsity.DENSE, DataType.COLUMNVECTOR, ExecType.CP);
+ }
+
+ @Test
+ public void testColStdDevsDenseColVectorSpark() {
+ testColStdDevs(TEST_NAME, Sparsity.DENSE, DataType.COLUMNVECTOR, ExecType.SPARK);
+ }
+
+ @Test
+ public void testColStdDevsDenseColVectorMR() {
+ testColStdDevs(TEST_NAME, Sparsity.DENSE, DataType.COLUMNVECTOR, ExecType.MR);
+ }
+
+ // Sparse matrix (rows x cols): CP, Spark, MR
+ @Test
+ public void testColStdDevsSparseMatrixCP() {
+ testColStdDevs(TEST_NAME, Sparsity.SPARSE, DataType.MATRIX, ExecType.CP);
+ }
+
+ @Test
+ public void testColStdDevsSparseMatrixSpark() {
+ testColStdDevs(TEST_NAME, Sparsity.SPARSE, DataType.MATRIX, ExecType.SPARK);
+ }
+
+ @Test
+ public void testColStdDevsSparseMatrixMR() {
+ testColStdDevs(TEST_NAME, Sparsity.SPARSE, DataType.MATRIX, ExecType.MR);
+ }
+
+ // Sparse row vector (1 x cols): CP, Spark, MR
+ @Test
+ public void testColStdDevsSparseRowVectorCP() {
+ testColStdDevs(TEST_NAME, Sparsity.SPARSE, DataType.ROWVECTOR, ExecType.CP);
+ }
+
+ @Test
+ public void testColStdDevsSparseRowVectorSpark() {
+ testColStdDevs(TEST_NAME, Sparsity.SPARSE, DataType.ROWVECTOR, ExecType.SPARK);
+ }
+
+ @Test
+ public void testColStdDevsSparseRowVectorMR() {
+ testColStdDevs(TEST_NAME, Sparsity.SPARSE, DataType.ROWVECTOR, ExecType.MR);
+ }
+
+ // Sparse column vector (rows x 1): CP, Spark, MR
+ @Test
+ public void testColStdDevsSparseColVectorCP() {
+ testColStdDevs(TEST_NAME, Sparsity.SPARSE, DataType.COLUMNVECTOR, ExecType.CP);
+ }
+
+ @Test
+ public void testColStdDevsSparseColVectorSpark() {
+ testColStdDevs(TEST_NAME, Sparsity.SPARSE, DataType.COLUMNVECTOR, ExecType.SPARK);
+ }
+
+ @Test
+ public void testColStdDevsSparseColVectorMR() {
+ testColStdDevs(TEST_NAME, Sparsity.SPARSE, DataType.COLUMNVECTOR, ExecType.MR);
+ }
+
+ // Empty matrix (rows x cols, sparsity 0): CP, Spark, MR
+ @Test
+ public void testColStdDevsEmptyMatrixCP() {
+ testColStdDevs(TEST_NAME, Sparsity.EMPTY, DataType.MATRIX, ExecType.CP);
+ }
+
+ @Test
+ public void testColStdDevsEmptyMatrixSpark() {
+ testColStdDevs(TEST_NAME, Sparsity.EMPTY, DataType.MATRIX, ExecType.SPARK);
+ }
+
+ @Test
+ public void testColStdDevsEmptyMatrixMR() {
+ testColStdDevs(TEST_NAME, Sparsity.EMPTY, DataType.MATRIX, ExecType.MR);
+ }
+
+ // Empty row vector (1 x cols, sparsity 0): CP, Spark, MR
+ @Test
+ public void testColStdDevsEmptyRowVectorCP() {
+ testColStdDevs(TEST_NAME, Sparsity.EMPTY, DataType.ROWVECTOR, ExecType.CP);
+ }
+
+ @Test
+ public void testColStdDevsEmptyRowVectorSpark() {
+ testColStdDevs(TEST_NAME, Sparsity.EMPTY, DataType.ROWVECTOR, ExecType.SPARK);
+ }
+
+ @Test
+ public void testColStdDevsEmptyRowVectorMR() {
+ testColStdDevs(TEST_NAME, Sparsity.EMPTY, DataType.ROWVECTOR, ExecType.MR);
+ }
+
+ // Empty column vector (rows x 1, sparsity 0): CP, Spark, MR
+ @Test
+ public void testColStdDevsEmptyColVectorCP() {
+ testColStdDevs(TEST_NAME, Sparsity.EMPTY, DataType.COLUMNVECTOR, ExecType.CP);
+ }
+
+ @Test
+ public void testColStdDevsEmptyColVectorSpark() {
+ testColStdDevs(TEST_NAME, Sparsity.EMPTY, DataType.COLUMNVECTOR, ExecType.SPARK);
+ }
+
+ @Test
+ public void testColStdDevsEmptyColVectorMR() {
+ testColStdDevs(TEST_NAME, Sparsity.EMPTY, DataType.COLUMNVECTOR, ExecType.MR);
+ }
+
+ /**
+ * Run "colSds(X)" on empty/sparse/dense matrices and row/column vectors on the
+ * CP/Spark/MR platforms, and compare the DML result to the R result within eps.
+ *
+ * @param testName Name of the test configuration and of the .dml/.R scripts.
+ * @param sparsity Selection between empty, sparse, and dense data.
+ * @param dataType Selection between a matrix, a row vector, and a column vector.
+ * @param platform Selection between CP/Spark/MR; the previous global runtime
+ * platform and local-Spark flag are restored afterwards.
+ */
+ private void testColStdDevs(String testName, Sparsity sparsity, DataType dataType,
+ ExecType platform) {
+ // Map the requested ExecType onto the global runtime platform (restored in finally)
+ RUNTIME_PLATFORM platformOld = rtplatform;
+ switch (platform) {
+ case MR:
+ rtplatform = RUNTIME_PLATFORM.HADOOP;
+ break;
+ case SPARK:
+ rtplatform = RUNTIME_PLATFORM.SPARK;
+ break;
+ default:
+ rtplatform = RUNTIME_PLATFORM.SINGLE_NODE;
+ break;
+ }
+
+ boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+ if (rtplatform == RUNTIME_PLATFORM.SPARK)
+ DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+ try {
+ // Load the test configuration; set the DML/R script paths and their arguments
+ getAndLoadTestConfiguration(testName);
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ fullDMLScriptName = HOME + testName + ".dml";
+ programArgs = new String[]{"-explain", "-stats", "-args",
+ input(INPUT_NAME), output(OUTPUT_NAME)};
+ fullRScriptName = HOME + testName + ".R";
+ rCmd = "Rscript" + " " + fullRScriptName + " " + inputDir() + " " + expectedDir();
+
+ // Generate random input data
+ // - sparsity value passed to getRandomMatrix (0 for EMPTY)
+ double sparsityVal;
+ switch (sparsity) {
+ case EMPTY:
+ sparsityVal = 0;
+ break;
+ case SPARSE:
+ sparsityVal = sparsitySparse;
+ break;
+ case DENSE:
+ default:
+ sparsityVal = sparsityDense;
+ }
+ // - dimensions: 1 x cols (ROWVECTOR), rows x 1 (COLUMNVECTOR), else rows x cols
+ int r;
+ int c;
+ switch (dataType) {
+ case ROWVECTOR:
+ r = 1;
+ c = cols;
+ break;
+ case COLUMNVECTOR:
+ r = rows;
+ c = 1;
+ break;
+ case MATRIX:
+ default:
+ r = rows;
+ c = cols;
+ }
+ // - generation (args presumably min=-1, max=1, sparsity, seed=7) and write with MTD
+ double[][] X = getRandomMatrix(r, c, -1, 1, sparsityVal, 7);
+ writeInputMatrixWithMTD(INPUT_NAME, X, true);
+
+ // Run the DML script, then the R reference script
+ runTest(true, false, null, -1);
+ runRScript(true);
+
+ // Compare the DML and R output matrices within eps
+ HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS(OUTPUT_NAME);
+ HashMap<CellIndex, Double> rfile = readRMatrixFromFS(OUTPUT_NAME);
+ TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
+ }
+ finally {
+ // Restore the global platform and local-Spark flag saved above
+ rtplatform = platformOld;
+ DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ }
+ }
+}