You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by ba...@apache.org on 2020/11/14 20:15:51 UTC
[systemds] 01/03: [SYSTEMDS-2727-9] Federated CM, Var, qsort & qpick
This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
commit d000364e972c6ab2816e0a4990c7ab9c736042c8
Author: Olga <ov...@gmail.com>
AuthorDate: Wed Aug 26 14:44:04 2020 +0200
[SYSTEMDS-2727-9] Federated CM, Var, qsort & qpick
This commit adds more primitive federated instructions for statistics
calculations.
- Correlation
- Variance,
- Quantile sort and
- Quantile pick.
closes #1103
closes #1102
---
scripts/builtin/dist.dml | 5 +-
src/main/java/org/apache/sysds/lops/MMTSJ.java | 3 +
.../controlprogram/federated/FederationUtils.java | 163 +++++++++++++++-
.../fed/AggregateUnaryFEDInstruction.java | 55 +++++-
.../fed/BinaryMatrixScalarFEDInstruction.java | 2 +-
.../fed/CentralMomentFEDInstruction.java | 187 ++++++++++++++++++
.../runtime/instructions/fed/FEDInstruction.java | 4 +-
.../instructions/fed/FEDInstructionUtils.java | 120 ++++++++----
.../fed/QuantilePickFEDInstruction.java | 210 +++++++++++++++++++++
.../fed/QuantileSortFEDInstruction.java | 163 ++++++++++++++++
.../instructions/fed/ReorgFEDInstruction.java | 4 +-
.../instructions/fed/TsmmFEDInstruction.java | 4 +-
.../federated/algorithms/FederatedBivarTest.java | 3 +-
...eratedUnivarTest.java => FederatedCorTest.java} | 106 ++++++-----
.../federated/algorithms/FederatedGLMTest.java | 22 +--
.../federated/algorithms/FederatedKmeansTest.java | 42 ++---
.../federated/algorithms/FederatedL2SVMTest.java | 7 +-
.../federated/algorithms/FederatedLogRegTest.java | 5 +-
.../federated/algorithms/FederatedPCATest.java | 30 ++-
.../federated/algorithms/FederatedUnivarTest.java | 23 +--
.../FederatedVarTest.java} | 97 +++-------
.../functions/federated/io/FederatedSSLTest.java | 2 +-
.../federated/io/FederatedWriterTest.java | 2 +-
.../paramserv/FederatedParamservTest.java | 7 +-
.../primitives/FederatedBinaryMatrixTest.java | 16 +-
.../primitives/FederatedBinaryVectorTest.java | 12 +-
...rameTest.java => FederatedCastToFrameTest.java} | 6 +-
...rixTest.java => FederatedCastToMatrixTest.java} | 6 +-
.../primitives/FederatedCentralMomentTest.java | 147 +++++++++++++++
.../primitives/FederatedConstructionTest.java | 21 ++-
.../primitives/FederatedFullAggregateTest.java | 13 +-
.../FederatedMatrixScalarOperationsTest.java | 5 +-
.../primitives/FederatedMultiplyTest.java | 6 +-
.../primitives/FederatedQuantileTest.java | 165 ++++++++++++++++
.../primitives/FederatedQuantileWeightsTest.java | 140 ++++++++++++++
.../federated/primitives/FederatedRCBindTest.java | 23 ++-
.../primitives/FederatedRightIndexTest.java | 16 +-
.../primitives/FederatedRowColAggregateTest.java | 2 +-
.../federated/primitives/FederatedSplitTest.java | 15 +-
.../primitives/FederatedStatisticsTest.java | 6 +-
.../federated/primitives/FederatedSumTest.java | 20 +-
.../federated/FederatedCentralMomentTest.dml | 14 +-
.../FederatedCentralMomentTestReference.dml | 12 +-
.../functions/federated/FederatedCorTest.dml | 21 ++-
.../federated/FederatedCorTestReference.dml | 13 +-
.../functions/federated/FederatedVarTest.dml | 20 +-
.../federated/FederatedVarTestReference.dml | 14 +-
.../federated/quantile/FederatedIQRTest.dml | 13 +-
.../quantile/FederatedIQRTestReference.dml | 13 +-
.../federated/quantile/FederatedIQRWeightsTest.dml | 14 +-
.../quantile/FederatedIQRWeightsTestReference.dml | 14 +-
.../federated/quantile/FederatedMedianTest.dml | 13 +-
.../quantile/FederatedMedianTestReference.dml | 13 +-
.../quantile/FederatedMedianWeightsTest.dml | 14 +-
.../FederatedMedianWeightsTestReference.dml | 14 +-
.../federated/quantile/FederatedQuantileTest.dml | 13 +-
.../quantile/FederatedQuantileTestReference.dml | 13 +-
.../quantile/FederatedQuantileWeightsTest.dml | 14 +-
.../FederatedQuantileWeightsTestReference.dml | 14 +-
.../federated/quantile/FederatedQuantilesTest.dml | 16 +-
.../quantile/FederatedQuantilesTestReference.dml | 16 +-
.../quantile/FederatedQuantilesWeightsTest.dml | 17 +-
.../FederatedQuantilesWeightsTestReference.dml | 17 +-
63 files changed, 1712 insertions(+), 495 deletions(-)
diff --git a/scripts/builtin/dist.dml b/scripts/builtin/dist.dml
index 9c473d8..e5fe930 100644
--- a/scripts/builtin/dist.dml
+++ b/scripts/builtin/dist.dml
@@ -23,7 +23,8 @@
m_dist = function(Matrix[Double] X) return (Matrix[Double] Y) {
G = X %*% t(X);
I = matrix(1, rows = nrow(G), cols = ncol(G));
- Y = -2 * (G) + t(I %*% diag(diag(G))) + t(diag(diag(G)) %*% I);
+ Y = -2 * (G) + (diag(G) * I) + (I * t(diag(G)));
+# Y = -2 * (G) + t(I %*% diag(diag(G))) + t(diag(diag(G)) %*% I);
Y = sqrt(Y);
Y = replace(target = Y, pattern=0/0, replacement = 0);
-}
\ No newline at end of file
+}
diff --git a/src/main/java/org/apache/sysds/lops/MMTSJ.java b/src/main/java/org/apache/sysds/lops/MMTSJ.java
index 50448d1..73e27d8 100644
--- a/src/main/java/org/apache/sysds/lops/MMTSJ.java
+++ b/src/main/java/org/apache/sysds/lops/MMTSJ.java
@@ -41,6 +41,9 @@ public class MMTSJ extends Lop
public boolean isLeft(){
return (this == LEFT);
}
+ public boolean isRight(){
+ return (this == RIGHT);
+ }
}
private MMTSJType _type = null;
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
index 37cb7d5..0a24cea 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
@@ -35,6 +35,7 @@ import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.Reques
import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode;
+import org.apache.sysds.runtime.functionobjects.CM;
import org.apache.sysds.runtime.functionobjects.KahanFunction;
import org.apache.sysds.runtime.functionobjects.Mean;
import org.apache.sysds.runtime.functionobjects.Plus;
@@ -169,11 +170,85 @@ public class FederationUtils {
}
}
- public static ScalarObject aggScalar(AggregateUnaryOperator aop, Future<FederatedResponse>[] ffr, FederationMap map) {
- if(!(aop.aggOp.increOp.fn instanceof KahanFunction || (aop.aggOp.increOp.fn instanceof Builtin &&
- (((Builtin) aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MIN
- || ((Builtin) aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MAX)
- || aop.aggOp.increOp.fn instanceof Mean ))) {
+ public static MatrixBlock aggVar(Future<FederatedResponse>[] ffr, Future<FederatedResponse>[] meanFfr, FederationMap map, boolean isRowAggregate, boolean isScalar) {
+ try {
+// else if(aop.aggOp.increOp.fn instanceof CM) {
+// double var = ((ScalarObject) ffr[0].get().getData()[0]).getDoubleValue();
+// double mean = ((ScalarObject) meanFfr[0].get().getData()[0]).getDoubleValue();
+// long size = map.getFederatedRanges()[0].getSize();
+// for(int i = 0; i < ffr.length - 1; i++) {
+// long l = size + map.getFederatedRanges()[i+1].getSize();
+// double k = ((size * var) + (map.getFederatedRanges()[i+1].getSize() * ((ScalarObject) ffr[i+1].get().getData()[0]).getDoubleValue())) / l;
+// var = k + (size * map.getFederatedRanges()[i+1].getSize()) * Math.pow((mean - ((ScalarObject) meanFfr[i+1].get().getData()[0]).getDoubleValue()) / l, 2);
+// mean = (mean * size + ((ScalarObject) meanFfr[i+1].get().getData()[0]).getDoubleValue() * (map.getFederatedRanges()[i+1].getSize())) / l;
+// size = l;
+// System.out.println("Olga");
+// // long l = sizes[i] + sizes[i + 1];
+// // double k = Math.pow(means[i] - means[i+1], 2) * (sizes[i] * sizes[i+1]);
+// // k += ((sizes[i] * vars[i]) + (sizes[i+1] * vars[i+1])) * l;
+// // vars[i+1] = k / Math.pow(l, 2);
+// //
+// // means[i+1] = (means[i] * sizes[i] + means[i] * sizes[i]) / l;
+// // sizes[i+1] = l;
+// }
+// return new DoubleObject(var);
+//
+// }
+
+
+ FederatedRange[] ranges = map.getFederatedRanges();
+ BinaryOperator plus = InstructionUtils.parseBinaryOperator("+");
+ BinaryOperator minus = InstructionUtils.parseBinaryOperator("-");
+
+ ScalarOperator mult1 = InstructionUtils.parseScalarBinaryOperator("*", false);
+ ScalarOperator dev1 = InstructionUtils.parseScalarBinaryOperator("/", false);
+ ScalarOperator pow = InstructionUtils.parseScalarBinaryOperator("^2", false);
+
+ long size1 = isScalar ? ranges[0].getSize() : ranges[0].getSize(isRowAggregate ? 0 : 1);
+ MatrixBlock var1 = (MatrixBlock)ffr[0].get().getData()[0];
+ MatrixBlock mean1 = (MatrixBlock)meanFfr[0].get().getData()[0];
+ for(int i=0; i < ffr.length - 1; i++) {
+ MatrixBlock var2 = (MatrixBlock)ffr[i+1].get().getData()[0];
+ MatrixBlock mean2 = (MatrixBlock)meanFfr[i+1].get().getData()[0];
+ long size2 = isScalar ? ranges[i+1].getSize() : ranges[i+1].getSize(isRowAggregate ? 0 : 1);
+
+ mult1 = mult1.setConstant(size1);
+ var1 = var1.scalarOperations(mult1, new MatrixBlock());
+ mult1 = mult1.setConstant(size2);
+ var1 = var1.binaryOperationsInPlace(plus, var2.scalarOperations(mult1, new MatrixBlock()));
+ dev1 = dev1.setConstant(size1 + size2);
+ var1 = var1.scalarOperations(dev1, new MatrixBlock());
+
+ MatrixBlock tmp1 = (mean1.binaryOperationsInPlace(minus, mean2)).scalarOperations(dev1, new MatrixBlock());
+ tmp1 = tmp1.scalarOperations(pow, new MatrixBlock());
+ mult1 = mult1.setConstant(size1*size2);
+ tmp1 = tmp1.scalarOperations(mult1, new MatrixBlock());
+
+ var1 = tmp1.binaryOperationsInPlace(plus, var1);
+
+ // next mean
+ mult1 = mult1.setConstant(size1);
+ tmp1 = mean1.scalarOperations(mult1, new MatrixBlock());
+ mult1 = mult1.setConstant(size2);
+ mean1 = tmp1.binaryOperationsInPlace(plus, mean2.scalarOperations(mult1, new MatrixBlock()));
+ mean1 = mean1.scalarOperations(dev1, new MatrixBlock());
+
+ size1 = size1 + size2;
+ }
+
+ return var1;
+ }
+ catch (Exception ex) {
+ throw new DMLRuntimeException(ex);
+ }
+ }
+
+ public static ScalarObject aggScalar(AggregateUnaryOperator aop, Future<FederatedResponse>[] ffr, Future<FederatedResponse>[] meanFfr, FederationMap map) {
+ if(!(aop.aggOp.increOp.fn instanceof KahanFunction || aop.aggOp.increOp.fn instanceof CM ||
+ (aop.aggOp.increOp.fn instanceof Builtin &&
+ (((Builtin) aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MIN ||
+ ((Builtin) aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MAX)
+ || aop.aggOp.increOp.fn instanceof Mean))) {
throw new DMLRuntimeException("Unsupported aggregation operator: "
+ aop.aggOp.increOp.getClass().getSimpleName());
}
@@ -187,6 +262,27 @@ public class FederationUtils {
else if( aop.aggOp.increOp.fn instanceof Mean ) {
return new DoubleObject(aggMean(ffr, map).getValue(0,0));
}
+ else if(aop.aggOp.increOp.fn instanceof CM) {
+ double var = ((ScalarObject) ffr[0].get().getData()[0]).getDoubleValue();
+ double mean = ((ScalarObject) meanFfr[0].get().getData()[0]).getDoubleValue();
+ long size = map.getFederatedRanges()[0].getSize();
+ for(int i = 0; i < ffr.length - 1; i++) {
+ long l = size + map.getFederatedRanges()[i+1].getSize();
+ double k = ((size * var) + (map.getFederatedRanges()[i+1].getSize() * ((ScalarObject) ffr[i+1].get().getData()[0]).getDoubleValue())) / l;
+ var = k + (size * map.getFederatedRanges()[i+1].getSize()) * Math.pow((mean - ((ScalarObject) meanFfr[i+1].get().getData()[0]).getDoubleValue()) / l, 2);
+ mean = (mean * size + ((ScalarObject) meanFfr[i+1].get().getData()[0]).getDoubleValue() * (map.getFederatedRanges()[i+1].getSize())) / l;
+ size = l;
+// long l = sizes[i] + sizes[i + 1];
+// double k = Math.pow(means[i] - means[i+1], 2) * (sizes[i] * sizes[i+1]);
+// k += ((sizes[i] * vars[i]) + (sizes[i+1] * vars[i+1])) * l;
+// vars[i+1] = k / Math.pow(l, 2);
+//
+// means[i+1] = (means[i] * sizes[i] + means[i] * sizes[i]) / l;
+// sizes[i+1] = l;
+ }
+ return new DoubleObject(var);
+
+ }
else { //if (aop.aggOp.increOp.fn instanceof KahanFunction)
double sum = 0; //uak+
for( Future<FederatedResponse> fr : ffr )
@@ -199,7 +295,7 @@ public class FederationUtils {
}
}
- public static MatrixBlock aggMatrix(AggregateUnaryOperator aop, Future<FederatedResponse>[] ffr, FederationMap map) {
+ public static MatrixBlock aggMatrix(AggregateUnaryOperator aop, Future<FederatedResponse>[] ffr, Future<FederatedResponse>[] meanFfr, FederationMap map) {
if (aop.isRowAggregate() && map.getType() == FederationMap.FType.ROW)
return bind(ffr, false);
else if (aop.isColAggregate() && map.getType() == FederationMap.FType.COL)
@@ -214,7 +310,10 @@ public class FederationUtils {
((Builtin) aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MAX)) {
boolean isMin = ((Builtin) aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MIN;
return aggMinMax(ffr,isMin,false, Optional.of(map.getType()));
- } else
+ } else if(aop.aggOp.increOp.fn instanceof CM) {
+ return aggVar(ffr, meanFfr, map, aop.isRowAggregate(), !(aop.isColAggregate() && aop.isRowAggregate())); //TODO
+ }
+ else
throw new DMLRuntimeException("Unsupported aggregation operator: "
+ aop.aggOp.increOp.fn.getClass().getSimpleName());
}
@@ -229,6 +328,56 @@ public class FederationUtils {
}
}
+ public static ScalarObject aggScalar(AggregateUnaryOperator aop, Future<FederatedResponse>[] ffr, FederationMap map) {
+ if(!(aop.aggOp.increOp.fn instanceof KahanFunction || (aop.aggOp.increOp.fn instanceof Builtin &&
+ (((Builtin) aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MIN
+ || ((Builtin) aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MAX)
+ || aop.aggOp.increOp.fn instanceof Mean ))) {
+ throw new DMLRuntimeException("Unsupported aggregation operator: "
+ + aop.aggOp.increOp.getClass().getSimpleName());
+ }
+
+ try {
+ if(aop.aggOp.increOp.fn instanceof Builtin){
+ // then we know it is a Min or Max based on the previous check.
+ boolean isMin = ((Builtin) aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MIN;
+ return new DoubleObject(aggMinMax(ffr, isMin, true, Optional.empty()).getValue(0,0));
+ }
+ else if( aop.aggOp.increOp.fn instanceof Mean ) {
+ return new DoubleObject(aggMean(ffr, map).getValue(0,0));
+ }
+ else { //if (aop.aggOp.increOp.fn instanceof KahanFunction)
+ double sum = 0; //uak+
+ for( Future<FederatedResponse> fr : ffr )
+ sum += ((ScalarObject)fr.get().getData()[0]).getDoubleValue();
+ return new DoubleObject(sum);
+ }
+ }
+ catch(Exception ex) {
+ throw new DMLRuntimeException(ex);
+ }
+ }
+
+ public static MatrixBlock aggMatrix(AggregateUnaryOperator aop, Future<FederatedResponse>[] ffr, FederationMap map) {
+ if (aop.isRowAggregate() && map.getType() == FederationMap.FType.ROW)
+ return bind(ffr, false);
+ else if (aop.isColAggregate() && map.getType() == FederationMap.FType.COL)
+ return bind(ffr, true);
+
+ if (aop.aggOp.increOp.fn instanceof KahanFunction)
+ return aggAdd(ffr);
+ else if( aop.aggOp.increOp.fn instanceof Mean )
+ return aggMean(ffr, map);
+ else if (aop.aggOp.increOp.fn instanceof Builtin &&
+ (((Builtin) aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MIN ||
+ ((Builtin) aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MAX)) {
+ boolean isMin = ((Builtin) aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MIN;
+ return aggMinMax(ffr,isMin,false, Optional.of(map.getType()));
+ } else
+ throw new DMLRuntimeException("Unsupported aggregation operator: "
+ + aop.aggOp.increOp.fn.getClass().getSimpleName());
+ }
+
public static FederationMap federateLocalData(CacheableData<?> data) {
long id = FederationUtils.getNextFedDataID();
FederatedLocalData federatedLocalData = new FederatedLocalData(id, data);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
index d06dfaa..1429dd3 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
@@ -25,13 +25,14 @@ import org.apache.sysds.lops.LopProperties.ExecType;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
-import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
+import org.apache.sysds.runtime.matrix.operators.Operator;
public class AggregateUnaryFEDInstruction extends UnaryFEDInstruction {
@@ -39,7 +40,17 @@ public class AggregateUnaryFEDInstruction extends UnaryFEDInstruction {
CPOperand out, String opcode, String istr) {
super(FEDType.AggregateUnary, auop, in, out, opcode, istr);
}
-
+
+ protected AggregateUnaryFEDInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out,
+ String opcode, String istr) {
+ super(FEDType.AggregateUnary, op, in1, in2, out, opcode, istr);
+ }
+
+ protected AggregateUnaryFEDInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out,
+ String opcode, String istr) {
+ super(FEDType.AggregateUnary, op, in1, in2, in3, out, opcode, istr);
+ }
+
public static AggregateUnaryFEDInstruction parseInstruction(String str) {
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
String opcode = parts[0];
@@ -53,6 +64,15 @@ public class AggregateUnaryFEDInstruction extends UnaryFEDInstruction {
@Override
public void processInstruction(ExecutionContext ec) {
+ if (getOpcode().contains("var")) {
+ processVar(ec);
+ }else{
+ processDefault(ec);
+ }
+
+ }
+
+ private void processDefault(ExecutionContext ec){
AggregateUnaryOperator aop = (AggregateUnaryOperator) _optr;
MatrixObject in = ec.getMatrixObject(input1);
FederationMap map = in.getFedMapping();
@@ -70,4 +90,35 @@ public class AggregateUnaryFEDInstruction extends UnaryFEDInstruction {
else
ec.setMatrixOutput(output.getName(), FederationUtils.aggMatrix(aop, tmp, map));
}
+
+ private void processVar(ExecutionContext ec){
+ AggregateUnaryOperator aop = (AggregateUnaryOperator) _optr;
+ MatrixObject in = ec.getMatrixObject(input1);
+ FederationMap map = in.getFedMapping();
+
+ // federated ranges mean for variance
+ Future<FederatedResponse>[] meanTmp = null;
+ if (getOpcode().contains("var")) {
+ String meanInstr = instString.replace(getOpcode(), getOpcode().replace("var", "mean"));
+ //create federated commands for aggregation
+ FederatedRequest meanFr1 = FederationUtils.callInstruction(meanInstr, output,
+ new CPOperand[]{input1}, new long[]{in.getFedMapping().getID()});
+ FederatedRequest meanFr2 = new FederatedRequest(RequestType.GET_VAR, meanFr1.getID());
+ FederatedRequest meanFr3 = map.cleanup(getTID(), meanFr1.getID());
+ meanTmp = map.execute(getTID(), meanFr1, meanFr2, meanFr3);
+ }
+
+ //create federated commands for aggregation
+ FederatedRequest fr1 = FederationUtils.callInstruction(instString, output,
+ new CPOperand[]{input1}, new long[]{in.getFedMapping().getID()});
+ FederatedRequest fr2 = new FederatedRequest(RequestType.GET_VAR, fr1.getID());
+ FederatedRequest fr3 = map.cleanup(getTID(), fr1.getID());
+
+ //execute federated commands and cleanups
+ Future<FederatedResponse>[] tmp = map.execute(getTID(), fr1, fr2, fr3);
+ if( output.isScalar() )
+ ec.setVariable(output.getName(), FederationUtils.aggScalar(aop, tmp, meanTmp, map));
+ else
+ ec.setMatrixOutput(output.getName(), FederationUtils.aggMatrix(aop, meanTmp, tmp, map));
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixScalarFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixScalarFEDInstruction.java
index b6ea1fb..895db4a 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixScalarFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixScalarFEDInstruction.java
@@ -38,7 +38,7 @@ public class BinaryMatrixScalarFEDInstruction extends BinaryFEDInstruction
CPOperand matrix = input1.isMatrix() ? input1 : input2;
CPOperand scalar = input2.isScalar() ? input2 : input1;
MatrixObject mo = ec.getMatrixObject(matrix);
-
+
//prepare federated request matrix-scalar
FederatedRequest fr1 = !scalar.isLiteral() ?
mo.getFedMapping().broadcast(ec.getScalarInput(scalar)) : null;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/CentralMomentFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/CentralMomentFEDInstruction.java
new file mode 100644
index 0000000..cc8e683
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/CentralMomentFEDInstruction.java
@@ -0,0 +1,187 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.functionobjects.CM;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.instructions.cp.DoubleObject;
+import org.apache.sysds.runtime.instructions.cp.ScalarObject;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.CMOperator;
+
+public class CentralMomentFEDInstruction extends AggregateUnaryFEDInstruction {
+
+ private CentralMomentFEDInstruction(CMOperator cm, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out,
+ String opcode, String str) {
+ super(cm, in1, in2, in3, out, opcode, str);
+ }
+
+ public static CentralMomentFEDInstruction parseInstruction(String str) {
+ CPOperand in1 = new CPOperand("", Types.ValueType.UNKNOWN, Types.DataType.UNKNOWN);
+ CPOperand in2 = null;
+ CPOperand in3 = null;
+ CPOperand out = new CPOperand("", Types.ValueType.UNKNOWN, Types.DataType.UNKNOWN);
+
+ String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+ String opcode = parts[0];
+
+ //check supported opcode
+ if( !opcode.equalsIgnoreCase("cm") ) {
+ throw new DMLRuntimeException("Unsupported opcode "+opcode);
+ }
+
+ if ( parts.length == 4 ) {
+ // Example: CP.cm.mVar0.Var1.mVar2; (without weights)
+ in2 = new CPOperand("", Types.ValueType.UNKNOWN, Types.DataType.UNKNOWN);
+ parseUnaryInstruction(str, in1, in2, out);
+ }
+ else if ( parts.length == 5) {
+ // CP.cm.mVar0.mVar1.Var2.mVar3; (with weights)
+ in2 = new CPOperand("", Types.ValueType.UNKNOWN, Types.DataType.UNKNOWN);
+ in3 = new CPOperand("", Types.ValueType.UNKNOWN, Types.DataType.UNKNOWN);
+ parseUnaryInstruction(str, in1, in2, in3, out);
+ }
+
+ /*
+ * Exact order of the central moment MAY NOT be known at compilation time.
+ * We first try to parse the second argument as an integer, and if we fail,
+ * we simply pass -1 so that getCMAggOpType() picks up AggregateOperationTypes.INVALID.
+ * It must be updated at run time in processInstruction() method.
+ */
+
+ int cmOrder;
+ try {
+ if ( in3 == null ) {
+ cmOrder = Integer.parseInt(in2.getName());
+ }
+ else {
+ cmOrder = Integer.parseInt(in3.getName());
+ }
+ } catch(NumberFormatException e) {
+ cmOrder = -1; // unknown at compilation time
+ }
+
+ CMOperator.AggregateOperationTypes opType = CMOperator.getCMAggOpType(cmOrder);
+ CMOperator cm = new CMOperator(CM.getCMFnObject(opType), opType);
+ return new CentralMomentFEDInstruction(cm, in1, in2, in3, out, opcode, str);
+ }
+
+ @Override
+ public void processInstruction( ExecutionContext ec ) {
+ MatrixObject mo = ec.getMatrixObject(input1.getName());
+ ScalarObject order = ec.getScalarInput(input3==null ? input2 : input3);
+
+ CMOperator cm_op = ((CMOperator) _optr);
+ if(cm_op.getAggOpType() == CMOperator.AggregateOperationTypes.INVALID)
+ cm_op = cm_op.setCMAggOp((int) order.getLongValue());
+
+ FederationMap fedMapping = mo.getFedMapping();
+ List<CM_COV_Object> globalCmobj = new ArrayList<>();
+
+ long varID = FederationUtils.getNextFedDataID();
+ CMOperator finalCm_op = cm_op;
+ fedMapping.mapParallel(varID, (range, data) -> {
+
+ FederatedResponse response;
+ try {
+ if (input3 == null ) {
+ response = data.executeFederatedOperation(
+ new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
+ new CentralMomentFEDInstruction.CMFunction(data.getVarID(), finalCm_op))).get();
+ } else {
+ MatrixBlock wtBlock = ec.getMatrixInput(input2.getName());
+
+ response = data.executeFederatedOperation(
+ new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
+ new CentralMomentFEDInstruction.CMWeightsFunction(data.getVarID(), finalCm_op, wtBlock))).get();
+ }
+
+ if(!response.isSuccessful())
+ response.throwExceptionFromResponse();
+ synchronized(globalCmobj) {
+ globalCmobj.add((CM_COV_Object) response.getData()[0]);
+ }
+ }
+ catch(Exception e) {
+ throw new DMLRuntimeException(e);
+ }
+ return null;
+ });
+
+ Optional<CM_COV_Object> res = globalCmobj.stream().reduce((arg0, arg1) -> (CM_COV_Object) finalCm_op.fn.execute(arg0, arg1));
+ try {
+ ec.setScalarOutput(output.getName(), new DoubleObject(res.get().getRequiredResult(finalCm_op)));
+ }
+ catch(Exception e) {
+ throw new DMLRuntimeException(e);
+ }
+ }
+
+ private static class CMFunction extends FederatedUDF {
+ private static final long serialVersionUID = 7460149207607220994L;
+ private final CMOperator _op;
+
+ public CMFunction (long input, CMOperator op) {
+ super(new long[] {input});
+ _op = op;
+ }
+
+ @Override
+ public FederatedResponse execute(ExecutionContext ec, Data... data) {
+ MatrixBlock mb = ((MatrixObject) data[0]).acquireReadAndRelease();
+ return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, mb.cmOperations(_op));
+ }
+ }
+
+
+ private static class CMWeightsFunction extends FederatedUDF {
+ private static final long serialVersionUID = -3685746246551622021L;
+ private final CMOperator _op;
+ private final MatrixBlock _weights;
+
+ protected CMWeightsFunction(long input, CMOperator op, MatrixBlock weights) {
+ super(new long[] {input});
+ _op = op;
+ _weights = weights;
+ }
+
+ @Override
+ public FederatedResponse execute(ExecutionContext ec, Data... data) {
+ MatrixBlock mb = ((MatrixObject) data[0]).acquireReadAndRelease();
+ return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, mb.cmOperations(_op, _weights));
+ }
+ }
+}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
index 8094c96..0d912d8 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
@@ -37,7 +37,9 @@ public abstract class FEDInstruction extends Instruction {
Tsmm,
MMChain,
Reorg,
- MatrixIndexing
+ MatrixIndexing,
+ QSort,
+ QPick
}
protected final FEDType _fedType;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index ef66b66..1b095e1 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -22,6 +22,7 @@ package org.apache.sysds.runtime.instructions.fed;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
+import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap.FType;
@@ -30,18 +31,25 @@ import org.apache.sysds.runtime.instructions.cp.AggregateBinaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.AggregateUnaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.BinaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.instructions.cp.IndexingCPInstruction;
import org.apache.sysds.runtime.instructions.cp.MMChainCPInstruction;
import org.apache.sysds.runtime.instructions.cp.MMTSJCPInstruction;
import org.apache.sysds.runtime.instructions.cp.MatrixIndexingCPInstruction;
import org.apache.sysds.runtime.instructions.cp.MultiReturnParameterizedBuiltinCPInstruction;
import org.apache.sysds.runtime.instructions.cp.ParameterizedBuiltinCPInstruction;
import org.apache.sysds.runtime.instructions.cp.ReorgCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.UnaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction.VariableOperationCode;
import org.apache.sysds.runtime.instructions.spark.AggregateUnarySPInstruction;
import org.apache.sysds.runtime.instructions.spark.AppendGAlignedSPInstruction;
import org.apache.sysds.runtime.instructions.spark.AppendGSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.BinarySPInstruction;
+import org.apache.sysds.runtime.instructions.spark.CentralMomentSPInstruction;
import org.apache.sysds.runtime.instructions.spark.MapmmSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.QuantilePickSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.QuantileSortSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.UnarySPInstruction;
import org.apache.sysds.runtime.instructions.spark.WriteSPInstruction;
public class FEDInstructionUtils {
@@ -53,9 +61,9 @@ public class FEDInstructionUtils {
/**
* Check and replace CP instructions with federated instructions if the instruction match criteria.
- *
+ *
* @param inst The instruction to analyse
- * @param ec The Execution Context
+ * @param ec The Execution Context
* @return The potentially modified instruction
*/
public static Instruction checkAndReplaceCP(Instruction inst, ExecutionContext ec) {
@@ -82,14 +90,32 @@ public class FEDInstructionUtils {
if( mo.isFederated() )
fedinst = TsmmFEDInstruction.parseInstruction(linst.getInstructionString());
}
- else if (inst instanceof AggregateUnaryCPInstruction) {
- AggregateUnaryCPInstruction instruction = (AggregateUnaryCPInstruction) inst;
- if( instruction.input1.isMatrix() && ec.containsVariable(instruction.input1) ) {
+ else if (inst instanceof UnaryCPInstruction && ! (inst instanceof IndexingCPInstruction)) {
+ UnaryCPInstruction instruction = (UnaryCPInstruction) inst;
+ if(inst instanceof ReorgCPInstruction && inst.getOpcode().equals("r'")) {
+ ReorgCPInstruction rinst = (ReorgCPInstruction) inst;
+ CacheableData<?> mo = ec.getCacheableData(rinst.input1);
+
+ if((mo instanceof MatrixObject || mo instanceof FrameObject) && mo.isFederated() )
+ fedinst = ReorgFEDInstruction.parseInstruction(rinst.getInstructionString());
+ }
+ else if(instruction.input1 != null && instruction.input1.isMatrix()
+ && ec.getMatrixObject(instruction.input1).isFederated()
+ && ec.containsVariable(instruction.input1)) {
+
MatrixObject mo1 = ec.getMatrixObject(instruction.input1);
- if (mo1.isFederated() && instruction.getAUType() == AggregateUnaryCPInstruction.AUType.DEFAULT){
- LOG.debug("Federated UnaryAggregate");
+
+ if(instruction.getOpcode().equalsIgnoreCase("cm")) {
+ fedinst = CentralMomentFEDInstruction.parseInstruction(inst.getInstructionString());
+ }
+ else if(inst instanceof AggregateUnaryCPInstruction &&
+ ((AggregateUnaryCPInstruction) instruction).getAUType() == AggregateUnaryCPInstruction.AUType.DEFAULT) {
fedinst = AggregateUnaryFEDInstruction.parseInstruction(inst.getInstructionString());
}
+ else if(inst.getOpcode().equalsIgnoreCase("qsort") &&
+ mo1.getFedMapping().getFederatedRanges().length == 1) {
+ fedinst = QuantileSortFEDInstruction.parseInstruction(inst.getInstructionString());
+ }
}
}
else if (inst instanceof BinaryCPInstruction) {
@@ -98,6 +124,8 @@ public class FEDInstructionUtils {
|| (instruction.input2.isMatrix() && ec.getMatrixObject(instruction.input2).isFederated()) ) {
if(instruction.getOpcode().equals("append"))
fedinst = AppendFEDInstruction.parseInstruction(inst.getInstructionString());
+ else if(instruction.getOpcode().equals("qpick"))
+ fedinst = QuantilePickFEDInstruction.parseInstruction(inst.getInstructionString());
else
fedinst = BinaryFEDInstruction.parseInstruction(inst.getInstructionString());
}
@@ -122,19 +150,14 @@ public class FEDInstructionUtils {
}
}
}
- else if(inst instanceof ReorgCPInstruction && inst.getOpcode().equals("r'")) {
- ReorgCPInstruction rinst = (ReorgCPInstruction) inst;
- CacheableData<?> mo = ec.getCacheableData(rinst.input1);
- if( mo.isFederated() )
- fedinst = ReorgFEDInstruction.parseInstruction(rinst.getInstructionString());
- }
- else if(inst instanceof MatrixIndexingCPInstruction && inst.getOpcode().equalsIgnoreCase("rightIndex")) {
+ else if(inst instanceof MatrixIndexingCPInstruction) {
// matrix indexing
+ LOG.info("Federated Indexing");
MatrixIndexingCPInstruction minst = (MatrixIndexingCPInstruction) inst;
- if(minst.input1.isMatrix()) {
- CacheableData<?> fo = ec.getCacheableData(minst.input1);
- if(fo.isFederated())
- fedinst = MatrixIndexingFEDInstruction.parseInstruction(minst.getInstructionString());
+ if(inst.getOpcode().equalsIgnoreCase("rightIndex")
+ && minst.input1.isMatrix() && ec.getCacheableData(minst.input1).isFederated()) {
+ LOG.info("Federated Right Indexing");
+ fedinst = MatrixIndexingFEDInstruction.parseInstruction(minst.getInstructionString());
}
}
else if(inst instanceof VariableCPInstruction ){
@@ -178,11 +201,47 @@ public class FEDInstructionUtils {
instruction.input1, instruction.input2, instruction.output, "ba+*", "FED...");
}
}
- else if (inst instanceof AggregateUnarySPInstruction) {
- AggregateUnarySPInstruction instruction = (AggregateUnarySPInstruction) inst;
- Data data = ec.getVariable(instruction.input1);
- if (data instanceof MatrixObject && ((MatrixObject) data).isFederated())
- fedinst = AggregateUnaryFEDInstruction.parseInstruction(inst.getInstructionString());
+ else if (inst instanceof UnarySPInstruction) {
+ if (inst instanceof CentralMomentSPInstruction) {
+ CentralMomentSPInstruction instruction = (CentralMomentSPInstruction) inst;
+ Data data = ec.getVariable(instruction.input1);
+ if (data instanceof MatrixObject && ((MatrixObject) data).isFederated())
+ fedinst = CentralMomentFEDInstruction.parseInstruction(inst.getInstructionString());
+ } else if (inst instanceof QuantileSortSPInstruction) {
+ QuantileSortSPInstruction instruction = (QuantileSortSPInstruction) inst;
+ Data data = ec.getVariable(instruction.input1);
+ if (data instanceof MatrixObject && ((MatrixObject) data).isFederated())
+ fedinst = QuantileSortFEDInstruction.parseInstruction(inst.getInstructionString());
+ }
+ else if (inst instanceof AggregateUnarySPInstruction) {
+ AggregateUnarySPInstruction instruction = (AggregateUnarySPInstruction) inst;
+ Data data = ec.getVariable(instruction.input1);
+ if(data instanceof MatrixObject && ((MatrixObject) data).isFederated())
+ fedinst = AggregateUnaryFEDInstruction.parseInstruction(inst.getInstructionString());
+ }
+ }
+ else if(inst instanceof BinarySPInstruction) {
+ if(inst instanceof QuantilePickSPInstruction) {
+ QuantilePickSPInstruction instruction = (QuantilePickSPInstruction) inst;
+ Data data = ec.getVariable(instruction.input1);
+ if(data instanceof MatrixObject && ((MatrixObject) data).isFederated())
+ fedinst = QuantilePickFEDInstruction.parseInstruction(inst.getInstructionString());
+ }
+ else if (inst instanceof AppendGAlignedSPInstruction) {
+ // TODO other Append Spark instructions
+ AppendGAlignedSPInstruction instruction = (AppendGAlignedSPInstruction) inst;
+ Data data = ec.getVariable(instruction.input1);
+ if (data instanceof MatrixObject && ((MatrixObject) data).isFederated()) {
+ fedinst = AppendFEDInstruction.parseInstruction(instruction.getInstructionString());
+ }
+ }
+ else if (inst instanceof AppendGSPInstruction) {
+ AppendGSPInstruction instruction = (AppendGSPInstruction) inst;
+ Data data = ec.getVariable(instruction.input1);
+ if(data instanceof MatrixObject && ((MatrixObject) data).isFederated()) {
+ fedinst = AppendFEDInstruction.parseInstruction(instruction.getInstructionString());
+ }
+ }
}
else if (inst instanceof WriteSPInstruction) {
WriteSPInstruction instruction = (WriteSPInstruction) inst;
@@ -193,21 +252,6 @@ public class FEDInstructionUtils {
return VariableCPInstruction.parseInstruction(instruction.getInstructionString());
}
}
- else if (inst instanceof AppendGAlignedSPInstruction) {
- // TODO other Append Spark instructions
- AppendGAlignedSPInstruction instruction = (AppendGAlignedSPInstruction) inst;
- Data data = ec.getVariable(instruction.input1);
- if (data instanceof MatrixObject && ((MatrixObject) data).isFederated()) {
- fedinst = AppendFEDInstruction.parseInstruction(instruction.getInstructionString());
- }
- }
- else if (inst instanceof AppendGSPInstruction) {
- AppendGSPInstruction instruction = (AppendGSPInstruction) inst;
- Data data = ec.getVariable(instruction.input1);
- if(data instanceof MatrixObject && ((MatrixObject) data).isFederated()) {
- fedinst = AppendFEDInstruction.parseInstruction(instruction.getInstructionString());
- }
- }
//set thread id for federated context management
if( fedinst != null ) {
fedinst.setTID(ec.getTID());
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantilePickFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantilePickFEDInstruction.java
new file mode 100644
index 0000000..f2052b5
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantilePickFEDInstruction.java
@@ -0,0 +1,210 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.sysds.lops.PickByCount.OperationTypes;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.instructions.cp.DoubleObject;
+import org.apache.sysds.runtime.instructions.cp.ScalarObject;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+
+public class QuantilePickFEDInstruction extends BinaryFEDInstruction {
+
+ private final OperationTypes _type;
+
+ private QuantilePickFEDInstruction(Operator op, CPOperand in, CPOperand out, OperationTypes type, boolean inmem,
+ String opcode, String istr) {
+ this(op, in, null, out, type, inmem, opcode, istr);
+ }
+
+ private QuantilePickFEDInstruction(Operator op, CPOperand in, CPOperand in2, CPOperand out, OperationTypes type,
+ boolean inmem, String opcode, String istr) {
+ super(FEDType.QPick, op, in, in2, out, opcode, istr);
+ _type = type;
+ }
+
+ public static QuantilePickFEDInstruction parseInstruction ( String str ) {
+ String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+ String opcode = parts[0];
+ if ( !opcode.equalsIgnoreCase("qpick") )
+ throw new DMLRuntimeException("Unknown opcode while parsing a QuantilePickCPInstruction: " + str);
+ //instruction parsing
+ if( parts.length == 4 ) {
+ //instructions of length 4 originate from unary - mr-iqm
+ //TODO this should be refactored to use pickvaluecount lops
+ CPOperand in1 = new CPOperand(parts[1]);
+ CPOperand in2 = new CPOperand(parts[2]);
+ CPOperand out = new CPOperand(parts[3]);
+ OperationTypes ptype = OperationTypes.IQM;
+ boolean inmem = false;
+ return new QuantilePickFEDInstruction(null, in1, in2, out, ptype, inmem, opcode, str);
+ }
+ else if( parts.length == 5 ) {
+ CPOperand in1 = new CPOperand(parts[1]);
+ CPOperand out = new CPOperand(parts[2]);
+ OperationTypes ptype = OperationTypes.valueOf(parts[3]);
+ boolean inmem = Boolean.parseBoolean(parts[4]);
+ return new QuantilePickFEDInstruction(null, in1, out, ptype, inmem, opcode, str);
+ }
+ else if( parts.length == 6 ) {
+ CPOperand in1 = new CPOperand(parts[1]);
+ CPOperand in2 = new CPOperand(parts[2]);
+ CPOperand out = new CPOperand(parts[3]);
+ OperationTypes ptype = OperationTypes.valueOf(parts[4]);
+ boolean inmem = Boolean.parseBoolean(parts[5]);
+ return new QuantilePickFEDInstruction(null, in1, in2, out, ptype, inmem, opcode, str);
+ }
+ return null;
+ }
+
+ @Override
+ public void processInstruction(ExecutionContext ec) {
+ MatrixObject in = ec.getMatrixObject(input1);
+ FederationMap fedMapping = in.getFedMapping();
+
+ List <Object> res = new ArrayList<>();
+ long varID = FederationUtils.getNextFedDataID();
+ fedMapping.mapParallel(varID, (range, data) -> {
+ FederatedResponse response;
+ try {
+ switch( _type )
+ {
+ case VALUEPICK:
+ if(input2.isScalar()) {
+ ScalarObject quantile = ec.getScalarInput(input2);
+ response = data.executeFederatedOperation(
+ new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,-1,
+ new QuantilePickFEDInstruction.ValuePick(data.getVarID(), quantile))).get();
+ }
+ else {
+ MatrixBlock quantiles = ec.getMatrixInput(input2.getName());
+ response = data.executeFederatedOperation(
+ new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,-1,
+ new QuantilePickFEDInstruction.ValuePick(data.getVarID(), quantiles))).get();
+ }
+ break;
+ case IQM:
+ response = data
+ .executeFederatedOperation(
+ new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
+ new QuantilePickFEDInstruction.IQM(data.getVarID()))).get();
+ break;
+ case MEDIAN:
+ response = data
+ .executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
+ new QuantilePickFEDInstruction.Median(data.getVarID()))).get();
+ break;
+ default:
+ throw new DMLRuntimeException("Unsupported qpick operation type: "+_type);
+ }
+
+ if(!response.isSuccessful())
+ response.throwExceptionFromResponse();
+
+ res.add(response.getData()[0]);
+ }
+ catch(Exception e) {
+ throw new DMLRuntimeException(e);
+ }
+ return null;
+ });
+
+ assert res.size() == 1;
+
+ if(output.isScalar())
+ ec.setScalarOutput(output.getName(), new DoubleObject((double) res.get(0)));
+ else
+ ec.setMatrixOutput(output.getName(), (MatrixBlock) res.get(0));
+ }
+
+ private static class ValuePick extends FederatedUDF {
+
+ private static final long serialVersionUID = -2594912886841345102L;
+ private final MatrixBlock _quantiles;
+
+ protected ValuePick(long input, ScalarObject quantile) {
+ super(new long[] {input});
+ _quantiles = new MatrixBlock(quantile.getDoubleValue());
+ }
+
+ protected ValuePick(long input, MatrixBlock quantiles) {
+ super(new long[] {input});
+ _quantiles = quantiles;
+ }
+ @Override
+ public FederatedResponse execute(ExecutionContext ec, Data... data) {
+ MatrixBlock mb = ((MatrixObject)data[0]).acquireReadAndRelease();
+ MatrixBlock picked;
+ if (_quantiles.getLength() == 1) {
+ return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS,
+ new Object[] {mb.pickValue(_quantiles.getValue(0, 0))});
+ }
+ else {
+ picked = mb.pickValues(_quantiles, new MatrixBlock());
+ return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS,
+ new Object[] {picked});
+ }
+ }
+ }
+
+ private static class IQM extends FederatedUDF {
+
+ private static final long serialVersionUID = 2223186699111957677L;
+
+ protected IQM(long input) {
+ super(new long[] {input});
+ }
+ @Override
+ public FederatedResponse execute(ExecutionContext ec, Data... data) {
+ MatrixBlock mb = ((MatrixObject)data[0]).acquireReadAndRelease();
+ return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS,
+ new Object[] {mb.interQuartileMean()});
+ }
+ }
+
+ private static class Median extends FederatedUDF {
+
+ private static final long serialVersionUID = -2808597461054603816L;
+
+ protected Median(long input) {
+ super(new long[] {input});
+ }
+ @Override
+ public FederatedResponse execute(ExecutionContext ec, Data... data) {
+ MatrixBlock mb = ((MatrixObject)data[0]).acquireReadAndRelease();
+ return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS,
+ new Object[] {mb.median()});
+ }
+ }
+}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantileSortFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantileSortFEDInstruction.java
new file mode 100644
index 0000000..0e994d9
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantileSortFEDInstruction.java
@@ -0,0 +1,163 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.lops.SortKeys;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+
+public class QuantileSortFEDInstruction extends UnaryFEDInstruction{
+
+ private QuantileSortFEDInstruction(CPOperand in, CPOperand out, String opcode, String istr) {
+ this(in, null, out, opcode, istr);
+ }
+
+ private QuantileSortFEDInstruction(CPOperand in1, CPOperand in2, CPOperand out, String opcode,
+ String istr) {
+ super(FEDInstruction.FEDType.QSort, null, in1, in2, out, opcode, istr);
+ }
+
+ public static QuantileSortFEDInstruction parseInstruction ( String str ) {
+ CPOperand in1 = new CPOperand("", Types.ValueType.UNKNOWN, Types.DataType.UNKNOWN);
+ CPOperand in2 = null;
+ CPOperand out = new CPOperand("", Types.ValueType.UNKNOWN, Types.DataType.UNKNOWN);
+
+ String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+ String opcode = parts[0];
+
+ if ( opcode.equalsIgnoreCase(SortKeys.OPCODE) ) {
+ if ( parts.length == 3 ) {
+ // Example: sort:mVar1:mVar2 (input=mVar1, output=mVar2)
+ parseUnaryInstruction(str, in1, out);
+ return new QuantileSortFEDInstruction(in1, out, opcode, str);
+ }
+ else if ( parts.length == 4 ) {
+ // Example: sort:mVar1:mVar2:mVar3 (input=mVar1, weights=mVar2, output=mVar3)
+ in2 = new CPOperand("", Types.ValueType.UNKNOWN, Types.DataType.UNKNOWN);
+ parseUnaryInstruction(str, in1, in2, out);
+ return new QuantileSortFEDInstruction(in1, in2, out, opcode, str);
+ }
+ else {
+ throw new DMLRuntimeException("Invalid number of operands in instruction: " + str);
+ }
+ }
+ else {
+ throw new DMLRuntimeException("Unknown opcode while parsing a QuantileSortFEDInstruction: " + str);
+ }
+ }
+
+
+// @Override
+// public void processInstruction(ExecutionContext ec) {
+// MatrixObject in = ec.getMatrixObject(input1.getName());
+// FederationMap map = in.getFedMapping();
+//
+// //create federated commands for aggregation
+// FederatedRequest fr1 = FederationUtils
+// .callInstruction(instString, output, new CPOperand[] {input1}, new long[] {in.getFedMapping().getID()});
+// FederatedRequest fr2 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr1.getID());
+// FederatedRequest fr3 = map.cleanup(getTID(), fr1.getID());
+//
+// Future<FederatedResponse>[] tmp = map.execute(getTID(), fr1, fr2, fr3);
+//
+// try {
+// Object d = tmp[0].get().getData()[0];
+// System.out.println(1);
+// }
+// catch(Exception e) {
+// e.printStackTrace();
+// }
+//
+// MatrixObject out = ec.getMatrixObject(output);
+// out.getDataCharacteristics().set(in.getDataCharacteristics());
+// out.setFedMapping(in.getFedMapping().copyWithNewID(fr2.getID()));
+// }
+
+
+ @Override
+ public void processInstruction(ExecutionContext ec) {
+ MatrixObject in = ec.getMatrixObject(input1);
+ FederationMap fedMapping = in.getFedMapping();
+
+ long varID = FederationUtils.getNextFedDataID();
+ FederationMap sortedMapping = fedMapping.mapParallel(varID, (range, data) -> {
+ try {
+ MatrixBlock wtBlock = null;
+ if (input2 != null) {
+ wtBlock = ec.getMatrixInput(input2.getName());
+ }
+
+ FederatedResponse response = data
+ .executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
+ new GetSorted(data.getVarID(), varID, wtBlock))).get();
+ if(!response.isSuccessful())
+ response.throwExceptionFromResponse();
+ }
+ catch(Exception e) {
+ throw new DMLRuntimeException(e);
+ }
+ return null;
+ });
+
+
+ MatrixObject sorted = ec.getMatrixObject(output);
+ sorted.getDataCharacteristics().set(in.getDataCharacteristics());
+
+ // set the federated mapping for the matrix
+ sorted.setFedMapping(sortedMapping);
+ }
+
+ private static class GetSorted extends FederatedUDF {
+
+ private static final long serialVersionUID = -1969015577260167645L;
+ private final long _outputID;
+ private final MatrixBlock _weights;
+
+ protected GetSorted(long input, long outputID, MatrixBlock weights) {
+ super(new long[] {input});
+ _outputID = outputID;
+ _weights = weights;
+ }
+ @Override
+ public FederatedResponse execute(ExecutionContext ec, Data... data) {
+ MatrixBlock mb = ((MatrixObject) data[0]).acquireReadAndRelease();
+
+ MatrixBlock res = mb.sortOperations(_weights, new MatrixBlock());
+
+ MatrixObject mout = ExecutionContext.createMatrixObject(res);
+
+ // add it to the list of variables
+ ec.setVariable(String.valueOf(_outputID), mout);
+ // return schema
+ return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS_EMPTY);
+ }
+ }
+}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
index a4b604b..d9e9d97 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
@@ -54,12 +54,12 @@ public class ReorgFEDInstruction extends UnaryFEDInstruction {
if( !mo1.isFederated() )
throw new DMLRuntimeException("Federated Reorg: "
+ "Federated input expected, but invoked w/ "+mo1.isFederated());
-
+
//execute transpose at federated site
FederatedRequest fr1 = FederationUtils.callInstruction(instString, output,
new CPOperand[]{input1}, new long[]{mo1.getFedMapping().getID()});
mo1.getFedMapping().execute(getTID(), true, fr1);
-
+
//drive output federated mapping
MatrixObject out = ec.getMatrixObject(output);
out.getDataCharacteristics().set(mo1.getNumColumns(),
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
index fbe88d6..ed9615f 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
@@ -25,12 +25,14 @@ import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import java.util.Arrays;
import java.util.concurrent.Future;
public class TsmmFEDInstruction extends BinaryFEDInstruction {
@@ -62,7 +64,7 @@ public class TsmmFEDInstruction extends BinaryFEDInstruction {
public void processInstruction(ExecutionContext ec) {
MatrixObject mo1 = ec.getMatrixObject(input1);
- if(mo1.isFederated() && _type.isLeft()) { // left tsmm
+ if((_type.isLeft() && mo1.isFederated(FederationMap.FType.ROW)) || (mo1.isFederated(FederationMap.FType.COL) && _type.isRight())) {
//construct commands: fed tsmm, retrieve results
FederatedRequest fr1 = FederationUtils.callInstruction(instString, output,
new CPOperand[]{input1}, new long[]{mo1.getFedMapping().getID()});
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedBivarTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedBivarTest.java
index 3a391d9..ced8bca 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedBivarTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedBivarTest.java
@@ -55,7 +55,8 @@ public class FederatedBivarTest extends AutomatedTestBase {
@Parameterized.Parameters
public static Collection<Object[]> data() {
return Arrays.asList(new Object[][] {{10000, 16},
- // {2000, 32}, {1000, 64},
+ // {2000, 32},
+ // {1000, 64},
{10000, 128}});
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedUnivarTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedCorTest.java
similarity index 65%
copy from src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedUnivarTest.java
copy to src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedCorTest.java
index f24aa6c..1b06279 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedUnivarTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedCorTest.java
@@ -19,7 +19,11 @@
package org.apache.sysds.test.functions.federated.algorithms;
-import org.apache.sysds.common.Types;
+import java.util.Arrays;
+import java.util.Collection;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.util.HDFSTool;
import org.apache.sysds.test.AutomatedTestBase;
@@ -30,70 +34,67 @@ import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
-import java.util.Arrays;
-import java.util.Collection;
-
@RunWith(value = Parameterized.class)
@net.jcip.annotations.NotThreadSafe
-public class FederatedUnivarTest extends AutomatedTestBase {
+public class FederatedCorTest extends AutomatedTestBase {
+
+ private final static String TEST_NAME = "FederatedCorTest";
private final static String TEST_DIR = "functions/federated/";
- private final static String TEST_NAME = "FederatedUnivarTest";
- private final static String TEST_CLASS_DIR = TEST_DIR + FederatedUnivarTest.class.getSimpleName() + "/";
+ private static final String TEST_CLASS_DIR = TEST_DIR + FederatedCorTest.class.getSimpleName() + "/";
private final static int blocksize = 1024;
+
@Parameterized.Parameter()
public int rows;
@Parameterized.Parameter(1)
public int cols;
-
- @Override
- public void setUp() {
- TestUtils.clearAssertionInformation();
- addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"B"}));
- }
+ @Parameterized.Parameter(2)
+ public boolean rowPartitioned;
@Parameterized.Parameters
public static Collection<Object[]> data() {
- return Arrays.asList(new Object[][] {
- {10000, 16},
- {2000, 32}, {1000, 64}, {10000, 128}
- });
+ return Arrays.asList(new Object[][] {{1600, 8, true}});
}
- @Test
- public void federatedUnivarSinglenode() {
- federatedL2SVM(Types.ExecMode.SINGLE_NODE);
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"S"}));
}
@Test
- public void federatedUnivarHybrid() {
- federatedL2SVM(Types.ExecMode.HYBRID);
+ public void testCorCP() {
+ runAggregateOperationTest(ExecMode.SINGLE_NODE);
}
- public void federatedL2SVM(Types.ExecMode execMode) {
- Types.ExecMode platformOld = setExecMode(execMode);
+ private void runAggregateOperationTest(ExecMode execMode) {
+ boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+ ExecMode platformOld = rtplatform;
+
+ if(rtplatform == ExecMode.SPARK)
+ DMLScript.USE_LOCAL_SPARK_CONFIG = true;
getAndLoadTestConfiguration(TEST_NAME);
String HOME = SCRIPT_DIR + TEST_DIR;
// write input matrices
- int quarterCols = cols / 4;
- // We have two matrices handled by a single federated worker
- double[][] X1 = getRandomMatrix(rows, quarterCols, 1, 5, 1, 3);
- double[][] X2 = getRandomMatrix(rows, quarterCols, 1, 5, 1, 7);
- double[][] X3 = getRandomMatrix(rows, quarterCols, 1, 5, 1, 8);
- double[][] X4 = getRandomMatrix(rows, quarterCols, 1, 5, 1, 9);
-
- // write types matrix shape of (1, D)
- double [][] Y = getRandomMatrix(1, cols, 0, 3, 1, 9);
- Arrays.stream(Y[0]).forEach(Math::ceil);
-
- MatrixCharacteristics mc= new MatrixCharacteristics(rows, quarterCols, blocksize, rows * quarterCols);
+ int r = rows;
+ int c = cols / 4;
+ if(rowPartitioned) {
+ r = rows / 4;
+ c = cols;
+ }
+
+ double[][] X1 = getRandomMatrix(r, c, 1, 5, 1, 3);
+ double[][] X2 = getRandomMatrix(r, c, 1, 5, 1, 7);
+ double[][] X3 = getRandomMatrix(r, c, 1, 5, 1, 8);
+ double[][] X4 = getRandomMatrix(r, c, 1, 5, 1, 9);
+
+ MatrixCharacteristics mc = new MatrixCharacteristics(r, c, blocksize, r * c);
writeInputMatrixWithMTD("X1", X1, false, mc);
writeInputMatrixWithMTD("X2", X2, false, mc);
writeInputMatrixWithMTD("X3", X3, false, mc);
writeInputMatrixWithMTD("X4", X4, false, mc);
- writeInputMatrixWithMTD("Y", Y, false);
// empty script name because we don't execute any script, just start the worker
fullDMLScriptName = "";
@@ -106,41 +107,46 @@ public class FederatedUnivarTest extends AutomatedTestBase {
Thread t3 = startLocalFedWorkerThread(port3);
Thread t4 = startLocalFedWorkerThread(port4);
+ rtplatform = execMode;
+ if(rtplatform == ExecMode.SPARK)
+ DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
-
// Run reference dml script with normal matrix
fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
- programArgs = new String[] {"-stats", "100", "-args", input("X1"), input("X2"), input("X3"), input("X4"), input("Y"), expected("B")};
+ programArgs = new String[] {"-stats", "100", "-args", input("X1"), input("X2"), input("X3"), input("X4"),
+ expected("S"), Boolean.toString(rowPartitioned).toUpperCase()};
runTest(true, false, null, -1);
// Run actual dml script with federated matrix
+
fullDMLScriptName = HOME + TEST_NAME + ".dml";
- programArgs = new String[] {"-stats", "100", "-nvargs",
+ programArgs = new String[] {"-explain", "-stats", "100", "-nvargs",
"in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
"in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
"in_X3=" + TestUtils.federatedAddress(port3, input("X3")),
- "in_X4=" + TestUtils.federatedAddress(port4, input("X4")),
- "in_Y=" + input("Y"), // types
- "rows=" + rows, "cols=" + cols,
- "out=" + output("B")};
+ "in_X4=" + TestUtils.federatedAddress(port4, input("X4")), "rows=" + rows, "cols=" + cols,
+ "rP=" + Boolean.toString(rowPartitioned).toUpperCase(), "out_S=" + output("S")};
+
runTest(true, false, null, -1);
// compare via files
compareResults(1e-9);
- TestUtils.shutdownThreads(t1, t2, t3, t4);
- // check for federated operations
- Assert.assertTrue(heavyHittersContainsString("fed_uacmax"));
+ // Assert.assertTrue(heavyHittersContainsString("k+"));
- //check that federated input files are still existing
+ // check that federated input files are still existing
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3")));
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4")));
- Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("Y")));
- resetExecMode(platformOld);
+ TestUtils.shutdownThreads(t1, t2, t3, t4);
+
+ rtplatform = platformOld;
+ DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+
}
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java
index 44de28f..1e608ce 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java
@@ -57,7 +57,9 @@ public class FederatedGLMTest extends AutomatedTestBase {
@Parameterized.Parameters
public static Collection<Object[]> data() {
// rows have to be even and > 1
- return Arrays.asList(new Object[][] {{10000, 10}, {1000, 100}, {2000, 43}});
+ return Arrays.asList(new Object[][] {
+ // {10000, 10}, {1000, 100},
+ {2000, 43}});
}
@Test
@@ -70,7 +72,6 @@ public class FederatedGLMTest extends AutomatedTestBase {
federatedGLM(Types.ExecMode.HYBRID);
}
-
public void federatedGLM(Types.ExecMode execMode) {
ExecMode platformOld = setExecMode(execMode);
@@ -99,8 +100,8 @@ public class FederatedGLMTest extends AutomatedTestBase {
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
- //
-
+ //
+
// Run reference dml script with normal matrix
fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
programArgs = new String[] {"-args", input("X1"), input("X2"), input("Y"), expected("Z")};
@@ -108,8 +109,7 @@ public class FederatedGLMTest extends AutomatedTestBase {
// Run actual dml script with federated matrix
fullDMLScriptName = HOME + TEST_NAME + ".dml";
- programArgs = new String[] {"-stats",
- "-nvargs", "in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
+ programArgs = new String[] {"-stats", "-nvargs", "in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
"in_X2=" + TestUtils.federatedAddress(port2, input("X2")), "rows=" + rows, "cols=" + cols,
"in_Y=" + input("Y"), "out=" + output("Z")};
runTest(true, false, null, -1);
@@ -121,15 +121,15 @@ public class FederatedGLMTest extends AutomatedTestBase {
// check for federated operations
Assert.assertTrue(heavyHittersContainsString("fed_ba+*"));
- Assert.assertTrue(heavyHittersContainsString("fed_uark+","fed_uarsqk+"));
+ Assert.assertTrue(heavyHittersContainsString("fed_uark+", "fed_uarsqk+"));
Assert.assertTrue(heavyHittersContainsString("fed_uack+"));
- //Assert.assertTrue(heavyHittersContainsString("fed_uak+"));
+ // Assert.assertTrue(heavyHittersContainsString("fed_uak+"));
Assert.assertTrue(heavyHittersContainsString("fed_mmchain"));
-
- //check that federated input files are still existing
+
+ // check that federated input files are still existing
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
-
+
resetExecMode(platformOld);
}
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java
index 0dd339f..8a33d20 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java
@@ -64,12 +64,13 @@ public class FederatedKmeansTest extends AutomatedTestBase {
@Parameterized.Parameters
public static Collection<Object[]> data() {
// rows have to be even and > 1
- return Arrays.asList(new Object[][] {
- {10000, 10, 1, 1}, {2000, 50, 1, 1}, {1000, 100, 1, 1},
- {10000, 10, 2, 1}, {2000, 50, 2, 1}, {1000, 100, 2, 1}, //concurrent requests
- {10000, 10, 2, 2}, //repeated exec
- //TODO more runs e.g., 16 -> but requires rework RPC framework first
- //(e.g., see paramserv?)
+ return Arrays.asList(new Object[][] {{10000, 10, 1, 1},
+ // {2000, 50, 1, 1}, {1000, 100, 1, 1},
+ {10000, 10, 2, 1},
+ // {2000, 50, 2, 1}, {1000, 100, 2, 1}, //concurrent requests
+ {10000, 10, 2, 2}, // repeated exec
+ // TODO more runs e.g., 16 -> but requires rework RPC framework first
+ // (e.g., see paramserv?)
});
}
@@ -77,7 +78,7 @@ public class FederatedKmeansTest extends AutomatedTestBase {
public void federatedKmeansSinglenode() {
federatedKmeans(Types.ExecMode.SINGLE_NODE);
}
-
+
@Test
public void federatedKmeansHybrid() {
federatedKmeans(Types.ExecMode.HYBRID);
@@ -106,29 +107,26 @@ public class FederatedKmeansTest extends AutomatedTestBase {
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
-
-
+
// Run reference dml script with normal matrix
fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
- programArgs = new String[] {"-args", input("X1"), input("X2"),
- String.valueOf(runs), expected("Z")};
+ programArgs = new String[] {"-args", input("X1"), input("X2"), String.valueOf(runs), expected("Z")};
runTest(true, false, null, -1);
// Run actual dml script with federated matrix
fullDMLScriptName = HOME + TEST_NAME + ".dml";
- programArgs = new String[] {"-stats",
- "-nvargs", "in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
+ programArgs = new String[] {"-stats", "-nvargs", "in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
"in_X2=" + TestUtils.federatedAddress(port2, input("X2")), "rows=" + rows, "cols=" + cols,
"runs=" + String.valueOf(runs), "out=" + output("Z")};
-
- for( int i=0; i<rep; i++ ) {
+
+ for(int i = 0; i < rep; i++) {
ParForProgramBlock.resetWorkerIDs();
FederationUtils.resetFedDataID();
runTest(true, false, null, -1);
-
+
// check for federated operations
Assert.assertTrue(heavyHittersContainsString("fed_ba+*"));
- //Assert.assertTrue(heavyHittersContainsString("fed_uasqk+"));
+ // Assert.assertTrue(heavyHittersContainsString("fed_uasqk+"));
Assert.assertTrue(heavyHittersContainsString("fed_uarmin"));
Assert.assertTrue(heavyHittersContainsString("fed_uark+"));
Assert.assertTrue(heavyHittersContainsString("fed_uack+"));
@@ -137,16 +135,16 @@ public class FederatedKmeansTest extends AutomatedTestBase {
Assert.assertTrue(heavyHittersContainsString("fed_<="));
Assert.assertTrue(heavyHittersContainsString("fed_/"));
Assert.assertTrue(heavyHittersContainsString("fed_r'"));
-
- //check that federated input files are still existing
+
+ // check that federated input files are still existing
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
}
-
+
// compare via files
- //compareResults(1e-9); --> randomized
+ // compareResults(1e-9); --> randomized
TestUtils.shutdownThreads(t1, t2);
-
+
resetExecMode(platformOld);
}
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedL2SVMTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedL2SVMTest.java
index e24935d..53bfc8d 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedL2SVMTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedL2SVMTest.java
@@ -55,7 +55,9 @@ public class FederatedL2SVMTest extends AutomatedTestBase {
@Parameterized.Parameters
public static Collection<Object[]> data() {
// rows have to be even and > 1
- return Arrays.asList(new Object[][] {{2, 1000}, {10, 100}, {100, 10}, {1000, 1}, {10, 2000}, {2000, 10}});
+ return Arrays.asList(new Object[][] {
+ // {2, 1000}, {10, 100}, {100, 10}, {1000, 1}, {10, 2000},
+ {2000, 10}});
}
@Test
@@ -102,8 +104,7 @@ public class FederatedL2SVMTest extends AutomatedTestBase {
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
-
-
+
// Run reference dml script with normal matrix
fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
programArgs = new String[] {"-args", input("X1"), input("X2"), input("Y"), expected("Z")};
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLogRegTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLogRegTest.java
index 0550ec9..42c614b 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLogRegTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLogRegTest.java
@@ -57,7 +57,9 @@ public class FederatedLogRegTest extends AutomatedTestBase {
@Parameterized.Parameters
public static Collection<Object[]> data() {
// rows have to be even and > 1
- return Arrays.asList(new Object[][] {{10000, 10}, {1000, 100}, {2000, 43}});
+ return Arrays.asList(new Object[][] {
+ // {10000, 10}, {1000, 100},
+ {2000, 43}});
}
@Test
@@ -98,7 +100,6 @@ public class FederatedLogRegTest extends AutomatedTestBase {
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
-
// Run reference dml script with normal matrix
fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPCATest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPCATest.java
index 33179b7..99c90ee 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPCATest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPCATest.java
@@ -60,9 +60,10 @@ public class FederatedPCATest extends AutomatedTestBase {
@Parameterized.Parameters
public static Collection<Object[]> data() {
// rows have to be even and > 1
- return Arrays.asList(new Object[][] {
- {10000, 10, false}, {2000, 50, false}, {1000, 100, false},
- {10000, 10, true}, {2000, 50, true}, {1000, 100, true}
+ return Arrays.asList(new Object[][] {{10000, 10, false},
+ // {2000, 50, false}, {1000, 100, false},
+ {10000, 10, true},
+ // {2000, 50, true}, {1000, 100, true}
});
}
@@ -70,7 +71,7 @@ public class FederatedPCATest extends AutomatedTestBase {
public void federatedPCASinglenode() {
federatedL2SVM(Types.ExecMode.SINGLE_NODE);
}
-
+
@Test
public void federatedPCAHybrid() {
federatedL2SVM(Types.ExecMode.HYBRID);
@@ -89,7 +90,7 @@ public class FederatedPCATest extends AutomatedTestBase {
double[][] X2 = getRandomMatrix(quarterRows, cols, 0, 1, 1, 7);
double[][] X3 = getRandomMatrix(quarterRows, cols, 0, 1, 1, 8);
double[][] X4 = getRandomMatrix(quarterRows, cols, 0, 1, 1, 9);
- MatrixCharacteristics mc= new MatrixCharacteristics(quarterRows, cols, blocksize, quarterRows * cols);
+ MatrixCharacteristics mc = new MatrixCharacteristics(quarterRows, cols, blocksize, quarterRows * cols);
writeInputMatrixWithMTD("X1", X1, false, mc);
writeInputMatrixWithMTD("X2", X2, false, mc);
writeInputMatrixWithMTD("X3", X3, false, mc);
@@ -108,8 +109,7 @@ public class FederatedPCATest extends AutomatedTestBase {
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
-
-
+
// Run reference dml script with normal matrix
fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
programArgs = new String[] {"-stats", "-args", input("X1"), input("X2"), input("X3"), input("X4"),
@@ -118,37 +118,35 @@ public class FederatedPCATest extends AutomatedTestBase {
// Run actual dml script with federated matrix
fullDMLScriptName = HOME + TEST_NAME + ".dml";
- programArgs = new String[] {"-stats", "-nvargs",
- "in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
+ programArgs = new String[] {"-stats", "-nvargs", "in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
"in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
"in_X3=" + TestUtils.federatedAddress(port3, input("X3")),
- "in_X4=" + TestUtils.federatedAddress(port4, input("X4")),
- "rows=" + rows, "cols=" + cols,
+ "in_X4=" + TestUtils.federatedAddress(port4, input("X4")), "rows=" + rows, "cols=" + cols,
"scaleAndShift=" + String.valueOf(scaleAndShift).toUpperCase(), "out=" + output("Z")};
runTest(true, false, null, -1);
// compare via files
compareResults(1e-9);
TestUtils.shutdownThreads(t1, t2, t3, t4);
-
+
// check for federated operations
Assert.assertTrue(heavyHittersContainsString("fed_ba+*"));
Assert.assertTrue(heavyHittersContainsString("fed_uack+"));
Assert.assertTrue(heavyHittersContainsString("fed_tsmm"));
- if( scaleAndShift ) {
+ if(scaleAndShift) {
Assert.assertTrue(heavyHittersContainsString("fed_uacsqk+"));
Assert.assertTrue(heavyHittersContainsString("fed_uacmean"));
Assert.assertTrue(heavyHittersContainsString("fed_-"));
Assert.assertTrue(heavyHittersContainsString("fed_/"));
Assert.assertTrue(heavyHittersContainsString("fed_replace"));
}
-
- //check that federated input files are still existing
+
+ // check that federated input files are still existing
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3")));
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4")));
-
+
resetExecMode(platformOld);
}
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedUnivarTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedUnivarTest.java
index f24aa6c..588796a 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedUnivarTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedUnivarTest.java
@@ -55,9 +55,8 @@ public class FederatedUnivarTest extends AutomatedTestBase {
@Parameterized.Parameters
public static Collection<Object[]> data() {
return Arrays.asList(new Object[][] {
- {10000, 16},
- {2000, 32}, {1000, 64}, {10000, 128}
- });
+ // {10000, 16},{2000, 32}, {1000, 64},
+ {10000, 128}});
}
@Test
@@ -85,10 +84,10 @@ public class FederatedUnivarTest extends AutomatedTestBase {
double[][] X4 = getRandomMatrix(rows, quarterCols, 1, 5, 1, 9);
// write types matrix shape of (1, D)
- double [][] Y = getRandomMatrix(1, cols, 0, 3, 1, 9);
+ double[][] Y = getRandomMatrix(1, cols, 0, 3, 1, 9);
Arrays.stream(Y[0]).forEach(Math::ceil);
- MatrixCharacteristics mc= new MatrixCharacteristics(rows, quarterCols, blocksize, rows * quarterCols);
+ MatrixCharacteristics mc = new MatrixCharacteristics(rows, quarterCols, blocksize, rows * quarterCols);
writeInputMatrixWithMTD("X1", X1, false, mc);
writeInputMatrixWithMTD("X2", X2, false, mc);
writeInputMatrixWithMTD("X3", X3, false, mc);
@@ -108,23 +107,21 @@ public class FederatedUnivarTest extends AutomatedTestBase {
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
-
// Run reference dml script with normal matrix
fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
- programArgs = new String[] {"-stats", "100", "-args", input("X1"), input("X2"), input("X3"), input("X4"), input("Y"), expected("B")};
+ programArgs = new String[] {"-stats", "100", "-args", input("X1"), input("X2"), input("X3"), input("X4"),
+ input("Y"), expected("B")};
runTest(true, false, null, -1);
// Run actual dml script with federated matrix
fullDMLScriptName = HOME + TEST_NAME + ".dml";
- programArgs = new String[] {"-stats", "100", "-nvargs",
+ programArgs = new String[] {"-stats", "100", "-nvargs",
"in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
"in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
"in_X3=" + TestUtils.federatedAddress(port3, input("X3")),
- "in_X4=" + TestUtils.federatedAddress(port4, input("X4")),
- "in_Y=" + input("Y"), // types
- "rows=" + rows, "cols=" + cols,
- "out=" + output("B")};
+ "in_X4=" + TestUtils.federatedAddress(port4, input("X4")), "in_Y=" + input("Y"), // types
+ "rows=" + rows, "cols=" + cols, "out=" + output("B")};
runTest(true, false, null, -1);
// compare via files
@@ -134,7 +131,7 @@ public class FederatedUnivarTest extends AutomatedTestBase {
// check for federated operations
Assert.assertTrue(heavyHittersContainsString("fed_uacmax"));
- //check that federated input files are still existing
+ // check that federated input files are still existing
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3")));
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRightIndexTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedVarTest.java
similarity index 64%
copy from src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRightIndexTest.java
copy to src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedVarTest.java
index 0adcb15..9579fef 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRightIndexTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedVarTest.java
@@ -17,11 +17,13 @@
* under the License.
*/
-package org.apache.sysds.test.functions.federated.primitives;
+package org.apache.sysds.test.functions.federated.algorithms;
import java.util.Arrays;
import java.util.Collection;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
@@ -36,89 +38,53 @@ import org.junit.runners.Parameterized;
@RunWith(value = Parameterized.class)
@net.jcip.annotations.NotThreadSafe
-public class FederatedRightIndexTest extends AutomatedTestBase {
- // private static final Log LOG = LogFactory.getLog(FederatedRightIndexTest.class.getName());
+public class FederatedVarTest extends AutomatedTestBase {
- private final static String TEST_NAME1 = "FederatedRightIndexRightTest";
- private final static String TEST_NAME2 = "FederatedRightIndexLeftTest";
- private final static String TEST_NAME3 = "FederatedRightIndexFullTest";
+ private static final Log LOG = LogFactory.getLog(FederatedVarTest.class.getName());
+ private final static String TEST_NAME = "FederatedVarTest";
private final static String TEST_DIR = "functions/federated/";
- private static final String TEST_CLASS_DIR = TEST_DIR + FederatedRightIndexTest.class.getSimpleName() + "/";
+ private static final String TEST_CLASS_DIR = TEST_DIR + FederatedVarTest.class.getSimpleName() + "/";
private final static int blocksize = 1024;
+
@Parameterized.Parameter()
public int rows;
@Parameterized.Parameter(1)
public int cols;
-
@Parameterized.Parameter(2)
- public int from;
-
- @Parameterized.Parameter(3)
- public int to;
-
- @Parameterized.Parameter(4)
public boolean rowPartitioned;
@Parameterized.Parameters
public static Collection<Object[]> data() {
return Arrays.asList(new Object[][] {
- {20, 10, 6, 8, true},
- {20, 10, 1, 1, true},
- {20, 10, 2, 10, true},
- // {20, 10, 2, 10, true},
- // {20, 12, 2, 10, false}, {20, 12, 1, 4, false}
+ // {10, 1000, false},
+ {100, 4, false},
+ // {36, 1000, true},
+ {1000, 10, true},
+ // {4, 100, true}
+ // {1600, 8, false},
});
}
- private enum IndexType {
- RIGHT, LEFT, FULL
- }
-
@Override
public void setUp() {
TestUtils.clearAssertionInformation();
- addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"S"}));
- addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"S"}));
- addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {"S"}));
+ addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"S.scalar"}));
}
@Test
- public void testRightIndexRightDenseMatrixCP() {
- runAggregateOperationTest(IndexType.RIGHT, ExecMode.SINGLE_NODE);
+ public void testVarCP() {
+ runAggregateOperationTest(ExecMode.SINGLE_NODE);
}
- @Test
- public void testRightIndexLeftDenseMatrixCP() {
- runAggregateOperationTest(IndexType.LEFT, ExecMode.SINGLE_NODE);
- }
-
- @Test
- public void testRightIndexFullDenseMatrixCP() {
- runAggregateOperationTest(IndexType.FULL, ExecMode.SINGLE_NODE);
- }
-
- private void runAggregateOperationTest(IndexType type, ExecMode execMode) {
+ private void runAggregateOperationTest(ExecMode execMode) {
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
ExecMode platformOld = rtplatform;
if(rtplatform == ExecMode.SPARK)
DMLScript.USE_LOCAL_SPARK_CONFIG = true;
- String TEST_NAME = null;
- switch(type) {
- case RIGHT:
- TEST_NAME = TEST_NAME1;
- break;
- case LEFT:
- TEST_NAME = TEST_NAME2;
- break;
- case FULL:
- TEST_NAME = TEST_NAME3;
- break;
- }
-
getAndLoadTestConfiguration(TEST_NAME);
String HOME = SCRIPT_DIR + TEST_DIR;
@@ -153,35 +119,34 @@ public class FederatedRightIndexTest extends AutomatedTestBase {
Thread t4 = startLocalFedWorkerThread(port4);
rtplatform = execMode;
- if(rtplatform == ExecMode.SPARK) {
- System.out.println(7);
+ if(rtplatform == ExecMode.SPARK)
DMLScript.USE_LOCAL_SPARK_CONFIG = true;
- }
+
TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
loadTestConfiguration(config);
// Run reference dml script with normal matrix
fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
- programArgs = new String[] {"-args", input("X1"), input("X2"), input("X3"), input("X4"), String.valueOf(from),
- String.valueOf(to), Boolean.toString(rowPartitioned).toUpperCase(), expected("S")};
- // LOG.error(runTest(null));
- runTest(null);
+ programArgs = new String[] {"-stats", "100", "-args", input("X1"), input("X2"), input("X3"), input("X4"),
+ expected("S"), Boolean.toString(rowPartitioned).toUpperCase()};
+ LOG.debug(runTest(null));
+
// Run actual dml script with federated matrix
fullDMLScriptName = HOME + TEST_NAME + ".dml";
- programArgs = new String[] {"-stats", "100", "-nvargs",
+ programArgs = new String[] {"-explain", "-stats", "100", "-nvargs",
"in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
"in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
"in_X3=" + TestUtils.federatedAddress(port3, input("X3")),
- "in_X4=" + TestUtils.federatedAddress(port4, input("X4")), "rows=" + rows, "cols=" + cols, "from=" + from,
- "to=" + to, "rP=" + Boolean.toString(rowPartitioned).toUpperCase(), "out_S=" + output("S")};
+ "in_X4=" + TestUtils.federatedAddress(port4, input("X4")), "rows=" + rows, "cols=" + cols,
+ "rP=" + Boolean.toString(rowPartitioned).toUpperCase(), "out_S=" + output("S")};
+
+ LOG.debug(runTest(null));
- // LOG.error(runTest(null));
- runTest(null);
// compare via files
- compareResults(1e-9);
+ compareResults(0.05);
- Assert.assertTrue(heavyHittersContainsString("fed_rightIndex"));
+ Assert.assertTrue(heavyHittersContainsString("fed_uavar"));
// check that federated input files are still existing
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSSLTest.java b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSSLTest.java
index 1088b68..6ec2f40 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSSLTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedSSLTest.java
@@ -45,7 +45,7 @@ public class FederatedSSLTest extends AutomatedTestBase {
// This test use the same scripts as the Federated Reader tests, just with SSL enabled.
private final static String TEST_DIR = "functions/federated/io/";
private final static String TEST_NAME = "FederatedReaderTest";
- private final static String TEST_CLASS_DIR = TEST_DIR + FederatedReaderTest.class.getSimpleName() + "/";
+ private final static String TEST_CLASS_DIR = TEST_DIR + FederatedSSLTest.class.getSimpleName() + "/";
private final static int blocksize = 1024;
private final static File TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR + "SSLConfig.xml");
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedWriterTest.java b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedWriterTest.java
index 587da82..c8a50fe 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedWriterTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedWriterTest.java
@@ -59,7 +59,7 @@ public class FederatedWriterTest extends AutomatedTestBase {
@Parameterized.Parameters
public static Collection<Object[]> data() {
// number of rows or cols has to be >= number of federated locations.
- return Arrays.asList(new Object[][] {{10, 13, true, 2},});
+ return Arrays.asList(new Object[][] {{10, 13, true, 2}});
}
@Test
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
index ada7412..c4d04ea 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/paramserv/FederatedParamservTest.java
@@ -24,6 +24,8 @@ import java.util.Arrays;
import java.util.Collection;
import java.util.List;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.test.AutomatedTestBase;
@@ -39,6 +41,7 @@ import org.junit.runners.Parameterized;
@RunWith(value = Parameterized.class)
@net.jcip.annotations.NotThreadSafe
public class FederatedParamservTest extends AutomatedTestBase {
+ private static final Log LOG = LogFactory.getLog(FederatedParamservTest.class.getName());
private final static String TEST_DIR = "functions/federated/paramserv/";
private final static String TEST_NAME = "FederatedParamservTest";
private final static String TEST_CLASS_DIR = TEST_DIR + FederatedParamservTest.class.getSimpleName() + "/";
@@ -159,9 +162,7 @@ public class FederatedParamservTest extends AutomatedTestBase {
}
programArgs = programArgsList.toArray(new String[0]);
- // ByteArrayOutputStream stdout =
- runTest(null);
- // System.out.print(stdout.toString());
+ LOG.debug(runTest(null));
Assert.assertEquals(0, Statistics.getNoOfExecutedSPInst());
// cleanup
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedBinaryMatrixTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedBinaryMatrixTest.java
index efd98c2..11f2bd4 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedBinaryMatrixTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedBinaryMatrixTest.java
@@ -56,8 +56,10 @@ public class FederatedBinaryMatrixTest extends AutomatedTestBase {
public static Collection<Object[]> data() {
// rows have to be even and > 1
return Arrays.asList(new Object[][] {
- {2, 1000},
- {10, 100}, {100, 10}, {1000, 1}, {10, 2000}, {2000, 10}
+ // {2, 1000},
+ {10, 100},
+ // {100, 10}, {1000, 1},
+ // {10, 2000}, {2000, 10}
});
}
@@ -66,12 +68,6 @@ public class FederatedBinaryMatrixTest extends AutomatedTestBase {
federatedMultiply(Types.ExecMode.SINGLE_NODE);
}
- /*
- * FIXME spark execution mode support
- *
- * @Test public void federatedMultiplySP() { federatedMultiply(Types.ExecMode.SPARK); }
- */
-
public void federatedMultiply(Types.ExecMode execMode) {
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
Types.ExecMode platformOld = rtplatform;
@@ -109,7 +105,7 @@ public class FederatedBinaryMatrixTest extends AutomatedTestBase {
fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
programArgs = new String[] {"-nvargs", "X1=" + input("X1"), "X2=" + input("X2"), "Y1=" + input("Y1"),
"Y2=" + input("Y2"), "Z=" + expected("Z")};
- runTest(true, false, null, -1);
+ runTest(null);
// Run actual dml script with federated matrix
fullDMLScriptName = HOME + TEST_NAME + ".dml";
@@ -117,7 +113,7 @@ public class FederatedBinaryMatrixTest extends AutomatedTestBase {
"X2=" + TestUtils.federatedAddress(port2, input("X2")),
"Y1=" + TestUtils.federatedAddress(port1, input("Y1")),
"Y2=" + TestUtils.federatedAddress(port2, input("Y2")), "r=" + rows, "c=" + cols, "Z=" + output("Z")};
- runTest(true, false, null, -1);
+ runTest(null);
// compare via files
compareResults(1e-9);
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedBinaryVectorTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedBinaryVectorTest.java
index c0a53d4..f11b7be 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedBinaryVectorTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedBinaryVectorTest.java
@@ -57,8 +57,10 @@ public class FederatedBinaryVectorTest extends AutomatedTestBase {
public static Collection<Object[]> data() {
// rows have to be even and > 1
return Arrays.asList(new Object[][] {
- {2, 1000},
- {10, 100}, {100, 10}, {1000, 1}, {10, 2000}, {2000, 10}
+ // {2, 1000},
+ // {10, 100},
+ {100, 10},
+ // {1000, 1}, {10, 2000}, {2000, 10}
});
}
@@ -67,12 +69,6 @@ public class FederatedBinaryVectorTest extends AutomatedTestBase {
federatedMultiply(Types.ExecMode.SINGLE_NODE);
}
- /*
- * FIXME spark execution mode support
- *
- * @Test public void federatedMultiplySP() { federatedMultiply(Types.ExecMode.SPARK); }
- */
-
public void federatedMultiply(Types.ExecMode execMode) {
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
Types.ExecMode platformOld = rtplatform;
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederetedCastToFrameTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCastToFrameTest.java
similarity index 95%
rename from src/test/java/org/apache/sysds/test/functions/federated/primitives/FederetedCastToFrameTest.java
rename to src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCastToFrameTest.java
index 5e05bf5..b67cc93 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederetedCastToFrameTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCastToFrameTest.java
@@ -38,12 +38,12 @@ import org.junit.runners.Parameterized;
@RunWith(value = Parameterized.class)
@net.jcip.annotations.NotThreadSafe
-public class FederetedCastToFrameTest extends AutomatedTestBase {
- private static final Log LOG = LogFactory.getLog(FederetedCastToFrameTest.class.getName());
+public class FederatedCastToFrameTest extends AutomatedTestBase {
+ private static final Log LOG = LogFactory.getLog(FederatedCastToFrameTest.class.getName());
private final static String TEST_DIR = "functions/federated/primitives/";
private final static String TEST_NAME = "FederatedCastToFrameTest";
- private final static String TEST_CLASS_DIR = TEST_DIR + FederetedCastToFrameTest.class.getSimpleName() + "/";
+ private final static String TEST_CLASS_DIR = TEST_DIR + FederatedCastToFrameTest.class.getSimpleName() + "/";
private final static int blocksize = 1024;
@Parameterized.Parameter()
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederetedCastToMatrixTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCastToMatrixTest.java
similarity index 96%
rename from src/test/java/org/apache/sysds/test/functions/federated/primitives/FederetedCastToMatrixTest.java
rename to src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCastToMatrixTest.java
index b075e47..57ffacf 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederetedCastToMatrixTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCastToMatrixTest.java
@@ -48,12 +48,12 @@ import org.junit.runners.Parameterized;
@RunWith(value = Parameterized.class)
@net.jcip.annotations.NotThreadSafe
-public class FederetedCastToMatrixTest extends AutomatedTestBase {
- private static final Log LOG = LogFactory.getLog(FederetedCastToMatrixTest.class.getName());
+public class FederatedCastToMatrixTest extends AutomatedTestBase {
+ private static final Log LOG = LogFactory.getLog(FederatedCastToMatrixTest.class.getName());
private final static String TEST_DIR = "functions/federated/primitives/";
private final static String TEST_NAME = "FederatedCastToMatrixTest";
- private final static String TEST_CLASS_DIR = TEST_DIR + FederetedCastToMatrixTest.class.getSimpleName() + "/";
+ private final static String TEST_CLASS_DIR = TEST_DIR + FederatedCastToMatrixTest.class.getSimpleName() + "/";
@Parameterized.Parameter()
public int rows;
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCentralMomentTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCentralMomentTest.java
new file mode 100644
index 0000000..4d644ce
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCentralMomentTest.java
@@ -0,0 +1,147 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.primitives;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Ignore;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedCentralMomentTest extends AutomatedTestBase {
+
+ private final static String TEST_DIR = "functions/federated/";
+ private final static String TEST_NAME = "FederatedCentralMomentTest";
+ private final static String TEST_CLASS_DIR = TEST_DIR + FederatedCentralMomentTest.class.getSimpleName() + "/";
+
+ private final static int blocksize = 1024;
+ @Parameterized.Parameter()
+ public int rows;
+
+ @Parameterized.Parameter(1)
+ public int k;
+
+ @Parameterized.Parameters
+ public static Collection<Object[]> data() {
+ return Arrays.asList(new Object[][] {
+ {1000, 2},
+ {1000, 3},
+ {1000, 4}
+ });
+ }
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"S.scalar"}));
+ }
+
+ @Test
+ public void federatedCentralMomentCP() { federatedCentralMoment(Types.ExecMode.SINGLE_NODE); }
+
+ @Test
+ @Ignore
+ public void federatedCentralMomentSP() { federatedCentralMoment(Types.ExecMode.SPARK); }
+
+ public void federatedCentralMoment(Types.ExecMode execMode) {
+ boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+ Types.ExecMode platformOld = rtplatform;
+
+ getAndLoadTestConfiguration(TEST_NAME);
+ String HOME = SCRIPT_DIR + TEST_DIR;
+
+ int r = rows / 4;
+
+ double[][] X1 = getRandomMatrix(r, 1, 1, 5, 1, 3);
+ double[][] X2 = getRandomMatrix(r, 1, 1, 5, 1, 7);
+ double[][] X3 = getRandomMatrix(r, 1, 1, 5, 1, 8);
+ double[][] X4 = getRandomMatrix(r, 1, 1, 5, 1, 9);
+
+ MatrixCharacteristics mc = new MatrixCharacteristics(r, 1, blocksize, r);
+ writeInputMatrixWithMTD("X1", X1, false, mc);
+ writeInputMatrixWithMTD("X2", X2, false, mc);
+ writeInputMatrixWithMTD("X3", X3, false, mc);
+ writeInputMatrixWithMTD("X4", X4, false, mc);
+
+ // empty script name because we don't execute any script, just start the worker
+ fullDMLScriptName = "";
+ int port1 = getRandomAvailablePort();
+ int port2 = getRandomAvailablePort();
+ int port3 = getRandomAvailablePort();
+ int port4 = getRandomAvailablePort();
+ Thread t1 = startLocalFedWorkerThread(port1);
+ Thread t2 = startLocalFedWorkerThread(port2);
+ Thread t3 = startLocalFedWorkerThread(port3);
+ Thread t4 = startLocalFedWorkerThread(port4);
+
+ // reference file should not be written to hdfs, so we set platform here
+ rtplatform = execMode;
+ if(rtplatform == Types.ExecMode.SPARK) {
+ DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+ }
+ // Run reference dml script with normal matrix for Row/Col
+ fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+ programArgs = new String[] {"-stats", "100", "-args",
+ input("X1"), input("X2"), input("X3"), input("X4"), expected("S"), String.valueOf(k)};
+ runTest(null);
+
+ TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
+ loadTestConfiguration(config);
+
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[] {"-stats", "100", "-nvargs",
+ "in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
+ "in_X2=" + TestUtils.federatedAddress(port2, input("X2")),
+ "in_X3=" + TestUtils.federatedAddress(port3, input("X3")),
+ "in_X4=" + TestUtils.federatedAddress(port4, input("X4")),
+ "rows=" + rows,
+ "cols=" + 1,
+ "out_S=" + output("S"),
+ "k=" + k};
+ runTest(null);
+
+ // compare all sums via files
+ compareResults(0.01);
+
+ Assert.assertTrue(heavyHittersContainsString("fed_cm"));
+
+ // check that federated input files are still existing
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3")));
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4")));
+
+ TestUtils.shutdownThreads(t1, t2, t3, t4);
+ rtplatform = platformOld;
+ DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ }
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedConstructionTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedConstructionTest.java
index 8a3b4d1..edfd1b4 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedConstructionTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedConstructionTest.java
@@ -19,6 +19,13 @@
package org.apache.sysds.test.functions.federated.primitives;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.FileFormat;
@@ -26,17 +33,11 @@ import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
+import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collection;
-import java.util.Collections;
-import java.util.List;
-
@RunWith(value = Parameterized.class)
@net.jcip.annotations.NotThreadSafe
public class FederatedConstructionTest extends AutomatedTestBase {
@@ -56,7 +57,10 @@ public class FederatedConstructionTest extends AutomatedTestBase {
@Parameterized.Parameters
public static Collection<Object[]> data() {
// cols have to be dividable by 4 for Frame tests
- return Arrays.asList(new Object[][] {{1, 1024}, {8, 256}, {256, 8}, {1024, 4}, {16, 2048}, {2048, 32}});
+ return Arrays.asList(new Object[][] {
+ // {1, 1024}, {8, 256}, {256, 8}, {1024, 4}, {16, 2048},
+ {2048, 32}
+ });
}
@Override
@@ -71,6 +75,7 @@ public class FederatedConstructionTest extends AutomatedTestBase {
}
@Test
+ @Ignore
public void federatedMatrixConstructionSP() {
federatedMatrixConstruction(Types.ExecMode.SPARK);
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedFullAggregateTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedFullAggregateTest.java
index 00f0f50..1617ab6 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedFullAggregateTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedFullAggregateTest.java
@@ -31,6 +31,7 @@ import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
import org.junit.Assert;
+import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
@@ -57,7 +58,13 @@ public class FederatedFullAggregateTest extends AutomatedTestBase {
@Parameterized.Parameters
public static Collection<Object[]> data() {
return Arrays.asList(
- new Object[][] {{10, 1000, false}, {100, 4, false}, {36, 1000, true}, {1000, 10, true}, {4, 100, true}});
+ new Object[][] {
+ // {10, 1000, false},
+ {100, 4, false},
+ // {36, 1000, true},
+ // {1000, 10, true},
+ {4, 100, true}
+ });
}
private enum OpType {
@@ -94,21 +101,25 @@ public class FederatedFullAggregateTest extends AutomatedTestBase {
}
@Test
+ @Ignore
public void testSumDenseMatrixSP() {
runColAggregateOperationTest(OpType.SUM, ExecType.SPARK);
}
@Test
+ @Ignore
public void testMeanDenseMatrixSP() {
runColAggregateOperationTest(OpType.MEAN, ExecType.SPARK);
}
@Test
+ @Ignore
public void testMaxDenseMatrixSP() {
runColAggregateOperationTest(OpType.MAX, ExecType.SPARK);
}
@Test
+ @Ignore
public void testMinDenseMatrixSP() {
runColAggregateOperationTest(OpType.MIN, ExecType.SPARK);
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMatrixScalarOperationsTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMatrixScalarOperationsTest.java
index becb246..bd19ee8 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMatrixScalarOperationsTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMatrixScalarOperationsTest.java
@@ -37,7 +37,10 @@ import static java.lang.Thread.sleep;
public class FederatedMatrixScalarOperationsTest extends AutomatedTestBase {
@Parameterized.Parameters
public static Iterable<Object[]> data() {
- return Arrays.asList(new Object[][] {{100, 100}, {10000, 100},});
+ return Arrays.asList(new Object[][] {
+ {100, 100},
+ // {10000, 100}
+ });
}
// internals 4 parameterized tests
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMultiplyTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMultiplyTest.java
index 3220e1a..4170914 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMultiplyTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMultiplyTest.java
@@ -56,7 +56,11 @@ public class FederatedMultiplyTest extends AutomatedTestBase {
@Parameterized.Parameters
public static Collection<Object[]> data() {
// rows have to be even and > 1
- return Arrays.asList(new Object[][] {{2, 1000}, {10, 100}, {100, 10}, {1000, 1}, {10, 2000}, {2000, 10}});
+ return Arrays.asList(new Object[][] {
+ // {2, 1000}, {10, 100},
+ {100, 10}, {1000, 1},
+ // {10, 2000}, {2000, 10}
+ });
}
@Test
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedQuantileTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedQuantileTest.java
new file mode 100644
index 0000000..1a71279
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedQuantileTest.java
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.primitives;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Ignore;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedQuantileTest extends AutomatedTestBase {
+
+ private final static String TEST_DIR = "functions/federated/quantile/";
+ private final static String TEST_NAME1 = "FederatedQuantileTest";
+ private final static String TEST_NAME2 = "FederatedMedianTest";
+ private final static String TEST_NAME3 = "FederatedIQMTest";
+ private final static String TEST_NAME4 = "FederatedQuantilesTest";
+ private final static String TEST_CLASS_DIR = TEST_DIR + FederatedQuantileTest.class.getSimpleName() + "/";
+
+ private final static int blocksize = 1024;
+ @Parameterized.Parameter()
+ public int rows;
+
+ @Parameterized.Parameters
+ public static Collection<Object[]> data() {
+ return Arrays.asList(new Object[][] {{1000}});
+ }
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"S.scalar"}));
+ addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"S.scalar"}));
+ addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {"S.scalar"}));
+ addTestConfiguration(TEST_NAME4, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] {"S"}));
+ }
+
+ @Test
+ public void federatedQuantile1CP() { federatedQuartile(Types.ExecMode.SINGLE_NODE, TEST_NAME1, 0.25); }
+
+ @Test
+ public void federatedQuantile2CP() { federatedQuartile(Types.ExecMode.SINGLE_NODE, TEST_NAME1, 0.5); }
+
+ @Test
+ public void federatedQuantile3CP() { federatedQuartile(Types.ExecMode.SINGLE_NODE, TEST_NAME1, 0.75); }
+
+ @Test
+ public void federatedMedianCP() { federatedQuartile(Types.ExecMode.SINGLE_NODE, TEST_NAME2, -1); }
+
+ @Test
+ public void federatedIQMCP() { federatedQuartile(Types.ExecMode.SINGLE_NODE, TEST_NAME1, -1); }
+
+ @Test
+ public void federatedQuantilesCP() { federatedQuartile(Types.ExecMode.SINGLE_NODE, TEST_NAME1, -1); }
+
+ @Test
+ @Ignore
+ public void federatedQuantile1SP() { federatedQuartile(Types.ExecMode.SPARK, TEST_NAME1, 0.25); }
+
+ @Test
+ @Ignore
+ public void federatedQuantile2SP() { federatedQuartile(Types.ExecMode.SPARK, TEST_NAME1, 0.5); }
+
+ @Test
+ @Ignore
+ public void federatedQuantile3SP() { federatedQuartile(Types.ExecMode.SPARK, TEST_NAME1, 0.75); }
+
+ @Test
+ @Ignore
+ public void federatedMedianSP() { federatedQuartile(Types.ExecMode.SPARK, TEST_NAME2, -1); }
+
+ @Test
+ @Ignore
+ public void federatedIQMSP() { federatedQuartile(Types.ExecMode.SPARK, TEST_NAME1, -1); }
+
+ @Test
+ @Ignore
+ public void federatedQuantilesSP() { federatedQuartile(Types.ExecMode.SPARK, TEST_NAME1, -1); }
+
+
+
+
+
+ public void federatedQuartile(Types.ExecMode execMode, String TEST_NAME, double p) {
+ boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+ Types.ExecMode platformOld = rtplatform;
+
+ getAndLoadTestConfiguration(TEST_NAME);
+ String HOME = SCRIPT_DIR + TEST_DIR;
+
+ double[][] X1 = getRandomMatrix(rows, 1, 1, 5, 1, 3);
+
+ MatrixCharacteristics mc = new MatrixCharacteristics(rows, 1, blocksize, rows);
+ writeInputMatrixWithMTD("X1", X1, false, mc);
+
+ // empty script name because we don't execute any script, just start the worker
+ fullDMLScriptName = "";
+ int port1 = getRandomAvailablePort();
+ Thread t1 = startLocalFedWorkerThread(port1);
+
+ // we need the reference file to not be written to hdfs, so we get the correct format
+ rtplatform = Types.ExecMode.SINGLE_NODE;
+ // Run reference dml script with normal matrix for Row/Col
+ fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+ programArgs = new String[] {"-explain", "-stats", "100", "-args", input("X1"), expected("S"), String.valueOf(p)};
+ runTest(true, false, null, -1);
+
+ // reference file should not be written to hdfs, so we set platform here
+ rtplatform = execMode;
+ if(rtplatform == Types.ExecMode.SPARK) {
+ DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+ }
+ TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
+ loadTestConfiguration(config);
+
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[] {"-explain", "-stats", "100", "-nvargs",
+ "in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
+ "rows=" + rows, "cols=" + 1, "p=" + String.valueOf(p),
+ "out_S=" + output("S")
+ };
+ runTest(true, false, null, -1);
+
+ // compare all sums via files
+ compareResults(1e-9);
+ Assert.assertTrue(heavyHittersContainsString("fed_qsort"));
+ Assert.assertTrue(heavyHittersContainsString("fed_qpick"));
+
+ // check that federated input files are still existing
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+
+ TestUtils.shutdownThreads(t1);
+ rtplatform = platformOld;
+ DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ }
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedQuantileWeightsTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedQuantileWeightsTest.java
new file mode 100644
index 0000000..4511c6d
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedQuantileWeightsTest.java
@@ -0,0 +1,140 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.primitives;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedQuantileWeightsTest extends AutomatedTestBase {
+
+ private final static String TEST_DIR = "functions/federated/quantile/";
+ private final static String TEST_NAME1 = "FederatedQuantileWeightsTest";
+ private final static String TEST_NAME2 = "FederatedMedianWeightsTest";
+ private final static String TEST_NAME3 = "FederatedIQMWeightsTest";
+ private final static String TEST_NAME4 = "FederatedQuantilesWeightsTest";
+ private final static String TEST_CLASS_DIR = TEST_DIR + FederatedQuantileWeightsTest.class.getSimpleName() + "/";
+
+ private final static int blocksize = 1024;
+ @Parameterized.Parameter()
+ public int rows;
+
+ @Parameterized.Parameters
+ public static Collection<Object[]> data() {
+ return Arrays.asList(new Object[][] {{1000}});
+ }
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"S.scalar"}));
+ addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"S.scalar"}));
+ addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {"S.scalar"}));
+ addTestConfiguration(TEST_NAME4, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] {"S"}));
+ }
+
+ @Test
+ public void federatedQuantile1CP() { federatedQuartile(Types.ExecMode.SINGLE_NODE, TEST_NAME1, 0.25); }
+
+ @Test
+ public void federatedQuantile2CP() { federatedQuartile(Types.ExecMode.SINGLE_NODE, TEST_NAME1, 0.5); }
+
+ @Test
+ public void federatedQuantile3CP() { federatedQuartile(Types.ExecMode.SINGLE_NODE, TEST_NAME1, 0.75); }
+
+ @Test
+ public void federatedMedianCP() { federatedQuartile(Types.ExecMode.SINGLE_NODE, TEST_NAME2, -1); }
+
+ @Test
+ public void federatedIQMCP() { federatedQuartile(Types.ExecMode.SINGLE_NODE, TEST_NAME1, -1); }
+
+ @Test
+ public void federatedQuantilesCP() { federatedQuartile(Types.ExecMode.SINGLE_NODE, TEST_NAME1, -1); }
+
+ public void federatedQuartile(Types.ExecMode execMode, String TEST_NAME, double p) {
+ boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+ Types.ExecMode platformOld = rtplatform;
+
+ getAndLoadTestConfiguration(TEST_NAME);
+ String HOME = SCRIPT_DIR + TEST_DIR;
+
+ double[][] X1 = getRandomMatrix(rows, 1, 1, 5, 1, 3);
+
+ MatrixCharacteristics mc = new MatrixCharacteristics(rows, 1, blocksize, rows);
+ writeInputMatrixWithMTD("X1", X1, false, mc);
+
+ double[][] W = getRandomMatrix(rows, 1, 1, 1, 1.0, 1);
+ writeInputMatrixWithMTD("W", W, false);
+
+ // empty script name because we don't execute any script, just start the worker
+ fullDMLScriptName = "";
+ int port1 = getRandomAvailablePort();
+ Thread t1 = startLocalFedWorkerThread(port1);
+
+ // we need the reference file to not be written to hdfs, so we get the correct format
+ rtplatform = Types.ExecMode.SINGLE_NODE;
+ // Run reference dml script with normal matrix for Row/Col
+ fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+ programArgs = new String[] {"-explain", "-stats", "100", "-args", input("X1"), expected("S"), String.valueOf(p), input("W")};
+ runTest(true, false, null, -1);
+
+ // reference file should not be written to hdfs, so we set platform here
+ rtplatform = execMode;
+ if(rtplatform == Types.ExecMode.SPARK) {
+ DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+ }
+ TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
+ loadTestConfiguration(config);
+
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[] {"-explain", "-stats", "100", "-nvargs",
+ "in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
+ "rows=" + rows, "cols=" + 1,
+ "p=" + String.valueOf(p), "W=" + input("W"),
+ "out_S=" + output("S")
+ };
+ runTest(true, false, null, -1);
+
+ // compare all sums via files
+ compareResults(1e-9);
+ Assert.assertTrue(heavyHittersContainsString("fed_qsort"));
+ Assert.assertTrue(heavyHittersContainsString("fed_qpick"));
+
+ // check that federated input files are still existing
+ Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+
+ TestUtils.shutdownThreads(t1);
+ rtplatform = platformOld;
+ DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ }
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRCBindTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRCBindTest.java
index ca745b9..abf37eb 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRCBindTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRCBindTest.java
@@ -19,18 +19,19 @@
package org.apache.sysds.test.functions.federated.primitives;
-import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.junit.runners.Parameterized;
+import java.util.Arrays;
+import java.util.Collection;
+
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
-import java.util.Arrays;
-import java.util.Collection;
-
+import org.junit.Ignore;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
@RunWith(value = Parameterized.class)
@net.jcip.annotations.NotThreadSafe
public class FederatedRCBindTest extends AutomatedTestBase {
@@ -48,7 +49,14 @@ public class FederatedRCBindTest extends AutomatedTestBase {
@Parameterized.Parameters
public static Collection<Object[]> data() {
//TODO add tests and support of aligned blocksized (which is however a special case)
- return Arrays.asList(new Object[][] {{1, 1001}, {10, 100}, {100, 10}, {1001, 1}, {10, 2001}, {2001, 10}});
+ return Arrays.asList(new Object[][] {
+ // {1, 1001},
+ // {10, 100},
+ {100, 10},
+ // {1001, 1},
+ // {10, 2001},
+ // {2001, 10}
+ });
}
@Override
@@ -67,6 +75,7 @@ public class FederatedRCBindTest extends AutomatedTestBase {
}
@Test
+ @Ignore
public void federatedRCBindSP() {
federatedRCBind(Types.ExecMode.SPARK);
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRightIndexTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRightIndexTest.java
index 0adcb15..d5f81e9 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRightIndexTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRightIndexTest.java
@@ -22,6 +22,8 @@ package org.apache.sysds.test.functions.federated.primitives;
import java.util.Arrays;
import java.util.Collection;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
@@ -37,7 +39,7 @@ import org.junit.runners.Parameterized;
@RunWith(value = Parameterized.class)
@net.jcip.annotations.NotThreadSafe
public class FederatedRightIndexTest extends AutomatedTestBase {
- // private static final Log LOG = LogFactory.getLog(FederatedRightIndexTest.class.getName());
+ private static final Log LOG = LogFactory.getLog(FederatedRightIndexTest.class.getName());
private final static String TEST_NAME1 = "FederatedRightIndexRightTest";
private final static String TEST_NAME2 = "FederatedRightIndexLeftTest";
@@ -64,11 +66,12 @@ public class FederatedRightIndexTest extends AutomatedTestBase {
@Parameterized.Parameters
public static Collection<Object[]> data() {
return Arrays.asList(new Object[][] {
- {20, 10, 6, 8, true},
+ // {20, 10, 6, 8, true},
{20, 10, 1, 1, true},
{20, 10, 2, 10, true},
// {20, 10, 2, 10, true},
- // {20, 12, 2, 10, false}, {20, 12, 1, 4, false}
+ // {20, 12, 2, 10, false},
+ // {20, 12, 1, 4, false}
});
}
@@ -164,8 +167,7 @@ public class FederatedRightIndexTest extends AutomatedTestBase {
fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
programArgs = new String[] {"-args", input("X1"), input("X2"), input("X3"), input("X4"), String.valueOf(from),
String.valueOf(to), Boolean.toString(rowPartitioned).toUpperCase(), expected("S")};
- // LOG.error(runTest(null));
- runTest(null);
+ LOG.debug(runTest(null));
// Run actual dml script with federated matrix
fullDMLScriptName = HOME + TEST_NAME + ".dml";
@@ -176,8 +178,8 @@ public class FederatedRightIndexTest extends AutomatedTestBase {
"in_X4=" + TestUtils.federatedAddress(port4, input("X4")), "rows=" + rows, "cols=" + cols, "from=" + from,
"to=" + to, "rP=" + Boolean.toString(rowPartitioned).toUpperCase(), "out_S=" + output("S")};
- // LOG.error(runTest(null));
- runTest(null);
+ LOG.debug(runTest(null));
+
// compare via files
compareResults(1e-9);
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRowColAggregateTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRowColAggregateTest.java
index b1c3991..7bfbeee 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRowColAggregateTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRowColAggregateTest.java
@@ -60,7 +60,7 @@ public class FederatedRowColAggregateTest extends AutomatedTestBase {
@Parameterized.Parameters
public static Collection<Object[]> data() {
return Arrays.asList(
- new Object[][] {{10, 1000, false},
+ new Object[][] {{10, 1000, false},
//{100, 4, false}, {36, 1000, true}, {1000, 10, true}, {4, 100, true}
});
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSplitTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSplitTest.java
index 04f2828..9c4b6d0 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSplitTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSplitTest.java
@@ -30,6 +30,7 @@ import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
@@ -53,7 +54,7 @@ public class FederatedSplitTest extends AutomatedTestBase {
@Parameterized.Parameters
public static Collection<Object[]> data() {
- return Arrays.asList(new Object[][] {{152, 12, "TRUE"},{132, 11, "FALSE"}});
+ return Arrays.asList(new Object[][] {{152, 12, "TRUE"}, {132, 11, "FALSE"}});
}
@Override
@@ -108,7 +109,7 @@ public class FederatedSplitTest extends AutomatedTestBase {
// Run actual dml script with federated matrix
fullDMLScriptName = HOME + TEST_NAME + ".dml";
- programArgs = new String[] {"-nvargs", "X1=" + TestUtils.federatedAddress(port1, input("X1")),
+ programArgs = new String[] {"-stats", "100", "-nvargs", "X1=" + TestUtils.federatedAddress(port1, input("X1")),
"X2=" + TestUtils.federatedAddress(port2, input("X2")),
"Y1=" + TestUtils.federatedAddress(port1, input("Y1")),
"Y2=" + TestUtils.federatedAddress(port2, input("Y2")), "r=" + rows, "c=" + cols, "Z=" + output("Z"),
@@ -120,6 +121,16 @@ public class FederatedSplitTest extends AutomatedTestBase {
// compare via files
compareResults(1e-9);
+ if(cont.equals("TRUE"))
+ Assert.assertTrue(heavyHittersContainsString("fed_rightIndex"));
+ else{
+
+ Assert.assertTrue(heavyHittersContainsString("fed_ba+*"));
+ // TODO add federated diag operator.
+ // Assert.assertTrue(heavyHittersContainsString("fed_rdiag"));
+
+ }
+
TestUtils.shutdownThreads(t1, t2);
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedStatisticsTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedStatisticsTest.java
index 0eccc8d..99e649f 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedStatisticsTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedStatisticsTest.java
@@ -59,7 +59,11 @@ public class FederatedStatisticsTest extends AutomatedTestBase {
@Parameterized.Parameters
public static Collection<Object[]> data() {
// rows have to be even and > 1
- return Arrays.asList(new Object[][] {{10000, 10}, {1000, 100}, {2000, 43}});
+ return Arrays.asList(new Object[][] {
+ // {10000, 10},
+ // {1000, 100},
+ {2000, 43}
+ });
}
@Test
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSumTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSumTest.java
index 9ce65eb..3d03f7b 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSumTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSumTest.java
@@ -19,18 +19,19 @@
package org.apache.sysds.test.functions.federated.primitives;
-import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.junit.runners.Parameterized;
+import java.util.Arrays;
+import java.util.Collection;
+
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
-
-import java.util.Arrays;
-import java.util.Collection;
+import org.junit.Ignore;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
@RunWith(value = Parameterized.class)
@net.jcip.annotations.NotThreadSafe
@@ -48,7 +49,11 @@ public class FederatedSumTest extends AutomatedTestBase {
@Parameterized.Parameters
public static Collection<Object[]> data() {
- return Arrays.asList(new Object[][] {{2, 1000}, {10, 100}, {100, 10}, {1000, 1}, {10, 2000}, {2000, 10}});
+ return Arrays.asList(new Object[][] {
+ // {2, 1000}, {10, 100},
+ {100, 10}, {1000, 1},
+ // {10, 2000}, {2000, 10}
+ });
}
@Override
@@ -64,6 +69,7 @@ public class FederatedSumTest extends AutomatedTestBase {
}
@Test
+ @Ignore
public void federatedSumSP() {
federatedSum(Types.ExecMode.SPARK);
}
diff --git a/scripts/builtin/dist.dml b/src/test/scripts/functions/federated/FederatedCentralMomentTest.dml
similarity index 70%
copy from scripts/builtin/dist.dml
copy to src/test/scripts/functions/federated/FederatedCentralMomentTest.dml
index 9c473d8..477b7b9 100644
--- a/scripts/builtin/dist.dml
+++ b/src/test/scripts/functions/federated/FederatedCentralMomentTest.dml
@@ -17,13 +17,11 @@
# specific language governing permissions and limitations
# under the License.
#
+#-------------------------------------------------------------
-# Returns Euclidian distance matrix (distances between N n-dimensional points)
+A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols),
+ list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols)));
-m_dist = function(Matrix[Double] X) return (Matrix[Double] Y) {
- G = X %*% t(X);
- I = matrix(1, rows = nrow(G), cols = ncol(G));
- Y = -2 * (G) + t(I %*% diag(diag(G))) + t(diag(diag(G)) %*% I);
- Y = sqrt(Y);
- Y = replace(target = Y, pattern=0/0, replacement = 0);
-}
\ No newline at end of file
+s = moment(A, $k);
+write(s, $out_S);
diff --git a/scripts/builtin/dist.dml b/src/test/scripts/functions/federated/FederatedCentralMomentTestReference.dml
similarity index 70%
copy from scripts/builtin/dist.dml
copy to src/test/scripts/functions/federated/FederatedCentralMomentTestReference.dml
index 9c473d8..15c3a7d 100644
--- a/scripts/builtin/dist.dml
+++ b/src/test/scripts/functions/federated/FederatedCentralMomentTestReference.dml
@@ -17,13 +17,9 @@
# specific language governing permissions and limitations
# under the License.
#
+#-------------------------------------------------------------
-# Returns Euclidian distance matrix (distances between N n-dimensional points)
+A = rbind(read($1), read($2), read($3), read($4));
-m_dist = function(Matrix[Double] X) return (Matrix[Double] Y) {
- G = X %*% t(X);
- I = matrix(1, rows = nrow(G), cols = ncol(G));
- Y = -2 * (G) + t(I %*% diag(diag(G))) + t(diag(diag(G)) %*% I);
- Y = sqrt(Y);
- Y = replace(target = Y, pattern=0/0, replacement = 0);
-}
\ No newline at end of file
+s = moment(A, $6);
+write(s, $5);
diff --git a/scripts/builtin/dist.dml b/src/test/scripts/functions/federated/FederatedCorTest.dml
similarity index 57%
copy from scripts/builtin/dist.dml
copy to src/test/scripts/functions/federated/FederatedCorTest.dml
index 9c473d8..b4d8e80 100644
--- a/scripts/builtin/dist.dml
+++ b/src/test/scripts/functions/federated/FederatedCorTest.dml
@@ -17,13 +17,18 @@
# specific language governing permissions and limitations
# under the License.
#
+#-------------------------------------------------------------
+
+if ($rP) {
+ A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols),
+ list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols)));
+} else {
+ A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4), list($rows, $cols/2),
+ list(0,$cols/2), list($rows, 3*($cols/4)), list(0, 3*($cols/4)), list($rows, $cols)));
+}
-# Returns Euclidian distance matrix (distances between N n-dimensional points)
+s = cor(A);
-m_dist = function(Matrix[Double] X) return (Matrix[Double] Y) {
- G = X %*% t(X);
- I = matrix(1, rows = nrow(G), cols = ncol(G));
- Y = -2 * (G) + t(I %*% diag(diag(G))) + t(diag(diag(G)) %*% I);
- Y = sqrt(Y);
- Y = replace(target = Y, pattern=0/0, replacement = 0);
-}
\ No newline at end of file
+write(s, $out_S);
diff --git a/scripts/builtin/dist.dml b/src/test/scripts/functions/federated/FederatedCorTestReference.dml
similarity index 70%
copy from scripts/builtin/dist.dml
copy to src/test/scripts/functions/federated/FederatedCorTestReference.dml
index 9c473d8..8e40e33 100644
--- a/scripts/builtin/dist.dml
+++ b/src/test/scripts/functions/federated/FederatedCorTestReference.dml
@@ -17,13 +17,10 @@
# specific language governing permissions and limitations
# under the License.
#
+#-------------------------------------------------------------
-# Returns Euclidian distance matrix (distances between N n-dimensional points)
+if($6) { A = rbind(read($1), read($2), read($3), read($4)); }
+else { A = cbind(read($1), read($2), read($3), read($4)); }
+s = cor(A);
-m_dist = function(Matrix[Double] X) return (Matrix[Double] Y) {
- G = X %*% t(X);
- I = matrix(1, rows = nrow(G), cols = ncol(G));
- Y = -2 * (G) + t(I %*% diag(diag(G))) + t(diag(diag(G)) %*% I);
- Y = sqrt(Y);
- Y = replace(target = Y, pattern=0/0, replacement = 0);
-}
\ No newline at end of file
+write(s, $5);
diff --git a/scripts/builtin/dist.dml b/src/test/scripts/functions/federated/FederatedVarTest.dml
similarity index 57%
copy from scripts/builtin/dist.dml
copy to src/test/scripts/functions/federated/FederatedVarTest.dml
index 9c473d8..88fd6a8 100644
--- a/scripts/builtin/dist.dml
+++ b/src/test/scripts/functions/federated/FederatedVarTest.dml
@@ -17,13 +17,17 @@
# specific language governing permissions and limitations
# under the License.
#
+#-------------------------------------------------------------
-# Returns Euclidian distance matrix (distances between N n-dimensional points)
+if ($rP) {
+ A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols),
+ list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols)));
+} else {
+ A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4), list($rows, $cols/2),
+ list(0,$cols/2), list($rows, 3*($cols/4)), list(0, 3*($cols/4)), list($rows, $cols)));
+}
-m_dist = function(Matrix[Double] X) return (Matrix[Double] Y) {
- G = X %*% t(X);
- I = matrix(1, rows = nrow(G), cols = ncol(G));
- Y = -2 * (G) + t(I %*% diag(diag(G))) + t(diag(diag(G)) %*% I);
- Y = sqrt(Y);
- Y = replace(target = Y, pattern=0/0, replacement = 0);
-}
\ No newline at end of file
+s = var(A);
+write(s, $out_S);
diff --git a/scripts/builtin/dist.dml b/src/test/scripts/functions/federated/FederatedVarTestReference.dml
similarity index 70%
copy from scripts/builtin/dist.dml
copy to src/test/scripts/functions/federated/FederatedVarTestReference.dml
index 9c473d8..36929c9 100644
--- a/scripts/builtin/dist.dml
+++ b/src/test/scripts/functions/federated/FederatedVarTestReference.dml
@@ -17,13 +17,11 @@
# specific language governing permissions and limitations
# under the License.
#
+#-------------------------------------------------------------
+
+if($6) { A = rbind(read($1), read($2), read($3), read($4)); }
+else { A = cbind(read($1), read($2), read($3), read($4)); }
-# Returns Euclidian distance matrix (distances between N n-dimensional points)
+s = var(A);
-m_dist = function(Matrix[Double] X) return (Matrix[Double] Y) {
- G = X %*% t(X);
- I = matrix(1, rows = nrow(G), cols = ncol(G));
- Y = -2 * (G) + t(I %*% diag(diag(G))) + t(diag(diag(G)) %*% I);
- Y = sqrt(Y);
- Y = replace(target = Y, pattern=0/0, replacement = 0);
-}
\ No newline at end of file
+write(s, $5);
diff --git a/scripts/builtin/dist.dml b/src/test/scripts/functions/federated/quantile/FederatedIQRTest.dml
similarity index 70%
copy from scripts/builtin/dist.dml
copy to src/test/scripts/functions/federated/quantile/FederatedIQRTest.dml
index 9c473d8..f50e72e 100644
--- a/scripts/builtin/dist.dml
+++ b/src/test/scripts/functions/federated/quantile/FederatedIQRTest.dml
@@ -17,13 +17,8 @@
# specific language governing permissions and limitations
# under the License.
#
+#-------------------------------------------------------------
-# Returns Euclidian distance matrix (distances between N n-dimensional points)
-
-m_dist = function(Matrix[Double] X) return (Matrix[Double] Y) {
- G = X %*% t(X);
- I = matrix(1, rows = nrow(G), cols = ncol(G));
- Y = -2 * (G) + t(I %*% diag(diag(G))) + t(diag(diag(G)) %*% I);
- Y = sqrt(Y);
- Y = replace(target = Y, pattern=0/0, replacement = 0);
-}
\ No newline at end of file
+A = federated(addresses=list($in_X1), ranges=list(list(0, 0), list($rows, $cols)));
+s = interQuartileMean(A);
+write(s, $out_S);
diff --git a/scripts/builtin/dist.dml b/src/test/scripts/functions/federated/quantile/FederatedIQRTestReference.dml
similarity index 70%
copy from scripts/builtin/dist.dml
copy to src/test/scripts/functions/federated/quantile/FederatedIQRTestReference.dml
index 9c473d8..84fcace 100644
--- a/scripts/builtin/dist.dml
+++ b/src/test/scripts/functions/federated/quantile/FederatedIQRTestReference.dml
@@ -17,13 +17,8 @@
# specific language governing permissions and limitations
# under the License.
#
+#-------------------------------------------------------------
-# Returns Euclidian distance matrix (distances between N n-dimensional points)
-
-m_dist = function(Matrix[Double] X) return (Matrix[Double] Y) {
- G = X %*% t(X);
- I = matrix(1, rows = nrow(G), cols = ncol(G));
- Y = -2 * (G) + t(I %*% diag(diag(G))) + t(diag(diag(G)) %*% I);
- Y = sqrt(Y);
- Y = replace(target = Y, pattern=0/0, replacement = 0);
-}
\ No newline at end of file
+A = read($1);
+s = interQuartileMean(A);
+write(s, $2);
diff --git a/scripts/builtin/dist.dml b/src/test/scripts/functions/federated/quantile/FederatedIQRWeightsTest.dml
similarity index 70%
copy from scripts/builtin/dist.dml
copy to src/test/scripts/functions/federated/quantile/FederatedIQRWeightsTest.dml
index 9c473d8..99b48c6 100644
--- a/scripts/builtin/dist.dml
+++ b/src/test/scripts/functions/federated/quantile/FederatedIQRWeightsTest.dml
@@ -17,13 +17,9 @@
# specific language governing permissions and limitations
# under the License.
#
+#-------------------------------------------------------------
-# Returns Euclidian distance matrix (distances between N n-dimensional points)
-
-m_dist = function(Matrix[Double] X) return (Matrix[Double] Y) {
- G = X %*% t(X);
- I = matrix(1, rows = nrow(G), cols = ncol(G));
- Y = -2 * (G) + t(I %*% diag(diag(G))) + t(diag(diag(G)) %*% I);
- Y = sqrt(Y);
- Y = replace(target = Y, pattern=0/0, replacement = 0);
-}
\ No newline at end of file
+A = federated(addresses=list($in_X1), ranges=list(list(0, 0), list($rows, $cols)));
+W = read($W);
+s = interQuartileMean(A, W);
+write(s, $out_S);
diff --git a/scripts/builtin/dist.dml b/src/test/scripts/functions/federated/quantile/FederatedIQRWeightsTestReference.dml
similarity index 70%
copy from scripts/builtin/dist.dml
copy to src/test/scripts/functions/federated/quantile/FederatedIQRWeightsTestReference.dml
index 9c473d8..afc9a1f 100644
--- a/scripts/builtin/dist.dml
+++ b/src/test/scripts/functions/federated/quantile/FederatedIQRWeightsTestReference.dml
@@ -17,13 +17,9 @@
# specific language governing permissions and limitations
# under the License.
#
+#-------------------------------------------------------------
-# Returns Euclidian distance matrix (distances between N n-dimensional points)
-
-m_dist = function(Matrix[Double] X) return (Matrix[Double] Y) {
- G = X %*% t(X);
- I = matrix(1, rows = nrow(G), cols = ncol(G));
- Y = -2 * (G) + t(I %*% diag(diag(G))) + t(diag(diag(G)) %*% I);
- Y = sqrt(Y);
- Y = replace(target = Y, pattern=0/0, replacement = 0);
-}
\ No newline at end of file
+A = read($1);
+W = read($4);
+s = interQuartileMean(A, W);
+write(s, $2);
diff --git a/scripts/builtin/dist.dml b/src/test/scripts/functions/federated/quantile/FederatedMedianTest.dml
similarity index 70%
copy from scripts/builtin/dist.dml
copy to src/test/scripts/functions/federated/quantile/FederatedMedianTest.dml
index 9c473d8..22f2504 100644
--- a/scripts/builtin/dist.dml
+++ b/src/test/scripts/functions/federated/quantile/FederatedMedianTest.dml
@@ -17,13 +17,8 @@
# specific language governing permissions and limitations
# under the License.
#
+#-------------------------------------------------------------
-# Returns Euclidian distance matrix (distances between N n-dimensional points)
-
-m_dist = function(Matrix[Double] X) return (Matrix[Double] Y) {
- G = X %*% t(X);
- I = matrix(1, rows = nrow(G), cols = ncol(G));
- Y = -2 * (G) + t(I %*% diag(diag(G))) + t(diag(diag(G)) %*% I);
- Y = sqrt(Y);
- Y = replace(target = Y, pattern=0/0, replacement = 0);
-}
\ No newline at end of file
+A = federated(addresses=list($in_X1), ranges=list(list(0, 0), list($rows, $cols)));
+s = median(A);
+write(s, $out_S);
diff --git a/scripts/builtin/dist.dml b/src/test/scripts/functions/federated/quantile/FederatedMedianTestReference.dml
similarity index 70%
copy from scripts/builtin/dist.dml
copy to src/test/scripts/functions/federated/quantile/FederatedMedianTestReference.dml
index 9c473d8..1544987 100644
--- a/scripts/builtin/dist.dml
+++ b/src/test/scripts/functions/federated/quantile/FederatedMedianTestReference.dml
@@ -17,13 +17,8 @@
# specific language governing permissions and limitations
# under the License.
#
+#-------------------------------------------------------------
-# Returns Euclidian distance matrix (distances between N n-dimensional points)
-
-m_dist = function(Matrix[Double] X) return (Matrix[Double] Y) {
- G = X %*% t(X);
- I = matrix(1, rows = nrow(G), cols = ncol(G));
- Y = -2 * (G) + t(I %*% diag(diag(G))) + t(diag(diag(G)) %*% I);
- Y = sqrt(Y);
- Y = replace(target = Y, pattern=0/0, replacement = 0);
-}
\ No newline at end of file
+A = read($1);
+s = median(A);
+write(s, $2);
diff --git a/scripts/builtin/dist.dml b/src/test/scripts/functions/federated/quantile/FederatedMedianWeightsTest.dml
similarity index 70%
copy from scripts/builtin/dist.dml
copy to src/test/scripts/functions/federated/quantile/FederatedMedianWeightsTest.dml
index 9c473d8..58ec328 100644
--- a/scripts/builtin/dist.dml
+++ b/src/test/scripts/functions/federated/quantile/FederatedMedianWeightsTest.dml
@@ -17,13 +17,9 @@
# specific language governing permissions and limitations
# under the License.
#
+#-------------------------------------------------------------
-# Returns Euclidian distance matrix (distances between N n-dimensional points)
-
-m_dist = function(Matrix[Double] X) return (Matrix[Double] Y) {
- G = X %*% t(X);
- I = matrix(1, rows = nrow(G), cols = ncol(G));
- Y = -2 * (G) + t(I %*% diag(diag(G))) + t(diag(diag(G)) %*% I);
- Y = sqrt(Y);
- Y = replace(target = Y, pattern=0/0, replacement = 0);
-}
\ No newline at end of file
+A = federated(addresses=list($in_X1), ranges=list(list(0, 0), list($rows, $cols)));
+W = read($W);
+s = median(A, W);
+write(s, $out_S);
diff --git a/scripts/builtin/dist.dml b/src/test/scripts/functions/federated/quantile/FederatedMedianWeightsTestReference.dml
similarity index 70%
copy from scripts/builtin/dist.dml
copy to src/test/scripts/functions/federated/quantile/FederatedMedianWeightsTestReference.dml
index 9c473d8..6b9e3de 100644
--- a/scripts/builtin/dist.dml
+++ b/src/test/scripts/functions/federated/quantile/FederatedMedianWeightsTestReference.dml
@@ -17,13 +17,9 @@
# specific language governing permissions and limitations
# under the License.
#
+#-------------------------------------------------------------
-# Returns Euclidian distance matrix (distances between N n-dimensional points)
-
-m_dist = function(Matrix[Double] X) return (Matrix[Double] Y) {
- G = X %*% t(X);
- I = matrix(1, rows = nrow(G), cols = ncol(G));
- Y = -2 * (G) + t(I %*% diag(diag(G))) + t(diag(diag(G)) %*% I);
- Y = sqrt(Y);
- Y = replace(target = Y, pattern=0/0, replacement = 0);
-}
\ No newline at end of file
+A = read($1);
+W = read($4);
+s = median(A, W);
+write(s, $2);
diff --git a/scripts/builtin/dist.dml b/src/test/scripts/functions/federated/quantile/FederatedQuantileTest.dml
similarity index 70%
copy from scripts/builtin/dist.dml
copy to src/test/scripts/functions/federated/quantile/FederatedQuantileTest.dml
index 9c473d8..1c84330 100644
--- a/scripts/builtin/dist.dml
+++ b/src/test/scripts/functions/federated/quantile/FederatedQuantileTest.dml
@@ -17,13 +17,8 @@
# specific language governing permissions and limitations
# under the License.
#
+#-------------------------------------------------------------
-# Returns Euclidian distance matrix (distances between N n-dimensional points)
-
-m_dist = function(Matrix[Double] X) return (Matrix[Double] Y) {
- G = X %*% t(X);
- I = matrix(1, rows = nrow(G), cols = ncol(G));
- Y = -2 * (G) + t(I %*% diag(diag(G))) + t(diag(diag(G)) %*% I);
- Y = sqrt(Y);
- Y = replace(target = Y, pattern=0/0, replacement = 0);
-}
\ No newline at end of file
+A = federated(addresses=list($in_X1), ranges=list(list(0, 0), list($rows, $cols)));
+s = quantile(A, $p);
+write(s, $out_S);
diff --git a/scripts/builtin/dist.dml b/src/test/scripts/functions/federated/quantile/FederatedQuantileTestReference.dml
similarity index 70%
copy from scripts/builtin/dist.dml
copy to src/test/scripts/functions/federated/quantile/FederatedQuantileTestReference.dml
index 9c473d8..7a1fb36 100644
--- a/scripts/builtin/dist.dml
+++ b/src/test/scripts/functions/federated/quantile/FederatedQuantileTestReference.dml
@@ -17,13 +17,8 @@
# specific language governing permissions and limitations
# under the License.
#
+#-------------------------------------------------------------
-# Returns Euclidian distance matrix (distances between N n-dimensional points)
-
-m_dist = function(Matrix[Double] X) return (Matrix[Double] Y) {
- G = X %*% t(X);
- I = matrix(1, rows = nrow(G), cols = ncol(G));
- Y = -2 * (G) + t(I %*% diag(diag(G))) + t(diag(diag(G)) %*% I);
- Y = sqrt(Y);
- Y = replace(target = Y, pattern=0/0, replacement = 0);
-}
\ No newline at end of file
+A = read($1);
+s = quantile (A, $3);
+write(s, $2);
diff --git a/scripts/builtin/dist.dml b/src/test/scripts/functions/federated/quantile/FederatedQuantileWeightsTest.dml
similarity index 70%
copy from scripts/builtin/dist.dml
copy to src/test/scripts/functions/federated/quantile/FederatedQuantileWeightsTest.dml
index 9c473d8..c423a65 100644
--- a/scripts/builtin/dist.dml
+++ b/src/test/scripts/functions/federated/quantile/FederatedQuantileWeightsTest.dml
@@ -17,13 +17,9 @@
# specific language governing permissions and limitations
# under the License.
#
+#-------------------------------------------------------------
-# Returns Euclidian distance matrix (distances between N n-dimensional points)
-
-m_dist = function(Matrix[Double] X) return (Matrix[Double] Y) {
- G = X %*% t(X);
- I = matrix(1, rows = nrow(G), cols = ncol(G));
- Y = -2 * (G) + t(I %*% diag(diag(G))) + t(diag(diag(G)) %*% I);
- Y = sqrt(Y);
- Y = replace(target = Y, pattern=0/0, replacement = 0);
-}
\ No newline at end of file
+A = federated(addresses=list($in_X1), ranges=list(list(0, 0), list($rows, $cols)));
+W = read($W);
+s = quantile(A, W, $p);
+write(s, $out_S);
diff --git a/scripts/builtin/dist.dml b/src/test/scripts/functions/federated/quantile/FederatedQuantileWeightsTestReference.dml
similarity index 70%
copy from scripts/builtin/dist.dml
copy to src/test/scripts/functions/federated/quantile/FederatedQuantileWeightsTestReference.dml
index 9c473d8..6796757 100644
--- a/scripts/builtin/dist.dml
+++ b/src/test/scripts/functions/federated/quantile/FederatedQuantileWeightsTestReference.dml
@@ -17,13 +17,9 @@
# specific language governing permissions and limitations
# under the License.
#
+#-------------------------------------------------------------
-# Returns Euclidian distance matrix (distances between N n-dimensional points)
-
-m_dist = function(Matrix[Double] X) return (Matrix[Double] Y) {
- G = X %*% t(X);
- I = matrix(1, rows = nrow(G), cols = ncol(G));
- Y = -2 * (G) + t(I %*% diag(diag(G))) + t(diag(diag(G)) %*% I);
- Y = sqrt(Y);
- Y = replace(target = Y, pattern=0/0, replacement = 0);
-}
\ No newline at end of file
+A = read($1);
+W = read($4);
+s = quantile (A, W, $3);
+write(s, $2);
diff --git a/scripts/builtin/dist.dml b/src/test/scripts/functions/federated/quantile/FederatedQuantilesTest.dml
similarity index 70%
copy from scripts/builtin/dist.dml
copy to src/test/scripts/functions/federated/quantile/FederatedQuantilesTest.dml
index 9c473d8..f5f22fa 100644
--- a/scripts/builtin/dist.dml
+++ b/src/test/scripts/functions/federated/quantile/FederatedQuantilesTest.dml
@@ -17,13 +17,11 @@
# specific language governing permissions and limitations
# under the License.
#
+#-------------------------------------------------------------
-# Returns Euclidian distance matrix (distances between N n-dimensional points)
-
-m_dist = function(Matrix[Double] X) return (Matrix[Double] Y) {
- G = X %*% t(X);
- I = matrix(1, rows = nrow(G), cols = ncol(G));
- Y = -2 * (G) + t(I %*% diag(diag(G))) + t(diag(diag(G)) %*% I);
- Y = sqrt(Y);
- Y = replace(target = Y, pattern=0/0, replacement = 0);
-}
\ No newline at end of file
+A = federated(addresses=list($in_X1), ranges=list(list(0, 0), list($rows, $cols)));
+P = matrix(0.25, 3, 1);
+P[2,1] = 0.5;
+P[3,1] = 0.75;
+s = quantile(X=A, P=P);
+write(s, $out_S);
diff --git a/scripts/builtin/dist.dml b/src/test/scripts/functions/federated/quantile/FederatedQuantilesTestReference.dml
similarity index 70%
copy from scripts/builtin/dist.dml
copy to src/test/scripts/functions/federated/quantile/FederatedQuantilesTestReference.dml
index 9c473d8..96970d0 100644
--- a/scripts/builtin/dist.dml
+++ b/src/test/scripts/functions/federated/quantile/FederatedQuantilesTestReference.dml
@@ -17,13 +17,11 @@
# specific language governing permissions and limitations
# under the License.
#
+#-------------------------------------------------------------
-# Returns Euclidian distance matrix (distances between N n-dimensional points)
-
-m_dist = function(Matrix[Double] X) return (Matrix[Double] Y) {
- G = X %*% t(X);
- I = matrix(1, rows = nrow(G), cols = ncol(G));
- Y = -2 * (G) + t(I %*% diag(diag(G))) + t(diag(diag(G)) %*% I);
- Y = sqrt(Y);
- Y = replace(target = Y, pattern=0/0, replacement = 0);
-}
\ No newline at end of file
+A = read($1);
+P = matrix(0.25, 3, 1);
+P[2,1] = 0.5;
+P[3,1] = 0.75;
+s = quantile(X=A, P=P);
+write(s, $2);
diff --git a/scripts/builtin/dist.dml b/src/test/scripts/functions/federated/quantile/FederatedQuantilesWeightsTest.dml
similarity index 70%
copy from scripts/builtin/dist.dml
copy to src/test/scripts/functions/federated/quantile/FederatedQuantilesWeightsTest.dml
index 9c473d8..86f9611 100644
--- a/scripts/builtin/dist.dml
+++ b/src/test/scripts/functions/federated/quantile/FederatedQuantilesWeightsTest.dml
@@ -17,13 +17,12 @@
# specific language governing permissions and limitations
# under the License.
#
+#-------------------------------------------------------------
-# Returns Euclidian distance matrix (distances between N n-dimensional points)
-
-m_dist = function(Matrix[Double] X) return (Matrix[Double] Y) {
- G = X %*% t(X);
- I = matrix(1, rows = nrow(G), cols = ncol(G));
- Y = -2 * (G) + t(I %*% diag(diag(G))) + t(diag(diag(G)) %*% I);
- Y = sqrt(Y);
- Y = replace(target = Y, pattern=0/0, replacement = 0);
-}
\ No newline at end of file
+A = federated(addresses=list($in_X1), ranges=list(list(0, 0), list($rows, $cols)));
+P = matrix(0.25, 3, 1);
+P[2,1] = 0.5;
+P[3,1] = 0.75;
+W = read($W);
+s = quantile(X=A, W=W, P=P);
+write(s, $out_S);
diff --git a/scripts/builtin/dist.dml b/src/test/scripts/functions/federated/quantile/FederatedQuantilesWeightsTestReference.dml
similarity index 70%
copy from scripts/builtin/dist.dml
copy to src/test/scripts/functions/federated/quantile/FederatedQuantilesWeightsTestReference.dml
index 9c473d8..3c6d7fd 100644
--- a/scripts/builtin/dist.dml
+++ b/src/test/scripts/functions/federated/quantile/FederatedQuantilesWeightsTestReference.dml
@@ -17,13 +17,12 @@
# specific language governing permissions and limitations
# under the License.
#
+#-------------------------------------------------------------
-# Returns Euclidian distance matrix (distances between N n-dimensional points)
-
-m_dist = function(Matrix[Double] X) return (Matrix[Double] Y) {
- G = X %*% t(X);
- I = matrix(1, rows = nrow(G), cols = ncol(G));
- Y = -2 * (G) + t(I %*% diag(diag(G))) + t(diag(diag(G)) %*% I);
- Y = sqrt(Y);
- Y = replace(target = Y, pattern=0/0, replacement = 0);
-}
\ No newline at end of file
+A = read($1);
+P = matrix(0.25, 3, 1);
+P[2,1] = 0.5;
+P[3,1] = 0.75;
+W = read($W);
+s = quantile(X=A, W=W, P=P);
+write(s, $2);