You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2020/08/14 13:09:19 UTC

[systemds] branch master updated: [SYSTEMDS-2618] Rework federated backend (UDF support for custom ops)

This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 6965dda  [SYSTEMDS-2618] Rework federated backend (UDF support for custom ops)
6965dda is described below

commit 6965dda6900e6dc11872f8d9a18ff093d7ebc2a3
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Fri Aug 14 15:08:26 2020 +0200

    [SYSTEMDS-2618] Rework federated backend (UDF support for custom ops)
    
    This patch finalizes the rework of the federated runtime backend to rely
    on only a few meta operations. So far we supported read, put, get, and
    instruction execution. Unfortunately, relying purely on instructions is
    not possible for more complex operations. Inspired by Spark's RDD
    operations, we extend this by a simple UDF mechanism where local UDF
    objects for custom operations are constructed at the driver,
    serialized, and shipped to the federated workers. This patch introduces
    the related UDF framework and re-implements the recently added transform
    encode via such UDFs.
---
 .../controlprogram/context/ExecutionContext.java   |  4 +-
 .../controlprogram/federated/FederatedRequest.java |  9 ++-
 .../controlprogram/federated/FederatedUDF.java     | 53 ++++++++++++++
 .../federated/FederatedWorkerHandler.java          | 81 ++++++----------------
 ...tiReturnParameterizedBuiltinFEDInstruction.java | 74 +++++++++++++++++++-
 5 files changed, 151 insertions(+), 70 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
index 7be3bfd..fcb5db3 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
@@ -560,7 +560,7 @@ public class ExecutionContext {
 		return null;
 	}
 	
-	private static CacheableData<?> createMatrixObject(MatrixBlock mb) {
+	public static MatrixObject createMatrixObject(MatrixBlock mb) {
 		MatrixObject ret = new MatrixObject(Types.ValueType.FP64, 
 			OptimizerUtils.getUniqueTempFileName());
 		ret.acquireModify(mb);
@@ -572,7 +572,7 @@ public class ExecutionContext {
 		return ret;
 	}
 	
-	private static CacheableData<?> createFrameObject(FrameBlock fb) {
+	public static FrameObject createFrameObject(FrameBlock fb) {
 		FrameObject ret = new FrameObject(OptimizerUtils.getUniqueTempFileName());
 		ret.acquireModify(fb);
 		ret.setMetaData(new MetaDataFormat(new MatrixCharacteristics(
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
index 24be89f..5880851 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
@@ -31,12 +31,11 @@ public class FederatedRequest implements Serializable {
 	
 	// commands sent to and excuted by federated workers
 	public enum RequestType {
-		READ_VAR, // create variable for local data, read on first access
-		PUT_VAR,  // receive data from main and store to local variable
-		GET_VAR,  // return local variable to main
+		READ_VAR,  // create variable for local data, read on first access
+		PUT_VAR,   // receive data from main and store to local variable
+		GET_VAR,   // return local variable to main
 		EXEC_INST, // execute arbitrary instruction over
-		FRAME_ENCODE, // TODO replace with user defined functions
-		CREATE_ENCODER
+		EXEC_UDF,  // execute arbitrary user-defined function
 	}
 	
 	private RequestType _method;
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedUDF.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedUDF.java
new file mode 100644
index 0000000..5423ffa
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedUDF.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.controlprogram.federated;
+
+import java.io.Serializable;
+
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.instructions.cp.Data;
+
+public abstract class FederatedUDF implements Serializable {
+	private static final long serialVersionUID = 799416525191257308L;
+	
+	private final long[] _inputIDs; // worker-local variable IDs, passed to execute in this order
+	
+	protected FederatedUDF(long[] inIDs) {
+		_inputIDs = inIDs.clone(); // defensive copy; fails fast on a null ID array
+	}
+	
+	public final long[] getInputIDs() {
+		return _inputIDs.clone(); // copy, so callers cannot alter this UDF's inputs
+	}
+	
+	/**
+	 * Execute the user-defined function on a set of data objects
+	 * (e.g., matrix objects, frame objects, or scalars), which are
+	 * looked up by the specified input IDs and passed in the same order.
+	 * 
+	 * Output data objects (potentially many) can be directly added
+	 * to the passed execution context and its variable map.
+	 * 
+	 * @param ec execution context of the federated worker
+	 * @param data one or many data objects, ordered like the input IDs
+	 * @return federated response, with zero or more output objects
+	 */
+	public abstract FederatedResponse execute(ExecutionContext ec, Data... data);
+}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index e2332e2..f14bbb0 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -31,7 +31,6 @@ import org.apache.sysds.common.Types;
 import org.apache.sysds.common.Types.DataType;
 import org.apache.sysds.common.Types.FileFormat;
 import org.apache.sysds.conf.ConfigurationManager;
-import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.parser.DataExpression;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.BasicProgramBlock;
@@ -48,15 +47,11 @@ import org.apache.sysds.runtime.instructions.cp.ListObject;
 import org.apache.sysds.runtime.instructions.cp.ScalarObject;
 import org.apache.sysds.runtime.io.FileFormatPropertiesCSV;
 import org.apache.sysds.runtime.io.IOUtilFunctions;
-import org.apache.sysds.runtime.matrix.data.FrameBlock;
-import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.meta.MatrixCharacteristics;
 import org.apache.sysds.runtime.meta.MetaDataFormat;
 import org.apache.sysds.runtime.privacy.DMLPrivacyException;
 import org.apache.sysds.runtime.privacy.PrivacyMonitor;
 import org.apache.sysds.runtime.privacy.PrivacyPropagator;
-import org.apache.sysds.runtime.transform.encode.Encoder;
-import org.apache.sysds.runtime.transform.encode.EncoderFactory;
 import org.apache.sysds.utils.JSONHelper;
 import org.apache.wink.json4j.JSONObject;
 
@@ -115,10 +110,8 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
 					return getVariable(request);
 				case EXEC_INST:
 					return execInstruction(request);
-				case CREATE_ENCODER:
-					return createFrameEncoder(request);
-				case FRAME_ENCODE:
-					return executeFrameEncode(request);
+				case EXEC_UDF:
+					return execUDF(request);
 				default:
 					String message = String.format("Method %s is not supported.", method);
 					return new FederatedResponse(FederatedResponse.ResponseType.ERROR, new FederatedWorkerHandlerException(message));
@@ -133,58 +126,6 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
 				+ ex.getClass() + " thrown when processing request", ex));
 		}
 	}
-
-	private FederatedResponse createFrameEncoder(FederatedRequest request) {
-		// param parsing
-		checkNumParams(request.getNumParams(), 2);
-		String spec = (String) request.getParam(0);
-		int globalOffset = (int) request.getParam(1);
-		long varID = request.getID();
-
-		Data dataObject = _ec.getVariable(String.valueOf(varID));
-		FrameObject fo = (FrameObject) PrivacyMonitor.handlePrivacy(dataObject);
-		FrameBlock data = fo.acquireRead();
-		String[] colNames = data.getColumnNames();
-
-		// create the encoder
-		Encoder encoder = EncoderFactory.createEncoder(spec, colNames,
-			data.getNumColumns(), null, globalOffset, globalOffset + data.getNumColumns());
-		// build necessary structures for encoding
-		encoder.build(data);
-		// otherwise data of FrameBlock would be null, therefore it would fail
-		// hack because serialization of FrameBlock does not function if Arrays are not allocated
-		fo.release();
-
-		return new FederatedResponse(ResponseType.SUCCESS, encoder);
-	}
-
-	private FederatedResponse executeFrameEncode(FederatedRequest request) {
-		checkNumParams(request.getNumParams(), 2);
-		Encoder encoder = (Encoder) request.getParam(0);
-		long newVarID = (long) request.getParam(1);
-		long varID = request.getID();
-
-		Data dataObject = _ec.getVariable(String.valueOf(varID));
-		FrameObject fo = (FrameObject) PrivacyMonitor.handlePrivacy(dataObject);
-		FrameBlock data = fo.acquireRead();
-
-		// apply transformation
-		MatrixBlock mbout = encoder.apply(data, new MatrixBlock(data.getNumRows(), data.getNumColumns(), false));
-
-		// copy characteristics
-		MatrixCharacteristics mc = new MatrixCharacteristics(fo.getDataCharacteristics());
-		MatrixObject mo = new MatrixObject(Types.ValueType.FP64, OptimizerUtils.getUniqueTempFileName(),
-			new MetaDataFormat(mc, FileFormat.BINARY));
-		// set the encoded data
-		mo.acquireModify(mbout);
-		mo.release();
-		fo.release();
-
-		// add it to the list of variables
-		_ec.setVariable(String.valueOf(newVarID), mo);
-		// return id handle
-		return new FederatedResponse(ResponseType.SUCCESS_EMPTY);
-	}
 	
 	private FederatedResponse readData(FederatedRequest request) {
 		checkNumParams(request.getNumParams(), 2);
@@ -309,6 +250,24 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
 		}
 		return new FederatedResponse(ResponseType.SUCCESS_EMPTY);
 	}
+	
+	private FederatedResponse execUDF(FederatedRequest request) {
+		checkNumParams(request.getNumParams(), 1);
+		//get function and resolve its inputs (in order) from the local variable map
+		FederatedUDF udf = (FederatedUDF) request.getParam(0);
+		Data[] inputs = Arrays.stream(udf.getInputIDs())
+			.mapToObj(id -> _ec.getVariable(String.valueOf(id)))
+			.toArray(Data[]::new);
+		
+		//execute user-defined function; ship the exception (type + cause), not just its message
+		try {
+			return udf.execute(_ec, inputs);
+		}
+		catch(Exception ex) {
+			return new FederatedResponse(ResponseType.ERROR, new FederatedWorkerHandlerException(
+				"Exception of type " + ex.getClass() + " thrown when executing UDF", ex));
+		}
+	}
 
 	private static void checkNumParams(int actual, int... expected) {
 		if (Arrays.stream(expected).anyMatch(x -> x == actual))
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
index 0aad335..5d25729 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
@@ -30,15 +30,21 @@ import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse.ResponseType;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
 import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
 import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.Data;
 import org.apache.sysds.runtime.matrix.data.FrameBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.runtime.privacy.PrivacyMonitor;
 import org.apache.sysds.runtime.transform.encode.Encoder;
 import org.apache.sysds.runtime.transform.encode.EncoderComposite;
+import org.apache.sysds.runtime.transform.encode.EncoderFactory;
 import org.apache.sysds.runtime.transform.encode.EncoderPassThrough;
 import org.apache.sysds.runtime.transform.encode.EncoderRecode;
 
@@ -92,7 +98,8 @@ public class MultiReturnParameterizedBuiltinFEDInstruction extends ComputationFE
 			// create an encoder with the given spec. The columnOffset (which is 1 based) has to be used to
 			// tell the federated worker how much the indexes in the spec have to be offset.
 			Future<FederatedResponse> response = data.executeFederatedOperation(
-				new FederatedRequest(RequestType.CREATE_ENCODER, data.getVarID(), spec, columnOffset));
+				new FederatedRequest(RequestType.EXEC_UDF, data.getVarID(),
+					new CreateFrameEncoder(data.getVarID(), spec, columnOffset)));
 			// collect responses with encoders
 			try {
 				Encoder encoder = (Encoder) response.get().getData()[0];
@@ -114,7 +121,8 @@ public class MultiReturnParameterizedBuiltinFEDInstruction extends ComputationFE
 			Encoder encoder = globalEncoder.subRangeEncoder(colStart, colEnd);
 			try {
 				FederatedResponse response = data.executeFederatedOperation(
-					new FederatedRequest(RequestType.FRAME_ENCODE, data.getVarID(), encoder, varID)).get();
+					new FederatedRequest(RequestType.EXEC_UDF, varID,
+						new ExecuteFrameEncoder(data.getVarID(), varID, encoder))).get();
 				if(!response.isSuccessful())
 					response.throwExceptionFromResponse();
 			}
@@ -134,4 +142,66 @@ public class MultiReturnParameterizedBuiltinFEDInstruction extends ComputationFE
 		ec.setFrameOutput(getOutput(1).getName(),
 			globalEncoder.getMetaData(new FrameBlock(globalEncoder.getNumCols(), Types.ValueType.STRING)));
 	}
+	
+	
+	public static class CreateFrameEncoder extends FederatedUDF {
+		private static final long serialVersionUID = 2376756757742169692L;
+		private final String _spec;
+		private final int _offset;
+		
+		public CreateFrameEncoder(long input, String spec, int offset) {
+			super(new long[]{input});
+			_spec = spec;
+			_offset = offset;
+		}
+
+		@Override
+		public FederatedResponse execute(ExecutionContext ec, Data... data) {
+			FrameObject fo = (FrameObject) PrivacyMonitor.handlePrivacy(data[0]);
+			FrameBlock fb = fo.acquireRead();
+			try {
+				// create the encoder for the local columns, shifted by the given column offset
+				Encoder encoder = EncoderFactory.createEncoder(_spec, fb.getColumnNames(),
+					fb.getNumColumns(), null, _offset, _offset + fb.getNumColumns());
+				// build necessary structures for encoding, then ship the encoder back
+				encoder.build(fb);
+				return new FederatedResponse(ResponseType.SUCCESS, encoder);
+			}
+			finally {
+				// always release the frame, even if encoder creation or build fails
+				fo.release();
+			}
+		}
+	}
+
+	public static class ExecuteFrameEncoder extends FederatedUDF {
+		private static final long serialVersionUID = 6034440964680578276L;
+		private final long _outputID;
+		private final Encoder _encoder;
+		
+		public ExecuteFrameEncoder(long input, long output, Encoder encoder) {
+			super(new long[] {input});
+			_outputID = output;
+			_encoder = encoder;
+		}
+
+		/**
+		 * Encodes the local frame with the shipped encoder and binds the
+		 * resulting matrix object under the output ID in the variable map.
+		 */
+		@Override
+		public FederatedResponse execute(ExecutionContext ec, Data... data) {
+			// resolve the local frame (subject to privacy handling)
+			FrameBlock in = ((FrameObject) PrivacyMonitor.handlePrivacy(data[0]))
+				.acquireReadAndRelease();
+			
+			// apply the (column sub-range) encoder to the local frame
+			MatrixBlock encoded = _encoder.apply(in,
+				new MatrixBlock(in.getNumRows(), in.getNumColumns(), false));
+			
+			// register the encoded matrix; the response itself carries no payload
+			ec.setVariable(String.valueOf(_outputID), ExecutionContext.createMatrixObject(encoded));
+			return new FederatedResponse(ResponseType.SUCCESS_EMPTY);
+		}
+	}
 }