You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2020/08/15 12:40:13 UTC

[systemds] 02/02: [SYSTEMDS-2620] Federated tsmm operations (e.g., PCA, lmDS, cor)

This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit a3e3ea949c6af02914356c430756477b948965ce
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Sat Aug 15 14:39:45 2020 +0200

    [SYSTEMDS-2620] Federated tsmm operations (e.g., PCA, lmDS, cor)
    
    * Federated tsmm: support for federated tsmm left over row-partitioned
    federated matrices.
    * Performance: aggAdd (e.g., in ba+*, uack+, and tsmm) via nary instead
    of binary operations.
---
 .../controlprogram/federated/FederationUtils.java  |  16 ++-
 .../fed/ComputationFEDInstruction.java             |  10 +-
 .../runtime/instructions/fed/FEDInstruction.java   |   3 +-
 .../instructions/fed/FEDInstructionUtils.java      |   6 +
 .../instructions/fed/TsmmFEDInstruction.java       |  82 +++++++++++++
 .../test/functions/federated/FederatedPCATest.java | 133 +++++++++++++++++++++
 .../functions/federated/FederatedPCATest.dml       |  25 ++++
 .../federated/FederatedPCATestReference.dml        |  24 ++++
 8 files changed, 283 insertions(+), 16 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
index ab0b3aa..f2c8227 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
@@ -29,13 +29,13 @@ import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
 import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
 import org.apache.sysds.runtime.functionobjects.KahanFunction;
-import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.functionobjects.Plus;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
 import org.apache.sysds.runtime.instructions.cp.DoubleObject;
 import org.apache.sysds.runtime.instructions.cp.ScalarObject;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
-import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
+import org.apache.sysds.runtime.matrix.operators.SimpleOperator;
 
 public class FederationUtils {
 	private static final IDSequence _idSeq = new IDSequence();
@@ -58,13 +58,11 @@ public class FederationUtils {
 
 	public static MatrixBlock aggAdd(Future<FederatedResponse>[] ffr) {
 		try {
-			BinaryOperator bop = InstructionUtils.parseBinaryOperator("+");
-			MatrixBlock ret = (MatrixBlock) (ffr[0].get().getData()[0]);
-			for (int i=1; i<ffr.length; i++) {
-				MatrixBlock tmp = (MatrixBlock) (ffr[i].get().getData()[0]);
-				ret.binaryOperationsInPlace(bop, tmp);
-			}
-			return ret;
+			SimpleOperator op = new SimpleOperator(Plus.getPlusFnObject());
+			MatrixBlock[] in = new MatrixBlock[ffr.length];
+			for(int i=0; i<ffr.length; i++)
+				in[i] = (MatrixBlock) ffr[i].get().getData()[0];
+			return MatrixBlock.naryOperations(op, in, new ScalarObject[0], new MatrixBlock());
 		}
 		catch(Exception ex) {
 			throw new DMLRuntimeException(ex);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/ComputationFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/ComputationFEDInstruction.java
index 9d972f4..ccaec24 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/ComputationFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/ComputationFEDInstruction.java
@@ -37,9 +37,8 @@ public abstract class ComputationFEDInstruction extends FEDInstruction implement
 	public final CPOperand output;
 	public final CPOperand input1, input2, input3;
 	
-	protected ComputationFEDInstruction(FEDType type, Operator op, CPOperand in1, CPOperand in2, CPOperand out,
-			String opcode,
-			String istr) {
+	protected ComputationFEDInstruction(FEDType type, Operator op,
+		CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) {
 		super(type, op, opcode, istr);
 		input1 = in1;
 		input2 = in2;
@@ -47,9 +46,8 @@ public abstract class ComputationFEDInstruction extends FEDInstruction implement
 		output = out;
 	}
 	
-	protected ComputationFEDInstruction(FEDType type, Operator op, CPOperand in1, CPOperand in2, CPOperand in3,
-			CPOperand out,
-			String opcode, String istr) {
+	protected ComputationFEDInstruction(FEDType type, Operator op,
+		CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String istr) {
 		super(type, op, opcode, istr);
 		input1 = in1;
 		input2 = in2;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
index f2d0791..d6bd388 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
@@ -32,7 +32,8 @@ public abstract class FEDInstruction extends Instruction {
 		Append,
 		Binary,
 		Init,
-		MultiReturnParameterizedBuiltin
+		MultiReturnParameterizedBuiltin,
+		Tsmm,
 	}
 	
 	protected final FEDType _fedType;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index d639baa..5f97350 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -76,6 +76,12 @@ public class FEDInstructionUtils {
 				}
 			}
 		}
+		else if( inst instanceof MMTSJCPInstruction ) {
+			MMTSJCPInstruction linst = (MMTSJCPInstruction) inst;
+			MatrixObject mo = ec.getMatrixObject(linst.input1);
+			if( mo.isFederated() )
+				return TsmmFEDInstruction.parseInstruction(linst.toString());
+		}
 		return inst;
 	}
 	
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
new file mode 100644
index 0000000..a3061ed
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import org.apache.sysds.lops.MMTSJ.MMTSJType;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+
+import java.util.concurrent.Future;
+
+public class TsmmFEDInstruction extends BinaryFEDInstruction {
+	private final MMTSJType _type; // tsmm variant; only the isLeft() case is executed below
+	@SuppressWarnings("unused")
+	private final int _numThreads; // parsed thread count, stored but not read in this class
+	/** Creates the instruction; no operator and no second input operand are passed to the superclass. */
+	public TsmmFEDInstruction(CPOperand in, CPOperand out, MMTSJType type, int k, String opcode, String istr) {
+		super(FEDType.Tsmm, null, in, null, out, opcode, istr);
+		_type = type;
+		_numThreads = k;
+	}
+	/** Parses a "tsmm" instruction string with the fields input, output, MMTSJ type, and thread count. */
+	public static TsmmFEDInstruction parseInstruction(String str) {
+		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+		String opcode = parts[0];
+		if(!opcode.equalsIgnoreCase("tsmm"))
+			throw new DMLRuntimeException("TsmmFedInstruction.parseInstruction():: Unknown opcode " + opcode);
+		
+		InstructionUtils.checkNumFields(parts, 4);
+		CPOperand in = new CPOperand(parts[1]);
+		CPOperand out = new CPOperand(parts[2]);
+		MMTSJType type = MMTSJType.valueOf(parts[3]);
+		int k = Integer.parseInt(parts[4]);
+		return new TsmmFEDInstruction(in, out, type, k, opcode, str);
+	}
+	/** Executes tsmm through the federated mapping of the input and adds up the retrieved partial results. */
+	@Override
+	public void processInstruction(ExecutionContext ec) {
+		MatrixObject mo1 = ec.getMatrixObject(input1);
+		
+		if(mo1.isFederated() && _type.isLeft()) { // left tsmm
+			//construct commands: fed tsmm, retrieve results
+			FederatedRequest fr1 = FederationUtils.callInstruction(instString, output,
+				new CPOperand[]{input1}, new long[]{mo1.getFedMapping().getID()});
+			FederatedRequest fr2 = new FederatedRequest(RequestType.GET_VAR, fr1.getID());
+			
+			//execute federated operations and aggregate
+			Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(fr1, fr2);
+			MatrixBlock ret = FederationUtils.aggAdd(tmp);
+			mo1.getFedMapping().cleanup(fr1.getID());
+			ec.setMatrixOutput(output.getName(), ret);
+		}
+		else { //other combinations; report the tsmm type, as the FED type passed to super is always Tsmm
+			throw new DMLRuntimeException("Federated Tsmm not supported with the "
+				+ "following federated objects: "+mo1.isFederated()+" "+_type);
+		}
+	}
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java b/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java
new file mode 100644
index 0000000..29826f8
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/federated/FederatedPCATest.java
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated;
+
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedPCATest extends AutomatedTestBase {
+	// Compares pca() on a federated matrix served by two local workers with pca() on the same data read locally.
+	private final static String TEST_DIR = "functions/federated/";
+	private final static String TEST_NAME = "FederatedPCATest";
+	private final static String TEST_CLASS_DIR = TEST_DIR + FederatedPCATest.class.getSimpleName() + "/";
+
+	private final static int blocksize = 1024;
+	@Parameterized.Parameter()
+	public int rows;
+	@Parameterized.Parameter(1)
+	public int cols;
+	@Parameterized.Parameter(2)
+	public boolean scaleAndShift;
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"}));
+	}
+
+	@Parameterized.Parameters
+	public static Collection<Object[]> data() {
+		// {rows, cols, scaleAndShift}; rows have to be even and > 1 because each worker gets rows/2 rows
+		return Arrays.asList(new Object[][] {
+			{10000, 10, false}, {2000, 50, false}, {1000, 100, false},
+			//TODO support for federated uacmean, uacvar
+			//{10000, 10, true}, {2000, 50, true}, {1000, 100, true}
+		});
+	}
+
+	@Test
+	public void federatedPCASinglenode() {
+		federatedL2SVM(Types.ExecMode.SINGLE_NODE);
+	}
+	
+	@Test
+	public void federatedPCAHybrid() {
+		federatedL2SVM(Types.ExecMode.HYBRID);
+	}
+	/** Runs the reference and the federated PCA script in the given exec mode and compares output Z. */
+	public void federatedL2SVM(Types.ExecMode execMode) {
+		ExecMode platformOld = setExecMode(execMode);
+		getAndLoadTestConfiguration(TEST_NAME);
+		String HOME = SCRIPT_DIR + TEST_DIR;
+
+		// write input matrices
+		int halfRows = rows / 2;
+		// We have two matrices, each handled by its own federated worker
+		double[][] X1 = getRandomMatrix(halfRows, cols, 0, 1, 1, 3);
+		double[][] X2 = getRandomMatrix(halfRows, cols, 0, 1, 1, 7);
+		writeInputMatrixWithMTD("X1", X1, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+		writeInputMatrixWithMTD("X2", X2, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+
+		// empty script name because we don't execute any script, just start the worker
+		fullDMLScriptName = "";
+		int port1 = getRandomAvailablePort();
+		int port2 = getRandomAvailablePort();
+		Thread t1 = startLocalFedWorker(port1);
+		Thread t2 = startLocalFedWorker(port2);
+		try {
+			TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
+			loadTestConfiguration(config);
+			setOutputBuffering(false);
+			
+			// Run reference dml script with normal matrix
+			fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+			programArgs = new String[] {"-args", input("X1"), input("X2"),
+				String.valueOf(scaleAndShift).toUpperCase(), expected("Z")};
+			runTest(true, false, null, -1);
+			
+			// Run actual dml script with federated matrix
+			fullDMLScriptName = HOME + TEST_NAME + ".dml";
+			programArgs = new String[] {"-stats",
+				"-nvargs", "in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
+				"in_X2=" + TestUtils.federatedAddress(port2, input("X2")), "rows=" + rows, "cols=" + cols,
+				"scaleAndShift=" + String.valueOf(scaleAndShift).toUpperCase(), "out=" + output("Z")};
+			runTest(true, false, null, -1);
+			
+			// compare via files
+			compareResults(1e-9);
+			// check for federated operations
+			Assert.assertTrue(heavyHittersContainsString("fed_ba+*"));
+			Assert.assertTrue(heavyHittersContainsString("fed_uack+"));
+			Assert.assertTrue(heavyHittersContainsString("fed_tsmm"));
+			if( scaleAndShift ) {
+				Assert.assertTrue(heavyHittersContainsString("fed_uacmean"));
+				Assert.assertTrue(heavyHittersContainsString("fed_uacvar"));
+			}
+		}
+		finally { // always stop the workers and restore the exec mode, even if a run or check above fails
+			TestUtils.shutdownThreads(t1, t2);
+			resetExecMode(platformOld);
+		}
+	}
+}
diff --git a/src/test/scripts/functions/federated/FederatedPCATest.dml b/src/test/scripts/functions/federated/FederatedPCATest.dml
new file mode 100644
index 0000000..b235d44
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedPCATest.dml
@@ -0,0 +1,25 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = federated(addresses=list($in_X1, $in_X2), # federated matrix built from the two worker addresses
+    ranges=list(list(0, 0), list($rows / 2, $cols), list($rows / 2, 0), list($rows, $cols))) # presumably one (begin, end) index pair per address, i.e. upper and lower half of the rows - TODO confirm
+[X2,M] = pca(X=X,  K=2, scale=$scaleAndShift, center=$scaleAndShift) # K=2; scale and center share one flag
+write(X2, $out) # write the first pca() output to $out
diff --git a/src/test/scripts/functions/federated/FederatedPCATestReference.dml b/src/test/scripts/functions/federated/FederatedPCATestReference.dml
new file mode 100644
index 0000000..0b17fe0
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedPCATestReference.dml
@@ -0,0 +1,24 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rbind(read($1), read($2)) # stack the two input matrices row-wise into one local matrix
+[X2,M] = pca(X=X,  K=2, scale=$3, center=$3) # K=2; scale and center both taken from $3
+write(X2, $4) # write the first pca() output to the path given as $4