You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2018/08/04 02:16:17 UTC
[2/3] systemml git commit: [MINOR] Various paramserv refactorings and
code cleanups
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/SparkDataPartitioner.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/SparkDataPartitioner.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/SparkDataPartitioner.java
new file mode 100644
index 0000000..031150b
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/SparkDataPartitioner.java
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv.dp;
+
+import static org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils.SEED;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+import org.apache.sysml.parser.Statement;
+import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
+import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils;
+import org.apache.sysml.runtime.instructions.spark.data.PartitionedBroadcast;
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+import org.apache.sysml.runtime.util.DataConverter;
+
+/**
+ * Prepares the broadcast lookup vectors (worker indicator and/or global
+ * permutations) required by the selected Spark data partitioning scheme and
+ * forwards the partitioning of individual row blocks to that scheme.
+ */
+public class SparkDataPartitioner implements Serializable {
+
+	private static final long serialVersionUID = 6841548626711057448L;
+	private DataPartitionSparkScheme _scheme;
+
+	/**
+	 * Instantiates the scheme implementation and creates the broadcasts it needs.
+	 *
+	 * @param scheme data partitioning scheme
+	 * @param sec spark execution context used to create the broadcasts
+	 * @param numEntries number of entries, i.e., length of the indicator and permutation vectors
+	 * @param numWorkers number of workers
+	 */
+	protected SparkDataPartitioner(Statement.PSScheme scheme, SparkExecutionContext sec, int numEntries, int numWorkers) {
+		switch (scheme) {
+			case DISJOINT_CONTIGUOUS:
+				_scheme = new DCSparkScheme();
+				// Create the worker id indicator
+				createDCIndicator(sec, numWorkers, numEntries);
+				break;
+			case DISJOINT_ROUND_ROBIN:
+				_scheme = new DRRSparkScheme();
+				// Create the worker id indicator
+				createDRIndicator(sec, numWorkers, numEntries);
+				break;
+			case DISJOINT_RANDOM:
+				_scheme = new DRSparkScheme();
+				// Create the global permutation
+				createGlobalPermutations(sec, numEntries, 1);
+				// Create the worker id indicator
+				createDCIndicator(sec, numWorkers, numEntries);
+				break;
+			case OVERLAP_RESHUFFLE:
+				_scheme = new ORSparkScheme();
+				// Create the global permutation separately for each worker
+				createGlobalPermutations(sec, numEntries, numWorkers);
+				break;
+		}
+	}
+
+	/**
+	 * Creates the round-robin worker indicator, where entry n is assigned to
+	 * worker (n % numWorkers), and registers its broadcast at the scheme.
+	 */
+	private void createDRIndicator(SparkExecutionContext sec, int numWorkers, int numEntries) {
+		double[] vector = IntStream.range(0, numEntries).mapToDouble(n -> n % numWorkers).toArray();
+		_scheme.setWorkerIndicator(broadcastVector(sec, vector));
+	}
+
+	/**
+	 * Creates the contiguous worker indicator, where worker i is assigned the
+	 * i-th run of ceil(numEntries/numWorkers) consecutive entries, and
+	 * registers its broadcast at the scheme.
+	 */
+	private void createDCIndicator(SparkExecutionContext sec, int numWorkers, int numEntries) {
+		// worker 0 is covered by the zero-initialized array
+		double[] vector = new double[numEntries];
+		int batchSize = (int) Math.ceil((double) numEntries / numWorkers);
+		for (int i = 1; i < numWorkers; i++) {
+			long begin = (long) batchSize * i;
+			// Due to the rounded-up batch size, the entries can be exhausted before
+			// all workers are served; Arrays.fill rejects ranges with begin > end.
+			if (begin >= numEntries)
+				break;
+			int end = (int) Math.min(begin + batchSize, numEntries);
+			Arrays.fill(vector, (int) begin, end, i);
+		}
+		_scheme.setWorkerIndicator(broadcastVector(sec, vector));
+	}
+
+	/**
+	 * Creates numPerm permutation vectors (seeded with SEED+i for the i-th one)
+	 * and registers their broadcasts at the scheme.
+	 */
+	private void createGlobalPermutations(SparkExecutionContext sec, int numEntries, int numPerm) {
+		List<PartitionedBroadcast<MatrixBlock>> perms = IntStream.range(0, numPerm).mapToObj(i -> {
+			MatrixBlock perm = MatrixBlock.sampleOperations(numEntries, numEntries, false, SEED+i);
+			double[] permValues = perm.getDenseBlockValues();
+			// Create the source-target id vector from the permutation ranging from 1 to number of entries
+			double[] vector = new double[numEntries];
+			for (int j = 0; j < permValues.length; j++) {
+				vector[(int) permValues[j] - 1] = j;
+			}
+			return broadcastVector(sec, vector);
+		}).collect(Collectors.toList());
+		_scheme.setGlobalPermutation(perms);
+	}
+
+	/**
+	 * Converts the given vector into a matrix block, wraps it into a matrix
+	 * object, and obtains the broadcast for it from the execution context.
+	 */
+	private static PartitionedBroadcast<MatrixBlock> broadcastVector(SparkExecutionContext sec, double[] vector) {
+		MatrixBlock vectorMB = DataConverter.convertToMatrixBlock(vector, true);
+		return sec.getBroadcastForMatrixObject(ParamservUtils.newMatrixObject(vectorMB));
+	}
+
+	/**
+	 * Partitions one block of features and labels by delegating to the scheme.
+	 *
+	 * @param numWorkers number of workers
+	 * @param features features block
+	 * @param labels labels block
+	 * @param rowID id passed on to the scheme (narrowed to int)
+	 * @return partitioning result of the scheme
+	 */
+	public DataPartitionSparkScheme.Result doPartitioning(int numWorkers, MatrixBlock features, MatrixBlock labels,
+			long rowID) {
+		// Set the rowID in order to get the according permutation
+		return _scheme.doPartitioning(numWorkers, (int) rowID, features, labels);
+	}
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcCall.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcCall.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcCall.java
new file mode 100644
index 0000000..8b0540b
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcCall.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv.rpc;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.caching.CacheDataOutput;
+import org.apache.sysml.runtime.instructions.cp.ListObject;
+import org.apache.sysml.runtime.util.ByteBufferDataInput;
+
+/**
+ * RPC request of the parameter server, consisting of the method id, the
+ * worker id, and an optional list object as payload.
+ */
+public class PSRpcCall extends PSRpcObject {
+
+	private int _method;
+	private int _workerID;
+	private ListObject _data;
+
+	/**
+	 * Creates a call from its individual parts.
+	 *
+	 * @param method rpc method id (PUSH or PULL)
+	 * @param workerID id of the calling worker
+	 * @param data payload, may be null
+	 */
+	public PSRpcCall(int method, int workerID, ListObject data) {
+		_method = method;
+		_workerID = workerID;
+		_data = data;
+	}
+
+	/**
+	 * Creates a call by deserializing the given buffer.
+	 *
+	 * @param buffer serialized call
+	 * @throws IOException if the buffer cannot be read
+	 */
+	public PSRpcCall(ByteBuffer buffer) throws IOException {
+		deserialize(buffer);
+	}
+
+	public int getMethod() {
+		return _method;
+	}
+
+	public int getWorkerID() {
+		return _workerID;
+	}
+
+	public ListObject getData() {
+		return _data;
+	}
+
+	/**
+	 * Reads method id, worker id, and - if bytes remain - the payload.
+	 *
+	 * @param buffer serialized call
+	 * @throws IOException if the buffer cannot be read
+	 */
+	@Override
+	public void deserialize(ByteBuffer buffer) throws IOException {
+		ByteBufferDataInput dis = new ByteBufferDataInput(buffer);
+		_method = dis.readInt();
+		validateMethod(_method);
+		_workerID = dis.readInt();
+		// serialize() only appends a payload for non-null data
+		if (dis.available() > 1)
+			_data = readAndDeserialize(dis);
+	}
+
+	/**
+	 * Writes method id, worker id, and - if present - the payload.
+	 *
+	 * @return buffer wrapping the serialized call
+	 * @throws IOException if writing fails
+	 */
+	@Override
+	public ByteBuffer serialize() throws IOException {
+		// 8 bytes for the two leading ints (method and worker id)
+		int len = 8 + getExactSerializedSize(_data);
+		CacheDataOutput dos = new CacheDataOutput(len);
+		dos.writeInt(_method);
+		dos.writeInt(_workerID);
+		if (_data != null)
+			serializeAndWriteListObject(_data, dos);
+		return ByteBuffer.wrap(dos.getBytes());
+	}
+
+	/**
+	 * Rejects every method id other than PUSH and PULL.
+	 */
+	private void validateMethod(int method) {
+		switch (method) {
+			case PUSH:
+			case PULL:
+				break;
+			default:
+				throw new DMLRuntimeException("PSRpcCall: only support rpc method 'push' or 'pull'");
+		}
+	}
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcFactory.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcFactory.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcFactory.java
new file mode 100644
index 0000000..a7db756
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcFactory.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv.rpc;
+
+import java.io.IOException;
+import java.util.Collections;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.network.TransportContext;
+import org.apache.spark.network.netty.SparkTransportConf;
+import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.util.TransportConf;
+import org.apache.spark.util.LongAccumulator;
+import org.apache.sysml.runtime.controlprogram.paramserv.LocalParamServer;
+import org.apache.sysml.runtime.controlprogram.paramserv.SparkPSProxy;
+
+/**
+ * Static factory methods for the server and client side of the parameter
+ * server rpc, both built on a Spark transport context.
+ */
+public class PSRpcFactory {
+
+	private static final String MODULE_NAME = "ps";
+
+	/**
+	 * Builds a transport context whose rpc handler wraps the given parameter server.
+	 */
+	private static TransportContext createTransportContext(SparkConf conf, LocalParamServer ps) {
+		TransportConf transportConf = SparkTransportConf.fromSparkConf(conf, MODULE_NAME, 0);
+		return new TransportContext(transportConf, new PSRpcHandler(ps));
+	}
+
+	/**
+	 * Create and start the server
+	 *
+	 * @param conf spark configuration
+	 * @param ps parameter server that handles the incoming calls
+	 * @param host host name to bind to
+	 * @return server
+	 */
+	public static TransportServer createServer(SparkConf conf, LocalParamServer ps, String host) {
+		// port 0: bind rpc to an ephemeral port
+		return createTransportContext(conf, ps)
+			.createServer(host, 0, Collections.emptyList());
+	}
+
+	/**
+	 * Creates a proxy connected to the rpc server at spark.driver.host and the given port.
+	 * The timeout is taken from spark.rpc.askTimeout if set, otherwise from
+	 * spark.network.timeout (default 120s).
+	 *
+	 * @param conf spark configuration
+	 * @param port port of the rpc server
+	 * @param aRPC accumulator handed over to the proxy
+	 * @return proxy of the remote parameter server
+	 * @throws IOException if the client cannot be created
+	 */
+	public static SparkPSProxy createSparkPSProxy(SparkConf conf, int port, LongAccumulator aRPC) throws IOException {
+		final long timeoutMs;
+		if (conf.contains("spark.rpc.askTimeout"))
+			timeoutMs = conf.getTimeAsMs("spark.rpc.askTimeout");
+		else
+			timeoutMs = conf.getTimeAsMs("spark.network.timeout", "120s");
+		String driverHost = conf.get("spark.driver.host");
+		TransportContext clientContext = createTransportContext(conf, new LocalParamServer());
+		return new SparkPSProxy(
+			clientContext.createClientFactory().createClient(driverHost, port), timeoutMs, aRPC);
+	}
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcHandler.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcHandler.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcHandler.java
new file mode 100644
index 0000000..cf8de6d
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcHandler.java
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv.rpc;
+
+import static org.apache.sysml.runtime.controlprogram.paramserv.rpc.PSRpcCall.PULL;
+import static org.apache.sysml.runtime.controlprogram.paramserv.rpc.PSRpcCall.PUSH;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+
+import org.apache.commons.lang.exception.ExceptionUtils;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.server.OneForOneStreamManager;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.StreamManager;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.paramserv.LocalParamServer;
+import org.apache.sysml.runtime.controlprogram.paramserv.rpc.PSRpcResponse.Type;
+import org.apache.sysml.runtime.instructions.cp.ListObject;
+
+/**
+ * Rpc handler that dispatches incoming push and pull calls to a local
+ * parameter server and returns the outcome as a serialized response.
+ */
+public final class PSRpcHandler extends RpcHandler {
+
+	private final LocalParamServer _server;
+
+	protected PSRpcHandler(LocalParamServer server) {
+		_server = server;
+	}
+
+	/**
+	 * Deserializes the call, executes it, and passes the serialized response to the
+	 * callback. A DMLRuntimeException raised by the parameter server is reported to
+	 * the caller as an ERROR response; any other failure propagates unchanged.
+	 *
+	 * @param client transport client (not used)
+	 * @param buffer serialized rpc call
+	 * @param callback callback receiving the serialized response
+	 */
+	@Override
+	public void receive(TransportClient client, ByteBuffer buffer, RpcResponseCallback callback) {
+		PSRpcCall call;
+		try {
+			call = new PSRpcCall(buffer);
+		} catch (IOException e) {
+			throw new DMLRuntimeException("PSRpcHandler: some error occurred when deserializing the rpc call.", e);
+		}
+		PSRpcResponse response;
+		switch (call.getMethod()) {
+			case PUSH:
+				response = handlePush(call);
+				break;
+			case PULL:
+				response = handlePull(call);
+				break;
+			default:
+				throw new DMLRuntimeException(String.format("Does not support the rpc call for method %s", call.getMethod()));
+		}
+		// Respond only after the response was built successfully; responding from a
+		// finally block would hit a null response on unexpected exceptions and
+		// replace the original failure by a NullPointerException.
+		sendResponse(callback, response);
+	}
+
+	/**
+	 * Executes a push call and wraps the outcome into a response.
+	 */
+	private PSRpcResponse handlePush(PSRpcCall call) {
+		try {
+			_server.push(call.getWorkerID(), call.getData());
+			return new PSRpcResponse(Type.SUCCESS_EMPTY);
+		} catch (DMLRuntimeException exception) {
+			return new PSRpcResponse(Type.ERROR, ExceptionUtils.getFullStackTrace(exception));
+		}
+	}
+
+	/**
+	 * Executes a pull call and wraps the pulled data or the error into a response.
+	 */
+	private PSRpcResponse handlePull(PSRpcCall call) {
+		try {
+			ListObject data = _server.pull(call.getWorkerID());
+			return new PSRpcResponse(Type.SUCCESS, data);
+		} catch (DMLRuntimeException exception) {
+			return new PSRpcResponse(Type.ERROR, ExceptionUtils.getFullStackTrace(exception));
+		}
+	}
+
+	/**
+	 * Serializes the response and hands it to the callback.
+	 */
+	private static void sendResponse(RpcResponseCallback callback, PSRpcResponse response) {
+		try {
+			callback.onSuccess(response.serialize());
+		} catch (IOException e) {
+			throw new DMLRuntimeException("PSRpcHandler: some error occrred when wrapping the rpc response.", e);
+		}
+	}
+
+	@Override
+	public StreamManager getStreamManager() {
+		return new OneForOneStreamManager();
+	}
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcObject.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcObject.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcObject.java
new file mode 100644
index 0000000..38d80a2
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcObject.java
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv.rpc;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils;
+import org.apache.sysml.runtime.instructions.cp.Data;
+import org.apache.sysml.runtime.instructions.cp.ListObject;
+import org.apache.sysml.runtime.io.IOUtilFunctions;
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+
+/**
+ * Base class of the parameter server rpc messages, providing the method ids
+ * and the shared (de)serialization of list objects that contain matrices.
+ */
+public abstract class PSRpcObject {
+
+	public static final int PUSH = 1;
+	public static final int PULL = 2;
+
+	public abstract void deserialize(ByteBuffer buffer) throws IOException;
+
+	public abstract ByteBuffer serialize() throws IOException;
+
+	/**
+	 * Deep serialize and write of a list object (currently only support list containing matrices).
+	 * Layout: list length, named flag, and per entry the optional name followed by the matrix.
+	 *
+	 * @param lo a list object containing only matrices
+	 * @param output output data to write to
+	 * @throws IOException if writing fails
+	 */
+	protected void serializeAndWriteListObject(ListObject lo, DataOutput output) throws IOException {
+		validateListObject(lo);
+		output.writeInt(lo.getLength()); //list length
+		output.writeBoolean(lo.isNamedList()); //named flag
+		int pos = 0;
+		while (pos < lo.getLength()) {
+			if (lo.isNamedList()) {
+				output.writeUTF(lo.getName(pos)); //entry name
+			}
+			MatrixObject mo = (MatrixObject) lo.getData().get(pos);
+			mo.acquireReadAndRelease().write(output); //entry matrix
+			pos++;
+		}
+		// The list object has been transferred in binary format,
+		// hence it is cleaned up here
+		ParamservUtils.cleanupListObject(lo);
+	}
+
+	/**
+	 * Reads a list object in the layout written by serializeAndWriteListObject.
+	 *
+	 * @param input input data to read from
+	 * @return list object of the read matrices, named if the named flag was set
+	 * @throws IOException if reading fails
+	 */
+	protected ListObject readAndDeserialize(DataInput input) throws IOException {
+		final int numEntries = input.readInt();
+		List<Data> entries = new ArrayList<>();
+		List<String> entryNames = null;
+		if (input.readBoolean()) {
+			entryNames = new ArrayList<>();
+		}
+		for (int pos = 0; pos < numEntries; pos++) {
+			if (entryNames != null) {
+				entryNames.add(input.readUTF());
+			}
+			MatrixBlock block = new MatrixBlock();
+			block.readFields(input);
+			entries.add(ParamservUtils.newMatrixObject(block, false));
+		}
+		return new ListObject(entries, entryNames);
+	}
+
+	/**
+	 * Get serialization size of a list object
+	 * (scheme: size|name|size|matrix)
+	 *
+	 * @param lo list object, may be null
+	 * @return serialization size, 0 for a null list object
+	 */
+	protected int getExactSerializedSize(ListObject lo) {
+		if (lo == null) {
+			return 0;
+		}
+		long size = 4 + 1; // list length and named flag
+		if (lo.isNamedList()) {
+			// names including their length prefix
+			for (String name : lo.getNames()) {
+				size += IOUtilFunctions.getUTFSize(name);
+			}
+		}
+		for (Data entry : lo.getData()) {
+			size += ((MatrixObject) entry).acquireReadAndRelease().getExactSizeOnDisk();
+		}
+		if (size > Integer.MAX_VALUE) {
+			throw new DMLRuntimeException("Serialized size ("+size+") larger than Integer.MAX_VALUE.");
+		}
+		return (int) size;
+	}
+
+	/**
+	 * Fails on the first list entry that is not a matrix object.
+	 */
+	private void validateListObject(ListObject lo) {
+		for (Data entry : lo.getData()) {
+			if (entry instanceof MatrixObject) {
+				continue;
+			}
+			throw new DMLRuntimeException(String.format("Paramserv func:"
+				+ " Unsupported deep serialize of %s, which is not matrix.", entry.getDebugName()));
+		}
+	}
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcResponse.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcResponse.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcResponse.java
new file mode 100644
index 0000000..68e1dd1
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcResponse.java
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv.rpc;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+
+import org.apache.sysml.runtime.util.ByteBufferDataInput;
+import org.apache.sysml.runtime.controlprogram.caching.CacheDataOutput;
+import org.apache.sysml.runtime.instructions.cp.ListObject;
+import org.apache.sysml.runtime.io.IOUtilFunctions;
+
+/**
+ * RPC response of the parameter server, consisting of a status and -
+ * depending on the status - a list object or an error message.
+ */
+public class PSRpcResponse extends PSRpcObject {
+	public enum Type {
+		SUCCESS,
+		SUCCESS_EMPTY,
+		ERROR,
+	}
+
+	private Type _status;
+	private Object _data; // Could be list object or exception
+
+	/**
+	 * Creates a response by deserializing the given buffer.
+	 *
+	 * @param buffer serialized response
+	 * @throws IOException if the buffer cannot be read
+	 */
+	public PSRpcResponse(ByteBuffer buffer) throws IOException {
+		deserialize(buffer);
+	}
+
+	public PSRpcResponse(Type status) {
+		this(status, null);
+	}
+
+	/**
+	 * Creates a response; a SUCCESS status without data is stored as SUCCESS_EMPTY.
+	 *
+	 * @param status response status
+	 * @param data list object or error message, may be null
+	 */
+	public PSRpcResponse(Type status, Object data) {
+		boolean successWithoutData = (status == Type.SUCCESS) && (data == null);
+		_status = successWithoutData ? Type.SUCCESS_EMPTY : status;
+		_data = data;
+	}
+
+	public boolean isSuccessful() {
+		return !(_status == Type.ERROR);
+	}
+
+	public String getErrorMessage() {
+		return (String) _data;
+	}
+
+	public ListObject getResultModel() {
+		return (ListObject) _data;
+	}
+
+	/**
+	 * Reads the status ordinal followed by the status-specific payload.
+	 *
+	 * @param buffer serialized response
+	 * @throws IOException if the buffer cannot be read
+	 */
+	@Override
+	public void deserialize(ByteBuffer buffer) throws IOException {
+		ByteBufferDataInput input = new ByteBufferDataInput(buffer);
+		_status = Type.values()[input.readInt()];
+		if (_status == Type.SUCCESS) {
+			_data = readAndDeserialize(input);
+		}
+		else if (_status == Type.ERROR) {
+			_data = input.readUTF();
+		}
+		// SUCCESS_EMPTY carries no payload
+	}
+
+	/**
+	 * Writes the status ordinal followed by the status-specific payload.
+	 *
+	 * @return buffer wrapping the serialized response
+	 * @throws IOException if writing fails
+	 */
+	@Override
+	public ByteBuffer serialize() throws IOException {
+		// payload size by status, on top of the 4 bytes of the status ordinal
+		int payloadSize;
+		if (_status == Type.SUCCESS) {
+			payloadSize = getExactSerializedSize((ListObject)_data);
+		}
+		else if (_status == Type.SUCCESS_EMPTY) {
+			payloadSize = 0;
+		}
+		else {
+			payloadSize = IOUtilFunctions.getUTFSize((String)_data);
+		}
+		CacheDataOutput output = new CacheDataOutput(4 + payloadSize);
+		output.writeInt(_status.ordinal());
+		if (_status == Type.SUCCESS) {
+			serializeAndWriteListObject((ListObject) _data, output);
+		}
+		else if (_status == Type.ERROR) {
+			output.writeUTF(_data.toString());
+		}
+		return ByteBuffer.wrap(output.getBytes());
+	}
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DCSparkScheme.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DCSparkScheme.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DCSparkScheme.java
deleted file mode 100644
index 666b891..0000000
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DCSparkScheme.java
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.runtime.controlprogram.paramserv.spark;
-
-import java.util.List;
-
-import org.apache.sysml.runtime.matrix.data.MatrixBlock;
-
-import scala.Tuple2;
-
-/**
- * Spark Disjoint_Contiguous data partitioner:
- * <p>
- * For each row, find out the shifted place according to the workerID indicator
- */
-public class DCSparkScheme extends DataPartitionSparkScheme {
-
- private static final long serialVersionUID = -2786906947020788787L;
-
- protected DCSparkScheme() {
- // No-args constructor used for deserialization
- }
-
- /**
-  * Partitions the features and labels of one row block by passing each of
-  * them, together with the row block id, to the inherited
-  * nonShuffledPartition helper and bundling both lists into a Result.
-  * The numWorkers argument is not used by this implementation.
-  *
-  * @param numWorkers number of workers (unused here)
-  * @param rblkID row block id
-  * @param features features block
-  * @param labels labels block
-  * @return result holding the partitioned features and labels
-  */
- @Override
- public Result doPartitioning(int numWorkers, int rblkID, MatrixBlock features, MatrixBlock labels) {
- List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pfs = nonShuffledPartition(rblkID, features);
- List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pls = nonShuffledPartition(rblkID, labels);
- return new Result(pfs, pls);
- }
-}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DRRSparkScheme.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DRRSparkScheme.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DRRSparkScheme.java
deleted file mode 100644
index 7683251..0000000
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DRRSparkScheme.java
+++ /dev/null
@@ -1,45 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.runtime.controlprogram.paramserv.spark;
-
-import java.util.List;
-
-import org.apache.sysml.runtime.matrix.data.MatrixBlock;
-
-import scala.Tuple2;
-
-/**
- * Spark Disjoint_Round_Robin data partitioner:
- */
-public class DRRSparkScheme extends DataPartitionSparkScheme {
-
- private static final long serialVersionUID = -3130831851505549672L;
-
- protected DRRSparkScheme() {
- // No-args constructor used for deserialization
- }
-
- /**
-  * Partitions the features and labels of one row block by passing each of
-  * them, together with the row block id, to the inherited
-  * nonShuffledPartition helper and bundling both lists into a Result.
-  * The numWorkers argument is not used by this implementation.
-  *
-  * @param numWorkers number of workers (unused here)
-  * @param rblkID row block id
-  * @param features features block
-  * @param labels labels block
-  * @return result holding the partitioned features and labels
-  */
- @Override
- public Result doPartitioning(int numWorkers, int rblkID, MatrixBlock features, MatrixBlock labels) {
- List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pfs = nonShuffledPartition(rblkID, features);
- List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pls = nonShuffledPartition(rblkID, labels);
- return new Result(pfs, pls);
- }
-}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DRSparkScheme.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DRSparkScheme.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DRSparkScheme.java
deleted file mode 100644
index 51cc523..0000000
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DRSparkScheme.java
+++ /dev/null
@@ -1,69 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.runtime.controlprogram.paramserv.spark;
-
-import java.util.List;
-import java.util.stream.Collectors;
-import java.util.stream.IntStream;
-
-import org.apache.sysml.hops.OptimizerUtils;
-import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils;
-import org.apache.sysml.runtime.matrix.data.MatrixBlock;
-
-import scala.Tuple2;
-
-/**
- * Spark data partitioner Disjoint_Random:
- *
- * For the current row block, find all the shifted place for each row (WorkerID => (row block ID, matrix)
- */
-public class DRSparkScheme extends DataPartitionSparkScheme {
-
- private static final long serialVersionUID = -7655310624144544544L;
-
- protected DRSparkScheme() {
- // No-args constructor used for deserialization
- }
-
- /**
-  * Partitions the features and labels of one row block row by row via
-  * partition and bundles both lists into a Result.
-  * The numWorkers argument is not used by this implementation.
-  *
-  * @param numWorkers number of workers (unused here)
-  * @param rblkID row block id
-  * @param features features block
-  * @param labels labels block
-  * @return result holding the partitioned features and labels
-  */
- @Override
- public Result doPartitioning(int numWorkers, int rblkID, MatrixBlock features, MatrixBlock labels) {
- List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pfs = partition(rblkID, features);
- List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pls = partition(rblkID, labels);
- return new Result(pfs, pls);
- }
-
- /**
-  * Emits one tuple (workerID, (shifted position, single-row matrix)) per row
-  * of the given block. The shifted position is looked up in the block rblkID
-  * of the first global permutation, the worker id in the worker indicator.
-  */
- private List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> partition(int rblkID, MatrixBlock mb) {
- MatrixBlock partialPerm = _globalPerms.get(0).getBlock(rblkID, 1);
-
- // For each row, find out the shifted place
- return IntStream.range(0, mb.getNumRows()).mapToObj(r -> {
- MatrixBlock rowMB = ParamservUtils.sliceMatrixBlock(mb, r + 1, r + 1);
- long shiftedPosition = (long) partialPerm.getValue(r, 0);
-
- // Get the shifted block and position
- int shiftedBlkID = (int) (shiftedPosition / OptimizerUtils.DEFAULT_BLOCKSIZE + 1);
-
- MatrixBlock indicator = _workerIndicator.getBlock(shiftedBlkID, 1);
- // NOTE(review): the row offset inside the indicator block is derived with '/'
- // (after narrowing to int) rather than '%' - confirm that this is intended
- int workerID = (int) indicator.getValue((int) shiftedPosition / OptimizerUtils.DEFAULT_BLOCKSIZE, 0);
- return new Tuple2<>(workerID, new Tuple2<>(shiftedPosition, rowMB));
- }).collect(Collectors.toList());
- }
-
-}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DataPartitionSparkScheme.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DataPartitionSparkScheme.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DataPartitionSparkScheme.java
deleted file mode 100644
index 9875dd2..0000000
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DataPartitionSparkScheme.java
+++ /dev/null
@@ -1,76 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.runtime.controlprogram.paramserv.spark;
-
-import java.io.Serializable;
-import java.util.List;
-import java.util.stream.Collectors;
-import java.util.stream.LongStream;
-
-import org.apache.sysml.hops.OptimizerUtils;
-import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils;
-import org.apache.sysml.runtime.instructions.spark.data.PartitionedBroadcast;
-import org.apache.sysml.runtime.matrix.data.MatrixBlock;
-
-import scala.Tuple2;
-
-public abstract class DataPartitionSparkScheme implements Serializable {
-
- protected final class Result {
- protected final List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pFeatures; // WorkerID => (rowID, matrix)
- protected final List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pLabels;
-
- protected Result(List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pFeatures, List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pLabels) {
- this.pFeatures = pFeatures;
- this.pLabels = pLabels;
- }
- }
-
- private static final long serialVersionUID = -3462829818083371171L;
-
- protected List<PartitionedBroadcast<MatrixBlock>> _globalPerms; // a list of global permutations
- protected PartitionedBroadcast<MatrixBlock> _workerIndicator; // a matrix indicating to which worker the given row belongs
-
- protected void setGlobalPermutation(List<PartitionedBroadcast<MatrixBlock>> gps) {
- _globalPerms = gps;
- }
-
- protected void setWorkerIndicator(PartitionedBroadcast<MatrixBlock> wi) {
- _workerIndicator = wi;
- }
-
- /**
- * Do non-reshuffled data partitioning according to worker indicator
- * @param rblkID row block ID
- * @param mb Matrix
- * @return list of tuple (workerID, (row block ID, matrix row))
- */
- protected List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> nonShuffledPartition(int rblkID, MatrixBlock mb) {
- MatrixBlock indicator = _workerIndicator.getBlock(rblkID, 1);
- return LongStream.range(0, mb.getNumRows()).mapToObj(r -> {
- int workerID = (int) indicator.getValue((int) r, 0);
- MatrixBlock rowMB = ParamservUtils.sliceMatrixBlock(mb, r + 1, r + 1);
- long shiftedPosition = r + (rblkID - 1) * OptimizerUtils.DEFAULT_BLOCKSIZE;
- return new Tuple2<>(workerID, new Tuple2<>(shiftedPosition, rowMB));
- }).collect(Collectors.toList());
- }
-
- public abstract Result doPartitioning(int numWorkers, int rblkID, MatrixBlock features, MatrixBlock labels);
-}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DataPartitionerSparkAggregator.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DataPartitionerSparkAggregator.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DataPartitionerSparkAggregator.java
deleted file mode 100644
index 39b8adf..0000000
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DataPartitionerSparkAggregator.java
+++ /dev/null
@@ -1,66 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.runtime.controlprogram.paramserv.spark;
-
-import java.io.Serializable;
-import java.util.LinkedList;
-
-import org.apache.spark.api.java.function.PairFunction;
-import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
-import org.apache.sysml.runtime.matrix.data.MatrixBlock;
-
-import scala.Tuple2;
-
-public class DataPartitionerSparkAggregator implements PairFunction<Tuple2<Integer,LinkedList<Tuple2<Long,Tuple2<MatrixBlock,MatrixBlock>>>>, Integer, Tuple2<MatrixBlock, MatrixBlock>>, Serializable {
-
- private static final long serialVersionUID = -1245300852709085117L;
- private long _fcol;
- private long _lcol;
-
- public DataPartitionerSparkAggregator() {
-
- }
-
- public DataPartitionerSparkAggregator(long fcol, long lcol) {
- _fcol = fcol;
- _lcol = lcol;
- }
-
- /**
- * Row-wise combine the matrix
- * @param input workerID => ordered list [(rowBlockID, (features, labels))]
- * @return workerID => [(features, labels)]
- * @throws Exception Some exception
- */
- @Override
- public Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>> call(Tuple2<Integer, LinkedList<Tuple2<Long, Tuple2<MatrixBlock, MatrixBlock>>>> input) throws Exception {
- MatrixBlock fmb = new MatrixBlock(input._2.size(), (int) _fcol, false);
- MatrixBlock lmb = new MatrixBlock(input._2.size(), (int) _lcol, false);
-
- for (int i = 0; i < input._2.size(); i++) {
- MatrixBlock tmpFMB = input._2.get(i)._2._1;
- MatrixBlock tmpLMB = input._2.get(i)._2._2;
- // Row-wise aggregation
- fmb = fmb.leftIndexingOperations(tmpFMB, i, i, 0, (int) _fcol - 1, fmb, MatrixObject.UpdateType.INPLACE_PINNED);
- lmb = lmb.leftIndexingOperations(tmpLMB, i, i, 0, (int) _lcol - 1, lmb, MatrixObject.UpdateType.INPLACE_PINNED);
- }
- return new Tuple2<>(input._1, new Tuple2<>(fmb, lmb));
- }
-}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DataPartitionerSparkMapper.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DataPartitionerSparkMapper.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DataPartitionerSparkMapper.java
deleted file mode 100644
index 2a69986..0000000
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DataPartitionerSparkMapper.java
+++ /dev/null
@@ -1,70 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.runtime.controlprogram.paramserv.spark;
-
-import java.io.Serializable;
-import java.util.Iterator;
-import java.util.LinkedList;
-import java.util.List;
-
-import org.apache.spark.api.java.function.PairFlatMapFunction;
-import org.apache.sysml.parser.Statement;
-import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
-import org.apache.sysml.runtime.matrix.data.MatrixBlock;
-
-import scala.Tuple2;
-
-public class DataPartitionerSparkMapper implements PairFlatMapFunction<Tuple2<Long, Tuple2<MatrixBlock, MatrixBlock>>, Integer, Tuple2<Long, Tuple2<MatrixBlock, MatrixBlock>>>, Serializable {
-
- private static final long serialVersionUID = 1710721606050403296L;
- private int _workersNum;
-
- private SparkDataPartitioner _dp;
-
- protected DataPartitionerSparkMapper() {
- // No-args constructor used for deserialization
- }
-
- public DataPartitionerSparkMapper(Statement.PSScheme scheme, int workersNum, SparkExecutionContext sec, int numEntries) {
- _workersNum = workersNum;
- _dp = new SparkDataPartitioner(scheme, sec, numEntries, workersNum);
- }
-
- /**
- * Do data partitioning
- * @param input RowBlockID => (features, labels)
- * @return WorkerID => (rowBlockID, (single row features, single row labels))
- * @throws Exception Some exception
- */
- @Override
- public Iterator<Tuple2<Integer, Tuple2<Long, Tuple2<MatrixBlock, MatrixBlock>>>> call(Tuple2<Long,Tuple2<MatrixBlock,MatrixBlock>> input)
- throws Exception {
- List<Tuple2<Integer, Tuple2<Long,Tuple2<MatrixBlock,MatrixBlock>>>> partitions = new LinkedList<>();
- MatrixBlock features = input._2._1;
- MatrixBlock labels = input._2._2;
- DataPartitionSparkScheme.Result result = _dp.doPartitioning(_workersNum, features, labels, input._1);
- for (int i = 0; i < result.pFeatures.size(); i++) {
- Tuple2<Integer, Tuple2<Long, MatrixBlock>> ft = result.pFeatures.get(i);
- Tuple2<Integer, Tuple2<Long, MatrixBlock>> lt = result.pLabels.get(i);
- partitions.add(new Tuple2<>(ft._1, new Tuple2<>(ft._2._1, new Tuple2<>(ft._2._2, lt._2._2))));
- }
- return partitions.iterator();
- }
-}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/ORSparkScheme.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/ORSparkScheme.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/ORSparkScheme.java
deleted file mode 100644
index 16ce516..0000000
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/ORSparkScheme.java
+++ /dev/null
@@ -1,60 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.runtime.controlprogram.paramserv.spark;
-
-import java.util.List;
-import java.util.stream.Collectors;
-import java.util.stream.IntStream;
-
-import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils;
-import org.apache.sysml.runtime.matrix.data.MatrixBlock;
-
-import scala.Tuple2;
-
-/**
- * Spark data partitioner Overlap_Reshuffle:
- *
- */
-public class ORSparkScheme extends DataPartitionSparkScheme {
-
- private static final long serialVersionUID = 6867567406403580311L;
-
- protected ORSparkScheme() {
- // No-args constructor used for deserialization
- }
-
- @Override
- public Result doPartitioning(int numWorkers, int rblkID, MatrixBlock features, MatrixBlock labels) {
- List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pfs = partition(numWorkers, rblkID, features);
- List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pls = partition(numWorkers, rblkID, labels);
- return new Result(pfs, pls);
- }
-
- private List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> partition(int numWorkers, int rblkID, MatrixBlock mb) {
- return IntStream.range(0, numWorkers).mapToObj(i -> i).flatMap(workerID -> {
- MatrixBlock partialPerm = _globalPerms.get(workerID).getBlock(rblkID, 1);
- return IntStream.range(0, mb.getNumRows()).mapToObj(r -> {
- MatrixBlock rowMB = ParamservUtils.sliceMatrixBlock(mb, r + 1, r + 1);
- long shiftedPosition = (long) partialPerm.getValue(r, 0);
- return new Tuple2<>(workerID, new Tuple2<>(shiftedPosition, rowMB));
- });
- }).collect(Collectors.toList());
- }
-}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkDataPartitioner.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkDataPartitioner.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkDataPartitioner.java
deleted file mode 100644
index 6883d0f..0000000
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkDataPartitioner.java
+++ /dev/null
@@ -1,106 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.runtime.controlprogram.paramserv.spark;
-
-import static org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils.SEED;
-
-import java.io.Serializable;
-import java.util.Arrays;
-import java.util.List;
-import java.util.stream.Collectors;
-import java.util.stream.IntStream;
-
-import org.apache.sysml.parser.Statement;
-import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
-import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils;
-import org.apache.sysml.runtime.instructions.spark.data.PartitionedBroadcast;
-import org.apache.sysml.runtime.matrix.data.MatrixBlock;
-import org.apache.sysml.runtime.util.DataConverter;
-
-public class SparkDataPartitioner implements Serializable {
-
- private static final long serialVersionUID = 6841548626711057448L;
- private DataPartitionSparkScheme _scheme;
-
- protected SparkDataPartitioner(Statement.PSScheme scheme, SparkExecutionContext sec, int numEntries, int numWorkers) {
- switch (scheme) {
- case DISJOINT_CONTIGUOUS:
- _scheme = new DCSparkScheme();
- // Create the worker id indicator
- createDCIndicator(sec, numWorkers, numEntries);
- break;
- case DISJOINT_ROUND_ROBIN:
- _scheme = new DRRSparkScheme();
- // Create the worker id indicator
- createDRIndicator(sec, numWorkers, numEntries);
- break;
- case DISJOINT_RANDOM:
- _scheme = new DRSparkScheme();
- // Create the global permutation
- createGlobalPermutations(sec, numEntries, 1);
- // Create the worker id indicator
- createDCIndicator(sec, numWorkers, numEntries);
- break;
- case OVERLAP_RESHUFFLE:
- _scheme = new ORSparkScheme();
- // Create the global permutation seperately for each worker
- createGlobalPermutations(sec, numEntries, numWorkers);
- break;
- }
- }
-
- private void createDRIndicator(SparkExecutionContext sec, int numWorkers, int numEntries) {
- double[] vector = IntStream.range(0, numEntries).mapToDouble(n -> n % numWorkers).toArray();
- MatrixBlock vectorMB = DataConverter.convertToMatrixBlock(vector, true);
- _scheme.setWorkerIndicator(sec.getBroadcastForMatrixObject(ParamservUtils.newMatrixObject(vectorMB)));
- }
-
- private void createDCIndicator(SparkExecutionContext sec, int numWorkers, int numEntries) {
- double[] vector = new double[numEntries];
- int batchSize = (int) Math.ceil((double) numEntries / numWorkers);
- for (int i = 1; i < numWorkers; i++) {
- int begin = batchSize * i;
- int end = Math.min(begin + batchSize, numEntries);
- Arrays.fill(vector, begin, end, i);
- }
- MatrixBlock vectorMB = DataConverter.convertToMatrixBlock(vector, true);
- _scheme.setWorkerIndicator(sec.getBroadcastForMatrixObject(ParamservUtils.newMatrixObject(vectorMB)));
- }
-
- private void createGlobalPermutations(SparkExecutionContext sec, int numEntries, int numPerm) {
- List<PartitionedBroadcast<MatrixBlock>> perms = IntStream.range(0, numPerm).mapToObj(i -> {
- MatrixBlock perm = MatrixBlock.sampleOperations(numEntries, numEntries, false, SEED+i);
- // Create the source-target id vector from the permutation ranging from 1 to number of entries
- double[] vector = new double[numEntries];
- for (int j = 0; j < perm.getDenseBlockValues().length; j++) {
- vector[(int) perm.getDenseBlockValues()[j] - 1] = j;
- }
- MatrixBlock vectorMB = DataConverter.convertToMatrixBlock(vector, true);
- return sec.getBroadcastForMatrixObject(ParamservUtils.newMatrixObject(vectorMB));
- }).collect(Collectors.toList());
- _scheme.setGlobalPermutation(perms);
- }
-
- public DataPartitionSparkScheme.Result doPartitioning(int numWorkers, MatrixBlock features, MatrixBlock labels,
- long rowID) {
- // Set the rowID in order to get the according permutation
- return _scheme.doPartitioning(numWorkers, (int) rowID, features, labels);
- }
-}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSBody.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSBody.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSBody.java
deleted file mode 100644
index 9354025..0000000
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSBody.java
+++ /dev/null
@@ -1,44 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.runtime.controlprogram.paramserv.spark;
-
-import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
-
-/**
- * Wrapper class containing all needed for launching spark remote worker
- */
-public class SparkPSBody {
-
- private ExecutionContext _ec;
-
- public SparkPSBody() {}
-
- public SparkPSBody(ExecutionContext ec) {
- _ec = ec;
- }
-
- public ExecutionContext getEc() {
- return _ec;
- }
-
- public void setEc(ExecutionContext ec) {
- this._ec = ec;
- }
-}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSProxy.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSProxy.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSProxy.java
deleted file mode 100644
index 48a4883..0000000
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSProxy.java
+++ /dev/null
@@ -1,85 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.runtime.controlprogram.paramserv.spark;
-
-import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcObject.PULL;
-import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcObject.PUSH;
-
-import java.io.IOException;
-
-import org.apache.spark.network.client.TransportClient;
-import org.apache.spark.util.LongAccumulator;
-import org.apache.sysml.api.DMLScript;
-import org.apache.sysml.runtime.DMLRuntimeException;
-import org.apache.sysml.runtime.controlprogram.paramserv.ParamServer;
-import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcCall;
-import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcResponse;
-import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing;
-import org.apache.sysml.runtime.instructions.cp.ListObject;
-
-public class SparkPSProxy extends ParamServer {
-
- private final TransportClient _client;
- private final long _rpcTimeout;
- private final LongAccumulator _aRPC;
-
- public SparkPSProxy(TransportClient client, long rpcTimeout, LongAccumulator aRPC) {
- super();
- _client = client;
- _rpcTimeout = rpcTimeout;
- _aRPC = aRPC;
- }
-
- private void accRpcRequestTime(Timing tRpc) {
- if (DMLScript.STATISTICS)
- _aRPC.add((long) tRpc.stop());
- }
-
- @Override
- public void push(int workerID, ListObject value) {
- Timing tRpc = DMLScript.STATISTICS ? new Timing(true) : null;
- PSRpcResponse response;
- try {
- response = new PSRpcResponse(_client.sendRpcSync(new PSRpcCall(PUSH, workerID, value).serialize(), _rpcTimeout));
- } catch (IOException e) {
- throw new DMLRuntimeException(String.format("SparkPSProxy: spark worker_%d failed to push gradients.", workerID), e);
- }
- accRpcRequestTime(tRpc);
- if (!response.isSuccessful()) {
- throw new DMLRuntimeException(String.format("SparkPSProxy: spark worker_%d failed to push gradients. \n%s", workerID, response.getErrorMessage()));
- }
- }
-
- @Override
- public ListObject pull(int workerID) {
- Timing tRpc = DMLScript.STATISTICS ? new Timing(true) : null;
- PSRpcResponse response;
- try {
- response = new PSRpcResponse(_client.sendRpcSync(new PSRpcCall(PULL, workerID, null).serialize(), _rpcTimeout));
- } catch (IOException e) {
- throw new DMLRuntimeException(String.format("SparkPSProxy: spark worker_%d failed to pull models.", workerID), e);
- }
- accRpcRequestTime(tRpc);
- if (!response.isSuccessful()) {
- throw new DMLRuntimeException(String.format("SparkPSProxy: spark worker_%d failed to pull models. \n%s", workerID, response.getErrorMessage()));
- }
- return response.getResultModel();
- }
-}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java
deleted file mode 100644
index cb3e729..0000000
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java
+++ /dev/null
@@ -1,168 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.runtime.controlprogram.paramserv.spark;
-
-import java.io.IOException;
-import java.util.HashMap;
-import java.util.Map;
-
-import org.apache.spark.SparkConf;
-import org.apache.spark.api.java.function.VoidFunction;
-import org.apache.spark.util.LongAccumulator;
-import org.apache.sysml.parser.Statement;
-import org.apache.sysml.runtime.codegen.CodegenUtils;
-import org.apache.sysml.runtime.controlprogram.paramserv.LocalPSWorker;
-import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils;
-import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcFactory;
-import org.apache.sysml.runtime.controlprogram.parfor.RemoteParForUtils;
-import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing;
-import org.apache.sysml.runtime.matrix.data.MatrixBlock;
-import org.apache.sysml.runtime.util.ProgramConverter;
-
-import scala.Tuple2;
-
-public class SparkPSWorker extends LocalPSWorker implements VoidFunction<Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>>> {
-
- private static final long serialVersionUID = -8674739573419648732L;
-
- private final String _program;
- private final HashMap<String, byte[]> _clsMap;
- private final SparkConf _conf;
- private final int _port; // rpc port
- private final String _aggFunc;
- private final LongAccumulator _aSetup; // accumulator for setup time
- private final LongAccumulator _aWorker; // accumulator for worker number
- private final LongAccumulator _aUpdate; // accumulator for model update
- private final LongAccumulator _aIndex; // accumulator for batch indexing
- private final LongAccumulator _aGrad; // accumulator for gradients computing
- private final LongAccumulator _aRPC; // accumulator for rpc request
- private final LongAccumulator _nBatches; //number of executed batches
- private final LongAccumulator _nEpochs; //number of executed epoches
-
- public SparkPSWorker(String updFunc, String aggFunc, Statement.PSFrequency freq, int epochs, long batchSize, String program, HashMap<String, byte[]> clsMap, SparkConf conf, int port, LongAccumulator aSetup, LongAccumulator aWorker, LongAccumulator aUpdate, LongAccumulator aIndex, LongAccumulator aGrad, LongAccumulator aRPC, LongAccumulator aBatches, LongAccumulator aEpochs) {
- _updFunc = updFunc;
- _aggFunc = aggFunc;
- _freq = freq;
- _epochs = epochs;
- _batchSize = batchSize;
- _program = program;
- _clsMap = clsMap;
- _conf = conf;
- _port = port;
- _aSetup = aSetup;
- _aWorker = aWorker;
- _aUpdate = aUpdate;
- _aIndex = aIndex;
- _aGrad = aGrad;
- _aRPC = aRPC;
- _nBatches = aBatches;
- _nEpochs = aEpochs;
- }
-
- @Override
- public String getWorkerName() {
- return String.format("Spark worker_%d", _workerID);
- }
-
- @Override
- public void call(Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>> input) throws Exception {
- Timing tSetup = new Timing(true);
- configureWorker(input);
- accSetupTime(tSetup);
-
- call(); // Launch the worker
- }
-
- private void configureWorker(Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>> input) throws IOException {
- _workerID = input._1;
-
- // Initialize codegen class cache (before program parsing)
- for (Map.Entry<String, byte[]> e : _clsMap.entrySet()) {
- CodegenUtils.getClassSync(e.getKey(), e.getValue());
- }
-
- // Deserialize the body to initialize the execution context
- SparkPSBody body = ProgramConverter.parseSparkPSBody(_program, _workerID);
- _ec = body.getEc();
-
- // Initialize the buffer pool and register it in the jvm shutdown hook in order to be cleanuped at the end
- RemoteParForUtils.setupBufferPool(_workerID);
-
- // Get some configurations
- long rpcTimeout = _conf.contains("spark.rpc.askTimeout") ?
- _conf.getTimeAsMs("spark.rpc.askTimeout") :
- _conf.getTimeAsMs("spark.network.timeout", "120s");
- String host = _conf.get("spark.driver.host");
-
- // Create the ps proxy
- _ps = PSRpcFactory.createSparkPSProxy(_conf, host, _port, rpcTimeout, _aRPC);
-
- // Initialize the update function
- setupUpdateFunction(_updFunc, _ec);
-
- // Initialize the agg function
- _ps.setupAggFunc(_ec, _aggFunc);
-
- // Lazy initialize the matrix of features and labels
- setFeatures(ParamservUtils.newMatrixObject(input._2._1));
- setLabels(ParamservUtils.newMatrixObject(input._2._2));
- _features.enableCleanup(false);
- _labels.enableCleanup(false);
- }
-
-
- @Override
- protected void incWorkerNumber() {
- _aWorker.add(1);
- }
-
- @Override
- protected void accLocalModelUpdateTime(Timing time) {
- if( time != null )
- _aUpdate.add((long) time.stop());
- }
-
- @Override
- protected void accBatchIndexingTime(Timing time) {
- if( time != null )
- _aIndex.add((long) time.stop());
- }
-
- @Override
- protected void accGradientComputeTime(Timing time) {
- if( time != null )
- _aGrad.add((long) time.stop());
- }
-
- @Override
- protected void accNumEpochs(int n) {
- _nEpochs.add(n);
- }
-
- @Override
- protected void accNumBatches(int n) {
- _nBatches.add(n);
- }
-
- private void accSetupTime(Timing time) {
- if( time != null )
- _aSetup.add((long) time.stop());
- }
-}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcCall.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcCall.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcCall.java
deleted file mode 100644
index a33fda2..0000000
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcCall.java
+++ /dev/null
@@ -1,86 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc;
-
-import java.io.IOException;
-import java.nio.ByteBuffer;
-
-import org.apache.sysml.runtime.DMLRuntimeException;
-import org.apache.sysml.runtime.controlprogram.caching.CacheDataOutput;
-import org.apache.sysml.runtime.instructions.cp.ListObject;
-import org.apache.sysml.runtime.util.ByteBufferDataInput;
-
-public class PSRpcCall extends PSRpcObject {
-
- private int _method;
- private int _workerID;
- private ListObject _data;
-
- public PSRpcCall(int method, int workerID, ListObject data) {
- _method = method;
- _workerID = workerID;
- _data = data;
- }
-
- public PSRpcCall(ByteBuffer buffer) throws IOException {
- deserialize(buffer);
- }
-
- public int getMethod() {
- return _method;
- }
-
- public int getWorkerID() {
- return _workerID;
- }
-
- public ListObject getData() {
- return _data;
- }
-
- public void deserialize(ByteBuffer buffer) throws IOException {
- ByteBufferDataInput dis = new ByteBufferDataInput(buffer);
- _method = dis.readInt();
- validateMethod(_method);
- _workerID = dis.readInt();
- if (dis.available() > 1)
- _data = readAndDeserialize(dis);
- }
-
- public ByteBuffer serialize() throws IOException {
- int len = 8 + getExactSerializedSize(_data);
- CacheDataOutput dos = new CacheDataOutput(len);
- dos.writeInt(_method);
- dos.writeInt(_workerID);
- if (_data != null)
- serializeAndWriteListObject(_data, dos);
- return ByteBuffer.wrap(dos.getBytes());
- }
-
- private void validateMethod(int method) {
- switch (method) {
- case PUSH:
- case PULL:
- break;
- default:
- throw new DMLRuntimeException("PSRpcCall: only support rpc method 'push' or 'pull'");
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcFactory.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcFactory.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcFactory.java
deleted file mode 100644
index 5e76d23..0000000
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcFactory.java
+++ /dev/null
@@ -1,57 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc;
-
-import java.io.IOException;
-import java.util.Collections;
-
-import org.apache.spark.SparkConf;
-import org.apache.spark.network.TransportContext;
-import org.apache.spark.network.netty.SparkTransportConf;
-import org.apache.spark.network.server.TransportServer;
-import org.apache.spark.network.util.TransportConf;
-import org.apache.spark.util.LongAccumulator;
-import org.apache.sysml.runtime.controlprogram.paramserv.LocalParamServer;
-import org.apache.sysml.runtime.controlprogram.paramserv.spark.SparkPSProxy;
-
-public class PSRpcFactory {
-
- private static final String MODULE_NAME = "ps";
-
- private static TransportContext createTransportContext(SparkConf conf, LocalParamServer ps) {
- TransportConf tc = SparkTransportConf.fromSparkConf(conf, MODULE_NAME, 0);
- PSRpcHandler handler = new PSRpcHandler(ps);
- return new TransportContext(tc, handler);
- }
-
- /**
- * Create and start the server
- * @return server
- */
- public static TransportServer createServer(SparkConf conf, LocalParamServer ps, String host) {
- TransportContext context = createTransportContext(conf, ps);
- return context.createServer(host, 0, Collections.emptyList()); // bind rpc to an ephemeral port
- }
-
- public static SparkPSProxy createSparkPSProxy(SparkConf conf, String host, int port, long rpcTimeout, LongAccumulator aRPC) throws IOException {
- TransportContext context = createTransportContext(conf, new LocalParamServer());
- return new SparkPSProxy(context.createClientFactory().createClient(host, port), rpcTimeout, aRPC);
- }
-}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcHandler.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcHandler.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcHandler.java
deleted file mode 100644
index a2c311e..0000000
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcHandler.java
+++ /dev/null
@@ -1,95 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc;
-
-import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcCall.PULL;
-import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcCall.PUSH;
-
-import java.io.IOException;
-import java.nio.ByteBuffer;
-
-import org.apache.commons.lang.exception.ExceptionUtils;
-import org.apache.spark.network.client.RpcResponseCallback;
-import org.apache.spark.network.client.TransportClient;
-import org.apache.spark.network.server.OneForOneStreamManager;
-import org.apache.spark.network.server.RpcHandler;
-import org.apache.spark.network.server.StreamManager;
-import org.apache.sysml.runtime.DMLRuntimeException;
-import org.apache.sysml.runtime.controlprogram.paramserv.LocalParamServer;
-import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcResponse.Type;
-import org.apache.sysml.runtime.instructions.cp.ListObject;
-
-public final class PSRpcHandler extends RpcHandler {
-
- private LocalParamServer _server;
-
- protected PSRpcHandler(LocalParamServer server) {
- _server = server;
- }
-
- @Override
- public void receive(TransportClient client, ByteBuffer buffer, RpcResponseCallback callback) {
- PSRpcCall call;
- try {
- call = new PSRpcCall(buffer);
- } catch (IOException e) {
- throw new DMLRuntimeException("PSRpcHandler: some error occurred when deserializing the rpc call.", e);
- }
- PSRpcResponse response = null;
- switch (call.getMethod()) {
- case PUSH:
- try {
- _server.push(call.getWorkerID(), call.getData());
- response = new PSRpcResponse(Type.SUCCESS_EMPTY);
- } catch (DMLRuntimeException exception) {
- response = new PSRpcResponse(Type.ERROR, ExceptionUtils.getFullStackTrace(exception));
- } finally {
- try {
- callback.onSuccess(response.serialize());
- } catch (IOException e) {
- throw new DMLRuntimeException("PSRpcHandler: some error occrred when wrapping the rpc response.", e);
- }
- }
- break;
- case PULL:
- ListObject data;
- try {
- data = _server.pull(call.getWorkerID());
- response = new PSRpcResponse(Type.SUCCESS, data);
- } catch (DMLRuntimeException exception) {
- response = new PSRpcResponse(Type.ERROR, ExceptionUtils.getFullStackTrace(exception));
- } finally {
- try {
- callback.onSuccess(response.serialize());
- } catch (IOException e) {
- throw new DMLRuntimeException("PSRpcHandler: some error occrred when wrapping the rpc response.", e);
- }
- }
- break;
- default:
- throw new DMLRuntimeException(String.format("Does not support the rpc call for method %s", call.getMethod()));
- }
- }
-
- @Override
- public StreamManager getStreamManager() {
- return new OneForOneStreamManager();
- }
-}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcObject.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcObject.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcObject.java
deleted file mode 100644
index 816cefd..0000000
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcObject.java
+++ /dev/null
@@ -1,107 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc;
-
-import java.io.DataInput;
-import java.io.DataOutput;
-import java.io.IOException;
-import java.nio.ByteBuffer;
-import java.util.ArrayList;
-import java.util.List;
-
-import org.apache.sysml.runtime.DMLRuntimeException;
-import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
-import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils;
-import org.apache.sysml.runtime.instructions.cp.Data;
-import org.apache.sysml.runtime.instructions.cp.ListObject;
-import org.apache.sysml.runtime.io.IOUtilFunctions;
-import org.apache.sysml.runtime.matrix.data.MatrixBlock;
-
-public abstract class PSRpcObject {
-
- public static final int PUSH = 1;
- public static final int PULL = 2;
-
- public abstract void deserialize(ByteBuffer buffer) throws IOException;
-
- public abstract ByteBuffer serialize() throws IOException;
-
- /**
- * Deep serialize and write of a list object (currently only support list containing matrices)
- * @param lo a list object containing only matrices
- * @param output output data to write to
- */
- protected void serializeAndWriteListObject(ListObject lo, DataOutput output) throws IOException {
- validateListObject(lo);
- output.writeInt(lo.getLength()); //write list length
- output.writeBoolean(lo.isNamedList()); //write list named
- for (int i = 0; i < lo.getLength(); i++) {
- if (lo.isNamedList())
- output.writeUTF(lo.getName(i)); //write name
- ((MatrixObject) lo.getData().get(i))
- .acquireReadAndRelease().write(output); //write matrix
- }
- // Cleanup the list object
- // because it is transferred to remote worker in binary format
- ParamservUtils.cleanupListObject(lo);
- }
-
- protected ListObject readAndDeserialize(DataInput input) throws IOException {
- int listLen = input.readInt();
- List<Data> data = new ArrayList<>();
- List<String> names = input.readBoolean() ?
- new ArrayList<>() : null;
- for(int i=0; i<listLen; i++) {
- if( names != null )
- names.add(input.readUTF());
- MatrixBlock mb = new MatrixBlock();
- mb.readFields(input);
- data.add(ParamservUtils.newMatrixObject(mb, false));
- }
- return new ListObject(data, names);
- }
-
- /**
- * Get serialization size of a list object
- * (scheme: size|name|size|matrix)
- * @param lo list object
- * @return serialization size
- */
- protected int getExactSerializedSize(ListObject lo) {
- if( lo == null ) return 0;
- long result = 4 + 1; // list length and of named
- if (lo.isNamedList()) //size for names incl length
- result += lo.getNames().stream().mapToLong(s -> IOUtilFunctions.getUTFSize(s)).sum();
- result += lo.getData().stream().mapToLong(d ->
- ((MatrixObject)d).acquireReadAndRelease().getExactSizeOnDisk()).sum();
- if( result > Integer.MAX_VALUE )
- throw new DMLRuntimeException("Serialized size ("+result+") larger than Integer.MAX_VALUE.");
- return (int) result;
- }
-
- private void validateListObject(ListObject lo) {
- for (Data d : lo.getData()) {
- if (!(d instanceof MatrixObject)) {
- throw new DMLRuntimeException(String.format("Paramserv func:"
- + " Unsupported deep serialize of %s, which is not matrix.", d.getDebugName()));
- }
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcResponse.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcResponse.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcResponse.java
deleted file mode 100644
index 010481e..0000000
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcResponse.java
+++ /dev/null
@@ -1,101 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc;
-
-import java.io.IOException;
-import java.nio.ByteBuffer;
-
-import org.apache.sysml.runtime.util.ByteBufferDataInput;
-import org.apache.sysml.runtime.controlprogram.caching.CacheDataOutput;
-import org.apache.sysml.runtime.instructions.cp.ListObject;
-import org.apache.sysml.runtime.io.IOUtilFunctions;
-
-public class PSRpcResponse extends PSRpcObject {
- public enum Type {
- SUCCESS,
- SUCCESS_EMPTY,
- ERROR,
- }
-
- private Type _status;
- private Object _data; // Could be list object or exception
-
- public PSRpcResponse(ByteBuffer buffer) throws IOException {
- deserialize(buffer);
- }
-
- public PSRpcResponse(Type status) {
- this(status, null);
- }
-
- public PSRpcResponse(Type status, Object data) {
- _status = status;
- _data = data;
- if( _status == Type.SUCCESS && data == null )
- _status = Type.SUCCESS_EMPTY;
- }
-
- public boolean isSuccessful() {
- return _status != Type.ERROR;
- }
-
- public String getErrorMessage() {
- return (String) _data;
- }
-
- public ListObject getResultModel() {
- return (ListObject) _data;
- }
-
- @Override
- public void deserialize(ByteBuffer buffer) throws IOException {
- ByteBufferDataInput dis = new ByteBufferDataInput(buffer);
- _status = Type.values()[dis.readInt()];
- switch (_status) {
- case SUCCESS:
- _data = readAndDeserialize(dis);
- break;
- case SUCCESS_EMPTY:
- break;
- case ERROR:
- _data = dis.readUTF();
- break;
- }
- }
-
- @Override
- public ByteBuffer serialize() throws IOException {
- int len = 4 + (_status==Type.SUCCESS ? getExactSerializedSize((ListObject)_data) :
- _status==Type.SUCCESS_EMPTY ? 0 : IOUtilFunctions.getUTFSize((String)_data));
- CacheDataOutput dos = new CacheDataOutput(len);
- dos.writeInt(_status.ordinal());
- switch (_status) {
- case SUCCESS:
- serializeAndWriteListObject((ListObject) _data, dos);
- break;
- case SUCCESS_EMPTY:
- break;
- case ERROR:
- dos.writeUTF(_data.toString());
- break;
- }
- return ByteBuffer.wrap(dos.getBytes());
- }
-}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
index 6220bb6..83ec3f7 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
@@ -65,15 +65,15 @@ import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
-import org.apache.sysml.runtime.controlprogram.paramserv.DataPartitionScheme;
-import org.apache.sysml.runtime.controlprogram.paramserv.DataPartitioner;
+import org.apache.sysml.runtime.controlprogram.paramserv.dp.DataPartitionLocalScheme;
+import org.apache.sysml.runtime.controlprogram.paramserv.dp.LocalDataPartitioner;
import org.apache.sysml.runtime.controlprogram.paramserv.LocalPSWorker;
import org.apache.sysml.runtime.controlprogram.paramserv.LocalParamServer;
import org.apache.sysml.runtime.controlprogram.paramserv.ParamServer;
import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils;
-import org.apache.sysml.runtime.controlprogram.paramserv.spark.SparkPSBody;
-import org.apache.sysml.runtime.controlprogram.paramserv.spark.SparkPSWorker;
-import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcFactory;
+import org.apache.sysml.runtime.controlprogram.paramserv.SparkPSBody;
+import org.apache.sysml.runtime.controlprogram.paramserv.SparkPSWorker;
+import org.apache.sysml.runtime.controlprogram.paramserv.rpc.PSRpcFactory;
import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysml.runtime.matrix.operators.Operator;
@@ -350,7 +350,7 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
switch (mode) {
case LOCAL:
case REMOTE_SPARK:
- return new LocalParamServer(model, aggFunc, updateType, ec, workerNum);
+ return LocalParamServer.create(model, aggFunc, updateType, ec, workerNum);
default:
throw new DMLRuntimeException("Unsupported parameter server: "+mode.name());
}
@@ -379,9 +379,7 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
private void partitionLocally(PSScheme scheme, ExecutionContext ec, List<LocalPSWorker> workers) {
MatrixObject features = ec.getMatrixObject(getParam(PS_FEATURES));
MatrixObject labels = ec.getMatrixObject(getParam(PS_LABELS));
- DataPartitionScheme.Result result = new DataPartitioner(scheme).doPartitioning(workers.size(), features.acquireRead(), labels.acquireRead());
- features.release();
- labels.release();
+ DataPartitionLocalScheme.Result result = new LocalDataPartitioner(scheme).doPartitioning(workers.size(), features.acquireReadAndRelease(), labels.acquireReadAndRelease());
List<MatrixObject> pfs = result.pFeatures;
List<MatrixObject> pls = result.pLabels;
if (pfs.size() < workers.size()) {
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/util/ProgramConverter.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/util/ProgramConverter.java b/src/main/java/org/apache/sysml/runtime/util/ProgramConverter.java
index fc9d9b4..21e6bd3 100644
--- a/src/main/java/org/apache/sysml/runtime/util/ProgramConverter.java
+++ b/src/main/java/org/apache/sysml/runtime/util/ProgramConverter.java
@@ -69,7 +69,7 @@ import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject.UpdateType;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
-import org.apache.sysml.runtime.controlprogram.paramserv.spark.SparkPSBody;
+import org.apache.sysml.runtime.controlprogram.paramserv.SparkPSBody;
import org.apache.sysml.runtime.controlprogram.parfor.ParForBody;
import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysml.runtime.instructions.CPInstructionParser;
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/test/java/org/apache/sysml/test/integration/functions/paramserv/BaseDataPartitionerTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/BaseDataPartitionerTest.java b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/BaseDataPartitionerTest.java
index 0092aed..2f39c91 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/BaseDataPartitionerTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/BaseDataPartitionerTest.java
@@ -22,8 +22,8 @@ package org.apache.sysml.test.integration.functions.paramserv;
import java.util.stream.IntStream;
import org.apache.sysml.parser.Statement;
-import org.apache.sysml.runtime.controlprogram.paramserv.DataPartitionScheme;
-import org.apache.sysml.runtime.controlprogram.paramserv.DataPartitioner;
+import org.apache.sysml.runtime.controlprogram.paramserv.dp.DataPartitionLocalScheme;
+import org.apache.sysml.runtime.controlprogram.paramserv.dp.LocalDataPartitioner;
import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.util.DataConverter;
@@ -54,26 +54,26 @@ public abstract class BaseDataPartitionerTest {
return IntStream.range(from, to).mapToDouble(i -> (double) i).toArray();
}
- protected DataPartitionScheme.Result launchLocalDataPartitionerDC() {
- DataPartitioner dp = new DataPartitioner(Statement.PSScheme.DISJOINT_CONTIGUOUS);
+ protected DataPartitionLocalScheme.Result launchLocalDataPartitionerDC() {
+ LocalDataPartitioner dp = new LocalDataPartitioner(Statement.PSScheme.DISJOINT_CONTIGUOUS);
MatrixBlock[] mbs = generateData();
return dp.doPartitioning(WORKER_NUM, mbs[0], mbs[1]);
}
- protected DataPartitionScheme.Result launchLocalDataPartitionerDR(MatrixBlock[] mbs) {
+ protected DataPartitionLocalScheme.Result launchLocalDataPartitionerDR(MatrixBlock[] mbs) {
ParamservUtils.SEED = System.nanoTime();
- DataPartitioner dp = new DataPartitioner(Statement.PSScheme.DISJOINT_RANDOM);
+ LocalDataPartitioner dp = new LocalDataPartitioner(Statement.PSScheme.DISJOINT_RANDOM);
return dp.doPartitioning(WORKER_NUM, mbs[0], mbs[1]);
}
- protected DataPartitionScheme.Result launchLocalDataPartitionerDRR() {
- DataPartitioner dp = new DataPartitioner(Statement.PSScheme.DISJOINT_ROUND_ROBIN);
+ protected DataPartitionLocalScheme.Result launchLocalDataPartitionerDRR() {
+ LocalDataPartitioner dp = new LocalDataPartitioner(Statement.PSScheme.DISJOINT_ROUND_ROBIN);
MatrixBlock[] mbs = generateData();
return dp.doPartitioning(WORKER_NUM, mbs[0], mbs[1]);
}
- protected DataPartitionScheme.Result launchLocalDataPartitionerOR() {
- DataPartitioner dp = new DataPartitioner(Statement.PSScheme.OVERLAP_RESHUFFLE);
+ protected DataPartitionLocalScheme.Result launchLocalDataPartitionerOR() {
+ LocalDataPartitioner dp = new LocalDataPartitioner(Statement.PSScheme.OVERLAP_RESHUFFLE);
MatrixBlock[] mbs = generateData();
return dp.doPartitioning(WORKER_NUM, mbs[0], mbs[1]);
}