You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink itself is not preserved in this plain-text version).
Posted to commits@systemds.apache.org by se...@apache.org on 2022/08/29 09:51:20 UTC
[systemds] branch main updated: [MINOR] Add Matrix Multiplication Chain Test and Fix Runtime Bug
This is an automated email from the ASF dual-hosted git repository.
sebwrede pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new e964be2cca [MINOR] Add Matrix Multiplication Chain Test and Fix Runtime Bug
e964be2cca is described below
commit e964be2cca10a11e357f777a1009dd772f99a5a5
Author: sebwrede <sw...@know-center.at>
AuthorDate: Fri Aug 26 16:24:16 2022 +0200
[MINOR] Add Matrix Multiplication Chain Test and Fix Runtime Bug
Closes #1690.
---
.../fed/AggregateBinaryFEDInstruction.java | 24 ++++++++++++++++++-
.../fedplanning/FederatedMultiplyPlanningTest.java | 12 +++++++++-
.../FederatedMultiplyPlanningTest12.dml | 27 ++++++++++++++++++++++
.../FederatedMultiplyPlanningTest12Reference.dml | 26 +++++++++++++++++++++
4 files changed, 87 insertions(+), 2 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
index 9340e9fb12..1a8115ee94 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
@@ -124,7 +124,13 @@ public class AggregateBinaryFEDInstruction extends BinaryFEDInstruction {
setOutputFedMapping(mo1.getFedMapping(), mo1, mo2, fr2.getID(), ec);
}
else {
- aggregateLocally(mo1.getFedMapping(), mo1.isFederated(FType.PART), ec, fr1, fr2);
+ boolean isDoubleBroadcast = (mo1.isFederated(FType.BROADCAST) && mo2.isFederated(FType.BROADCAST));
+ if (isDoubleBroadcast){
+ aggregateLocallySingleWorker(mo1.getFedMapping(), ec, fr1, fr2);
+ }
+ else{
+ aggregateLocally(mo1.getFedMapping(), false, ec, fr1, fr2);
+ }
}
}
//#2 vector - federated matrix multiplication
@@ -231,4 +237,20 @@ public class AggregateBinaryFEDInstruction extends BinaryFEDInstruction {
ret = FederationUtils.bind(ffr, false);
ec.setMatrixOutput(output.getName(), ret);
}
+ /** Runs fr, then a GET_VAR and a cleanup request for the last request's ID, via fedMap.execute, and sets worker response 0's MatrixBlock as the output. */
+ private void aggregateLocallySingleWorker(FederationMap fedMap, ExecutionContext ec, FederatedRequest... fr) {
+ //fetch (GET_VAR) the result of the last request in fr, then issue a cleanup request for that same ID
+ long callInstID = fr[fr.length - 1].getID();
+ FederatedRequest frG = new FederatedRequest(RequestType.GET_VAR, callInstID);
+ FederatedRequest frC = fedMap.cleanup(getTID(), callInstID);
+ //execute fr followed by frG and frC (appended in that order); presumably one future per federated worker
+ Future<FederatedResponse>[] ffr = fedMap.execute(getTID(), ArrayUtils.addAll(fr, frG, frC));
+ try {
+ //use only the first response; assumes all workers computed the same result (presumably both operands broadcast) - NOTE(review): ffr[1..] are never awaited, so failures on other workers are not surfaced here
+ MatrixBlock ret = (MatrixBlock) ffr[0].get().getData()[0];
+ ec.setMatrixOutput(output.getName(), ret);
+ } catch(Exception ex){ // NOTE(review): this also catches InterruptedException without restoring the thread's interrupt flag
+ throw new DMLRuntimeException(ex);
+ }
+ }
}
diff --git a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
index 2477bdef85..415cd21178 100644
--- a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
@@ -56,6 +56,7 @@ public class FederatedMultiplyPlanningTest extends AutomatedTestBase {
private final static String TEST_NAME_9 = "FederatedMultiplyPlanningTest9";
private final static String TEST_NAME_10 = "FederatedMultiplyPlanningTest10";
private final static String TEST_NAME_11 = "FederatedMultiplyPlanningTest11";
+ private final static String TEST_NAME_12 = "FederatedMultiplyPlanningTest12";
private final static String TEST_CLASS_DIR = TEST_DIR + FederatedMultiplyPlanningTest.class.getSimpleName() + "/";
private static File TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR, "SystemDS-config-cost-based.xml");
@@ -79,6 +80,7 @@ public class FederatedMultiplyPlanningTest extends AutomatedTestBase {
addTestConfiguration(TEST_NAME_9, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_9, new String[] {"Z.scalar"}));
addTestConfiguration(TEST_NAME_10, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_10, new String[] {"Z"}));
addTestConfiguration(TEST_NAME_11, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_11, new String[] {"Z"}));
+ addTestConfiguration(TEST_NAME_12, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_12, new String[] {"Z"}));
}
@Parameterized.Parameters
@@ -161,6 +163,14 @@ public class FederatedMultiplyPlanningTest extends AutomatedTestBase {
federatedTwoMatricesSingleNodeTest(TEST_NAME_11, expectedHeavyHitters);
}
+ @Test // matrix multiplication chain: z1 = z0 %*% z0, z2 = z1 %*% z1 (see FederatedMultiplyPlanningTest12.dml)
+ public void federatedMultiplyPlanningTest12(){
+ String[] expectedHeavyHitters = new String[]{"fed_fedinit"};
+ rows = 30; // square inputs: z0 %*% z0 in the DML script requires rows == cols (presumably passed as $r / $c)
+ cols = 30;
+ federatedTwoMatricesSingleNodeTest(TEST_NAME_12, expectedHeavyHitters);
+ }
+
private void writeStandardMatrix(String matrixName, long seed){
writeStandardMatrix(matrixName, seed, new PrivacyConstraint(PrivacyConstraint.PrivacyLevel.PrivateAggregation));
}
@@ -215,7 +225,7 @@ public class FederatedMultiplyPlanningTest extends AutomatedTestBase {
writeColStandardMatrix("W1", 76, null);
writeColStandardMatrix("W2", 11, null);
}
- else if ( testName.equals(TEST_NAME_10) ){
+ else if ( testName.equals(TEST_NAME_10) || testName.equals(TEST_NAME_12) ){
writeStandardMatrix("X1", 42, null);
writeStandardMatrix("X2", 1340, null);
}
diff --git a/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest12.dml b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest12.dml
new file mode 100644
index 0000000000..3ef9909e68
--- /dev/null
+++ b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest12.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# Matrix multiplication chain on a federated z0, row-partitioned into X1 (rows 0 to r/2) and X2 (rows r/2 to r).
+z0 = federated(addresses=list($X1, $X2),
+ ranges=list(list(0, 0), list($r / 2, $c), list($r / 2, 0), list($r, $c)))
+z1 = z0 %*% z0 # requires a square z0 (r == c)
+z2 = z1 %*% z1
+print(toString(z2))
+write(z2, $Z)
diff --git a/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest12Reference.dml b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest12Reference.dml
new file mode 100644
index 0000000000..652172c2a8
--- /dev/null
+++ b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest12Reference.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+# Local reference: the same multiplication chain on the row-wise concatenation of X1 and X2.
+z0 = rbind(read($X1), read($X2))
+z1 = z0 %*% z0 # requires a square z0
+z2 = z1 %*% z1
+print(toString(z2))
+write(z2, $Z)