You are viewing a plain-text version of this content; the hyperlink to its canonical version is not preserved in plain text.
Posted to commits@systemds.apache.org by ba...@apache.org on 2021/05/04 09:44:11 UTC
[systemds] branch master updated: [SYSTEMDS-2542] Federated
rowProds, colProds
This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new c14198b [SYSTEMDS-2542] Federated rowProds, colProds
c14198b is described below
commit c14198b75dcee7c77b4f68ca88807cb00d7ffe97
Author: Olga <ov...@gmail.com>
AuthorDate: Mon Apr 12 18:45:54 2021 +0200
[SYSTEMDS-2542] Federated rowProds, colProds
- Added federated colProds and rowProds aggregation (FederationUtils.aggProd) with tests
- Added missing trailing newlines and minor formatting fixes
Closes #1248
---
.../controlprogram/federated/FederationUtils.java | 30 ++++++++++++++++++-
.../primitives/FederatedColAggregateTest.java | 35 +++++++++++++++-------
.../primitives/FederatedRowAggregateTest.java | 23 ++++++++++----
...FederatedSumTest.dml => FederatedColProdTest.R} | 24 +++++++--------
...atedRowVarTest.dml => FederatedColProdTest.dml} | 4 +--
...mTest.dml => FederatedColProdTestReference.dml} | 16 +++-------
.../federated/aggregate/FederatedColVarTest.dml | 2 +-
...atedRowVarTest.dml => FederatedRowProdTest.dml} | 4 +--
...mTest.dml => FederatedRowProdTestReference.dml} | 16 +++-------
.../federated/aggregate/FederatedRowVarTest.dml | 2 +-
.../federated/aggregate/FederatedSumTest.dml | 2 +-
11 files changed, 97 insertions(+), 61 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
index f569364..94fe0bd 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
@@ -203,6 +203,32 @@ public class FederationUtils {
}
}
+	/**
+	 * Combines the partial products returned by the federated workers into the final row/column product.
+	 * The partial result of each response is placed into one row (ROW-federated input) or one column
+	 * (any other federation type) of an intermediate block, and {@code aop} is then applied to that block.
+	 *
+	 * @param ffr    futures of the federated responses; element 0 of each response's data is a MatrixBlock
+	 * @param fedMap federation map of the input; its type selects row-wise or column-wise stacking
+	 * @param aop    aggregate-unary operator re-applied to the stacked partial results
+	 * @return the aggregated result (1 x n for ROW federation, m x 1 otherwise)
+	 */
+	public static MatrixBlock aggProd(Future<FederatedResponse>[] ffr, FederationMap fedMap, AggregateUnaryOperator aop) {
+		try {
+			final boolean byRows = fedMap.getType() == FederationMap.FType.ROW;
+			// relies on the first federated range covering the full, non-partitioned dimension
+			final int fixedDim = (int) (byRows ?
+				fedMap.getFederatedRanges()[0].getEndDims()[1] :
+				fedMap.getFederatedRanges()[0].getEndDims()[0]);
+			final int parts = ffr.length;
+			// 1.0 is the neutral element of the product
+			final MatrixBlock stacked = byRows ?
+				new MatrixBlock(parts, fixedDim, 1.0) : new MatrixBlock(fixedDim, parts, 1.0);
+			final MatrixBlock result = byRows ?
+				new MatrixBlock(1, fixedDim, 1.0) : new MatrixBlock(fixedDim, 1, 1.0);
+
+			int idx = 0;
+			for(Future<FederatedResponse> response : ffr) {
+				MatrixBlock partial = (MatrixBlock) response.get().getData()[0];
+				if(byRows)
+					stacked.copy(idx, idx, 0, fixedDim - 1, partial, true);
+				else
+					stacked.copy(0, fixedDim - 1, idx, idx, partial, true);
+				idx++;
+			}
+
+			LibMatrixAgg.aggregateUnaryMatrix(stacked, result, aop);
+			return result;
+		}
+		catch(Exception ex) {
+			throw new DMLRuntimeException(ex);
+		}
+	}
+
public static MatrixBlock aggMinMaxIndex(Future<FederatedResponse>[] ffr, boolean isMin, FederationMap map) {
try {
MatrixBlock prev = (MatrixBlock) ffr[0].get().getData()[0];
@@ -410,6 +436,8 @@ public class FederationUtils {
return aggAdd(ffr);
else if( aop.aggOp.increOp.fn instanceof Mean )
return aggMean(ffr, map);
+ else if(aop.aggOp.increOp.fn instanceof Multiply)
+ return aggProd(ffr, map, aop);
else if (aop.aggOp.increOp.fn instanceof Builtin) {
if ((((Builtin) aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MIN ||
((Builtin) aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MAX)) {
@@ -419,7 +447,7 @@ public class FederationUtils {
else if((((Builtin) aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MININDEX)
|| (((Builtin) aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MAXINDEX)) {
boolean isMin = ((Builtin) aop.aggOp.increOp.fn).getBuiltinCode() == BuiltinCode.MININDEX;
- return aggMinMaxIndex(ffr,isMin, map);
+ return aggMinMaxIndex(ffr, isMin, map);
}
else throw new DMLRuntimeException("Unsupported aggregation operator: "
+ aop.aggOp.increOp.fn.getClass().getSimpleName());
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedColAggregateTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedColAggregateTest.java
index 1bdcb5b..870e7c2 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedColAggregateTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedColAggregateTest.java
@@ -41,6 +41,7 @@ public class FederatedColAggregateTest extends AutomatedTestBase {
private final static String TEST_NAME2 = "FederatedColMeanTest";
private final static String TEST_NAME3 = "FederatedColMaxTest";
private final static String TEST_NAME4 = "FederatedColMinTest";
+ private final static String TEST_NAME5 = "FederatedColProdTest";
private final static String TEST_NAME10 = "FederatedColVarTest";
private final static String TEST_DIR = "functions/federated/aggregate/";
@@ -58,13 +59,13 @@ public class FederatedColAggregateTest extends AutomatedTestBase {
public static Collection<Object[]> data() {
+	// Keep both a column-partitioned and a row-partitioned configuration so every column
+	// aggregate (including PROD) is exercised on both federation layouts. The tuple layout
+	// appears to be {rows, cols, rowPartitioned} -- confirm against the test's parameter fields.
return Arrays.asList(
new Object[][] {
{10, 1000, false},
{1000, 40, true},
});
}
private enum OpType {
- SUM, MEAN, MAX, MIN, VAR
+ SUM, MEAN, MAX, MIN, VAR, PROD
}
@Override
@@ -75,6 +76,7 @@ public class FederatedColAggregateTest extends AutomatedTestBase {
addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {"S"}));
addTestConfiguration(TEST_NAME4, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] {"S"}));
addTestConfiguration(TEST_NAME10, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME10, new String[] {"S"}));
+ addTestConfiguration(TEST_NAME5, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME5, new String[] {"S"}));
}
@Test
@@ -103,6 +105,11 @@ public class FederatedColAggregateTest extends AutomatedTestBase {
runAggregateOperationTest(OpType.VAR, ExecMode.SINGLE_NODE);
}
+	/** Runs the federated colProds script in single-node (CP) execution mode. */
+	@Test
+	public void testColProdDenseMatrixCP() {
+		final ExecMode mode = ExecMode.SINGLE_NODE;
+		runAggregateOperationTest(OpType.PROD, mode);
+	}
+
private void runAggregateOperationTest(OpType type, ExecMode execMode) {
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
ExecMode platformOld = rtplatform;
@@ -127,6 +134,9 @@ public class FederatedColAggregateTest extends AutomatedTestBase {
case VAR:
TEST_NAME = TEST_NAME10;
break;
+ case PROD:
+ TEST_NAME = TEST_NAME5;
+ break;
}
getAndLoadTestConfiguration(TEST_NAME);
@@ -140,16 +150,16 @@ public class FederatedColAggregateTest extends AutomatedTestBase {
c = cols;
}
- double[][] X1 = getRandomMatrix(r, c, 1, 3, 1, 3);
- double[][] X2 = getRandomMatrix(r, c, 1, 3, 1, 7);
- double[][] X3 = getRandomMatrix(r, c, 1, 3, 1, 8);
- double[][] X4 = getRandomMatrix(r, c, 1, 3, 1, 9);
+ double[][] X1 = getRandomMatrix(r, c, 3, 3, 1, 3);
+ double[][] X2 = getRandomMatrix(r, c, 3, 3, 1, 7);
+ double[][] X3 = getRandomMatrix(r, c, 3, 3, 1, 8);
+ double[][] X4 = getRandomMatrix(r, c, 3, 3, 1, 9);
MatrixCharacteristics mc = new MatrixCharacteristics(r, c, blocksize, r * c);
- writeInputMatrixWithMTD("X1", X1, false, mc);
- writeInputMatrixWithMTD("X2", X2, false, mc);
- writeInputMatrixWithMTD("X3", X3, false, mc);
- writeInputMatrixWithMTD("X4", X4, false, mc);
+ writeInputMatrixWithMTD("X1", X1, true, mc);
+ writeInputMatrixWithMTD("X2", X2, true, mc);
+ writeInputMatrixWithMTD("X3", X3, true, mc);
+ writeInputMatrixWithMTD("X4", X4, true, mc);
// empty script name because we don't execute any script, just start the worker
fullDMLScriptName = "";
@@ -189,7 +199,7 @@ public class FederatedColAggregateTest extends AutomatedTestBase {
runTest(true, false, null, -1);
// compare via files
- compareResults(type == FederatedColAggregateTest.OpType.VAR ? 1e-2 : 1e-9);
+ compareResults((type == FederatedColAggregateTest.OpType.VAR) || (type == OpType.PROD) ? 1e-2 : 1e-9);
String fedInst = "fed_uac";
@@ -209,6 +219,9 @@ public class FederatedColAggregateTest extends AutomatedTestBase {
case VAR:
Assert.assertTrue(heavyHittersContainsString(fedInst.concat("var")));
break;
+ case PROD:
+ Assert.assertTrue(heavyHittersContainsString(fedInst.concat("*")));
+ break;
}
// check that federated input files are still existing
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRowAggregateTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRowAggregateTest.java
index e0a3632..c140dc8 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRowAggregateTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRowAggregateTest.java
@@ -42,6 +42,7 @@ public class FederatedRowAggregateTest extends AutomatedTestBase {
private final static String TEST_NAME7 = "FederatedRowMaxTest";
private final static String TEST_NAME8 = "FederatedRowMinTest";
private final static String TEST_NAME9 = "FederatedRowVarTest";
+ private final static String TEST_NAME10 = "FederatedRowProdTest";
private final static String TEST_DIR = "functions/federated/aggregate/";
private static final String TEST_CLASS_DIR = TEST_DIR + FederatedRowAggregateTest.class.getSimpleName() + "/";
@@ -64,7 +65,7 @@ public class FederatedRowAggregateTest extends AutomatedTestBase {
}
private enum OpType {
- SUM, MEAN, MAX, MIN, VAR
+ SUM, MEAN, MAX, MIN, VAR, PROD
}
@Override
@@ -75,6 +76,7 @@ public class FederatedRowAggregateTest extends AutomatedTestBase {
addTestConfiguration(TEST_NAME7, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME7, new String[] {"S"}));
addTestConfiguration(TEST_NAME8, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME8, new String[] {"S"}));
addTestConfiguration(TEST_NAME9, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME9, new String[] {"S"}));
+ addTestConfiguration(TEST_NAME10, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME10, new String[] {"S"}));
}
@Test
@@ -102,6 +104,11 @@ public class FederatedRowAggregateTest extends AutomatedTestBase {
runAggregateOperationTest(OpType.VAR, ExecMode.SINGLE_NODE);
}
+	/** Runs the federated rowProds script in single-node (CP) execution mode. */
+	@Test
+	public void testRowProdDenseMatrixCP() {
+		final ExecMode mode = ExecMode.SINGLE_NODE;
+		runAggregateOperationTest(OpType.PROD, mode);
+	}
+
private void runAggregateOperationTest(OpType type, ExecMode execMode) {
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
ExecMode platformOld = rtplatform;
@@ -126,6 +133,9 @@ public class FederatedRowAggregateTest extends AutomatedTestBase {
case VAR:
TEST_NAME = TEST_NAME9;
break;
+ case PROD:
+ TEST_NAME = TEST_NAME10;
+ break;
}
getAndLoadTestConfiguration(TEST_NAME);
@@ -139,10 +149,10 @@ public class FederatedRowAggregateTest extends AutomatedTestBase {
c = cols;
}
- double[][] X1 = getRandomMatrix(r, c, 1, 3, 1, 3);
- double[][] X2 = getRandomMatrix(r, c, 1, 3, 1, 7);
- double[][] X3 = getRandomMatrix(r, c, 1, 3, 1, 8);
- double[][] X4 = getRandomMatrix(r, c, 1, 3, 1, 9);
+ double[][] X1 = getRandomMatrix(r, c, 3, 3, 1, 3);
+ double[][] X2 = getRandomMatrix(r, c, 3, 3, 1, 7);
+ double[][] X3 = getRandomMatrix(r, c, 3, 3, 1, 8);
+ double[][] X4 = getRandomMatrix(r, c, 3, 3, 1, 9);
MatrixCharacteristics mc = new MatrixCharacteristics(r, c, blocksize, r * c);
writeInputMatrixWithMTD("X1", X1, false, mc);
@@ -208,6 +218,9 @@ public class FederatedRowAggregateTest extends AutomatedTestBase {
case VAR:
Assert.assertTrue(heavyHittersContainsString(fedInst.concat("var")));
break;
+ case PROD:
+ Assert.assertTrue(heavyHittersContainsString(fedInst.concat("*")));
+ break;
}
// check that federated input files are still existing
diff --git a/src/test/scripts/functions/federated/aggregate/FederatedSumTest.dml b/src/test/scripts/functions/federated/aggregate/FederatedColProdTest.R
similarity index 61%
copy from src/test/scripts/functions/federated/aggregate/FederatedSumTest.dml
copy to src/test/scripts/functions/federated/aggregate/FederatedColProdTest.R
index 9de439e..95ef9e2 100644
--- a/src/test/scripts/functions/federated/aggregate/FederatedSumTest.dml
+++ b/src/test/scripts/functions/federated/aggregate/FederatedColProdTest.R
@@ -18,17 +18,15 @@
# under the License.
#
#-------------------------------------------------------------
+# R reference for FederatedColProdTest: column products of the four input partitions X1..X4.
+# args[1]: input directory (X1.mtx .. X4.mtx), args[2]: output directory (result written as S),
+# args[3] (optional, default TRUE): TRUE if the partitions are stacked row-wise, FALSE for column-wise.
+args<-commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+library("matrixStats")
-if ($rP) {
- A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
- ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols),
- list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols)));
-} else {
- A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
- ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4), list($rows, $cols/2),
- list(0,$cols/2), list($rows, 3*($cols/4)), list(0, 3*($cols/4)), list($rows, $cols)));
-}
-
-s = sum(A);
-
-write(s, $out_S);
\ No newline at end of file
+X1 = as.matrix(readMM(paste(args[1], "X1.mtx", sep="")));
+X2 = as.matrix(readMM(paste(args[1], "X2.mtx", sep="")));
+X3 = as.matrix(readMM(paste(args[1], "X3.mtx", sep="")));
+X4 = as.matrix(readMM(paste(args[1], "X4.mtx", sep="")));
+# mirror the DML reference: rbind for row partitions (default), cbind for column partitions
+rowPartitioned = if (length(args) >= 3) as.logical(args[3]) else TRUE
+X = if (rowPartitioned) rbind(X1, X2, X3, X4) else cbind(X1, X2, X3, X4)
+# matrixStats::colProds returns a plain vector, which as(..., "CsparseMatrix") would turn into an
+# ncol x 1 column; transpose to a 1 x ncol row vector to match the shape of DML's colProds
+R = t(colProds(X))
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep=""));
diff --git a/src/test/scripts/functions/federated/aggregate/FederatedRowVarTest.dml b/src/test/scripts/functions/federated/aggregate/FederatedColProdTest.dml
similarity index 97%
copy from src/test/scripts/functions/federated/aggregate/FederatedRowVarTest.dml
copy to src/test/scripts/functions/federated/aggregate/FederatedColProdTest.dml
index 8b4a57d..ae90cd7 100644
--- a/src/test/scripts/functions/federated/aggregate/FederatedRowVarTest.dml
+++ b/src/test/scripts/functions/federated/aggregate/FederatedColProdTest.dml
@@ -30,5 +30,5 @@ if ($rP) {
list(0,$cols/2), list($rows, 3*($cols/4)), list(0, 3*($cols/4)), list($rows, $cols)));
}
-s = rowVars(A);
-write(s, $out_S);
\ No newline at end of file
+s = colProds(A);
+write(s, $out_S);
diff --git a/src/test/scripts/functions/federated/aggregate/FederatedSumTest.dml b/src/test/scripts/functions/federated/aggregate/FederatedColProdTestReference.dml
similarity index 61%
copy from src/test/scripts/functions/federated/aggregate/FederatedSumTest.dml
copy to src/test/scripts/functions/federated/aggregate/FederatedColProdTestReference.dml
index 9de439e..6fa8d53 100644
--- a/src/test/scripts/functions/federated/aggregate/FederatedSumTest.dml
+++ b/src/test/scripts/functions/federated/aggregate/FederatedColProdTestReference.dml
@@ -19,16 +19,8 @@
#
#-------------------------------------------------------------
-if ($rP) {
- A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
- ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols),
- list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols)));
-} else {
- A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
- ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4), list($rows, $cols/2),
- list(0,$cols/2), list($rows, 3*($cols/4)), list(0, 3*($cols/4)), list($rows, $cols)));
-}
+# local reference: read the four partitions, combine them row-wise ($6 true) or column-wise,
+# then compute the column products that the federated run is compared against
+X1 = read($1);
+X2 = read($2);
+X3 = read($3);
+X4 = read($4);
+if($6) { A = rbind(X1, X2, X3, X4); }
+else { A = cbind(X1, X2, X3, X4); }
-s = sum(A);
-
-write(s, $out_S);
\ No newline at end of file
+s = colProds(A);
+write(s, $5);
diff --git a/src/test/scripts/functions/federated/aggregate/FederatedColVarTest.dml b/src/test/scripts/functions/federated/aggregate/FederatedColVarTest.dml
index 186dc1d..a7ac87f 100644
--- a/src/test/scripts/functions/federated/aggregate/FederatedColVarTest.dml
+++ b/src/test/scripts/functions/federated/aggregate/FederatedColVarTest.dml
@@ -31,4 +31,4 @@ if ($rP) {
}
s = colVars(A);
-write(s, $out_S);
\ No newline at end of file
+write(s, $out_S);
diff --git a/src/test/scripts/functions/federated/aggregate/FederatedRowVarTest.dml b/src/test/scripts/functions/federated/aggregate/FederatedRowProdTest.dml
similarity index 97%
copy from src/test/scripts/functions/federated/aggregate/FederatedRowVarTest.dml
copy to src/test/scripts/functions/federated/aggregate/FederatedRowProdTest.dml
index 8b4a57d..9d5f11d 100644
--- a/src/test/scripts/functions/federated/aggregate/FederatedRowVarTest.dml
+++ b/src/test/scripts/functions/federated/aggregate/FederatedRowProdTest.dml
@@ -30,5 +30,5 @@ if ($rP) {
list(0,$cols/2), list($rows, 3*($cols/4)), list(0, 3*($cols/4)), list($rows, $cols)));
}
-s = rowVars(A);
-write(s, $out_S);
\ No newline at end of file
+s = rowProds(A);
+write(s, $out_S);
diff --git a/src/test/scripts/functions/federated/aggregate/FederatedSumTest.dml b/src/test/scripts/functions/federated/aggregate/FederatedRowProdTestReference.dml
similarity index 61%
copy from src/test/scripts/functions/federated/aggregate/FederatedSumTest.dml
copy to src/test/scripts/functions/federated/aggregate/FederatedRowProdTestReference.dml
index 9de439e..b917d13 100644
--- a/src/test/scripts/functions/federated/aggregate/FederatedSumTest.dml
+++ b/src/test/scripts/functions/federated/aggregate/FederatedRowProdTestReference.dml
@@ -19,16 +19,8 @@
#
#-------------------------------------------------------------
-if ($rP) {
- A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
- ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), list(2*$rows/4, $cols),
- list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), list($rows, $cols)));
-} else {
- A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
- ranges=list(list(0, 0), list($rows, $cols/4), list(0,$cols/4), list($rows, $cols/2),
- list(0,$cols/2), list($rows, 3*($cols/4)), list(0, 3*($cols/4)), list($rows, $cols)));
-}
+# local reference: read the four partitions, combine them row-wise ($6 true) or column-wise,
+# then compute the row products that the federated run is compared against
+X1 = read($1);
+X2 = read($2);
+X3 = read($3);
+X4 = read($4);
+if($6) { A = rbind(X1, X2, X3, X4); }
+else { A = cbind(X1, X2, X3, X4); }
-s = sum(A);
-
-write(s, $out_S);
\ No newline at end of file
+s = rowProds(A);
+write(s, $5);
diff --git a/src/test/scripts/functions/federated/aggregate/FederatedRowVarTest.dml b/src/test/scripts/functions/federated/aggregate/FederatedRowVarTest.dml
index 8b4a57d..bec43a2 100644
--- a/src/test/scripts/functions/federated/aggregate/FederatedRowVarTest.dml
+++ b/src/test/scripts/functions/federated/aggregate/FederatedRowVarTest.dml
@@ -31,4 +31,4 @@ if ($rP) {
}
s = rowVars(A);
-write(s, $out_S);
\ No newline at end of file
+write(s, $out_S);
diff --git a/src/test/scripts/functions/federated/aggregate/FederatedSumTest.dml b/src/test/scripts/functions/federated/aggregate/FederatedSumTest.dml
index 9de439e..72a7cd6 100644
--- a/src/test/scripts/functions/federated/aggregate/FederatedSumTest.dml
+++ b/src/test/scripts/functions/federated/aggregate/FederatedSumTest.dml
@@ -31,4 +31,4 @@ if ($rP) {
s = sum(A);
-write(s, $out_S);
\ No newline at end of file
+write(s, $out_S);