You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2021/05/18 18:47:01 UTC
[systemds] branch master updated: [SYSTEMDS-2922] Federated codegen multi-aggregate operations
This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 7325220 [SYSTEMDS-2922] Federated codegen multi-aggregate operations
7325220 is described below
commit 7325220d6581b37b5fd94bba7f782f2898e078e4
Author: ywcb00 <yw...@ywcb.org>
AuthorDate: Tue May 18 20:44:38 2021 +0200
[SYSTEMDS-2922] Federated codegen multi-aggregate operations
Closes #1277.
---
.../instructions/fed/FEDInstructionUtils.java | 9 +-
.../instructions/fed/SpoofFEDInstruction.java | 21 ++++-
.../codegen/FederatedCellwiseTmplTest.java | 21 +++--
...mplTest.java => FederatedMultiAggTmplTest.java} | 83 +++++++++----------
.../codegen/FederatedRowwiseTmplTest.java | 21 +++--
.../codegen/FederatedMultiAggTmplTest.dml | 95 ++++++++++++++++++++++
.../codegen/FederatedMultiAggTmplTestReference.dml | 93 +++++++++++++++++++++
7 files changed, 277 insertions(+), 66 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index 4b38250..1ff7fbf 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -20,6 +20,7 @@
package org.apache.sysds.runtime.instructions.fed;
import org.apache.sysds.runtime.codegen.SpoofCellwise;
+import org.apache.sysds.runtime.codegen.SpoofMultiAggregate;
import org.apache.sysds.runtime.codegen.SpoofRowwise;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
@@ -228,8 +229,9 @@ public class FEDInstructionUtils {
else if(inst instanceof SpoofCPInstruction) {
SpoofCPInstruction instruction = (SpoofCPInstruction) inst;
Class<?> scla = instruction.getOperatorClass().getSuperclass();
- if( (scla == SpoofCellwise.class && instruction.isFederated(ec))
- || (scla == SpoofRowwise.class&& instruction.isFederated(ec, FType.ROW))) {
+ if(((scla == SpoofCellwise.class || scla == SpoofMultiAggregate.class)
+ && instruction.isFederated(ec))
+ || (scla == SpoofRowwise.class && instruction.isFederated(ec, FType.ROW))) {
fedinst = SpoofFEDInstruction.parseInstruction(instruction.getInstructionString());
}
}
@@ -337,7 +339,8 @@ public class FEDInstructionUtils {
else if(inst instanceof SpoofSPInstruction) {
SpoofSPInstruction instruction = (SpoofSPInstruction) inst;
Class<?> scla = instruction.getOperatorClass().getSuperclass();
- if( (scla == SpoofCellwise.class && instruction.isFederated(ec))
+ if(((scla == SpoofCellwise.class || scla == SpoofMultiAggregate.class)
+ && instruction.isFederated(ec))
|| (scla == SpoofRowwise.class && instruction.isFederated(ec, FType.ROW))) {
fedinst = SpoofFEDInstruction.parseInstruction(inst.getInstructionString());
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java
index 2ceada6..01d6051 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java
@@ -26,6 +26,7 @@ import org.apache.sysds.runtime.codegen.SpoofCellwise.AggOp;
import org.apache.sysds.runtime.codegen.SpoofCellwise.CellType;
import org.apache.sysds.runtime.codegen.SpoofRowwise;
import org.apache.sysds.runtime.codegen.SpoofRowwise.RowType;
+import org.apache.sysds.runtime.codegen.SpoofMultiAggregate;
import org.apache.sysds.runtime.codegen.SpoofOperator;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -40,6 +41,7 @@ import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
import java.util.ArrayList;
@@ -153,11 +155,17 @@ public class SpoofFEDInstruction extends FEDInstruction
setOutputCellwise(ec, response, fedMap);
else if(_op.getClass().getSuperclass() == SpoofRowwise.class)
setOutputRowwise(ec, response, fedMap);
+
+ else if(_op.getClass().getSuperclass() == SpoofMultiAggregate.class)
+ setOutputMultiAgg(ec, response, fedMap);
else
- throw new DMLRuntimeException("Federated code generation only supported for cellwise and rowwise templates.");
+ throw new DMLRuntimeException("Federated code generation only supported for cellwise, rowwise, and multiaggregate templates.");
}
private static boolean needsBroadcastSliced(FederationMap fedMap, long rowNum, long colNum) {
+ if(rowNum == fedMap.getMaxIndexInRange(0) && colNum == fedMap.getMaxIndexInRange(1))
+ return true;
+
if(fedMap.getType() == FType.ROW) {
return (rowNum == fedMap.getMaxIndexInRange(0) && (colNum == 1 || colNum == fedMap.getSize()))
|| (colNum > 1 && rowNum == fedMap.getSize());
@@ -291,4 +299,15 @@ public class SpoofFEDInstruction extends FEDInstruction
throw new DMLRuntimeException("AggregationType not supported yet.");
}
}
+
+ private void setOutputMultiAgg(ExecutionContext ec, Future<FederatedResponse>[] response, FederationMap fedMap)
+ {
+ MatrixBlock[] partRes = FederationUtils.getResults(response);
+ SpoofCellwise.AggOp[] aggOps = ((SpoofMultiAggregate)_op).getAggOps();
+ for(int counter = 1; counter < partRes.length; counter++) {
+ SpoofMultiAggregate.aggregatePartialResults(aggOps, partRes[0], partRes[counter]);
+ }
+ ec.setMatrixOutput(_output.getName(), partRes[0]);
+ }
+
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCellwiseTmplTest.java b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCellwiseTmplTest.java
index 5ea32bc..2f6e5e4 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCellwiseTmplTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedCellwiseTmplTest.java
@@ -29,6 +29,7 @@ import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
import org.junit.Assert;
import org.junit.BeforeClass;
+import org.junit.Ignore;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.Test;
@@ -114,15 +115,17 @@ public class FederatedCellwiseTmplTest extends AutomatedTestBase
TestUtils.clearDirectory(TEST_DATA_DIR + TEST_CLASS_DIR);
}
-// @Test
-// public void federatedCodegenCellwiseSingleNode() {
-// testFederatedCodegen(ExecMode.SINGLE_NODE);
-// }
-//
-// @Test
-// public void federatedCodegenCellwiseSpark() {
-// testFederatedCodegen(ExecMode.SPARK);
-// }
+ @Test
+ @Ignore
+ public void federatedCodegenCellwiseSingleNode() {
+ testFederatedCodegen(ExecMode.SINGLE_NODE);
+ }
+
+ @Test
+ @Ignore
+ public void federatedCodegenCellwiseSpark() {
+ testFederatedCodegen(ExecMode.SPARK);
+ }
@Test
public void federatedCodegenCellwiseHybrid() {
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedRowwiseTmplTest.java b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedMultiAggTmplTest.java
similarity index 78%
copy from src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedRowwiseTmplTest.java
copy to src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedMultiAggTmplTest.java
index 0f63cb3..4526c61 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedRowwiseTmplTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedMultiAggTmplTest.java
@@ -29,6 +29,7 @@ import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
import org.junit.Assert;
import org.junit.BeforeClass;
+import org.junit.Ignore;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.Test;
@@ -39,17 +40,17 @@ import java.util.HashMap;
@RunWith(value = Parameterized.class)
@net.jcip.annotations.NotThreadSafe
-public class FederatedRowwiseTmplTest extends AutomatedTestBase
+public class FederatedMultiAggTmplTest extends AutomatedTestBase
{
- private final static String TEST_NAME = "FederatedRowwiseTmplTest";
+ private final static String TEST_NAME = "FederatedMultiAggTmplTest";
private final static String TEST_DIR = "functions/federated/codegen/";
- private final static String TEST_CLASS_DIR = TEST_DIR + FederatedRowwiseTmplTest.class.getSimpleName() + "/";
+ private final static String TEST_CLASS_DIR = TEST_DIR + FederatedMultiAggTmplTest.class.getSimpleName() + "/";
private final static String TEST_CONF = "SystemDS-config-codegen.xml";
private final static String OUTPUT_NAME = "Z";
- private final static double TOLERANCE = 1e-13;
+ private final static double TOLERANCE = 1e-14;
private final static int BLOCKSIZE = 1024;
@Parameterized.Parameter()
@@ -68,33 +69,27 @@ public class FederatedRowwiseTmplTest extends AutomatedTestBase
@Parameterized.Parameters
public static Collection<Object[]> data() {
- // rows must be even
+ // rows must be even for row partitioned X
+ // cols must be even for col partitioned X
return Arrays.asList(new Object[][] {
- // {test_num, rows, cols, row_paritioned}
+ // {test_num, rows, cols, row_partitioned}
// row partitioned
{1, 6, 4, true},
- // {2, 6, 2, true},
+ // {2, 6, 4, true},
{3, 6, 4, true},
- {4, 6, 4, true},
- {10, 150, 10, true},
- {15, 150, 10, true},
- // {20, 1500, 8, true},
- {21, 1500, 8, true},
- {25, 600, 10, true},
- {31, 150, 10, true},
- // {40, 300, 20, true},
- {45, 1500, 100, true},
- {50, 376, 4, true},
-
- // col partitioned (should not create a federated spoof instruction)
- // column partitioned federated data is not supported within federated rowwise templates
+ // {4, 6, 4, true},
+ {5, 6, 4, true},
+ {6, 6, 4, true},
+ {7, 20, 1, true},
+
+ // column partitioned
{1, 6, 4, false},
- {3, 6, 4, false},
- {15, 150, 10, false},
- {25, 600, 10, false},
- {31, 150, 10, false},
- {50, 376, 4, false},
+ {2, 6, 4, false},
+ // {3, 6, 4, false},
+ {4, 6, 4, false},
+ // {5, 6, 4, false},
+ {6, 6, 4, false},
});
}
@@ -103,22 +98,24 @@ public class FederatedRowwiseTmplTest extends AutomatedTestBase
TestUtils.clearDirectory(TEST_DATA_DIR + TEST_CLASS_DIR);
}
- // @Test
- // public void federatedCodegenRowwiseSingleNode() {
- // testFederatedCodegenRowwise(ExecMode.SINGLE_NODE);
- // }
- //
- // @Test
- // public void federatedCodegenRowwiseSpark() {
- // testFederatedCodegenRowwise(ExecMode.SPARK);
- // }
-
@Test
- public void federatedCodegenCellwiseHybrid() {
- testFederatedCodegenRowwise(ExecMode.HYBRID);
+ @Ignore
+ public void federatedCodegenMultiAggSingleNode() {
+ testFederatedCodegenMultiAgg(ExecMode.SINGLE_NODE);
}
-
- private void testFederatedCodegenRowwise(ExecMode exec_mode) {
+
+ @Test
+ @Ignore
+ public void federatedCodegenMultiAggSpark() {
+ testFederatedCodegenMultiAgg(ExecMode.SPARK);
+ }
+
+ @Test
+ public void federatedCodegenMultiAggHybrid() {
+ testFederatedCodegenMultiAgg(ExecMode.HYBRID);
+ }
+
+ private void testFederatedCodegenMultiAgg(ExecMode exec_mode) {
// store the previous platform config to restore it after the test
ExecMode platform_old = setExecMode(exec_mode);
@@ -134,8 +131,8 @@ public class FederatedRowwiseTmplTest extends AutomatedTestBase
// generate dataset
// matrix handled by two federated workers
- double[][] X1 = getRandomMatrix(fed_rows, fed_cols, 0, 1, 0.1, 3);
- double[][] X2 = getRandomMatrix(fed_rows, fed_cols, 0, 1, 0.1, 11);
+ double[][] X1 = getRandomMatrix(fed_rows, fed_cols, 0, 1, 1, 3);
+ double[][] X2 = getRandomMatrix(fed_rows, fed_cols, 0, 1, 1, 7);
writeInputMatrixWithMTD("X1", X1, false, new MatrixCharacteristics(fed_rows, fed_cols, BLOCKSIZE, fed_rows * fed_cols));
writeInputMatrixWithMTD("X2", X2, false, new MatrixCharacteristics(fed_rows, fed_cols, BLOCKSIZE, fed_rows * fed_cols));
@@ -177,13 +174,11 @@ public class FederatedRowwiseTmplTest extends AutomatedTestBase
TestUtils.shutdownThreads(thread1, thread2);
// check for federated operations
- if(row_partitioned)
- Assert.assertTrue(heavyHittersContainsSubString("fed_spoofRA"));
+ Assert.assertTrue(heavyHittersContainsSubString("fed_spoofMA"));
// check that federated input files are still existing
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
-
resetExecMode(platform_old);
}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedRowwiseTmplTest.java b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedRowwiseTmplTest.java
index 0f63cb3..b4bff76 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedRowwiseTmplTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/codegen/FederatedRowwiseTmplTest.java
@@ -29,6 +29,7 @@ import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
import org.junit.Assert;
import org.junit.BeforeClass;
+import org.junit.Ignore;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.Test;
@@ -103,15 +104,17 @@ public class FederatedRowwiseTmplTest extends AutomatedTestBase
TestUtils.clearDirectory(TEST_DATA_DIR + TEST_CLASS_DIR);
}
- // @Test
- // public void federatedCodegenRowwiseSingleNode() {
- // testFederatedCodegenRowwise(ExecMode.SINGLE_NODE);
- // }
- //
- // @Test
- // public void federatedCodegenRowwiseSpark() {
- // testFederatedCodegenRowwise(ExecMode.SPARK);
- // }
+ @Test
+ @Ignore
+ public void federatedCodegenRowwiseSingleNode() {
+ testFederatedCodegenRowwise(ExecMode.SINGLE_NODE);
+ }
+
+ @Test
+ @Ignore
+ public void federatedCodegenRowwiseSpark() {
+ testFederatedCodegenRowwise(ExecMode.SPARK);
+ }
@Test
public void federatedCodegenCellwiseHybrid() {
diff --git a/src/test/scripts/functions/federated/codegen/FederatedMultiAggTmplTest.dml b/src/test/scripts/functions/federated/codegen/FederatedMultiAggTmplTest.dml
new file mode 100644
index 0000000..358fdb1
--- /dev/null
+++ b/src/test/scripts/functions/federated/codegen/FederatedMultiAggTmplTest.dml
@@ -0,0 +1,95 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+test_num = $in_test_num;
+row_part = $in_rp;
+
+if(row_part) {
+ X = federated(addresses=list($in_X1, $in_X2),
+ ranges=list(list(0, 0), list($rows / 2, $cols), list($rows / 2, 0), list($rows, $cols)));
+}
+else {
+ X = federated(addresses=list($in_X1, $in_X2),
+ ranges=list(list(0, 0), list($rows, $cols / 2), list(0, $cols / 2), list($rows, $cols)));
+}
+
+if(test_num == 1) {
+ # X ... 6x4 matrix
+ r1 = min(X > 0.5);
+ r2 = max(X > 0.5);
+ Z = as.matrix(r1 + r2);
+}
+else if(test_num == 2) {
+ # X ... 6x4 matrix
+ r1 = sum(X > 0.5);
+ r2 = sum((X > 0.5)^2);
+ Z = as.matrix(r1 + r2);
+}
+else if(test_num == 3) {
+ # X ... 6x4 matrix
+
+ #disjoint partitions with shared read
+ r1 = sum(X == 0.7)
+ r2 = sum(X == 0.3)
+ Z = as.matrix(r1 + r2);
+}
+else if(test_num == 4) {
+ # X ... 6x4 matrix
+ Y = matrix(seq(2,25), rows=6, cols=4);
+
+ #disjoint partitions with partial shared reads
+ r1 = sum(X * Y);
+ r2 = sum(X ^ 2);
+ r3 = sum(Y ^ 2);
+ Z = as.matrix(r1 + r2 + r3);
+}
+else if(test_num == 5) {
+ # X ... 6x4 matrix
+ U = matrix(seq(0,23), rows=6, cols=4);
+ V = matrix(seq(2,25), rows=6, cols=4);
+ W = matrix(seq(3,26), rows=6, cols=4);
+
+ #disjoint partitions with transitive partial shared reads
+ r1 = sum(X * U);
+ r2 = sum(V * W);
+ r3 = sum(X * V * W);
+ Z = as.matrix(r1 + r2 + r3);
+}
+else if(test_num == 6) {
+ # X ... 6x4 matrix
+
+ r1 = min(X);
+ r2 = max(X);
+ r3 = sum(X);
+ Z = as.matrix(r1 + r2 + r3);
+}
+else if(test_num == 7) {
+ # X ... 20x1 vector
+ Y = seq(2,21);
+ while(FALSE){}
+
+ r1 = t(X) %*% X;
+ r2 = t(X) %*% Y;
+ r3 = t(Y) %*% Y;
+ Z = r1 + r2 + r3;
+}
+
+write(Z, $out_Z);
diff --git a/src/test/scripts/functions/federated/codegen/FederatedMultiAggTmplTestReference.dml b/src/test/scripts/functions/federated/codegen/FederatedMultiAggTmplTestReference.dml
new file mode 100644
index 0000000..4c76863
--- /dev/null
+++ b/src/test/scripts/functions/federated/codegen/FederatedMultiAggTmplTestReference.dml
@@ -0,0 +1,93 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+test_num = $in_test_num;
+row_part = $in_rp;
+
+if(row_part) {
+ X = rbind(read($in_X1), read($in_X2));
+}
+else {
+ X = cbind(read($in_X1), read($in_X2));
+}
+
+if(test_num == 1) {
+ # X ... 6x4 matrix
+ r1 = min(X > 0.5);
+ r2 = max(X > 0.5);
+ Z = as.matrix(r1 + r2);
+}
+else if(test_num == 2) {
+ # X ... 6x4 matrix
+ r1 = sum(X > 0.5);
+ r2 = sum((X > 0.5)^2);
+ Z = as.matrix(r1 + r2);
+}
+else if(test_num == 3) {
+ # X ... 6x4 matrix
+
+ #disjoint partitions with shared read
+ r1 = sum(X == 0.7)
+ r2 = sum(X == 0.3)
+ Z = as.matrix(r1 + r2);
+}
+else if(test_num == 4) {
+ # X ... 6x4 matrix
+ Y = matrix(seq(2,25), rows=6, cols=4);
+
+ #disjoint partitions with partial shared reads
+ r1 = sum(X * Y);
+ r2 = sum(X ^ 2);
+ r3 = sum(Y ^ 2);
+ Z = as.matrix(r1 + r2 + r3);
+}
+else if(test_num == 5) {
+ # X ... 6x4 matrix
+ U = matrix(seq(0,23), rows=6, cols=4);
+ V = matrix(seq(2,25), rows=6, cols=4);
+ W = matrix(seq(3,26), rows=6, cols=4);
+
+ #disjoint partitions with transitive partial shared reads
+ r1 = sum(X * U);
+ r2 = sum(V * W);
+ r3 = sum(X * V * W);
+ Z = as.matrix(r1 + r2 + r3);
+}
+else if(test_num == 6) {
+ # X ... 6x4 matrix
+
+ r1 = min(X);
+ r2 = max(X);
+ r3 = sum(X);
+ Z = as.matrix(r1 + r2 + r3);
+}
+else if(test_num == 7) {
+ # X ... 20x1 vector
+ Y = seq(2,21);
+ while(FALSE){}
+
+ r1 = t(X) %*% X;
+ r2 = t(X) %*% Y;
+ r3 = t(Y) %*% Y;
+ Z = r1 + r2 + r3;
+}
+
+write(Z, $out_Z);