You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2020/08/15 12:40:13 UTC
[systemds] 02/02: [SYSTEMDS-2620] Federated tsmm operations (e.g.,
PCA, lmDS, cor)
This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
commit a3e3ea949c6af02914356c430756477b948965ce
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Sat Aug 15 14:39:45 2020 +0200
[SYSTEMDS-2620] Federated tsmm operations (e.g., PCA, lmDS, cor)
* Federated tsmm: support for federated tsmm left over row-partitioned
federated matrices.
* Performance: aggAdd (e.g., in ba+*, uack+, and tsmm) via nary instead
of binary operations.
---
.../controlprogram/federated/FederationUtils.java | 16 ++-
.../fed/ComputationFEDInstruction.java | 10 +-
.../runtime/instructions/fed/FEDInstruction.java | 3 +-
.../instructions/fed/FEDInstructionUtils.java | 6 +
.../instructions/fed/TsmmFEDInstruction.java | 82 +++++++++++++
.../test/functions/federated/FederatedPCATest.java | 133 +++++++++++++++++++++
.../functions/federated/FederatedPCATest.dml | 25 ++++
.../federated/FederatedPCATestReference.dml | 24 ++++
8 files changed, 283 insertions(+), 16 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
index ab0b3aa..f2c8227 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
@@ -29,13 +29,13 @@ import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
import org.apache.sysds.runtime.functionobjects.KahanFunction;
-import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.DoubleObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
-import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
+import org.apache.sysds.runtime.matrix.operators.SimpleOperator;
public class FederationUtils {
private static final IDSequence _idSeq = new IDSequence();
@@ -58,13 +58,11 @@ public class FederationUtils {
public static MatrixBlock aggAdd(Future<FederatedResponse>[] ffr) {
try {
- BinaryOperator bop = InstructionUtils.parseBinaryOperator("+");
- MatrixBlock ret = (MatrixBlock) (ffr[0].get().getData()[0]);
- for (int i=1; i<ffr.length; i++) {
- MatrixBlock tmp = (MatrixBlock) (ffr[i].get().getData()[0]);
- ret.binaryOperationsInPlace(bop, tmp);
- }
- return ret;
+ SimpleOperator op = new SimpleOperator(Plus.getPlusFnObject());
+ MatrixBlock[] in = new MatrixBlock[ffr.length];
+ for(int i=0; i<ffr.length; i++)
+ in[i] = (MatrixBlock) ffr[i].get().getData()[0];
+ return MatrixBlock.naryOperations(op, in, new ScalarObject[0], new MatrixBlock());
}
catch(Exception ex) {
throw new DMLRuntimeException(ex);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/ComputationFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/ComputationFEDInstruction.java
index 9d972f4..ccaec24 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/ComputationFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/ComputationFEDInstruction.java
@@ -37,9 +37,8 @@ public abstract class ComputationFEDInstruction extends FEDInstruction implement
public final CPOperand output;
public final CPOperand input1, input2, input3;
- protected ComputationFEDInstruction(FEDType type, Operator op, CPOperand in1, CPOperand in2, CPOperand out,
- String opcode,
- String istr) {
+ protected ComputationFEDInstruction(FEDType type, Operator op,
+ CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) {
super(type, op, opcode, istr);
input1 = in1;
input2 = in2;
@@ -47,9 +46,8 @@ public abstract class ComputationFEDInstruction extends FEDInstruction implement
output = out;
}
- protected ComputationFEDInstruction(FEDType type, Operator op, CPOperand in1, CPOperand in2, CPOperand in3,
- CPOperand out,
- String opcode, String istr) {
+ protected ComputationFEDInstruction(FEDType type, Operator op,
+ CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String istr) {
super(type, op, opcode, istr);
input1 = in1;
input2 = in2;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
index f2d0791..d6bd388 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
@@ -32,7 +32,8 @@ public abstract class FEDInstruction extends Instruction {
Append,
Binary,
Init,
- MultiReturnParameterizedBuiltin
+ MultiReturnParameterizedBuiltin,
+ Tsmm,
}
protected final FEDType _fedType;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index d639baa..5f97350 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -76,6 +76,12 @@ public class FEDInstructionUtils {
}
}
}
+ else if( inst instanceof MMTSJCPInstruction ) {
+ MMTSJCPInstruction linst = (MMTSJCPInstruction) inst;
+ MatrixObject mo = ec.getMatrixObject(linst.input1);
+ if( mo.isFederated() )
+ return TsmmFEDInstruction.parseInstruction(linst.toString());
+ }
return inst;
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
new file mode 100644
index 0000000..a3061ed
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import org.apache.sysds.lops.MMTSJ.MMTSJType;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+
+import java.util.concurrent.Future;
+
+public class TsmmFEDInstruction extends BinaryFEDInstruction {
+ private final MMTSJType _type;
+ @SuppressWarnings("unused")
+ private final int _numThreads;
+
+ public TsmmFEDInstruction(CPOperand in, CPOperand out, MMTSJType type, int k, String opcode, String istr) {
+ super(FEDType.Tsmm, null, in, null, out, opcode, istr);
+ _type = type;
+ _numThreads = k;
+ }
+
+ public static TsmmFEDInstruction parseInstruction(String str) {
+ String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+ String opcode = parts[0];
+ if(!opcode.equalsIgnoreCase("tsmm"))
+ throw new DMLRuntimeException("TsmmFedInstruction.parseInstruction():: Unknown opcode " + opcode);
+
+ InstructionUtils.checkNumFields(parts, 4);
+ CPOperand in = new CPOperand(parts[1]);
+ CPOperand out = new CPOperand(parts[2]);
+ MMTSJType type = MMTSJType.valueOf(parts[3]);
+ int k = Integer.parseInt(parts[4]);
+ return new TsmmFEDInstruction(in, out, type, k, opcode, str);
+ }
+
+ @Override
+ public void processInstruction(ExecutionContext ec) {
+ MatrixObject mo1 = ec.getMatrixObject(input1);
+
+ if(mo1.isFederated() && _type.isLeft()) { // left tsmm
+ //construct commands: fed tsmm, retrieve results
+ FederatedRequest fr1 = FederationUtils.callInstruction(instString, output,
+ new CPOperand[]{input1}, new long[]{mo1.getFedMapping().getID()});
+ FederatedRequest fr2 = new FederatedRequest(RequestType.GET_VAR, fr1.getID());
+
+ //execute federated operations and aggregate
+ Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(fr1, fr2);
+ MatrixBlock ret = FederationUtils.aggAdd(tmp);
+ mo1.getFedMapping().cleanup(fr1.getID());
+ ec.setMatrixOutput(output.getName(), ret);
+ }
+ else { //other combinations
+ throw new DMLRuntimeException("Federated Tsmm not supported with the "
+ + "following federated objects: "+mo1.isFederated()+" "+_fedType);
+ }
+ }
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java b/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java
new file mode 100644
index 0000000..29826f8
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated;
+
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedPCATest extends AutomatedTestBase {
+
+ private final static String TEST_DIR = "functions/federated/";
+ private final static String TEST_NAME = "FederatedPCATest";
+ private final static String TEST_CLASS_DIR = TEST_DIR + FederatedPCATest.class.getSimpleName() + "/";
+
+ private final static int blocksize = 1024;
+ @Parameterized.Parameter()
+ public int rows;
+ @Parameterized.Parameter(1)
+ public int cols;
+ @Parameterized.Parameter(2)
+ public boolean scaleAndShift;
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"}));
+ }
+
+ @Parameterized.Parameters
+ public static Collection<Object[]> data() {
+ // rows have to be even and > 1
+ return Arrays.asList(new Object[][] {
+ {10000, 10, false}, {2000, 50, false}, {1000, 100, false},
+ //TODO support for federated uacmean, uacvar
+ //{10000, 10, true}, {2000, 50, true}, {1000, 100, true}
+ });
+ }
+
+ @Test
+ public void federatedPCASinglenode() {
+ federatedL2SVM(Types.ExecMode.SINGLE_NODE);
+ }
+
+ @Test
+ public void federatedPCAHybrid() {
+ federatedL2SVM(Types.ExecMode.HYBRID);
+ }
+
+ public void federatedL2SVM(Types.ExecMode execMode) {
+ ExecMode platformOld = setExecMode(execMode);
+
+ getAndLoadTestConfiguration(TEST_NAME);
+ String HOME = SCRIPT_DIR + TEST_DIR;
+
+ // write input matrices
+ int halfRows = rows / 2;
+ // We have two matrices handled by a single federated worker
+ double[][] X1 = getRandomMatrix(halfRows, cols, 0, 1, 1, 3);
+ double[][] X2 = getRandomMatrix(halfRows, cols, 0, 1, 1, 7);
+ writeInputMatrixWithMTD("X1", X1, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+ writeInputMatrixWithMTD("X2", X2, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+
+ // empty script name because we don't execute any script, just start the worker
+ fullDMLScriptName = "";
+ int port1 = getRandomAvailablePort();
+ int port2 = getRandomAvailablePort();
+ Thread t1 = startLocalFedWorker(port1);
+ Thread t2 = startLocalFedWorker(port2);
+
+ TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
+ loadTestConfiguration(config);
+ setOutputBuffering(false);
+
+ // Run reference dml script with normal matrix
+ fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+ programArgs = new String[] {"-args", input("X1"), input("X2"),
+ String.valueOf(scaleAndShift).toUpperCase(), expected("Z")};
+ runTest(true, false, null, -1);
+
+ // Run actual dml script with federated matrix
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[] {"-stats",
+ "-nvargs", "in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
+ "in_X2=" + TestUtils.federatedAddress(port2, input("X2")), "rows=" + rows, "cols=" + cols,
+ "scaleAndShift=" + String.valueOf(scaleAndShift).toUpperCase(), "out=" + output("Z")};
+ runTest(true, false, null, -1);
+
+ // compare via files
+ compareResults(1e-9);
+ TestUtils.shutdownThreads(t1, t2);
+
+ // check for federated operations
+ Assert.assertTrue(heavyHittersContainsString("fed_ba+*"));
+ Assert.assertTrue(heavyHittersContainsString("fed_uack+"));
+ Assert.assertTrue(heavyHittersContainsString("fed_tsmm"));
+ if( scaleAndShift ) {
+ Assert.assertTrue(heavyHittersContainsString("fed_uacmean"));
+ Assert.assertTrue(heavyHittersContainsString("fed_uacvar"));
+ }
+
+ resetExecMode(platformOld);
+ }
+}
diff --git a/src/test/scripts/functions/federated/FederatedPCATest.dml b/src/test/scripts/functions/federated/FederatedPCATest.dml
new file mode 100644
index 0000000..b235d44
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedPCATest.dml
@@ -0,0 +1,25 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = federated(addresses=list($in_X1, $in_X2),
+ ranges=list(list(0, 0), list($rows / 2, $cols), list($rows / 2, 0), list($rows, $cols)))
+[X2,M] = pca(X=X, K=2, scale=$scaleAndShift, center=$scaleAndShift)
+write(X2, $out)
diff --git a/src/test/scripts/functions/federated/FederatedPCATestReference.dml b/src/test/scripts/functions/federated/FederatedPCATestReference.dml
new file mode 100644
index 0000000..0b17fe0
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedPCATestReference.dml
@@ -0,0 +1,24 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rbind(read($1), read($2))
+[X2,M] = pca(X=X, K=2, scale=$3, center=$3)
+write(X2, $4)