You are viewing a plain-text version of this content; the link to the canonical version was not preserved in this copy.
Posted to commits@systemds.apache.org by ba...@apache.org on 2020/11/12 19:57:55 UTC

[systemds] 04/05: [SYSTEMDS-2723] Cast to frame Federated

This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit 6d19dbaffc3e2ead068474c734ad289141025b5a
Author: baunsgaard <ba...@tugraz.at>
AuthorDate: Thu Nov 12 15:01:34 2020 +0100

    [SYSTEMDS-2723] Cast to frame Federated
---
 .../controlprogram/federated/FederatedData.java    |   7 ++
 .../instructions/fed/FEDInstructionUtils.java      |  25 ++--
 .../instructions/fed/VariableFEDInstruction.java   |  57 ++++++++++
 .../primitives/FederetedCastToFrameTest.java       | 126 +++++++++++++++++++++
 .../primitives/FederatedCastToFrameTest.dml        |  26 +++++
 .../FederatedCastToFrameTestReference.dml          |  25 ++++
 6 files changed, 259 insertions(+), 7 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
index f9702c0..d19d132 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
@@ -74,6 +74,13 @@ public class FederatedData {
 			_allFedSites.add(_address);
 	}
 
+	public FederatedData(Types.DataType dataType, InetSocketAddress address, String filepath, long varID) {
+		_address = address; // unlike the constructor above, the address is not added to _allFedSites
+		_filepath = filepath;
+		_varID = varID;
+		_dataType = dataType;
+	}
+
 	public InetSocketAddress getAddress() {
 		return _address;
 	}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index 2edc5f2..68e1cee 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -36,6 +36,8 @@ import org.apache.sysds.runtime.instructions.cp.MatrixIndexingCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.MultiReturnParameterizedBuiltinCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.ParameterizedBuiltinCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.ReorgCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.UnaryCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.UnaryMatrixCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction.VariableOperationCode;
 import org.apache.sysds.runtime.instructions.spark.AggregateUnarySPInstruction;
@@ -82,13 +84,15 @@ public class FEDInstructionUtils {
 			if( mo.isFederated() )
 				fedinst = TsmmFEDInstruction.parseInstruction(linst.getInstructionString());
 		}
-		else if (inst instanceof AggregateUnaryCPInstruction) {
-			AggregateUnaryCPInstruction instruction = (AggregateUnaryCPInstruction) inst;
-			if( instruction.input1.isMatrix() && ec.containsVariable(instruction.input1) ) {
-				MatrixObject mo1 = ec.getMatrixObject(instruction.input1);
-				if (mo1.isFederated() && instruction.getAUType() == AggregateUnaryCPInstruction.AUType.DEFAULT){
-					LOG.debug("Federated UnaryAggregate");
-					fedinst = AggregateUnaryFEDInstruction.parseInstruction(inst.getInstructionString());
+		// Dispatch only AggregateUnary here: a generic UnaryCPInstruction branch would shadow any later
+		// branch for other UnaryCPInstruction subclasses (e.g. Reorg, MatrixIndexing) and leave them local.
+		else if (inst instanceof AggregateUnaryCPInstruction) {
+			AggregateUnaryCPInstruction instruction = (AggregateUnaryCPInstruction) inst;
+			if( instruction.input1.isMatrix() && ec.containsVariable(instruction.input1) ) {
+				MatrixObject mo1 = ec.getMatrixObject(instruction.input1);
+				if (mo1.isFederated() && instruction.getAUType() == AggregateUnaryCPInstruction.AUType.DEFAULT){
+					LOG.debug("Federated UnaryAggregate");
+					fedinst = AggregateUnaryFEDInstruction.parseInstruction(inst.getInstructionString());
 				}
 			}
 		}
@@ -141,12 +145,19 @@ public class FEDInstructionUtils {
 			VariableCPInstruction ins = (VariableCPInstruction) inst;
 
 			if(ins.getVariableOpcode() == VariableOperationCode.Write 
+				&& ins.getInput1().isMatrix()
 				&& ins.getInput3().getName().contains("federated")){
 				fedinst = VariableFEDInstruction.parseInstruction(ins);
 			}
+			else if(ins.getVariableOpcode() == VariableOperationCode.CastAsFrameVariable
+				&& ins.getInput1().isMatrix() && ec.containsVariable(ins.getInput1()) // guard lookup, cf. AggregateUnary branch
+				&& ec.getCacheableData(ins.getInput1()).isFederated()){
+				fedinst = VariableFEDInstruction.parseInstruction(ins);
+			}
 
 		}
 
+
 		
 		//set thread id for federated context management
 		if( fedinst != null ) {
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java
index 91efc2c..7d39e9d 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java
@@ -19,11 +19,25 @@
 
 package org.apache.sysds.runtime.instructions.fed;
 
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+
 import org.apache.commons.lang3.tuple.Pair;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.ValueType;
 import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
 import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction.VariableOperationCode;
 import org.apache.sysds.runtime.lineage.LineageItem;
@@ -52,6 +66,13 @@ public class VariableFEDInstruction extends FEDInstruction implements LineageTra
 				processWriteInstruction(ec);
 				break;
 
+			case CastAsFrameVariable:
+				processCastAsFrameVariableInstruction(ec);
+				break;
+			case CastAsMatrixVariable:
+				processCastAsMatrixVariableInstruction(ec);
+				break;
+
 			default:
 				throw new DMLRuntimeException("Unsupported Opcode for federated Variable Instruction : " + opcode);
 		}
@@ -66,6 +87,42 @@ public class VariableFEDInstruction extends FEDInstruction implements LineageTra
 		_in.processInstruction(ec);
 	}
 
+	private void processCastAsMatrixVariableInstruction(ExecutionContext ec){
+		// Federated as.matrix casts are not supported yet. Only throw (no additional LOG.error), so the
+		// failure is reported once, by whoever handles the exception.
+		throw new DMLRuntimeException("Not Implemented Cast as Matrix");
+	}
+
+	private void processCastAsFrameVariableInstruction(ExecutionContext ec){
+		// Casts a federated matrix to a federated frame: each worker casts its own partition, while the
+		// coordinator only derives the output meta data and federation map (the data stays remote).
+		MatrixObject mo1 = ec.getMatrixObject(_in.getInput1());
+		if( !mo1.isFederated() )
+			throw new DMLRuntimeException("Federated Cast as Frame: "
+				+ "Federated input expected, but invoked w/ "+mo1.isFederated());
+		//execute the cast at the federated sites (the output is referenced via fr1's ID below)
+		FederatedRequest fr1 = FederationUtils.callInstruction(_in.getInstructionString(), _in.getOutput(),
+			new CPOperand[]{_in.getInput1()}, new long[]{mo1.getFedMapping().getID()});
+		mo1.getFedMapping().execute(getTID(), true, fr1);
+
+		//derive output meta data: a cast keeps the shape (rows x cols), unlike a transpose
+		FrameObject out = ec.getFrameObject(_in.getOutput());
+		out.getDataCharacteristics().set(mo1.getNumRows(),
+			mo1.getNumColumns(), (int)mo1.getBlocksize(), mo1.getNnz());
+		ValueType[] schema = new ValueType[(int)mo1.getNumColumns()];
+		Arrays.fill(schema, ValueType.FP64); // one FP64 column per matrix column
+		out.setSchema(schema);
+
+		//derive output federated mapping: same ranges/sites and new var IDs, but typed as FRAME
+		FederationMap outMap = mo1.getFedMapping().copyWithNewID(fr1.getID());
+		Map<FederatedRange, FederatedData> newMap = new HashMap<>();
+		for(Map.Entry<FederatedRange, FederatedData> pair : outMap.getFedMapping().entrySet()){
+			FederatedData om = pair.getValue();
+			newMap.put(pair.getKey(), new FederatedData(Types.DataType.FRAME, om.getAddress(), om.getFilepath(), om.getVarID()));
+		}
+		out.setFedMapping(new FederationMap(fr1.getID(), newMap));
+	}
+
 	@Override
 	public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
 		return _in.getLineageItem(ec);
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederetedCastToFrameTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederetedCastToFrameTest.java
new file mode 100644
index 0000000..bbef96e
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederetedCastToFrameTest.java
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.primitives;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Ignore;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederetedCastToFrameTest extends AutomatedTestBase {
+	private static final Log LOG = LogFactory.getLog(FederetedCastToFrameTest.class.getName());
+
+	private final static String TEST_DIR = "functions/federated/primitives/";
+	private final static String TEST_NAME = "FederatedCastToFrameTest";
+	private final static String TEST_CLASS_DIR = TEST_DIR + FederetedCastToFrameTest.class.getSimpleName() + "/";
+
+	private final static int blocksize = 1024;
+	@Parameterized.Parameter()
+	public int rows;
+	@Parameterized.Parameter(1)
+	public int cols;
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME));
+	}
+
+	@Parameterized.Parameters
+	public static Collection<Object[]> data() {
+		// rows have to be even and > 1
+		return Arrays.asList(new Object[][] {{10, 32}});
+	}
+
+	@Test
+	public void federatedMultiplyCP() {
+		federatedMultiply(Types.ExecMode.SINGLE_NODE);
+	}
+
+	@Test
+	@Ignore
+	public void federatedMultiplySP() {
+		// TODO Fix me Spark execution error
+		federatedMultiply(Types.ExecMode.SPARK);
+	}
+
+	public void federatedMultiply(Types.ExecMode execMode) {
+		boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+		Types.ExecMode platformOld = rtplatform;
+		rtplatform = execMode;
+		if(rtplatform == Types.ExecMode.SPARK)
+			DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+		getAndLoadTestConfiguration(TEST_NAME);
+		String HOME = SCRIPT_DIR + TEST_DIR;
+
+		// write the two row halves of the input; each half is served by its own federated worker
+		int halfRows = rows / 2;
+		double[][] X1 = getRandomMatrix(halfRows, cols, 0, 1, 1, 42);
+		double[][] X2 = getRandomMatrix(halfRows, cols, 0, 1, 1, 1340);
+		writeInputMatrixWithMTD("X1", X1, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+		writeInputMatrixWithMTD("X2", X2, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+
+		int port1 = getRandomAvailablePort();
+		int port2 = getRandomAvailablePort();
+		Thread t1 = startLocalFedWorkerThread(port1);
+		Thread t2 = startLocalFedWorkerThread(port2);
+		try {
+			TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
+			loadTestConfiguration(config);
+
+			// Run reference dml script with normal matrix
+			fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+			programArgs = new String[] {"-nvargs", "X1=" + input("X1"), "X2=" + input("X2")};
+			String out = runTest(null).toString().split("SystemDS Statistics:")[0];
+
+			// Run actual dml script with federated matrix
+			fullDMLScriptName = HOME + TEST_NAME + ".dml";
+			programArgs = new String[] {"-stats", "100", "-nvargs", "X1=" + TestUtils.federatedAddress(port1, input("X1")),
+				"X2=" + TestUtils.federatedAddress(port2, input("X2")), "r=" + rows, "c=" + cols};
+			String fedOut = runTest(null).toString();
+			LOG.debug(fedOut);
+			fedOut = fedOut.split("SystemDS Statistics:")[0];
+
+			Assert.assertEquals("Equal Printed Output", out, fedOut);
+			Assert.assertTrue("Contains federated Cast to frame", heavyHittersContainsString("fed_castdtf"));
+		}
+		finally {
+			// always stop the workers and restore global settings, also when an assertion fails
+			TestUtils.shutdownThreads(t1, t2);
+			rtplatform = platformOld;
+			DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+		}
+	}
+}
diff --git a/src/test/scripts/functions/federated/primitives/FederatedCastToFrameTest.dml b/src/test/scripts/functions/federated/primitives/FederatedCastToFrameTest.dml
new file mode 100644
index 0000000..6efd3f4
--- /dev/null
+++ b/src/test/scripts/functions/federated/primitives/FederatedCastToFrameTest.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = federated(addresses=list($X1, $X2),
+    ranges=list(list(0, 0), list($r / 2, $c), list($r / 2, 0), list($r, $c)))
+Z = as.frame(X)
+print(nrow(Z) + " x " + ncol(Z)) # the cast must keep the dims (no transpose)
+print(toString(Z[1]))
diff --git a/src/test/scripts/functions/federated/primitives/FederatedCastToFrameTestReference.dml b/src/test/scripts/functions/federated/primitives/FederatedCastToFrameTestReference.dml
new file mode 100644
index 0000000..919e309
--- /dev/null
+++ b/src/test/scripts/functions/federated/primitives/FederatedCastToFrameTestReference.dml
@@ -0,0 +1,25 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rbind(read($X1), read($X2))
+Z = as.frame(X)
+print(nrow(Z) + " x " + ncol(Z)) # the cast must keep the dims (no transpose)
+print(toString(Z[1]))