You are viewing a plain text version of this content. The canonical link for it was a hyperlink labeled "here" in the original page; that link is not preserved in this plain-text version.
Posted to commits@systemds.apache.org by se...@apache.org on 2022/08/29 09:51:20 UTC

[systemds] branch main updated: [MINOR] Add Matrix Multiplication Chain Test and Fix Runtime Bug

This is an automated email from the ASF dual-hosted git repository.

sebwrede pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new e964be2cca [MINOR] Add Matrix Multiplication Chain Test and Fix Runtime Bug
e964be2cca is described below

commit e964be2cca10a11e357f777a1009dd772f99a5a5
Author: sebwrede <sw...@know-center.at>
AuthorDate: Fri Aug 26 16:24:16 2022 +0200

    [MINOR] Add Matrix Multiplication Chain Test and Fix Runtime Bug
    
    Closes #1690.
---
 .../fed/AggregateBinaryFEDInstruction.java         | 24 ++++++++++++++++++-
 .../fedplanning/FederatedMultiplyPlanningTest.java | 12 +++++++++-
 .../FederatedMultiplyPlanningTest12.dml            | 27 ++++++++++++++++++++++
 .../FederatedMultiplyPlanningTest12Reference.dml   | 26 +++++++++++++++++++++
 4 files changed, 87 insertions(+), 2 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
index 9340e9fb12..1a8115ee94 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
@@ -124,7 +124,13 @@ public class AggregateBinaryFEDInstruction extends BinaryFEDInstruction {
 				setOutputFedMapping(mo1.getFedMapping(), mo1, mo2, fr2.getID(), ec);
 			}
 			else {
-				aggregateLocally(mo1.getFedMapping(), mo1.isFederated(FType.PART), ec, fr1, fr2);
+				boolean isDoubleBroadcast = (mo1.isFederated(FType.BROADCAST) && mo2.isFederated(FType.BROADCAST));
+				if (isDoubleBroadcast){
+					aggregateLocallySingleWorker(mo1.getFedMapping(), ec, fr1, fr2);
+				}
+				else{
+					aggregateLocally(mo1.getFedMapping(), false, ec, fr1, fr2);
+				}
 			}
 		}
 		//#2 vector - federated matrix multiplication
@@ -231,4 +237,20 @@ public class AggregateBinaryFEDInstruction extends BinaryFEDInstruction {
 			ret = FederationUtils.bind(ffr, false);
 		ec.setMatrixOutput(output.getName(), ret);
 	}
+
+	private void aggregateLocallySingleWorker(FederationMap fedMap, ExecutionContext ec, FederatedRequest... fr) { // used when both inputs are BROADCAST: the output is read from a single worker response
+		//build a GET request for the variable produced by the last request in fr, plus a cleanup request for that variable
+		long callInstID = fr[fr.length - 1].getID();
+		FederatedRequest frG = new FederatedRequest(RequestType.GET_VAR, callInstID);
+		FederatedRequest frC = fedMap.cleanup(getTID(), callInstID);
+		//execute the given requests followed by the GET and cleanup requests
+		Future<FederatedResponse>[] ffr = fedMap.execute(getTID(), ArrayUtils.addAll(fr, frG, frC));
+		try {
+			//use only the first response (assumed identical across workers because both inputs are broadcast)
+			MatrixBlock ret = (MatrixBlock) ffr[0].get().getData()[0];
+			ec.setMatrixOutput(output.getName(), ret);
+		} catch(Exception ex){ // NOTE(review): this also wraps InterruptedException without restoring the thread's interrupt flag
+			throw new DMLRuntimeException(ex);
+		}
+	}
 }
diff --git a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
index 2477bdef85..415cd21178 100644
--- a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
@@ -56,6 +56,7 @@ public class FederatedMultiplyPlanningTest extends AutomatedTestBase {
 	private final static String TEST_NAME_9 = "FederatedMultiplyPlanningTest9";
 	private final static String TEST_NAME_10 = "FederatedMultiplyPlanningTest10";
 	private final static String TEST_NAME_11 = "FederatedMultiplyPlanningTest11";
+	private final static String TEST_NAME_12 = "FederatedMultiplyPlanningTest12";
 	private final static String TEST_CLASS_DIR = TEST_DIR + FederatedMultiplyPlanningTest.class.getSimpleName() + "/";
 	private static File TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR, "SystemDS-config-cost-based.xml");
 
@@ -79,6 +80,7 @@ public class FederatedMultiplyPlanningTest extends AutomatedTestBase {
 		addTestConfiguration(TEST_NAME_9, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_9, new String[] {"Z.scalar"}));
 		addTestConfiguration(TEST_NAME_10, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_10, new String[] {"Z"}));
 		addTestConfiguration(TEST_NAME_11, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_11, new String[] {"Z"}));
+		addTestConfiguration(TEST_NAME_12, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_12, new String[] {"Z"}));
 	}
 
 	@Parameterized.Parameters
@@ -161,6 +163,14 @@ public class FederatedMultiplyPlanningTest extends AutomatedTestBase {
 		federatedTwoMatricesSingleNodeTest(TEST_NAME_11, expectedHeavyHitters);
 	}
 
+	@Test
+	public void federatedMultiplyPlanningTest12(){ // matrix multiplication chain from FederatedMultiplyPlanningTest12.dml (z1 = z0 %*% z0, z2 = z1 %*% z1) with rows = cols = 30
+		String[] expectedHeavyHitters = new String[]{"fed_fedinit"};
+		rows = 30;
+		cols = 30;
+		federatedTwoMatricesSingleNodeTest(TEST_NAME_12, expectedHeavyHitters);
+	}
+
 	private void writeStandardMatrix(String matrixName, long seed){
 		writeStandardMatrix(matrixName, seed, new PrivacyConstraint(PrivacyConstraint.PrivacyLevel.PrivateAggregation));
 	}
@@ -215,7 +225,7 @@ public class FederatedMultiplyPlanningTest extends AutomatedTestBase {
 			writeColStandardMatrix("W1", 76, null);
 			writeColStandardMatrix("W2", 11, null);
 		}
-		else if ( testName.equals(TEST_NAME_10) ){
+		else if ( testName.equals(TEST_NAME_10) || testName.equals(TEST_NAME_12) ){
 			writeStandardMatrix("X1", 42, null);
 			writeStandardMatrix("X2", 1340, null);
 		}
diff --git a/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest12.dml b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest12.dml
new file mode 100644
index 0000000000..3ef9909e68
--- /dev/null
+++ b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest12.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+z0 = federated(addresses=list($X1, $X2),
+              ranges=list(list(0, 0), list($r / 2, $c), list($r / 2, 0), list($r, $c)))  # row split: X1 covers rows 0..r/2, X2 covers rows r/2..r, both over all c columns
+z1 = z0 %*% z0
+z2 = z1 %*% z1  # second multiplication consumes the result of the first (chain)
+print(toString(z2))
+write(z2, $Z)
diff --git a/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest12Reference.dml b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest12Reference.dml
new file mode 100644
index 0000000000..652172c2a8
--- /dev/null
+++ b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest12Reference.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+z0 = rbind(read($X1), read($X2))  # local (non-federated) equivalent of z0 in FederatedMultiplyPlanningTest12.dml
+z1 = z0 %*% z0
+z2 = z1 %*% z1
+print(toString(z2))
+write(z2, $Z)