You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2020/08/18 21:01:28 UTC

[systemds] branch master updated: [SYSTEMDS-2627] Federated mmchain instruction for lmCG, MLogreg, GLM

This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 0f698fe  [SYSTEMDS-2627] Federated mmchain instruction for lmCG, MLogreg, GLM
0f698fe is described below

commit 0f698fee39191de324b72cc3e22c609d340b124e
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Tue Aug 18 22:55:29 2020 +0200

    [SYSTEMDS-2627] Federated mmchain instruction for lmCG, MLogreg, GLM
    
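    This patch adds a federated mmchain instruction for a common
    matrix-vector multiplication chain as it appears in the inner loop of
    lmCG, MLogreg, and GLM. It also makes operand renaming during
    instruction manipulation more robust (variable names are now matched
    together with the data-type prefix, so a name such as X no longer also
    matches X2), and adds a federated GLM test.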
    
    Furthermore, the Spark cluster analysis now takes the conf-only path
    when running with the local Spark configuration, which avoids
    unnecessary Spark context creation in local mode (such context creation
    sometimes interfered with netty port allocation in federated tests).
---
 .../java/org/apache/sysds/lops/MapMultChain.java   |   5 +-
 .../context/SparkExecutionContext.java             |   6 +
 .../controlprogram/federated/FederationUtils.java  |   8 +-
 .../instructions/cp/MMChainCPInstruction.java      |  25 ++--
 .../runtime/instructions/fed/FEDInstruction.java   |   1 +
 .../instructions/fed/FEDInstructionUtils.java      |  31 +++--
 .../instructions/fed/MMChainFEDInstruction.java    | 112 +++++++++++++++++
 .../test/functions/federated/FederatedGLMTest.java | 135 +++++++++++++++++++++
 .../functions/federated/FederatedGLMTest.dml       |  27 +++++
 .../federated/FederatedGLMTestReference.dml        |  25 ++++
 10 files changed, 348 insertions(+), 27 deletions(-)

diff --git a/src/main/java/org/apache/sysds/lops/MapMultChain.java b/src/main/java/org/apache/sysds/lops/MapMultChain.java
index 79d57f7..b45d813 100644
--- a/src/main/java/org/apache/sysds/lops/MapMultChain.java
+++ b/src/main/java/org/apache/sysds/lops/MapMultChain.java
@@ -35,7 +35,10 @@ public class MapMultChain extends Lop
 		XtXv,  //(t(X) %*% (X %*% v))
 		XtwXv, //(t(X) %*% (w * (X %*% v)))
 		XtXvy, //(t(X) %*% ((X %*% v) - y))
-		NONE,
+		NONE;
+		public boolean isWeighted() {
+			return this == XtwXv || this == ChainType.XtXvy;
+		}
 	}
 	
 	private ChainType _chainType = null;
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java b/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
index 65348f1..2be647d 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
@@ -1771,6 +1771,12 @@ public class SparkExecutionContext extends ExecutionContext
 				_defaultPar = (defaultPar>1) ? defaultPar : numExecutors * numCoresPerExec;
 				_confOnly &= true;
 			}
+			else if( DMLScript.USE_LOCAL_SPARK_CONFIG ) {
+				//avoid unnecessary spark context creation in local mode (e.g., tests)
+				_numExecutors = 1;
+				_defaultPar = 2;
+				_confOnly &= true;
+			}
 			else {
 				//get default parallelism (total number of executors and cores)
 				//note: spark context provides this information while conf does not
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
index c34fa62..429834b 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
@@ -52,10 +52,14 @@ public class FederationUtils {
 		//TODO better and safe replacement of operand names --> instruction utils
 		long id = getNextFedDataID();
 		String linst = inst.replace(ExecType.SPARK.name(), ExecType.CP.name());
-		linst = linst.replace(Lop.OPERAND_DELIMITOR+varOldOut.getName(), Lop.OPERAND_DELIMITOR+String.valueOf(id));
+		linst = linst.replace(
+			Lop.OPERAND_DELIMITOR+varOldOut.getName()+Lop.DATATYPE_PREFIX,
+			Lop.OPERAND_DELIMITOR+String.valueOf(id)+Lop.DATATYPE_PREFIX);
 		for(int i=0; i<varOldIn.length; i++)
 			if( varOldIn[i] != null ) {
-				linst = linst.replace(Lop.OPERAND_DELIMITOR+varOldIn[i].getName(), Lop.OPERAND_DELIMITOR+String.valueOf(varNewIn[i]));
+				linst = linst.replace(
+					Lop.OPERAND_DELIMITOR+varOldIn[i].getName()+Lop.DATATYPE_PREFIX,
+					Lop.OPERAND_DELIMITOR+String.valueOf(varNewIn[i])+Lop.DATATYPE_PREFIX);
 				linst = linst.replace("="+varOldIn[i].getName(), "="+String.valueOf(varNewIn[i])); //parameterized
 			}
 		return new FederatedRequest(RequestType.EXEC_INST, id, linst);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/MMChainCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/MMChainCPInstruction.java
index f540343..dcff65b 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/MMChainCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/MMChainCPInstruction.java
@@ -36,31 +36,30 @@ public class MMChainCPInstruction extends UnaryCPInstruction {
 		_type = type;
 		_numThreads = k;
 	}
+	
+	public ChainType getMMChainType() {
+		return _type;
+	}
 
 	public static MMChainCPInstruction parseInstruction ( String str ) {
 		//parse instruction parts (without exec type)
 		String[] parts = InstructionUtils.getInstructionPartsWithValueType( str );
 		InstructionUtils.checkNumFields( parts, 5, 6 );
-	
 		String opcode = parts[0];
 		CPOperand in1 = new CPOperand(parts[1]);
 		CPOperand in2 = new CPOperand(parts[2]);
 		
-		if( parts.length==6 )
-		{
+		if( parts.length==6 ) {
 			CPOperand out= new CPOperand(parts[3]);
 			ChainType type = ChainType.valueOf(parts[4]);
 			int k = Integer.parseInt(parts[5]);
-			
 			return new MMChainCPInstruction(null, in1, in2, null, out, type, k, opcode, str);
 		}
-		else //parts.length==7
-		{
+		else { //parts.length==7
 			CPOperand in3 = new CPOperand(parts[3]);
 			CPOperand out = new CPOperand(parts[4]);
 			ChainType type = ChainType.valueOf(parts[5]);
 			int k = Integer.parseInt(parts[6]);
-			
 			return new MMChainCPInstruction(null, in1, in2, in3, out, type, k, opcode, str);
 		}
 	}
@@ -70,19 +69,15 @@ public class MMChainCPInstruction extends UnaryCPInstruction {
 		//get inputs
 		MatrixBlock X = ec.getMatrixInput(input1.getName());
 		MatrixBlock v = ec.getMatrixInput(input2.getName());
-		MatrixBlock w = (_type==ChainType.XtwXv || _type==ChainType.XtXvy) ? 
-			ec.getMatrixInput(input3.getName()) : null;
+		MatrixBlock w = _type.isWeighted() ? ec.getMatrixInput(input3.getName()) : null;
+		
 		//execute mmchain operation 
-		 MatrixBlock out = X.chainMatrixMultOperations(v, w, new MatrixBlock(), _type, _numThreads);
+		MatrixBlock out = X.chainMatrixMultOperations(v, w, new MatrixBlock(), _type, _numThreads);
+		
 		//set output and release inputs
 		ec.setMatrixOutput(output.getName(), out);
 		ec.releaseMatrixInput(input1.getName(), input2.getName());
 		if( w !=null )
 			ec.releaseMatrixInput(input3.getName());
 	}
-	
-	public ChainType getMMChainType()
-	{
-		return _type;
-	}
 }
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
index 6df1b1e..77dedfd 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
@@ -35,6 +35,7 @@ public abstract class FEDInstruction extends Instruction {
 		MultiReturnParameterizedBuiltin,
 		ParameterizedBuiltin,
 		Tsmm,
+		MMChain,
 	}
 	
 	protected final FEDType _fedType;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index 4325456..bbdaa8e 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -45,6 +45,18 @@ public class FEDInstructionUtils {
 				}
 			}
 		}
+		else if( inst instanceof MMChainCPInstruction) {
+			MMChainCPInstruction linst = (MMChainCPInstruction) inst;
+			MatrixObject mo = ec.getMatrixObject(linst.input1);
+			if( mo.isFederated() )
+				fedinst = MMChainFEDInstruction.parseInstruction(linst.getInstructionString());
+		}
+		else if( inst instanceof MMTSJCPInstruction ) {
+			MMTSJCPInstruction linst = (MMTSJCPInstruction) inst;
+			MatrixObject mo = ec.getMatrixObject(linst.input1);
+			if( mo.isFederated() )
+				fedinst = TsmmFEDInstruction.parseInstruction(linst.getInstructionString());
+		}
 		else if (inst instanceof AggregateUnaryCPInstruction) {
 			AggregateUnaryCPInstruction instruction = (AggregateUnaryCPInstruction) inst;
 			if( instruction.input1.isMatrix() && ec.containsVariable(instruction.input1) ) {
@@ -77,12 +89,6 @@ public class FEDInstructionUtils {
 				}
 			}
 		}
-		else if( inst instanceof MMTSJCPInstruction ) {
-			MMTSJCPInstruction linst = (MMTSJCPInstruction) inst;
-			MatrixObject mo = ec.getMatrixObject(linst.input1);
-			if( mo.isFederated() )
-				fedinst = TsmmFEDInstruction.parseInstruction(linst.getInstructionString());
-		}
 		
 		//set thread id for federated context management
 		if( fedinst != null ) {
@@ -94,13 +100,14 @@ public class FEDInstructionUtils {
 	}
 	
 	public static Instruction checkAndReplaceSP(Instruction inst, ExecutionContext ec) {
+		FEDInstruction fedinst = null;
 		if (inst instanceof MapmmSPInstruction) {
 			// FIXME does not yet work for MV multiplication. SPARK execution mode not supported for federated l2svm
 			MapmmSPInstruction instruction = (MapmmSPInstruction) inst;
 			Data data = ec.getVariable(instruction.input1);
 			if (data instanceof MatrixObject && ((MatrixObject) data).isFederated()) {
 				// TODO correct FED instruction string
-				return new AggregateBinaryFEDInstruction(instruction.getOperator(),
+				fedinst = new AggregateBinaryFEDInstruction(instruction.getOperator(),
 					instruction.input1, instruction.input2, instruction.output, "ba+*", "FED...");
 			}
 		}
@@ -108,7 +115,7 @@ public class FEDInstructionUtils {
 			AggregateUnarySPInstruction instruction = (AggregateUnarySPInstruction) inst;
 			Data data = ec.getVariable(instruction.input1);
 			if (data instanceof MatrixObject && ((MatrixObject) data).isFederated())
-				return AggregateUnaryFEDInstruction.parseInstruction(inst.getInstructionString());
+				fedinst = AggregateUnaryFEDInstruction.parseInstruction(inst.getInstructionString());
 		}
 		else if (inst instanceof WriteSPInstruction) {
 			WriteSPInstruction instruction = (WriteSPInstruction) inst;
@@ -124,9 +131,15 @@ public class FEDInstructionUtils {
 			AppendGAlignedSPInstruction instruction = (AppendGAlignedSPInstruction) inst;
 			Data data = ec.getVariable(instruction.input1);
 			if (data instanceof MatrixObject && ((MatrixObject) data).isFederated()) {
-				return AppendFEDInstruction.parseInstruction(instruction.getInstructionString());
+				fedinst = AppendFEDInstruction.parseInstruction(instruction.getInstructionString());
 			}
 		}
+		//set thread id for federated context management
+		if( fedinst != null ) {
+			fedinst.setTID(ec.getTID());
+			return fedinst;
+		}
+		
 		return inst;
 	}
 }
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/MMChainFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/MMChainFEDInstruction.java
new file mode 100644
index 0000000..2dee64b
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/MMChainFEDInstruction.java
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import org.apache.sysds.lops.MapMultChain.ChainType;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+
+import java.util.concurrent.Future;
+
+public class MMChainFEDInstruction extends UnaryFEDInstruction {
+	
+	public MMChainFEDInstruction(CPOperand in1, CPOperand in2, CPOperand in3, 
+		CPOperand out, ChainType type, int k, String opcode, String istr) {
+		super(FEDType.MMChain, null, in1, in2, in3, out, opcode, istr);
+		_type = type;
+	}
+	
+	private final ChainType _type;
+
+	public ChainType getMMChainType() {
+		return _type;
+	}
+
+	public static MMChainFEDInstruction parseInstruction ( String str ) {
+		//parse instruction parts (without exec type)
+		String[] parts = InstructionUtils.getInstructionPartsWithValueType( str );
+		InstructionUtils.checkNumFields( parts, 5, 6 );
+		String opcode = parts[0];
+		CPOperand in1 = new CPOperand(parts[1]);
+		CPOperand in2 = new CPOperand(parts[2]);
+		
+		if( parts.length==6 ) {
+			CPOperand out= new CPOperand(parts[3]);
+			ChainType type = ChainType.valueOf(parts[4]);
+			int k = Integer.parseInt(parts[5]);
+			return new MMChainFEDInstruction(in1, in2, null, out, type, k, opcode, str);
+		}
+		else { //parts.length==7
+			CPOperand in3 = new CPOperand(parts[3]);
+			CPOperand out = new CPOperand(parts[4]);
+			ChainType type = ChainType.valueOf(parts[5]);
+			int k = Integer.parseInt(parts[6]);
+			return new MMChainFEDInstruction(in1, in2, in3, out, type, k, opcode, str);
+		}
+	}
+	
+	@Override
+	public void processInstruction(ExecutionContext ec) {
+		MatrixObject mo1 = ec.getMatrixObject(input1);
+		MatrixObject mo2 = ec.getMatrixObject(input2);
+		MatrixObject mo3 = _type.isWeighted() ? ec.getMatrixObject(input3) : null;
+		
+		if( !mo1.isFederated() )
+			throw new DMLRuntimeException("Federated MMChain: Federated main input expected, "
+				+ "but invoked w/ "+mo1.isFederated()+" "+mo2.isFederated());
+	
+		if( !_type.isWeighted() ) { //XtXv
+			//construct commands: broadcast vector, execute, get and aggregate, cleanup
+			FederatedRequest fr1 = mo1.getFedMapping().broadcast(mo2);
+			FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
+				new CPOperand[]{input1, input2}, new long[]{mo1.getFedMapping().getID(), fr1.getID()});
+			FederatedRequest fr3 = new FederatedRequest(RequestType.GET_VAR, fr2.getID());
+			
+			//execute federated operations and aggregate
+			Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3);
+			MatrixBlock ret = FederationUtils.aggAdd(tmp);
+			mo1.getFedMapping().cleanup(getTID(), fr1.getID(), fr2.getID());
+			ec.setMatrixOutput(output.getName(), ret);
+		}
+		else { //XtwXv | XtXvy
+			//construct commands: broadcast 2 vectors, execute, get and aggregate, cleanup
+			FederatedRequest[] fr0 = mo1.getFedMapping().broadcastSliced(mo3, false);
+			FederatedRequest fr1 = mo1.getFedMapping().broadcast(mo2);
+			FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
+				new CPOperand[]{input1, input2, input3},
+				new long[]{mo1.getFedMapping().getID(), fr1.getID(), fr0[0].getID()});
+			FederatedRequest fr3 = new FederatedRequest(RequestType.GET_VAR, fr2.getID());
+			
+			//execute federated operations and aggregate
+			Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(getTID(), fr0, fr1, fr2, fr3);
+			MatrixBlock ret = FederationUtils.aggAdd(tmp);
+			mo1.getFedMapping().cleanup(getTID(), fr0[0].getID(), fr1.getID(), fr2.getID());
+			ec.setMatrixOutput(output.getName(), ret);
+		}
+	}
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/FederatedGLMTest.java b/src/test/java/org/apache/sysds/test/functions/federated/FederatedGLMTest.java
new file mode 100644
index 0000000..fe24bc8
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/federated/FederatedGLMTest.java
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated;
+
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedGLMTest extends AutomatedTestBase {
+
+	private final static String TEST_DIR = "functions/federated/";
+	private final static String TEST_NAME = "FederatedGLMTest";
+	private final static String TEST_CLASS_DIR = TEST_DIR + FederatedGLMTest.class.getSimpleName() + "/";
+
+	private final static int blocksize = 1024;
+	@Parameterized.Parameter()
+	public int rows;
+	@Parameterized.Parameter(1)
+	public int cols;
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"Z"}));
+	}
+
+	@Parameterized.Parameters
+	public static Collection<Object[]> data() {
+		// rows have to be even and > 1
+		return Arrays.asList(new Object[][] {{10000, 10}, {1000, 100}, {2000, 43}});
+	}
+
+	@Test
+	public void federatedSinglenodeGLM() {
+		federatedGLM(Types.ExecMode.SINGLE_NODE);
+	}
+
+	@Test
+	public void federatedHybridGLM() {
+		federatedGLM(Types.ExecMode.HYBRID);
+	}
+
+	
+	public void federatedGLM(Types.ExecMode execMode) {
+		ExecMode platformOld = setExecMode(execMode);
+		Thread t1 = null, t2 = null;
+		try {
+			getAndLoadTestConfiguration(TEST_NAME);
+			String HOME = SCRIPT_DIR + TEST_DIR;
+
+			// write inputs: X is split row-wise into two halves, each served by its own federated worker
+			int halfRows = rows / 2;
+			double[][] X1 = getRandomMatrix(halfRows, cols, 0, 1, 1, 42);
+			double[][] X2 = getRandomMatrix(halfRows, cols, 0, 1, 1, 1340);
+			double[][] Y = getRandomMatrix(rows, 1, -1, 1, 1, 1233);
+			for(int i = 0; i < rows; i++) //binarize labels to {-1, 1}
+				Y[i][0] = (Y[i][0] > 0) ? 1 : -1;
+			writeInputMatrixWithMTD("X1", X1, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+			writeInputMatrixWithMTD("X2", X2, false, new MatrixCharacteristics(halfRows, cols, blocksize, halfRows * cols));
+			writeInputMatrixWithMTD("Y", Y, false, new MatrixCharacteristics(rows, 1, blocksize, rows));
+
+			// empty script name because we don't execute any script, just start the workers
+			fullDMLScriptName = "";
+			int port1 = getRandomAvailablePort();
+			int port2 = getRandomAvailablePort();
+			t1 = startLocalFedWorker(port1);
+			t2 = startLocalFedWorker(port2);
+
+			TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
+			loadTestConfiguration(config);
+			setOutputBuffering(false);
+
+			// run reference dml script with local (non-federated) inputs
+			fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+			programArgs = new String[] {"-args", input("X1"), input("X2"), input("Y"), expected("Z")};
+			runTest(true, false, null, -1);
+
+			// run actual dml script with federated matrix
+			fullDMLScriptName = HOME + TEST_NAME + ".dml";
+			programArgs = new String[] {"-stats",
+				"-nvargs", "in_X1=" + TestUtils.federatedAddress(port1, input("X1")),
+				"in_X2=" + TestUtils.federatedAddress(port2, input("X2")), "rows=" + rows, "cols=" + cols,
+				"in_Y=" + input("Y"), "out=" + output("Z")};
+			runTest(true, false, null, -1);
+
+			// compare via files
+			compareResults(1e-9);
+			// check for federated operations
+			Assert.assertTrue(heavyHittersContainsString("fed_ba+*"));
+			Assert.assertTrue(heavyHittersContainsString("fed_uark+","fed_uarsqk+"));
+			Assert.assertTrue(heavyHittersContainsString("fed_uack+"));
+			Assert.assertTrue(heavyHittersContainsString("fed_uak+"));
+			Assert.assertTrue(heavyHittersContainsString("fed_mmchain"));
+
+			// check that federated input files still exist
+			Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+			Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+		}
+		finally { //always stop started workers (frees ports) and restore exec mode, even on failure
+			if( t1 != null ) TestUtils.shutdownThreads(t1);
+			if( t2 != null ) TestUtils.shutdownThreads(t2);
+			resetExecMode(platformOld);
+		}
+	}
+}
diff --git a/src/test/scripts/functions/federated/FederatedGLMTest.dml b/src/test/scripts/functions/federated/FederatedGLMTest.dml
new file mode 100644
index 0000000..aa23b5e
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedGLMTest.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = federated(addresses=list($in_X1, $in_X2),
+    ranges=list(list(0, 0), list($rows / 2, $cols), list($rows / 2, 0), list($rows, $cols)))
+Y = read($in_Y)
+
+model = glm(X=X, Y=Y, icpt = FALSE, tol = 1e-6, reg = 0.01)
+write(model, $out)
diff --git a/src/test/scripts/functions/federated/FederatedGLMTestReference.dml b/src/test/scripts/functions/federated/FederatedGLMTestReference.dml
new file mode 100644
index 0000000..a307c8c
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedGLMTestReference.dml
@@ -0,0 +1,25 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rbind(read($1), read($2))
+Y = read($3)
+model = glm(X=X, Y=Y, icpt = FALSE, tol = 1e-6, reg = 0.01)
+write(model, $4)