You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this copy.)
Posted to commits@systemds.apache.org by se...@apache.org on 2022/04/22 11:59:50 UTC
[systemds] branch main updated: [MINOR] Edit Operation ^2 to Compile as Fed Instruction
This is an automated email from the ASF dual-hosted git repository.
sebwrede pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 8468eff6fe [MINOR] Edit Operation ^2 to Compile as Fed Instruction
8468eff6fe is described below
commit 8468eff6fe4a787209006779345c26f9901f64eb
Author: sebwrede <sw...@know-center.at>
AuthorDate: Fri Apr 22 12:27:28 2022 +0200
[MINOR] Edit Operation ^2 to Compile as Fed Instruction
---
src/main/java/org/apache/sysds/lops/Unary.java | 17 ++++++++++++---
.../runtime/instructions/FEDInstructionParser.java | 1 +
.../fedplanning/FederatedMultiplyPlanningTest.java | 15 ++++++++++++-
.../FederatedMultiplyPlanningTest10.dml | 25 ++++++++++++++++++++++
.../FederatedMultiplyPlanningTest10Reference.dml | 24 +++++++++++++++++++++
5 files changed, 78 insertions(+), 4 deletions(-)
diff --git a/src/main/java/org/apache/sysds/lops/Unary.java b/src/main/java/org/apache/sysds/lops/Unary.java
index f0a59fabfd..e68d060e71 100644
--- a/src/main/java/org/apache/sysds/lops/Unary.java
+++ b/src/main/java/org/apache/sysds/lops/Unary.java
@@ -138,6 +138,13 @@ public class Unary extends Lop
|| op==OpOp1.POW2
|| op==OpOp1.MULT2;
}
+
+ private void appendFedOut(StringBuilder sb){
+ // Only FED instructions get the _fedOutput setting appended as the last operand.
+ if (getExecType() != ExecType.FED)
+ return;
+ sb.append(OPERAND_DELIMITOR).append(_fedOutput.name());
+ }
@Override
public String getInstructions(String input1, String output) {
@@ -158,12 +165,14 @@ public class Unary extends Lop
sb.append( prepOutputOperand(output) );
//num threads for cumulative cp ops
- if( getExecType() == ExecType.CP && isMultiThreadedOp(operation) ) {
+ if( (getExecType() == ExecType.CP || getExecType() == ExecType.FED) && isMultiThreadedOp(operation) ) {
sb.append( OPERAND_DELIMITOR );
sb.append( _numThreads );
sb.append( OPERAND_DELIMITOR );
sb.append( _inplace );
}
+
+ appendFedOut(sb);
return sb.toString();
}
@@ -191,10 +200,12 @@ public class Unary extends Lop
sb.append( OPERAND_DELIMITOR );
sb.append( prepOutputOperand(output));
- if( getExecType() == ExecType.CP ) {
+ if( getExecType() == ExecType.CP || getExecType() == ExecType.FED ) {
sb.append( OPERAND_DELIMITOR );
- sb.append( String.valueOf(_numThreads) );
+ sb.append(_numThreads);
}
+
+ appendFedOut(sb);
return sb.toString();
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
index 4992426d6a..e983e06f5e 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
@@ -61,6 +61,7 @@ public class FEDInstructionParser extends InstructionParser
String2FEDInstructionType.put( "*" , FEDType.Binary );
String2FEDInstructionType.put( "/" , FEDType.Binary );
String2FEDInstructionType.put( "1-*", FEDType.Binary); //special * case
+ String2FEDInstructionType.put( "^2" , FEDType.Binary); //special ^ case
String2FEDInstructionType.put( "max", FEDType.Binary );
String2FEDInstructionType.put( "==", FEDType.Binary);
String2FEDInstructionType.put( "!=", FEDType.Binary);
diff --git a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
index 6ec10232dd..e8d16f6bcb 100644
--- a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java
@@ -55,8 +55,9 @@ public class FederatedMultiplyPlanningTest extends AutomatedTestBase {
private final static String TEST_NAME_7 = "FederatedMultiplyPlanningTest7";
private final static String TEST_NAME_8 = "FederatedMultiplyPlanningTest8";
private final static String TEST_NAME_9 = "FederatedMultiplyPlanningTest9";
+ private final static String TEST_NAME_10 = "FederatedMultiplyPlanningTest10";
private final static String TEST_CLASS_DIR = TEST_DIR + FederatedMultiplyPlanningTest.class.getSimpleName() + "/";
- private final static File TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR, "SystemDS-config-cost-based.xml");
+ private static File TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR, "SystemDS-config-cost-based.xml");
private final static int blocksize = 1024;
@Parameterized.Parameter()
@@ -76,6 +77,7 @@ public class FederatedMultiplyPlanningTest extends AutomatedTestBase {
addTestConfiguration(TEST_NAME_7, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_7, new String[] {"Z"}));
addTestConfiguration(TEST_NAME_8, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_8, new String[] {"Z.scalar"}));
addTestConfiguration(TEST_NAME_9, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_9, new String[] {"Z.scalar"}));
+ addTestConfiguration(TEST_NAME_10, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_10, new String[] {"Z"}));
}
@Parameterized.Parameters
@@ -146,6 +148,13 @@ public class FederatedMultiplyPlanningTest extends AutomatedTestBase {
federatedTwoMatricesSingleNodeTest(TEST_NAME_9, expectedHeavyHitters);
}
+ @Test
+ public void federatedMultiplyPlanningTest10(){
+ String[] expectedHeavyHitters = new String[]{"fed_fedinit", "fed_^2"};
+ TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR, "SystemDS-config-fout.xml");
+ federatedTwoMatricesSingleNodeTest(TEST_NAME_10, expectedHeavyHitters);
+ }
+
private void writeStandardMatrix(String matrixName, long seed){
writeStandardMatrix(matrixName, seed, new PrivacyConstraint(PrivacyConstraint.PrivacyLevel.PrivateAggregation));
}
@@ -200,6 +209,10 @@ public class FederatedMultiplyPlanningTest extends AutomatedTestBase {
writeColStandardMatrix("W1", 76, null);
writeColStandardMatrix("W2", 11, null);
}
+ else if ( testName.equals(TEST_NAME_10) ){
+ writeStandardMatrix("X1", 42, null);
+ writeStandardMatrix("X2", 1340, null);
+ }
else {
writeStandardMatrix("X1", 42);
writeStandardMatrix("X2", 1340);
diff --git a/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest10.dml b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest10.dml
new file mode 100644
index 0000000000..6621a717d3
--- /dev/null
+++ b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest10.dml
@@ -0,0 +1,25 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = federated(addresses=list($X1, $X2),
+ ranges=list(list(0, 0), list($r / 2, $c), list($r / 2, 0), list($r, $c))) # row-partitioned: $X1 holds rows [0, r/2), $X2 holds rows [r/2, r)
+Z = X^2 # elementwise square; the planning test expects this to execute as fed_^2
+write(Z, $Z)
diff --git a/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest10Reference.dml b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest10Reference.dml
new file mode 100644
index 0000000000..2ce47943e0
--- /dev/null
+++ b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest10Reference.dml
@@ -0,0 +1,24 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rbind(read($X1), read($X2)) # local equivalent of the row-partitioned federated X
+Z = X^2 # same elementwise square as the federated script, used as the expected result
+write(Z, $Z)