You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2018/08/04 02:16:16 UTC
[1/3] systemml git commit: [MINOR] Various paramserv refactorings and code cleanups
Repository: systemml
Updated Branches:
refs/heads/master e11ae6af3 -> 382f847de
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/test/java/org/apache/sysml/test/integration/functions/paramserv/LocalDataPartitionerTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/LocalDataPartitionerTest.java b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/LocalDataPartitionerTest.java
index 1e4538a..4733406 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/LocalDataPartitionerTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/LocalDataPartitionerTest.java
@@ -23,7 +23,7 @@ import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
-import org.apache.sysml.runtime.controlprogram.paramserv.DataPartitionScheme;
+import org.apache.sysml.runtime.controlprogram.paramserv.dp.DataPartitionLocalScheme;
import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
@@ -36,7 +36,7 @@ public class LocalDataPartitionerTest extends BaseDataPartitionerTest {
@Test
public void testLocalDataPartitionerDC() {
- DataPartitionScheme.Result result = launchLocalDataPartitionerDC();
+ DataPartitionLocalScheme.Result result = launchLocalDataPartitionerDC();
Assert.assertEquals(WORKER_NUM, result.pFeatures.size());
Assert.assertEquals(WORKER_NUM, result.pLabels.size());
@@ -45,7 +45,7 @@ public class LocalDataPartitionerTest extends BaseDataPartitionerTest {
}
}
- private void assertDCResult(DataPartitionScheme.Result result, int workerID) {
+ private void assertDCResult(DataPartitionLocalScheme.Result result, int workerID) {
Assert.assertArrayEquals(generateExpectedData(workerID * (ROW_SIZE / WORKER_NUM) * COL_SIZE, (workerID + 1) * (ROW_SIZE / WORKER_NUM) * COL_SIZE), result.pFeatures.get(workerID).acquireRead().getDenseBlockValues(), 0);
Assert.assertArrayEquals(generateExpectedData(workerID * (ROW_SIZE / WORKER_NUM), (workerID + 1) * (ROW_SIZE / WORKER_NUM)), result.pLabels.get(workerID).acquireRead().getDenseBlockValues(), 0);
}
@@ -53,7 +53,7 @@ public class LocalDataPartitionerTest extends BaseDataPartitionerTest {
@Test
public void testLocalDataPartitionerDR() {
MatrixBlock[] mbs = generateData();
- DataPartitionScheme.Result result = launchLocalDataPartitionerDR(mbs);
+ DataPartitionLocalScheme.Result result = launchLocalDataPartitionerDR(mbs);
Assert.assertEquals(WORKER_NUM, result.pFeatures.size());
Assert.assertEquals(WORKER_NUM, result.pLabels.size());
@@ -82,7 +82,7 @@ public class LocalDataPartitionerTest extends BaseDataPartitionerTest {
@Test
public void testLocalDataPartitionerDRR() {
- DataPartitionScheme.Result result = launchLocalDataPartitionerDRR();
+ DataPartitionLocalScheme.Result result = launchLocalDataPartitionerDRR();
Assert.assertEquals(WORKER_NUM, result.pFeatures.size());
Assert.assertEquals(WORKER_NUM, result.pLabels.size());
@@ -91,7 +91,7 @@ public class LocalDataPartitionerTest extends BaseDataPartitionerTest {
}
}
- private void assertDRRResult(DataPartitionScheme.Result result, int workerID) {
+ private void assertDRRResult(DataPartitionLocalScheme.Result result, int workerID) {
Tuple2<double[], double[]> expected = generateExpectedData(workerID, WORKER_NUM, ROW_SIZE / WORKER_NUM);
Assert.assertArrayEquals(expected._1, result.pFeatures.get(workerID).acquireRead().getDenseBlockValues(), 0);
Assert.assertArrayEquals(expected._2, result.pLabels.get(workerID).acquireRead().getDenseBlockValues(), 0);
@@ -114,7 +114,7 @@ public class LocalDataPartitionerTest extends BaseDataPartitionerTest {
@Test
public void testLocalDataPartitionerOR() {
ParamservUtils.SEED = System.nanoTime();
- DataPartitionScheme.Result result = launchLocalDataPartitionerOR();
+ DataPartitionLocalScheme.Result result = launchLocalDataPartitionerOR();
Assert.assertEquals(WORKER_NUM, result.pFeatures.size());
Assert.assertEquals(WORKER_NUM, result.pLabels.size());
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/test/java/org/apache/sysml/test/integration/functions/paramserv/RpcObjectTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/RpcObjectTest.java b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/RpcObjectTest.java
index 17bfa4c..464b0b1 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/RpcObjectTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/RpcObjectTest.java
@@ -23,9 +23,9 @@ import java.io.IOException;
import java.util.Arrays;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
-import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcCall;
-import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcObject;
-import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcResponse;
+import org.apache.sysml.runtime.controlprogram.paramserv.rpc.PSRpcCall;
+import org.apache.sysml.runtime.controlprogram.paramserv.rpc.PSRpcObject;
+import org.apache.sysml.runtime.controlprogram.paramserv.rpc.PSRpcResponse;
import org.apache.sysml.runtime.instructions.cp.ListObject;
import org.junit.Assert;
import org.junit.Test;
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/test/java/org/apache/sysml/test/integration/functions/paramserv/SparkDataPartitionerTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/SparkDataPartitionerTest.java b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/SparkDataPartitionerTest.java
index b0e4a27..8cae4a4 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/SparkDataPartitionerTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/SparkDataPartitionerTest.java
@@ -26,7 +26,7 @@ import org.apache.sysml.api.DMLScript;
import org.apache.sysml.parser.Statement;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
-import org.apache.sysml.runtime.controlprogram.paramserv.DataPartitionScheme;
+import org.apache.sysml.runtime.controlprogram.paramserv.dp.DataPartitionLocalScheme;
import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.junit.Assert;
@@ -51,14 +51,14 @@ public class SparkDataPartitionerTest extends BaseDataPartitionerTest {
@Test
public void testSparkDataPartitionerDC() {
- DataPartitionScheme.Result localResult = launchLocalDataPartitionerDC();
+ DataPartitionLocalScheme.Result localResult = launchLocalDataPartitionerDC();
Map<Integer, Tuple2<MatrixBlock, MatrixBlock>> sparkResult = doPartitioning(Statement.PSScheme.DISJOINT_CONTIGUOUS);
// Compare the both
assertResult(localResult, sparkResult);
}
- private void assertResult(DataPartitionScheme.Result local, Map<Integer, Tuple2<MatrixBlock, MatrixBlock>> spark) {
+ private void assertResult(DataPartitionLocalScheme.Result local, Map<Integer, Tuple2<MatrixBlock, MatrixBlock>> spark) {
IntStream.range(0, WORKER_NUM).forEach(w -> {
Assert.assertArrayEquals(local.pFeatures.get(w).acquireRead().getDenseBlockValues(), spark.get(w)._1.getDenseBlockValues(), 0);
Assert.assertArrayEquals(local.pLabels.get(w).acquireRead().getDenseBlockValues(), spark.get(w)._2.getDenseBlockValues(), 0);
@@ -69,7 +69,7 @@ public class SparkDataPartitionerTest extends BaseDataPartitionerTest {
public void testSparkDataPartitionerDR() {
ParamservUtils.SEED = System.nanoTime();
MatrixBlock[] mbs = generateData();
- DataPartitionScheme.Result localResult = launchLocalDataPartitionerDR(mbs);
+ DataPartitionLocalScheme.Result localResult = launchLocalDataPartitionerDR(mbs);
Map<Integer, Tuple2<MatrixBlock, MatrixBlock>> sparkResult = doPartitioning(Statement.PSScheme.DISJOINT_RANDOM);
// Compare the both
@@ -78,7 +78,7 @@ public class SparkDataPartitionerTest extends BaseDataPartitionerTest {
@Test
public void testSparkDataPartitionerDRR() {
- DataPartitionScheme.Result localResult = launchLocalDataPartitionerDRR();
+ DataPartitionLocalScheme.Result localResult = launchLocalDataPartitionerDRR();
Map<Integer, Tuple2<MatrixBlock, MatrixBlock>> sparkResult = doPartitioning(Statement.PSScheme.DISJOINT_ROUND_ROBIN);
// Compare the both
@@ -88,7 +88,7 @@ public class SparkDataPartitionerTest extends BaseDataPartitionerTest {
@Test
public void testSparkDataPartitionerOR() {
ParamservUtils.SEED = System.nanoTime();
- DataPartitionScheme.Result localResult = launchLocalDataPartitionerOR();
+ DataPartitionLocalScheme.Result localResult = launchLocalDataPartitionerOR();
Map<Integer, Tuple2<MatrixBlock, MatrixBlock>> sparkResult = doPartitioning(Statement.PSScheme.OVERLAP_RESHUFFLE);
// Compare the both
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml
index 35b0bd2..5ccda12 100644
--- a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml
+++ b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml
@@ -101,8 +101,8 @@ train = function(matrix[double] X, matrix[double] Y,
# Regularization
lambda = 5e-04
- # Create the model object
- modelList = list(W1=W1, b1=b1, W2=W2, b2=b2, W3=W3, b3=b3, W4=W4, b4=b4, vW1=vW1, vW2=vW2, vW3=vW3, vW4=vW4, vb1=vb1, vb2=vb2, vb3=vb3, vb4=vb4)
+ # Create the model list
+ modelList = list(W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4, vb1, vb2, vb3, vb4)
# Create the hyper parameter list
params = list(lr=lr, mu=mu, decay=decay, C=C, Hin=Hin, Win=Win, Hf=Hf, Wf=Wf, stride=stride, pad=pad, lambda=lambda, F1=F1, F2=F2, N3=N3)
@@ -110,15 +110,14 @@ train = function(matrix[double] X, matrix[double] Y,
# Use paramserv function
modelList2 = paramserv(model=modelList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml::gradients", agg="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv.dml::aggregation", mode=mode, utype=utype, freq=freq, epochs=epochs, batchsize=batchsize, k=workers, scheme=scheme, hyperparams=params, checkpointing="NONE")
- W1 = as.matrix(modelList2["W1"])
- b1 = as.matrix(modelList2["b1"])
- W2 = as.matrix(modelList2["W2"])
- b2 = as.matrix(modelList2["b2"])
- W3 = as.matrix(modelList2["W3"])
- b3 = as.matrix(modelList2["b3"])
- W4 = as.matrix(modelList2["W4"])
- b4 = as.matrix(modelList2["b4"])
-
+ W1 = as.matrix(modelList2[1])
+ W2 = as.matrix(modelList2[2])
+ W3 = as.matrix(modelList2[3])
+ W4 = as.matrix(modelList2[4])
+ b1 = as.matrix(modelList2[5])
+ b2 = as.matrix(modelList2[6])
+ b3 = as.matrix(modelList2[7])
+ b4 = as.matrix(modelList2[8])
}
# Should always use 'features' (batch features), 'labels' (batch labels),
@@ -130,27 +129,25 @@ gradients = function(matrix[double] features,
list[unknown] model)
return (list[unknown] gradients) {
-# PB: not be able to get scalar from list
-
- C = as.scalar(hyperparams["C"])
- Hin = 28
- Win = 28
- Hf = 5
- Wf = 5
- stride = 1
- pad = 2
- lambda = 5e-04
- F1 = 32
- F2 = 64
- N3 = 512
- W1 = as.matrix(model["W1"])
- b1 = as.matrix(model["b1"])
- W2 = as.matrix(model["W2"])
- b2 = as.matrix(model["b2"])
- W3 = as.matrix(model["W3"])
- b3 = as.matrix(model["b3"])
- W4 = as.matrix(model["W4"])
- b4 = as.matrix(model["b4"])
+ C = as.integer(as.scalar(hyperparams["C"]))
+ Hin = as.integer(as.scalar(hyperparams["Hin"]))
+ Win = as.integer(as.scalar(hyperparams["Win"]))
+ Hf = as.integer(as.scalar(hyperparams["Hf"]))
+ Wf = as.integer(as.scalar(hyperparams["Wf"]))
+ stride = as.integer(as.scalar(hyperparams["stride"]))
+ pad = as.integer(as.scalar(hyperparams["pad"]))
+ lambda = as.double(as.scalar(hyperparams["lambda"]))
+ F1 = as.integer(as.scalar(hyperparams["F1"]))
+ F2 = as.integer(as.scalar(hyperparams["F2"]))
+ N3 = as.integer(as.scalar(hyperparams["N3"]))
+ W1 = as.matrix(model[1])
+ W2 = as.matrix(model[2])
+ W3 = as.matrix(model[3])
+ W4 = as.matrix(model[4])
+ b1 = as.matrix(model[5])
+ b2 = as.matrix(model[6])
+ b3 = as.matrix(model[7])
+ b4 = as.matrix(model[8])
# Compute forward pass
## layer 1: conv1 -> relu1 -> pool1
@@ -202,7 +199,7 @@ gradients = function(matrix[double] features,
dW3 = dW3 + dW3_reg
dW4 = dW4 + dW4_reg
- gradients = list(dW1=dW1, dW2=dW2, dW3=dW3, dW4=dW4, db1=db1, db2=db2, db3=db3, db4=db4)
+ gradients = list(dW1, dW2, dW3, dW4, db1, db2, db3, db4)
}
# Should use the arguments named 'model', 'gradients', 'hyperparams'
@@ -211,33 +208,32 @@ aggregation = function(list[unknown] model,
list[unknown] gradients,
list[unknown] hyperparams)
return (list[unknown] modelResult) {
-
- W1 = as.matrix(model["W1"])
- W2 = as.matrix(model["W2"])
- W3 = as.matrix(model["W3"])
- W4 = as.matrix(model["W4"])
- b1 = as.matrix(model["b1"])
- b2 = as.matrix(model["b2"])
- b3 = as.matrix(model["b3"])
- b4 = as.matrix(model["b4"])
- dW1 = as.matrix(gradients["dW1"])
- dW2 = as.matrix(gradients["dW2"])
- dW3 = as.matrix(gradients["dW3"])
- dW4 = as.matrix(gradients["dW4"])
- db1 = as.matrix(gradients["db1"])
- db2 = as.matrix(gradients["db2"])
- db3 = as.matrix(gradients["db3"])
- db4 = as.matrix(gradients["db4"])
- vW1 = as.matrix(model["vW1"])
- vW2 = as.matrix(model["vW2"])
- vW3 = as.matrix(model["vW3"])
- vW4 = as.matrix(model["vW4"])
- vb1 = as.matrix(model["vb1"])
- vb2 = as.matrix(model["vb2"])
- vb3 = as.matrix(model["vb3"])
- vb4 = as.matrix(model["vb4"])
- lr = 0.01
- mu = 0.9
+ W1 = as.matrix(model[1])
+ W2 = as.matrix(model[2])
+ W3 = as.matrix(model[3])
+ W4 = as.matrix(model[4])
+ b1 = as.matrix(model[5])
+ b2 = as.matrix(model[6])
+ b3 = as.matrix(model[7])
+ b4 = as.matrix(model[8])
+ dW1 = as.matrix(gradients[1])
+ dW2 = as.matrix(gradients[2])
+ dW3 = as.matrix(gradients[3])
+ dW4 = as.matrix(gradients[4])
+ db1 = as.matrix(gradients[5])
+ db2 = as.matrix(gradients[6])
+ db3 = as.matrix(gradients[7])
+ db4 = as.matrix(gradients[8])
+ vW1 = as.matrix(model[9])
+ vW2 = as.matrix(model[10])
+ vW3 = as.matrix(model[11])
+ vW4 = as.matrix(model[12])
+ vb1 = as.matrix(model[13])
+ vb2 = as.matrix(model[14])
+ vb3 = as.matrix(model[15])
+ vb4 = as.matrix(model[16])
+ lr = as.double(as.scalar(hyperparams["lr"]))
+ mu = as.double(as.scalar(hyperparams["mu"]))
# Optimize with SGD w/ Nesterov momentum
[W1, vW1] = sgd_nesterov::update(W1, dW1, lr, mu, vW1)
@@ -249,7 +245,7 @@ aggregation = function(list[unknown] model,
[W4, vW4] = sgd_nesterov::update(W4, dW4, lr, mu, vW4)
[b4, vb4] = sgd_nesterov::update(b4, db4, lr, mu, vb4)
- modelResult = list(W1=W1, b1=b1, W2=W2, b2=b2, W3=W3, b3=b3, W4=W4, b4=b4, vW1=vW1, vW2=vW2, vW3=vW3, vW4=vW4, vb1=vb1, vb2=vb2, vb3=vb3, vb4=vb4)
+ modelResult = list(W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4, vb1, vb2, vb3, vb4)
}
predict = function(matrix[double] X, int C, int Hin, int Win, int batch_size,
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml
index a3677aa..e7056f0 100644
--- a/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml
+++ b/src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml
@@ -101,7 +101,7 @@ train = function(matrix[double] X, matrix[double] Y,
lambda = 5e-04
# Create the model object
- modelList = list(W1=W1, b1=b1, W2=W2, b2=b2, W3=W3, b3=b3, W4=W4, b4=b4, vW1=vW1, vW2=vW2, vW3=vW3, vW4=vW4, vb1=vb1, vb2=vb2, vb3=vb3, vb4=vb4)
+ modelList = list(W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4, vb1, vb2, vb3, vb4)
# Create the hyper parameter list
params = list(lr=lr, mu=mu, decay=decay, C=C, Hin=Hin, Win=Win, Hf=Hf, Wf=Wf, stride=stride, pad=pad, lambda=lambda, F1=F1, F2=F2, N3=N3)
@@ -109,14 +109,14 @@ train = function(matrix[double] X, matrix[double] Y,
# Use paramserv function
modelList2 = paramserv(model=modelList, features=X, labels=Y, val_features=X_val, val_labels=Y_val, upd="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml::gradients", agg="./src/test/scripts/functions/paramserv/mnist_lenet_paramserv_minimum_version.dml::aggregation", mode="LOCAL", utype="BSP", epochs=epochs, hyperparams=params)
- W1 = as.matrix(modelList2["W1"])
- b1 = as.matrix(modelList2["b1"])
- W2 = as.matrix(modelList2["W2"])
- b2 = as.matrix(modelList2["b2"])
- W3 = as.matrix(modelList2["W3"])
- b3 = as.matrix(modelList2["b3"])
- W4 = as.matrix(modelList2["W4"])
- b4 = as.matrix(modelList2["b4"])
+ W1 = as.matrix(modelList2[1])
+ W2 = as.matrix(modelList2[2])
+ W3 = as.matrix(modelList2[3])
+ W4 = as.matrix(modelList2[4])
+ b1 = as.matrix(modelList2[5])
+ b2 = as.matrix(modelList2[6])
+ b3 = as.matrix(modelList2[7])
+ b4 = as.matrix(modelList2[8])
}
@@ -126,25 +126,25 @@ gradients = function(matrix[double] features,
list[unknown] model)
return (list[unknown] gradients) {
- C = 1
- Hin = 28
- Win = 28
- Hf = 5
- Wf = 5
- stride = 1
- pad = 2
- lambda = 5e-04
- F1 = 32
- F2 = 64
- N3 = 512
- W1 = as.matrix(model["W1"])
- b1 = as.matrix(model["b1"])
- W2 = as.matrix(model["W2"])
- b2 = as.matrix(model["b2"])
- W3 = as.matrix(model["W3"])
- b3 = as.matrix(model["b3"])
- W4 = as.matrix(model["W4"])
- b4 = as.matrix(model["b4"])
+ C = as.integer(as.scalar(hyperparams["C"]))
+ Hin = as.integer(as.scalar(hyperparams["Hin"]))
+ Win = as.integer(as.scalar(hyperparams["Win"]))
+ Hf = as.integer(as.scalar(hyperparams["Hf"]))
+ Wf = as.integer(as.scalar(hyperparams["Wf"]))
+ stride = as.integer(as.scalar(hyperparams["stride"]))
+ pad = as.integer(as.scalar(hyperparams["pad"]))
+ lambda = as.double(as.scalar(hyperparams["lambda"]))
+ F1 = as.integer(as.scalar(hyperparams["F1"]))
+ F2 = as.integer(as.scalar(hyperparams["F2"]))
+ N3 = as.integer(as.scalar(hyperparams["N3"]))
+ W1 = as.matrix(model[1])
+ W2 = as.matrix(model[2])
+ W3 = as.matrix(model[3])
+ W4 = as.matrix(model[4])
+ b1 = as.matrix(model[5])
+ b2 = as.matrix(model[6])
+ b3 = as.matrix(model[7])
+ b4 = as.matrix(model[8])
# Compute forward pass
## layer 1: conv1 -> relu1 -> pool1
@@ -196,41 +196,39 @@ gradients = function(matrix[double] features,
dW3 = dW3 + dW3_reg
dW4 = dW4 + dW4_reg
- gradients = list(dW1=dW1, dW2=dW2, dW3=dW3, dW4=dW4, db1=db1, db2=db2, db3=db3, db4=db4)
-
+ gradients = list(dW1, dW2, dW3, dW4, db1, db2, db3, db4)
}
aggregation = function(list[unknown] model,
list[unknown] gradients,
list[unknown] hyperparams)
return (list[unknown] modelResult) {
-
- W1 = as.matrix(model["W1"])
- W2 = as.matrix(model["W2"])
- W3 = as.matrix(model["W3"])
- W4 = as.matrix(model["W4"])
- b1 = as.matrix(model["b1"])
- b2 = as.matrix(model["b2"])
- b3 = as.matrix(model["b3"])
- b4 = as.matrix(model["b4"])
- dW1 = as.matrix(gradients["dW1"])
- dW2 = as.matrix(gradients["dW2"])
- dW3 = as.matrix(gradients["dW3"])
- dW4 = as.matrix(gradients["dW4"])
- db1 = as.matrix(gradients["db1"])
- db2 = as.matrix(gradients["db2"])
- db3 = as.matrix(gradients["db3"])
- db4 = as.matrix(gradients["db4"])
- vW1 = as.matrix(model["vW1"])
- vW2 = as.matrix(model["vW2"])
- vW3 = as.matrix(model["vW3"])
- vW4 = as.matrix(model["vW4"])
- vb1 = as.matrix(model["vb1"])
- vb2 = as.matrix(model["vb2"])
- vb3 = as.matrix(model["vb3"])
- vb4 = as.matrix(model["vb4"])
- lr = 0.01
- mu = 0.9
+ W1 = as.matrix(model[1])
+ W2 = as.matrix(model[2])
+ W3 = as.matrix(model[3])
+ W4 = as.matrix(model[4])
+ b1 = as.matrix(model[5])
+ b2 = as.matrix(model[6])
+ b3 = as.matrix(model[7])
+ b4 = as.matrix(model[8])
+ dW1 = as.matrix(gradients[1])
+ dW2 = as.matrix(gradients[2])
+ dW3 = as.matrix(gradients[3])
+ dW4 = as.matrix(gradients[4])
+ db1 = as.matrix(gradients[5])
+ db2 = as.matrix(gradients[6])
+ db3 = as.matrix(gradients[7])
+ db4 = as.matrix(gradients[8])
+ vW1 = as.matrix(model[9])
+ vW2 = as.matrix(model[10])
+ vW3 = as.matrix(model[11])
+ vW4 = as.matrix(model[12])
+ vb1 = as.matrix(model[13])
+ vb2 = as.matrix(model[14])
+ vb3 = as.matrix(model[15])
+ vb4 = as.matrix(model[16])
+ lr = as.double(as.scalar(hyperparams["lr"]))
+ mu = as.double(as.scalar(hyperparams["mu"]))
# Optimize with SGD w/ Nesterov momentum
[W1, vW1] = sgd_nesterov::update(W1, dW1, lr, mu, vW1)
@@ -242,7 +240,7 @@ aggregation = function(list[unknown] model,
[W4, vW4] = sgd_nesterov::update(W4, dW4, lr, mu, vW4)
[b4, vb4] = sgd_nesterov::update(b4, db4, lr, mu, vb4)
- modelResult = list(W1=W1, b1=b1, W2=W2, b2=b2, W3=W3, b3=b3, W4=W4, b4=b4, vW1=vW1, vW2=vW2, vW3=vW3, vW4=vW4, vb1=vb1, vb2=vb2, vb3=vb3, vb4=vb4)
+ modelResult = list(W1, W2, W3, W4, b1, b2, b3, b4, vW1, vW2, vW3, vW4, vb1, vb2, vb3, vb4)
}
predict = function(matrix[double] X, int C, int Hin, int Win, int batch_size,
[3/3] systemml git commit: [MINOR] Various paramserv refactorings and code cleanups
Posted by mb...@apache.org.
[MINOR] Various paramserv refactorings and code cleanups
Closes #814.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/382f847d
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/382f847d
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/382f847d
Branch: refs/heads/master
Commit: 382f847de6e33cdb5386b5eb5912eb5da0dff8d6
Parents: e11ae6a
Author: EdgarLGB <gu...@atos.net>
Authored: Fri Aug 3 19:17:15 2018 -0700
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Fri Aug 3 19:17:15 2018 -0700
----------------------------------------------------------------------
.../controlprogram/paramserv/DCScheme.java | 61 -------
.../controlprogram/paramserv/DRRScheme.java | 57 -------
.../controlprogram/paramserv/DRScheme.java | 62 -------
.../paramserv/DataPartitionScheme.java | 40 -----
.../paramserv/DataPartitioner.java | 49 ------
.../controlprogram/paramserv/LocalPSWorker.java | 40 ++---
.../paramserv/LocalParamServer.java | 6 +-
.../controlprogram/paramserv/ORScheme.java | 61 -------
.../controlprogram/paramserv/PSWorker.java | 1 +
.../controlprogram/paramserv/ParamServer.java | 40 ++---
.../paramserv/ParamservUtils.java | 66 ++++++--
.../controlprogram/paramserv/SparkPSBody.java | 44 +++++
.../controlprogram/paramserv/SparkPSProxy.java | 84 ++++++++++
.../controlprogram/paramserv/SparkPSWorker.java | 158 +++++++++++++++++
.../paramserv/dp/DCLocalScheme.java | 62 +++++++
.../paramserv/dp/DCSparkScheme.java | 47 ++++++
.../paramserv/dp/DRLocalScheme.java | 63 +++++++
.../paramserv/dp/DRRLocalScheme.java | 58 +++++++
.../paramserv/dp/DRRSparkScheme.java | 45 +++++
.../paramserv/dp/DRSparkScheme.java | 69 ++++++++
.../paramserv/dp/DataPartitionLocalScheme.java | 40 +++++
.../paramserv/dp/DataPartitionSparkScheme.java | 76 +++++++++
.../dp/DataPartitionerSparkAggregator.java | 66 ++++++++
.../dp/DataPartitionerSparkMapper.java | 70 ++++++++
.../paramserv/dp/LocalDataPartitioner.java | 52 ++++++
.../paramserv/dp/ORLocalScheme.java | 62 +++++++
.../paramserv/dp/ORSparkScheme.java | 60 +++++++
.../paramserv/dp/SparkDataPartitioner.java | 106 ++++++++++++
.../controlprogram/paramserv/rpc/PSRpcCall.java | 86 ++++++++++
.../paramserv/rpc/PSRpcFactory.java | 61 +++++++
.../paramserv/rpc/PSRpcHandler.java | 95 +++++++++++
.../paramserv/rpc/PSRpcObject.java | 107 ++++++++++++
.../paramserv/rpc/PSRpcResponse.java | 101 +++++++++++
.../paramserv/spark/DCSparkScheme.java | 47 ------
.../paramserv/spark/DRRSparkScheme.java | 45 -----
.../paramserv/spark/DRSparkScheme.java | 69 --------
.../spark/DataPartitionSparkScheme.java | 76 ---------
.../spark/DataPartitionerSparkAggregator.java | 66 --------
.../spark/DataPartitionerSparkMapper.java | 70 --------
.../paramserv/spark/ORSparkScheme.java | 60 -------
.../paramserv/spark/SparkDataPartitioner.java | 106 ------------
.../paramserv/spark/SparkPSBody.java | 44 -----
.../paramserv/spark/SparkPSProxy.java | 85 ----------
.../paramserv/spark/SparkPSWorker.java | 168 -------------------
.../paramserv/spark/rpc/PSRpcCall.java | 86 ----------
.../paramserv/spark/rpc/PSRpcFactory.java | 57 -------
.../paramserv/spark/rpc/PSRpcHandler.java | 95 -----------
.../paramserv/spark/rpc/PSRpcObject.java | 107 ------------
.../paramserv/spark/rpc/PSRpcResponse.java | 101 -----------
.../cp/ParamservBuiltinCPInstruction.java | 16 +-
.../sysml/runtime/util/ProgramConverter.java | 2 +-
.../paramserv/BaseDataPartitionerTest.java | 20 +--
.../paramserv/LocalDataPartitionerTest.java | 14 +-
.../functions/paramserv/RpcObjectTest.java | 6 +-
.../paramserv/SparkDataPartitionerTest.java | 12 +-
.../paramserv/mnist_lenet_paramserv.dml | 118 +++++++------
.../mnist_lenet_paramserv_minimum_version.dml | 114 +++++++------
57 files changed, 1851 insertions(+), 1828 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DCScheme.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DCScheme.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DCScheme.java
deleted file mode 100644
index 00aaa21..0000000
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DCScheme.java
+++ /dev/null
@@ -1,61 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.runtime.controlprogram.paramserv;
-
-import java.util.ArrayList;
-import java.util.List;
-import java.util.stream.Collectors;
-
-import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
-import org.apache.sysml.runtime.matrix.data.MatrixBlock;
-
-/**
- * Disjoint_Contiguous data partitioner:
- *
- * for each worker, use a right indexing
- * operation X[beg:end,] to obtain contiguous,
- * non-overlapping partitions of rows.
- */
-public class DCScheme extends DataPartitionScheme {
-
- public static List<MatrixBlock> partition(int k, MatrixBlock mb) {
- List<MatrixBlock> list = new ArrayList<>();
- long stepSize = (long) Math.ceil((double) mb.getNumRows() / k);
- long begin = 1;
- while (begin < mb.getNumRows()) {
- long end = Math.min(begin - 1 + stepSize, mb.getNumRows());
- MatrixBlock pmo = ParamservUtils.sliceMatrixBlock(mb, begin, end);
- list.add(pmo);
- begin = end + 1;
- }
- return list;
- }
-
- private List<MatrixObject> doPartitioning(int k, MatrixBlock mb) {
- return partition(k, mb).stream().map(ParamservUtils::newMatrixObject).collect(Collectors.toList());
- }
-
- @Override
- public Result doPartitioning(int workersNum, MatrixBlock features, MatrixBlock labels) {
- List<MatrixObject> pfs = doPartitioning(workersNum, features);
- List<MatrixObject> pls = doPartitioning(workersNum, labels);
- return new Result(pfs, pls);
- }
-}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DRRScheme.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DRRScheme.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DRRScheme.java
deleted file mode 100644
index 90c62d6..0000000
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DRRScheme.java
+++ /dev/null
@@ -1,57 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.runtime.controlprogram.paramserv;
-
-import java.util.List;
-import java.util.stream.Collectors;
-import java.util.stream.IntStream;
-import java.util.stream.LongStream;
-
-import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
-import org.apache.sysml.runtime.matrix.data.MatrixBlock;
-import org.apache.sysml.runtime.util.DataConverter;
-
-/**
- * Disjoint_Round_Robin data partitioner:
- * for each worker, use a permutation multiply
- * or simpler a removeEmpty such as removeEmpty
- * (target=X, margin=rows, select=(seq(1,nrow(X))%%k)==id)
- */
-public class DRRScheme extends DataPartitionScheme {
-
- public static MatrixBlock removeEmpty(MatrixBlock mb, int k, int workerId) {
- double[] data = LongStream.range(0, mb.getNumRows()).mapToDouble(l -> l % k == workerId ? 1 : 0).toArray();
- MatrixBlock select = DataConverter.convertToMatrixBlock(data, true);
- return mb.removeEmptyOperations(new MatrixBlock(), true, true, select);
- }
-
- private MatrixObject internalRemoveEmpty(MatrixBlock mb, int k, int workerId) {
- MatrixObject result = ParamservUtils.newMatrixObject(removeEmpty(mb, k, workerId));
- result.enableCleanup(false);
- return result;
- }
-
- @Override
- public Result doPartitioning(int workersNum, MatrixBlock features, MatrixBlock labels) {
- List<MatrixObject> pfs = IntStream.range(0, workersNum).mapToObj(i -> internalRemoveEmpty(features, workersNum, i)).collect(Collectors.toList());
- List<MatrixObject> pls = IntStream.range(0, workersNum).mapToObj(i -> internalRemoveEmpty(labels, workersNum, i)).collect(Collectors.toList());
- return new Result(pfs, pls);
- }
-}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DRScheme.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DRScheme.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DRScheme.java
deleted file mode 100644
index 062a7ab..0000000
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DRScheme.java
+++ /dev/null
@@ -1,62 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.runtime.controlprogram.paramserv;
-
-import static org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils.SEED;
-
-import java.util.List;
-import java.util.stream.Collectors;
-import java.util.stream.IntStream;
-
-import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
-import org.apache.sysml.runtime.instructions.InstructionUtils;
-import org.apache.sysml.runtime.matrix.data.MatrixBlock;
-
-/**
- * Data partitioner Disjoint_Random:
- * for each worker, use a permutation multiply P[beg:end,] %*% X,
- * where P is constructed for example with P=table(seq(1,nrow(X)),sample(nrow(X), nrow(X))),
- * i.e., sampling without replacement to ensure disjointness.
- */
-public class DRScheme extends DataPartitionScheme {
-
- private List<MatrixBlock> partition(int k, MatrixBlock mb, MatrixBlock permutation) {
- int batchSize = (int) Math.ceil((double) mb.getNumRows() / k);
- return IntStream.range(0, k).mapToObj(i -> {
- int begin = i * batchSize;
- int end = Math.min((i + 1) * batchSize, mb.getNumRows());
- MatrixBlock slicedPerm = permutation.slice(begin, end - 1);
- return slicedPerm.aggregateBinaryOperations(slicedPerm, mb, new MatrixBlock(), InstructionUtils.getMatMultOperator(k));
- }).collect(Collectors.toList());
- }
-
- private List<MatrixObject> internalDoPartitioning(int k, MatrixBlock mb, MatrixBlock permutation) {
- return partition(k, mb, permutation).stream().map(ParamservUtils::newMatrixObject).collect(Collectors.toList());
- }
-
- @Override
- public Result doPartitioning(int workersNum, MatrixBlock features, MatrixBlock labels) {
- // Generate a single permutation matrix (workers use slices)
- MatrixBlock permutation = ParamservUtils.generatePermutation(features.getNumRows(), SEED);
- List<MatrixObject> pfs = internalDoPartitioning(workersNum, features, permutation);
- List<MatrixObject> pls = internalDoPartitioning(workersNum, labels, permutation);
- return new Result(pfs, pls);
- }
-}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionScheme.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionScheme.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionScheme.java
deleted file mode 100644
index f2ea0aa..0000000
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitionScheme.java
+++ /dev/null
@@ -1,40 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.runtime.controlprogram.paramserv;
-
-import java.util.List;
-
-import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
-import org.apache.sysml.runtime.matrix.data.MatrixBlock;
-
-public abstract class DataPartitionScheme {
-
- public final class Result {
- public final List<MatrixObject> pFeatures;
- public final List<MatrixObject> pLabels;
-
- public Result(List<MatrixObject> pFeatures, List<MatrixObject> pLabels) {
- this.pFeatures = pFeatures;
- this.pLabels = pLabels;
- }
- }
-
- public abstract Result doPartitioning(int workersNum, MatrixBlock features, MatrixBlock labels);
-}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitioner.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitioner.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitioner.java
deleted file mode 100644
index 3f28cd1..0000000
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/DataPartitioner.java
+++ /dev/null
@@ -1,49 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.runtime.controlprogram.paramserv;
-
-import org.apache.sysml.parser.Statement;
-import org.apache.sysml.runtime.matrix.data.MatrixBlock;
-
-public class DataPartitioner {
-
- private DataPartitionScheme _scheme;
-
- public DataPartitioner(Statement.PSScheme scheme) {
- switch (scheme) {
- case DISJOINT_CONTIGUOUS:
- _scheme = new DCScheme();
- break;
- case DISJOINT_ROUND_ROBIN:
- _scheme = new DRRScheme();
- break;
- case DISJOINT_RANDOM:
- _scheme = new DRScheme();
- break;
- case OVERLAP_RESHUFFLE:
- _scheme = new ORScheme();
- break;
- }
- }
-
- public DataPartitionScheme.Result doPartitioning(int workersNum, MatrixBlock features, MatrixBlock labels) {
- return _scheme.doPartitioning(workersNum, features, labels);
- }
-}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
index f76fddb..04050b2 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
@@ -54,15 +54,17 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
incWorkerNumber();
try {
long dataSize = _features.getNumRows();
- int totalIter = (int) Math.ceil((double) dataSize / _batchSize);
+ int batchIter = (int) Math.ceil((double) dataSize / _batchSize);
switch (_freq) {
case BATCH:
- computeBatch(dataSize, totalIter);
+ computeBatch(dataSize, batchIter);
break;
case EPOCH:
- computeEpoch(dataSize, totalIter);
+ computeEpoch(dataSize, batchIter);
break;
+ default:
+ throw new DMLRuntimeException(String.format("%s not support update frequency %s", getWorkerName(), _freq));
}
if (LOG.isDebugEnabled()) {
@@ -74,25 +76,23 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
return null;
}
- private void computeEpoch(long dataSize, int totalIter) {
+ private void computeEpoch(long dataSize, int batchIter) {
for (int i = 0; i < _epochs; i++) {
// Pull the global parameters from ps
ListObject params = pullModel();
ListObject accGradients = null;
- for (int j = 0; j < totalIter; j++) {
- _ec.setVariable(Statement.PS_MODEL, params);
-
- ListObject gradients = computeGradients(dataSize, totalIter, i, j);
+ for (int j = 0; j < batchIter; j++) {
+ ListObject gradients = computeGradients(params, dataSize, batchIter, i, j);
+ boolean localUpdate = j < batchIter - 1;
// Accumulate the intermediate gradients
- accGradients = ParamservUtils.accrueGradients(accGradients, gradients);
+ accGradients = ParamservUtils.accrueGradients(accGradients, gradients, !localUpdate);
// Update the local model with gradients
- if( j < totalIter - 1 )
- params = updateModel(params, gradients, i, j, totalIter);
- ParamservUtils.cleanupListObject(_ec, gradients);
-
+ if(localUpdate)
+ params = updateModel(params, gradients, i, j, batchIter);
+
accNumBatches(1);
}
@@ -107,7 +107,7 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
}
}
- private ListObject updateModel(ListObject globalParams, ListObject gradients, int i, int j, int totalIter) {
+ private ListObject updateModel(ListObject globalParams, ListObject gradients, int i, int j, int batchIter) {
Timing tUpd = DMLScript.STATISTICS ? new Timing(true) : null;
globalParams = _ps.updateLocalModel(_ec, gradients, globalParams);
@@ -117,7 +117,7 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
if (LOG.isDebugEnabled()) {
LOG.debug(String.format("%s: local global parameter [size:%d kb] updated. "
+ "[Epoch:%d Total epoch:%d Iteration:%d Total iteration:%d]",
- getWorkerName(), globalParams.getDataSize(), i + 1, _epochs, j + 1, totalIter));
+ getWorkerName(), globalParams.getDataSize(), i + 1, _epochs, j + 1, batchIter));
}
return globalParams;
}
@@ -127,8 +127,7 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
for (int j = 0; j < totalIter; j++) {
ListObject globalParams = pullModel();
- _ec.setVariable(Statement.PS_MODEL, globalParams);
- ListObject gradients = computeGradients(dataSize, totalIter, i, j);
+ ListObject gradients = computeGradients(globalParams, dataSize, totalIter, i, j);
// Push the gradients to ps
pushGradients(gradients);
@@ -163,7 +162,8 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
}
}
- private ListObject computeGradients(long dataSize, int totalIter, int i, int j) {
+ private ListObject computeGradients(ListObject params, long dataSize, int batchIter, int i, int j) {
+ _ec.setVariable(Statement.PS_MODEL, params);
long begin = j * _batchSize + 1;
long end = Math.min((j + 1) * _batchSize, dataSize);
@@ -180,7 +180,7 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
LOG.debug(String.format("%s: got batch data [size:%d kb] of index from %d to %d [last index: %d]. "
+ "[Epoch:%d Total epoch:%d Iteration:%d Total iteration:%d]", getWorkerName(),
bFeatures.getDataSize() / 1024 + bLabels.getDataSize() / 1024, begin, end, dataSize, i + 1, _epochs,
- j + 1, totalIter));
+ j + 1, batchIter));
}
// Invoke the update function
@@ -189,7 +189,7 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
accGradientComputeTime(tGrad);
// Get the gradients
- ListObject gradients = (ListObject) _ec.getVariable(_output.getName());
+ ListObject gradients = _ec.getListObject(_output.getName());
ParamservUtils.cleanupData(_ec, Statement.PS_FEATURES);
ParamservUtils.cleanupData(_ec, Statement.PS_LABELS);
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java
index 0c73acb..a2904fe 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalParamServer.java
@@ -30,7 +30,11 @@ public class LocalParamServer extends ParamServer {
super();
}
- public LocalParamServer(ListObject model, String aggFunc, Statement.PSUpdateType updateType, ExecutionContext ec, int workerNum) {
+ public static LocalParamServer create(ListObject model, String aggFunc, Statement.PSUpdateType updateType, ExecutionContext ec, int workerNum) {
+ return new LocalParamServer(model, aggFunc, updateType, ec, workerNum);
+ }
+
+ private LocalParamServer(ListObject model, String aggFunc, Statement.PSUpdateType updateType, ExecutionContext ec, int workerNum) {
super(model, aggFunc, updateType, ec, workerNum);
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ORScheme.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ORScheme.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ORScheme.java
deleted file mode 100644
index 2692efa..0000000
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ORScheme.java
+++ /dev/null
@@ -1,61 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.runtime.controlprogram.paramserv;
-
-import static org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils.SEED;
-
-import java.util.List;
-import java.util.stream.Collectors;
-import java.util.stream.IntStream;
-
-import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
-import org.apache.sysml.runtime.instructions.InstructionUtils;
-import org.apache.sysml.runtime.matrix.data.MatrixBlock;
-
-/**
- * Data partitioner Overlap_Reshuffle:
- * for each worker, use a new permutation multiply P %*% X,
- * where P is constructed for example with P=table(seq(1,nrow(X),sample(nrow(X), nrow(X))))
- */
-public class ORScheme extends DataPartitionScheme {
-
- public static List<MatrixBlock> partition(int k, MatrixBlock mb, List<MatrixBlock> permutations) {
- return IntStream.range(0, k).mapToObj(i -> {
- MatrixBlock permutation = permutations.get(i);
- return permutation.aggregateBinaryOperations(permutation, mb, new MatrixBlock(),
- InstructionUtils.getMatMultOperator(k));
- }).collect(Collectors.toList());
- }
-
- private List<MatrixObject> doPartitioning(int k, MatrixBlock mb, List<MatrixBlock> permutations) {
- return partition(k, mb, permutations).stream().map(ParamservUtils::newMatrixObject).collect(Collectors.toList());
- }
-
- @Override
- public Result doPartitioning(int workersNum, MatrixBlock features, MatrixBlock labels) {
- // Generate a different permutation matrix for each worker
- List<MatrixBlock> permutations = IntStream.range(0, workersNum)
- .mapToObj(i -> ParamservUtils.generatePermutation(features.getNumRows(), SEED+i))
- .collect(Collectors.toList());
- List<MatrixObject> pfs = doPartitioning(workersNum, features, permutations);
- List<MatrixObject> pls = doPartitioning(workersNum, labels, permutations);
- return new Result(pfs, pls);
- }
-}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
index 5f2d552..63600d1 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/PSWorker.java
@@ -36,6 +36,7 @@ import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction;
+// TODO use the validation features and labels to calculate the model precision during training
public abstract class PSWorker implements Serializable
{
private static final long serialVersionUID = -3510485051178200118L;
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
index 0f5f70d..2b2249e 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
@@ -81,15 +81,10 @@ public abstract class ParamServer
setupAggFunc(_ec, aggFunc);
// broadcast initial model
- try {
- broadcastModel();
- }
- catch (InterruptedException e) {
- throw new DMLRuntimeException("Param server: failed to broadcast the initial model.", e);
- }
+ broadcastModel(true);
}
- public void setupAggFunc(ExecutionContext ec, String aggFunc) {
+ protected void setupAggFunc(ExecutionContext ec, String aggFunc) {
String[] cfn = ParamservUtils.getCompleteFuncName(aggFunc, PS_FUNC_PREFIX);
String ns = cfn[0];
String fname = cfn[1];
@@ -140,11 +135,9 @@ public abstract class ParamServer
// Accumulate the intermediate gradients
if( ACCRUE_BSP_GRADIENTS )
- _accGradients = ParamservUtils.accrueGradients(
- _accGradients, gradients, true);
+ _accGradients = ParamservUtils.accrueGradients(_accGradients, gradients, true);
else
updateGlobalModel(gradients);
- ParamservUtils.cleanupListObject(_ec, gradients);
if (allFinished()) {
// Update the global model with accrued gradients
@@ -155,7 +148,7 @@ public abstract class ParamServer
// Broadcast the updated model
resetFinishedStates();
- broadcastModel();
+ broadcastModel(true);
if (LOG.isDebugEnabled())
LOG.debug("Global parameter is broadcasted successfully.");
}
@@ -199,7 +192,7 @@ public abstract class ParamServer
_inst.processInstruction(ec);
// Get the new model
- ListObject newModel = (ListObject) ec.getVariable(_outputName);
+ ListObject newModel = ec.getListObject(_outputName);
// Clean up the list according to the data referencing status
ParamservUtils.cleanupListObject(ec, Statement.PS_MODEL, newModel.getStatus());
@@ -218,23 +211,26 @@ public abstract class ParamServer
private void setFinishedState(int workerID) {
_finishedStates[workerID] = true;
}
-
- private void broadcastModel() throws InterruptedException {
- Timing tBroad = DMLScript.STATISTICS ? new Timing(true) : null;
-
- //broadcast copy of the model to all workers, cleaned up by workers
- for (BlockingQueue<ListObject> q : _modelMap.values())
- q.put(ParamservUtils.copyList(_model));
- if (DMLScript.STATISTICS)
- Statistics.accPSModelBroadcastTime((long) tBroad.stop());
+ /**
+ * Broadcast the model to all workers, in parallel if {@code par} is true
+ */
+ private void broadcastModel(boolean par) {
+ IntStream stream = IntStream.range(0, _modelMap.size());
+ (par ? stream.parallel() : stream).forEach(workerID -> {
+ try {
+ broadcastModel(workerID);
+ } catch (InterruptedException e) {
+ throw new DMLRuntimeException("Paramserv func: some error occurred when broadcasting model", e);
+ }
+ });
}
private void broadcastModel(int workerID) throws InterruptedException {
Timing tBroad = DMLScript.STATISTICS ? new Timing(true) : null;
//broadcast copy of model to specific worker, cleaned up by worker
- _modelMap.get(workerID).put(ParamservUtils.copyList(_model));
+ _modelMap.get(workerID).put(ParamservUtils.copyList(_model, false));
if (DMLScript.STATISTICS)
Statistics.accPSModelBroadcastTime((long) tBroad.stop());
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
index e9292d1..f8b5dda 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
@@ -58,8 +58,8 @@ import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
-import org.apache.sysml.runtime.controlprogram.paramserv.spark.DataPartitionerSparkAggregator;
-import org.apache.sysml.runtime.controlprogram.paramserv.spark.DataPartitionerSparkMapper;
+import org.apache.sysml.runtime.controlprogram.paramserv.dp.DataPartitionerSparkAggregator;
+import org.apache.sysml.runtime.controlprogram.paramserv.dp.DataPartitionerSparkMapper;
import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysml.runtime.functionobjects.Plus;
import org.apache.sysml.runtime.instructions.cp.Data;
@@ -86,12 +86,10 @@ public class ParamservUtils {
* Deep copy the list object
*
* @param lo list object
+ * @param cleanup whether to clean up the given list object after copying
* @return a new copied list object
*/
- public static ListObject copyList(ListObject lo) {
- if (lo.getLength() == 0) {
- return lo;
- }
+ public static ListObject copyList(ListObject lo, boolean cleanup) {
List<Data> newData = IntStream.range(0, lo.getLength()).mapToObj(i -> {
Data oldData = lo.slice(i);
if (oldData instanceof MatrixObject)
@@ -101,7 +99,10 @@ public class ParamservUtils {
else
return oldData;
}).collect(Collectors.toList());
- return new ListObject(newData, lo.getNames());
+ ListObject result = new ListObject(newData, lo.getNames());
+ if (cleanup)
+ ParamservUtils.cleanupListObject(lo);
+ return result;
}
/**
@@ -197,6 +198,12 @@ public class ParamservUtils {
return mb.slice((int) rl - 1, (int) rh - 1);
}
+ /**
+ * Generate the permutation
+ * @param numEntries permutation size
+ * @param seed seed used to generate random number
+ * @return permutation matrix
+ */
public static MatrixBlock generatePermutation(int numEntries, long seed) {
// Create a sequence and sample w/o replacement
// (no need to materialize the sequence because ctable only uses its meta data)
@@ -208,6 +215,12 @@ public class ParamservUtils {
new MatrixBlock(numEntries, numEntries, true));
}
+ /**
+ * Get the namespace and function name of a given physical func name
+ * @param funcName physical func name (e.g., "ns:func")
+ * @param prefix prefix
+ * @return a string array of size 2 where array[0] is the namespace and array[1] is the function name
+ */
public static String[] getCompleteFuncName(String funcName, String prefix) {
String[] keys = DMLProgram.splitFunctionKey(funcName);
String ns = (keys.length==2) ? keys[0] : null;
@@ -373,9 +386,9 @@ public class ParamservUtils {
Timing tSetup = DMLScript.STATISTICS ? new Timing(true) : null;
// Get input RDD
JavaPairRDD<MatrixIndexes, MatrixBlock> featuresRDD = (JavaPairRDD<MatrixIndexes, MatrixBlock>)
- sec.getRDDHandleForMatrixObject(features, InputInfo.BinaryBlockInputInfo);
+ sec.getRDDHandleForMatrixObject(features, InputInfo.BinaryBlockInputInfo);
JavaPairRDD<MatrixIndexes, MatrixBlock> labelsRDD = (JavaPairRDD<MatrixIndexes, MatrixBlock>)
- sec.getRDDHandleForMatrixObject(labels, InputInfo.BinaryBlockInputInfo);
+ sec.getRDDHandleForMatrixObject(labels, InputInfo.BinaryBlockInputInfo);
DataPartitionerSparkMapper mapper = new DataPartitionerSparkMapper(scheme, workerNum, sec, (int) features.getNumRows());
JavaPairRDD<Integer, Tuple2<MatrixBlock, MatrixBlock>> result = ParamservUtils
@@ -408,21 +421,38 @@ public class ParamservUtils {
return result;
}
- public static ListObject accrueGradients(ListObject accGradients, ListObject gradients) {
- return accrueGradients(accGradients, gradients, false);
+ /**
+ * Accumulate the given gradients into the accrued gradients
+ *
+ * @param accGradients accrued gradients list object
+ * @param gradients given gradients list object
+ * @param cleanup clean up the given gradients list object
+ * @return new accrued gradients list object
+ */
+ public static ListObject accrueGradients(ListObject accGradients, ListObject gradients, boolean cleanup) {
+ return accrueGradients(accGradients, gradients, false, cleanup);
}
-
- public static ListObject accrueGradients(ListObject accGradients, ListObject gradients, boolean par) {
+
+ /**
+ * Accumulate the given gradients into the accrued gradients
+ *
+ * @param accGradients accrued gradients list object
+ * @param gradients given gradients list object
+ * @param par parallel execution
+ * @param cleanup clean up the given gradients list object
+ * @return new accrued gradients list object
+ */
+ public static ListObject accrueGradients(ListObject accGradients, ListObject gradients, boolean par, boolean cleanup) {
if (accGradients == null)
- return ParamservUtils.copyList(gradients);
+ return ParamservUtils.copyList(gradients, cleanup);
IntStream range = IntStream.range(0, accGradients.getLength());
(par ? range.parallel() : range).forEach(i -> {
- MatrixBlock mb1 = ((MatrixObject) accGradients.getData().get(i)).acquireRead();
- MatrixBlock mb2 = ((MatrixObject) gradients.getData().get(i)).acquireRead();
+ MatrixBlock mb1 = ((MatrixObject) accGradients.getData().get(i)).acquireReadAndRelease();
+ MatrixBlock mb2 = ((MatrixObject) gradients.getData().get(i)).acquireReadAndRelease();
mb1.binaryOperationsInPlace(new BinaryOperator(Plus.getPlusFnObject()), mb2);
- ((MatrixObject) accGradients.getData().get(i)).release();
- ((MatrixObject) gradients.getData().get(i)).release();
});
+ if (cleanup)
+ ParamservUtils.cleanupListObject(gradients);
return accGradients;
}
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/SparkPSBody.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/SparkPSBody.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/SparkPSBody.java
new file mode 100644
index 0000000..58690a6
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/SparkPSBody.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv;
+
+import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
+
+/**
+ * Wrapper class containing everything needed to launch a Spark remote worker
+ */
+public class SparkPSBody {
+
+ private ExecutionContext _ec;
+
+ public SparkPSBody() {}
+
+ public SparkPSBody(ExecutionContext ec) {
+ _ec = ec;
+ }
+
+ public ExecutionContext getEc() {
+ return _ec;
+ }
+
+ public void setEc(ExecutionContext ec) {
+ this._ec = ec;
+ }
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/SparkPSProxy.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/SparkPSProxy.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/SparkPSProxy.java
new file mode 100644
index 0000000..fd88b83
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/SparkPSProxy.java
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv;
+
+import static org.apache.sysml.runtime.controlprogram.paramserv.rpc.PSRpcObject.PULL;
+import static org.apache.sysml.runtime.controlprogram.paramserv.rpc.PSRpcObject.PUSH;
+
+import java.io.IOException;
+
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.util.LongAccumulator;
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.paramserv.rpc.PSRpcCall;
+import org.apache.sysml.runtime.controlprogram.paramserv.rpc.PSRpcResponse;
+import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing;
+import org.apache.sysml.runtime.instructions.cp.ListObject;
+
+public class SparkPSProxy extends ParamServer {
+
+ private final TransportClient _client;
+ private final long _rpcTimeout;
+ private final LongAccumulator _aRPC;
+
+ public SparkPSProxy(TransportClient client, long rpcTimeout, LongAccumulator aRPC) {
+ super();
+ _client = client;
+ _rpcTimeout = rpcTimeout;
+ _aRPC = aRPC;
+ }
+
+ private void accRpcRequestTime(Timing tRpc) {
+ if (DMLScript.STATISTICS)
+ _aRPC.add((long) tRpc.stop());
+ }
+
+ @Override
+ public void push(int workerID, ListObject value) {
+ Timing tRpc = DMLScript.STATISTICS ? new Timing(true) : null;
+ PSRpcResponse response;
+ try {
+ response = new PSRpcResponse(_client.sendRpcSync(new PSRpcCall(PUSH, workerID, value).serialize(), _rpcTimeout));
+ } catch (IOException e) {
+ throw new DMLRuntimeException(String.format("SparkPSProxy: spark worker_%d failed to push gradients.", workerID), e);
+ }
+ accRpcRequestTime(tRpc);
+ if (!response.isSuccessful()) {
+ throw new DMLRuntimeException(String.format("SparkPSProxy: spark worker_%d failed to push gradients. \n%s", workerID, response.getErrorMessage()));
+ }
+ }
+
+ @Override
+ public ListObject pull(int workerID) {
+ Timing tRpc = DMLScript.STATISTICS ? new Timing(true) : null;
+ PSRpcResponse response;
+ try {
+ response = new PSRpcResponse(_client.sendRpcSync(new PSRpcCall(PULL, workerID, null).serialize(), _rpcTimeout));
+ } catch (IOException e) {
+ throw new DMLRuntimeException(String.format("SparkPSProxy: spark worker_%d failed to pull models.", workerID), e);
+ }
+ accRpcRequestTime(tRpc);
+ if (!response.isSuccessful()) {
+ throw new DMLRuntimeException(String.format("SparkPSProxy: spark worker_%d failed to pull models. \n%s", workerID, response.getErrorMessage()));
+ }
+ return response.getResultModel();
+ }
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/SparkPSWorker.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/SparkPSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/SparkPSWorker.java
new file mode 100644
index 0000000..bc8fc9e
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/SparkPSWorker.java
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.function.VoidFunction;
+import org.apache.spark.util.LongAccumulator;
+import org.apache.sysml.parser.Statement;
+import org.apache.sysml.runtime.codegen.CodegenUtils;
+import org.apache.sysml.runtime.controlprogram.paramserv.rpc.PSRpcFactory;
+import org.apache.sysml.runtime.controlprogram.parfor.RemoteParForUtils;
+import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing;
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+import org.apache.sysml.runtime.util.ProgramConverter;
+
+import scala.Tuple2;
+
+/**
+ * Parameter server worker implemented as a Spark {@code VoidFunction}: each call
+ * receives a worker ID together with its (features, labels) partition, configures
+ * the worker from the serialized program and then launches the worker logic
+ * inherited from {@link LocalPSWorker}. Statistics are reported through Spark
+ * accumulators.
+ */
+public class SparkPSWorker extends LocalPSWorker implements VoidFunction<Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>>> {
+
+ private static final long serialVersionUID = -8674739573419648732L;
+
+ private final String _program; // serialized program used to rebuild the execution context
+ private final HashMap<String, byte[]> _clsMap; // codegen classes (name => bytes)
+ private final SparkConf _conf;
+ private final int _port; // rpc port
+ private final String _aggFunc;
+ private final LongAccumulator _aSetup; // accumulator for setup time
+ private final LongAccumulator _aWorker; // accumulator for worker number
+ private final LongAccumulator _aUpdate; // accumulator for model update
+ private final LongAccumulator _aIndex; // accumulator for batch indexing
+ private final LongAccumulator _aGrad; // accumulator for gradients computing
+ private final LongAccumulator _aRPC; // accumulator for rpc request
+ private final LongAccumulator _nBatches; // number of executed batches
+ private final LongAccumulator _nEpochs; // number of executed epochs
+
+ /**
+ * Stores the worker configuration; the worker itself is only set up
+ * once {@link #call(Tuple2)} is invoked.
+ */
+ public SparkPSWorker(String updFunc, String aggFunc, Statement.PSFrequency freq, int epochs, long batchSize, String program, HashMap<String, byte[]> clsMap, SparkConf conf, int port, LongAccumulator aSetup, LongAccumulator aWorker, LongAccumulator aUpdate, LongAccumulator aIndex, LongAccumulator aGrad, LongAccumulator aRPC, LongAccumulator aBatches, LongAccumulator aEpochs) {
+ // Inherited worker settings
+ _updFunc = updFunc;
+ _freq = freq;
+ _epochs = epochs;
+ _batchSize = batchSize;
+
+ // Spark specific settings
+ _aggFunc = aggFunc;
+ _program = program;
+ _clsMap = clsMap;
+ _conf = conf;
+ _port = port;
+
+ // Statistics accumulators
+ _aSetup = aSetup;
+ _aWorker = aWorker;
+ _aUpdate = aUpdate;
+ _aIndex = aIndex;
+ _aGrad = aGrad;
+ _aRPC = aRPC;
+ _nBatches = aBatches;
+ _nEpochs = aEpochs;
+ }
+
+ @Override
+ public String getWorkerName() {
+ return String.format("Spark worker_%d", _workerID);
+ }
+
+ /**
+ * Configures the worker for the given partition (timing the setup) and runs it.
+ *
+ * @param input workerID => (features, labels)
+ * @throws Exception if the setup or the worker execution fails
+ */
+ @Override
+ public void call(Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>> input) throws Exception {
+ Timing setupTimer = new Timing(true);
+ configureWorker(input);
+ accSetupTime(setupTimer);
+
+ // Launch the worker
+ call();
+ }
+
+ /**
+ * Prepares this worker for execution: codegen classes, execution context,
+ * buffer pool, parameter server proxy, update/aggregation functions and data.
+ *
+ * @param input workerID => (features, labels)
+ * @throws IOException if the setup fails
+ */
+ private void configureWorker(Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>> input) throws IOException {
+ _workerID = input._1;
+
+ // The codegen class cache has to be filled before the program is parsed
+ for (Map.Entry<String, byte[]> clazz : _clsMap.entrySet()) {
+ CodegenUtils.getClassSync(clazz.getKey(), clazz.getValue());
+ }
+
+ // Deserialize the body to obtain the execution context
+ SparkPSBody psBody = ProgramConverter.parseSparkPSBody(_program, _workerID);
+ _ec = psBody.getEc();
+
+ // Initialize the buffer pool and register it in the jvm shutdown hook so that it is cleaned up at the end
+ RemoteParForUtils.setupBufferPool(_workerID);
+
+ // Create the ps proxy
+ _ps = PSRpcFactory.createSparkPSProxy(_conf, _port, _aRPC);
+
+ // Initialize the update function
+ setupUpdateFunction(_updFunc, _ec);
+
+ // Initialize the agg function
+ _ps.setupAggFunc(_ec, _aggFunc);
+
+ // Lazy initialize the matrix of features and labels
+ Tuple2<MatrixBlock, MatrixBlock> data = input._2;
+ setFeatures(ParamservUtils.newMatrixObject(data._1, false));
+ setLabels(ParamservUtils.newMatrixObject(data._2, false));
+ }
+
+ @Override
+ protected void incWorkerNumber() {
+ _aWorker.add(1);
+ }
+
+ @Override
+ protected void accLocalModelUpdateTime(Timing time) {
+ addStoppedTime(_aUpdate, time);
+ }
+
+ @Override
+ protected void accBatchIndexingTime(Timing time) {
+ addStoppedTime(_aIndex, time);
+ }
+
+ @Override
+ protected void accGradientComputeTime(Timing time) {
+ addStoppedTime(_aGrad, time);
+ }
+
+ @Override
+ protected void accNumEpochs(int n) {
+ _nEpochs.add(n);
+ }
+
+ @Override
+ protected void accNumBatches(int n) {
+ _nBatches.add(n);
+ }
+
+ private void accSetupTime(Timing time) {
+ addStoppedTime(_aSetup, time);
+ }
+
+ /**
+ * Stops the given timer and adds its value to the accumulator;
+ * a missing timer (null) is ignored.
+ */
+ private static void addStoppedTime(LongAccumulator acc, Timing time) {
+ if (time == null) {
+ return;
+ }
+ acc.add((long) time.stop());
+ }
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DCLocalScheme.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DCLocalScheme.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DCLocalScheme.java
new file mode 100644
index 0000000..9a3e502
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DCLocalScheme.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv.dp;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.stream.Collectors;
+
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils;
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+
+/**
+ * Disjoint_Contiguous data partitioner:
+ *
+ * for each worker, use a right indexing
+ * operation X[beg:end,] to obtain contiguous,
+ * non-overlapping partitions of rows.
+ */
+public class DCLocalScheme extends DataPartitionLocalScheme {
+
+ /**
+ * Splits the rows of the given matrix into contiguous, non-overlapping
+ * partitions of ceil(nrow/k) rows each; the last partition may be smaller.
+ *
+ * @param k number of requested partitions
+ * @param mb matrix to partition
+ * @return row partitions in row order; holds fewer than k entries if the rows run out earlier
+ */
+ public static List<MatrixBlock> partition(int k, MatrixBlock mb) {
+ List<MatrixBlock> list = new ArrayList<>();
+ long numRows = mb.getNumRows();
+ long stepSize = (long) Math.ceil((double) numRows / k);
+ long begin = 1; // inclusive row bounds, counted from 1
+ // "<=" (not "<"): a trailing partition that consists of only the last row must not be dropped
+ while (begin <= numRows) {
+ long end = Math.min(begin - 1 + stepSize, numRows);
+ MatrixBlock pmo = ParamservUtils.sliceMatrixBlock(mb, begin, end);
+ list.add(pmo);
+ begin = end + 1;
+ }
+ return list;
+ }
+
+ /**
+ * Partitions the matrix and wraps every partition into a matrix object.
+ */
+ private List<MatrixObject> doPartitioning(int k, MatrixBlock mb) {
+ return partition(k, mb).stream().map(ParamservUtils::newMatrixObject).collect(Collectors.toList());
+ }
+
+ /**
+ * Partitions features and labels with the same contiguous row ranges.
+ *
+ * @param workersNum number of workers
+ * @param features feature matrix
+ * @param labels label matrix
+ * @return partitioned features and labels
+ */
+ @Override
+ public Result doPartitioning(int workersNum, MatrixBlock features, MatrixBlock labels) {
+ List<MatrixObject> pfs = doPartitioning(workersNum, features);
+ List<MatrixObject> pls = doPartitioning(workersNum, labels);
+ return new Result(pfs, pls);
+ }
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DCSparkScheme.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DCSparkScheme.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DCSparkScheme.java
new file mode 100644
index 0000000..f42e0b6
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DCSparkScheme.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv.dp;
+
+import java.util.List;
+
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+
+import scala.Tuple2;
+
+/**
+ * Spark Disjoint_Contiguous data partitioner:
+ * <p>
+ * For each row, look up the target worker in the workerID indicator;
+ * the rows themselves are not reshuffled.
+ */
+public class DCSparkScheme extends DataPartitionSparkScheme {
+
+ private static final long serialVersionUID = -2786906947020788787L;
+
+ protected DCSparkScheme() {
+ // No-args constructor used for deserialization
+ }
+
+ /**
+ * Partitions one row block of features and labels without reshuffling.
+ *
+ * @param numWorkers number of workers (not used by this scheme)
+ * @param rblkID row block ID
+ * @param features features of the row block
+ * @param labels labels of the row block
+ * @return row-wise partitioned features and labels
+ */
+ @Override
+ public Result doPartitioning(int numWorkers, int rblkID, MatrixBlock features, MatrixBlock labels) {
+ final List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> featureRows = nonShuffledPartition(rblkID, features);
+ final List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> labelRows = nonShuffledPartition(rblkID, labels);
+ return new Result(featureRows, labelRows);
+ }
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DRLocalScheme.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DRLocalScheme.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DRLocalScheme.java
new file mode 100644
index 0000000..464be99
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DRLocalScheme.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv.dp;
+
+import static org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils.SEED;
+
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils;
+import org.apache.sysml.runtime.instructions.InstructionUtils;
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+
+/**
+ * Data partitioner Disjoint_Random:
+ * for each worker, use a permutation multiply P[beg:end,] %*% X,
+ * where P is constructed for example with P=table(seq(1,nrow(X)),sample(nrow(X), nrow(X))),
+ * i.e., sampling without replacement to ensure disjointness.
+ */
+public class DRLocalScheme extends DataPartitionLocalScheme {
+
+ /**
+ * Multiplies consecutive row slices of the permutation matrix with the given matrix,
+ * one slice of ceil(nrow/k) rows per worker; the last slice may be smaller.
+ *
+ * @param k number of requested partitions
+ * @param mb matrix to partition
+ * @param permutation permutation matrix shared by all workers
+ * @return partitions in worker order; holds fewer than k entries if the rows run out earlier
+ */
+ private List<MatrixBlock> partition(int k, MatrixBlock mb, MatrixBlock permutation) {
+ int numRows = mb.getNumRows();
+ int batchSize = (int) Math.ceil((double) numRows / k);
+ return IntStream.range(0, k)
+ // Skip workers whose range would start behind the last row (e.g. 5 rows, k=4),
+ // otherwise the permutation would be sliced with begin > end
+ .filter(i -> i * batchSize < numRows)
+ .mapToObj(i -> {
+ int begin = i * batchSize;
+ int end = Math.min((i + 1) * batchSize, numRows);
+ MatrixBlock slicedPerm = permutation.slice(begin, end - 1);
+ return slicedPerm.aggregateBinaryOperations(slicedPerm, mb, new MatrixBlock(), InstructionUtils.getMatMultOperator(k));
+ }).collect(Collectors.toList());
+ }
+
+ /**
+ * Partitions the matrix and wraps every partition into a matrix object.
+ */
+ private List<MatrixObject> internalDoPartitioning(int k, MatrixBlock mb, MatrixBlock permutation) {
+ return partition(k, mb, permutation).stream().map(ParamservUtils::newMatrixObject).collect(Collectors.toList());
+ }
+
+ /**
+ * Partitions features and labels with the same permutation so that their rows stay aligned.
+ *
+ * @param workersNum number of workers
+ * @param features feature matrix
+ * @param labels label matrix
+ * @return partitioned features and labels
+ */
+ @Override
+ public Result doPartitioning(int workersNum, MatrixBlock features, MatrixBlock labels) {
+ // Generate a single permutation matrix (workers use slices)
+ MatrixBlock permutation = ParamservUtils.generatePermutation(features.getNumRows(), SEED);
+ List<MatrixObject> pfs = internalDoPartitioning(workersNum, features, permutation);
+ List<MatrixObject> pls = internalDoPartitioning(workersNum, labels, permutation);
+ return new Result(pfs, pls);
+ }
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DRRLocalScheme.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DRRLocalScheme.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DRRLocalScheme.java
new file mode 100644
index 0000000..2061903
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DRRLocalScheme.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv.dp;
+
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+import java.util.stream.LongStream;
+
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils;
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+import org.apache.sysml.runtime.util.DataConverter;
+
+/**
+ * Disjoint_Round_Robin data partitioner:
+ * for each worker, use a permutation multiply
+ * or simpler a removeEmpty such as removeEmpty
+ * (target=X, margin=rows, select=(seq(1,nrow(X))%%k)==id)
+ */
+public class DRRLocalScheme extends DataPartitionLocalScheme {
+
+ /**
+ * Extracts the rows of one worker: a row with 0-based index r is selected
+ * iff {@code r % k == workerId}.
+ *
+ * @param mb matrix to partition
+ * @param k number of workers
+ * @param workerId id of the worker whose rows are extracted
+ * @return result of the removeEmpty operation with the worker's selection vector
+ */
+ public static MatrixBlock removeEmpty(MatrixBlock mb, int k, int workerId) {
+ int numRows = mb.getNumRows();
+ // 0/1 vector marking the rows that belong to the given worker
+ double[] selection = LongStream.range(0, numRows)
+ .mapToDouble(row -> (row % k == workerId) ? 1 : 0)
+ .toArray();
+ MatrixBlock select = DataConverter.convertToMatrixBlock(selection, true);
+ return mb.removeEmptyOperations(new MatrixBlock(), true, true, select);
+ }
+
+ /**
+ * Extracts the rows of one worker as a matrix object with cleanup disabled.
+ */
+ private MatrixObject internalRemoveEmpty(MatrixBlock mb, int k, int workerId) {
+ MatrixBlock workerRows = removeEmpty(mb, k, workerId);
+ MatrixObject wrapped = ParamservUtils.newMatrixObject(workerRows);
+ wrapped.enableCleanup(false);
+ return wrapped;
+ }
+
+ /**
+ * Creates one round-robin partition of the given matrix per worker.
+ */
+ private List<MatrixObject> partitionRoundRobin(MatrixBlock mb, int workersNum) {
+ return IntStream.range(0, workersNum)
+ .mapToObj(workerId -> internalRemoveEmpty(mb, workersNum, workerId))
+ .collect(Collectors.toList());
+ }
+
+ /**
+ * Partitions features and labels with the same round-robin row assignment.
+ *
+ * @param workersNum number of workers
+ * @param features feature matrix
+ * @param labels label matrix
+ * @return partitioned features and labels
+ */
+ @Override
+ public Result doPartitioning(int workersNum, MatrixBlock features, MatrixBlock labels) {
+ List<MatrixObject> featureParts = partitionRoundRobin(features, workersNum);
+ List<MatrixObject> labelParts = partitionRoundRobin(labels, workersNum);
+ return new Result(featureParts, labelParts);
+ }
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DRRSparkScheme.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DRRSparkScheme.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DRRSparkScheme.java
new file mode 100644
index 0000000..025f774
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DRRSparkScheme.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv.dp;
+
+import java.util.List;
+
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+
+import scala.Tuple2;
+
+/**
+ * Spark Disjoint_Round_Robin data partitioner:
+ * <p>
+ * Rows are assigned through the worker indicator without being reshuffled.
+ */
+public class DRRSparkScheme extends DataPartitionSparkScheme {
+
+ private static final long serialVersionUID = -3130831851505549672L;
+
+ protected DRRSparkScheme() {
+ // No-args constructor used for deserialization
+ }
+
+ /**
+ * Partitions one row block of features and labels without reshuffling.
+ *
+ * @param numWorkers number of workers (not used by this scheme)
+ * @param rblkID row block ID
+ * @param features features of the row block
+ * @param labels labels of the row block
+ * @return row-wise partitioned features and labels
+ */
+ @Override
+ public Result doPartitioning(int numWorkers, int rblkID, MatrixBlock features, MatrixBlock labels) {
+ final List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> featureRows = nonShuffledPartition(rblkID, features);
+ final List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> labelRows = nonShuffledPartition(rblkID, labels);
+ return new Result(featureRows, labelRows);
+ }
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DRSparkScheme.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DRSparkScheme.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DRSparkScheme.java
new file mode 100644
index 0000000..df61af9
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DRSparkScheme.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv.dp;
+
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils;
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+
+import scala.Tuple2;
+
+/**
+ * Spark data partitioner Disjoint_Random:
+ *
+ * For the current row block, find the shifted position of each row
+ * and emit WorkerID => (shifted position, matrix row).
+ */
+public class DRSparkScheme extends DataPartitionSparkScheme {
+
+ private static final long serialVersionUID = -7655310624144544544L;
+
+ protected DRSparkScheme() {
+ // No-args constructor used for deserialization
+ }
+
+ /**
+ * Partitions one row block of features and labels; both use the same
+ * permutation so that their rows stay aligned.
+ *
+ * @param numWorkers number of workers (not used by this scheme)
+ * @param rblkID row block ID
+ * @param features features of the row block
+ * @param labels labels of the row block
+ * @return row-wise partitioned features and labels
+ */
+ @Override
+ public Result doPartitioning(int numWorkers, int rblkID, MatrixBlock features, MatrixBlock labels) {
+ List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pfs = partition(rblkID, features);
+ List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pls = partition(rblkID, labels);
+ return new Result(pfs, pls);
+ }
+
+ /**
+ * Maps every row of the block to its shifted position and the worker owning that position.
+ *
+ * @param rblkID row block ID
+ * @param mb rows of the block
+ * @return list of tuple (workerID, (shifted position, matrix row))
+ */
+ private List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> partition(int rblkID, MatrixBlock mb) {
+ MatrixBlock partialPerm = _globalPerms.get(0).getBlock(rblkID, 1);
+
+ // For each row, find out the shifted place
+ return IntStream.range(0, mb.getNumRows()).mapToObj(r -> {
+ MatrixBlock rowMB = ParamservUtils.sliceMatrixBlock(mb, r + 1, r + 1);
+ long shiftedPosition = (long) partialPerm.getValue(r, 0);
+
+ // Get the shifted block and the row offset inside that block
+ // (the offset is the remainder; the quotient is the block index again)
+ int shiftedBlkID = (int) (shiftedPosition / OptimizerUtils.DEFAULT_BLOCKSIZE + 1);
+ int shiftedRow = (int) (shiftedPosition % OptimizerUtils.DEFAULT_BLOCKSIZE);
+
+ MatrixBlock indicator = _workerIndicator.getBlock(shiftedBlkID, 1);
+ int workerID = (int) indicator.getValue(shiftedRow, 0);
+ return new Tuple2<>(workerID, new Tuple2<>(shiftedPosition, rowMB));
+ }).collect(Collectors.toList());
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DataPartitionLocalScheme.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DataPartitionLocalScheme.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DataPartitionLocalScheme.java
new file mode 100644
index 0000000..8d03345
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DataPartitionLocalScheme.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv.dp;
+
+import java.util.List;
+
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+
+/**
+ * Base class of the local data partitioning schemes, which split
+ * a feature matrix and a label matrix into per-worker partitions.
+ */
+public abstract class DataPartitionLocalScheme {
+
+ /**
+ * Partitioning outcome: the lists of feature and label partitions.
+ */
+ public final class Result {
+ public final List<MatrixObject> pFeatures;
+ public final List<MatrixObject> pLabels;
+
+ public Result(List<MatrixObject> features, List<MatrixObject> labels) {
+ pFeatures = features;
+ pLabels = labels;
+ }
+ }
+
+ /**
+ * Partitions the given features and labels for the workers.
+ *
+ * @param workersNum number of workers
+ * @param features feature matrix
+ * @param labels label matrix
+ * @return partitioned features and labels
+ */
+ public abstract Result doPartitioning(int workersNum, MatrixBlock features, MatrixBlock labels);
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DataPartitionSparkScheme.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DataPartitionSparkScheme.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DataPartitionSparkScheme.java
new file mode 100644
index 0000000..7992ac8
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DataPartitionSparkScheme.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv.dp;
+
+import java.io.Serializable;
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.LongStream;
+
+import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils;
+import org.apache.sysml.runtime.instructions.spark.data.PartitionedBroadcast;
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+
+import scala.Tuple2;
+
+/**
+ * Base class of the Spark data partitioning schemes, which assign
+ * the rows of a row block to workers.
+ */
+public abstract class DataPartitionSparkScheme implements Serializable {
+
+ /**
+ * Partitioning outcome of one row block: single rows keyed by worker ID.
+ */
+ protected final class Result {
+ protected final List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pFeatures; // WorkerID => (rowID, matrix)
+ protected final List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pLabels;
+
+ protected Result(List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> features, List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> labels) {
+ pFeatures = features;
+ pLabels = labels;
+ }
+ }
+
+ private static final long serialVersionUID = -3462829818083371171L;
+
+ protected List<PartitionedBroadcast<MatrixBlock>> _globalPerms; // a list of global permutations
+ protected PartitionedBroadcast<MatrixBlock> _workerIndicator; // a matrix indicating to which worker the given row belongs
+
+ protected void setGlobalPermutation(List<PartitionedBroadcast<MatrixBlock>> gps) {
+ _globalPerms = gps;
+ }
+
+ protected void setWorkerIndicator(PartitionedBroadcast<MatrixBlock> wi) {
+ _workerIndicator = wi;
+ }
+
+ /**
+ * Do non-reshuffled data partitioning according to worker indicator
+ * @param rblkID row block ID
+ * @param mb Matrix
+ * @return list of tuple (workerID, (global row position, matrix row))
+ */
+ protected List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> nonShuffledPartition(int rblkID, MatrixBlock mb) {
+ final MatrixBlock indicatorBlk = _workerIndicator.getBlock(rblkID, 1);
+ // Position of the first row of this block within the whole matrix
+ final long blkOffset = (rblkID - 1) * OptimizerUtils.DEFAULT_BLOCKSIZE;
+ return LongStream.range(0, mb.getNumRows()).mapToObj(row -> {
+ int targetWorker = (int) indicatorBlk.getValue((int) row, 0);
+ MatrixBlock singleRow = ParamservUtils.sliceMatrixBlock(mb, row + 1, row + 1);
+ long globalPosition = row + blkOffset;
+ return new Tuple2<>(targetWorker, new Tuple2<>(globalPosition, singleRow));
+ }).collect(Collectors.toList());
+ }
+
+ /**
+ * Partitions one row block of features and labels.
+ *
+ * @param numWorkers number of workers
+ * @param rblkID row block ID
+ * @param features features of the row block
+ * @param labels labels of the row block
+ * @return row-wise partitioned features and labels
+ */
+ public abstract Result doPartitioning(int numWorkers, int rblkID, MatrixBlock features, MatrixBlock labels);
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DataPartitionerSparkAggregator.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DataPartitionerSparkAggregator.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DataPartitionerSparkAggregator.java
new file mode 100644
index 0000000..0314ccf
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DataPartitionerSparkAggregator.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv.dp;
+
+import java.io.Serializable;
+import java.util.LinkedList;
+
+import org.apache.spark.api.java.function.PairFunction;
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+
+import scala.Tuple2;
+
+public class DataPartitionerSparkAggregator implements PairFunction<Tuple2<Integer,LinkedList<Tuple2<Long,Tuple2<MatrixBlock,MatrixBlock>>>>, Integer, Tuple2<MatrixBlock, MatrixBlock>>, Serializable {
+
+ private static final long serialVersionUID = -1245300852709085117L;
+ private long _fcol;
+ private long _lcol;
+
+ public DataPartitionerSparkAggregator() {
+
+ }
+
+ public DataPartitionerSparkAggregator(long fcol, long lcol) {
+ _fcol = fcol;
+ _lcol = lcol;
+ }
+
+ /**
+ * Row-wise combine the matrix
+ * @param input workerID => ordered list [(rowBlockID, (features, labels))]
+ * @return workerID => [(features, labels)]
+ * @throws Exception Some exception
+ */
+ @Override
+ public Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>> call(Tuple2<Integer, LinkedList<Tuple2<Long, Tuple2<MatrixBlock, MatrixBlock>>>> input) throws Exception {
+ MatrixBlock fmb = new MatrixBlock(input._2.size(), (int) _fcol, false);
+ MatrixBlock lmb = new MatrixBlock(input._2.size(), (int) _lcol, false);
+
+ for (int i = 0; i < input._2.size(); i++) {
+ MatrixBlock tmpFMB = input._2.get(i)._2._1;
+ MatrixBlock tmpLMB = input._2.get(i)._2._2;
+ // Row-wise aggregation
+ fmb = fmb.leftIndexingOperations(tmpFMB, i, i, 0, (int) _fcol - 1, fmb, MatrixObject.UpdateType.INPLACE_PINNED);
+ lmb = lmb.leftIndexingOperations(tmpLMB, i, i, 0, (int) _lcol - 1, lmb, MatrixObject.UpdateType.INPLACE_PINNED);
+ }
+ return new Tuple2<>(input._1, new Tuple2<>(fmb, lmb));
+ }
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DataPartitionerSparkMapper.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DataPartitionerSparkMapper.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DataPartitionerSparkMapper.java
new file mode 100644
index 0000000..bd30121
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/DataPartitionerSparkMapper.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv.dp;
+
+import java.io.Serializable;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+
+import org.apache.spark.api.java.function.PairFlatMapFunction;
+import org.apache.sysml.parser.Statement;
+import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+
+import scala.Tuple2;
+
+public class DataPartitionerSparkMapper implements PairFlatMapFunction<Tuple2<Long, Tuple2<MatrixBlock, MatrixBlock>>, Integer, Tuple2<Long, Tuple2<MatrixBlock, MatrixBlock>>>, Serializable {
+
+ private static final long serialVersionUID = 1710721606050403296L;
+ private int _workersNum;
+
+ private SparkDataPartitioner _dp;
+
+ protected DataPartitionerSparkMapper() {
+ // No-args constructor used for deserialization
+ }
+
+ public DataPartitionerSparkMapper(Statement.PSScheme scheme, int workersNum, SparkExecutionContext sec, int numEntries) {
+ _workersNum = workersNum;
+ _dp = new SparkDataPartitioner(scheme, sec, numEntries, workersNum);
+ }
+
+ /**
+ * Do data partitioning
+ * @param input RowBlockID => (features, labels)
+ * @return WorkerID => (rowBlockID, (single row features, single row labels))
+ * @throws Exception Some exception
+ */
+ @Override
+ public Iterator<Tuple2<Integer, Tuple2<Long, Tuple2<MatrixBlock, MatrixBlock>>>> call(Tuple2<Long,Tuple2<MatrixBlock,MatrixBlock>> input)
+ throws Exception {
+ List<Tuple2<Integer, Tuple2<Long,Tuple2<MatrixBlock,MatrixBlock>>>> partitions = new LinkedList<>();
+ MatrixBlock features = input._2._1;
+ MatrixBlock labels = input._2._2;
+ DataPartitionSparkScheme.Result result = _dp.doPartitioning(_workersNum, features, labels, input._1);
+ for (int i = 0; i < result.pFeatures.size(); i++) {
+ Tuple2<Integer, Tuple2<Long, MatrixBlock>> ft = result.pFeatures.get(i);
+ Tuple2<Integer, Tuple2<Long, MatrixBlock>> lt = result.pLabels.get(i);
+ partitions.add(new Tuple2<>(ft._1, new Tuple2<>(ft._2._1, new Tuple2<>(ft._2._2, lt._2._2))));
+ }
+ return partitions.iterator();
+ }
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/LocalDataPartitioner.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/LocalDataPartitioner.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/LocalDataPartitioner.java
new file mode 100644
index 0000000..68cf9b6
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/LocalDataPartitioner.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv.dp;
+
+import org.apache.sysml.parser.Statement;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+
+/**
+ * Local data partitioner that selects the partitioning scheme implementation
+ * for a given {@link Statement.PSScheme} and delegates to it.
+ */
+public class LocalDataPartitioner {
+
+	private final DataPartitionLocalScheme _scheme;
+
+	/**
+	 * @param scheme data partitioning scheme to use
+	 * @throws DMLRuntimeException if the scheme is not supported
+	 */
+	public LocalDataPartitioner(Statement.PSScheme scheme) {
+		_scheme = createScheme(scheme);
+	}
+
+	// Maps the scheme enum to its local implementation
+	private static DataPartitionLocalScheme createScheme(Statement.PSScheme scheme) {
+		switch (scheme) {
+			case DISJOINT_CONTIGUOUS:
+				return new DCLocalScheme();
+			case DISJOINT_ROUND_ROBIN:
+				return new DRRLocalScheme();
+			case DISJOINT_RANDOM:
+				return new DRLocalScheme();
+			case OVERLAP_RESHUFFLE:
+				return new ORLocalScheme();
+			default:
+				throw new DMLRuntimeException(String.format("LocalDataPartitioner: not support data partition scheme '%s'", scheme));
+		}
+	}
+
+	/**
+	 * Partition the features and labels with the selected scheme.
+	 * @param workersNum number of workers
+	 * @param features features matrix
+	 * @param labels labels matrix
+	 * @return partitioning result of the selected scheme
+	 */
+	public DataPartitionLocalScheme.Result doPartitioning(int workersNum, MatrixBlock features, MatrixBlock labels) {
+		return _scheme.doPartitioning(workersNum, features, labels);
+	}
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/ORLocalScheme.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/ORLocalScheme.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/ORLocalScheme.java
new file mode 100644
index 0000000..b7d8b97
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/ORLocalScheme.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv.dp;
+
+import static org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils.SEED;
+
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils;
+import org.apache.sysml.runtime.instructions.InstructionUtils;
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+
+/**
+ * Data partitioner Overlap_Reshuffle:
+ * for each worker, use a new permutation multiply P %*% X,
+ * where P is constructed for example with P=table(seq(1,nrow(X),sample(nrow(X), nrow(X))))
+ */
+/**
+ * Data partitioner Overlap_Reshuffle:
+ * for each worker, use a new permutation multiply P %*% X,
+ * where P is constructed for example with P=table(seq(1,nrow(X)),sample(nrow(X), nrow(X)))
+ */
+public class ORLocalScheme extends DataPartitionLocalScheme {
+
+	/**
+	 * Multiply the first k permutation matrices with the given matrix.
+	 * @param k number of partitions; also passed to the matrix multiplication operator
+	 * @param mb matrix to permute
+	 * @param permutations permutation matrices, one per partition
+	 * @return list of k permuted matrices, in partition order
+	 */
+	public static List<MatrixBlock> partition(int k, MatrixBlock mb, List<MatrixBlock> permutations) {
+		return IntStream.range(0, k)
+			.mapToObj(permutations::get)
+			.map(perm -> perm.aggregateBinaryOperations(perm, mb, new MatrixBlock(),
+				InstructionUtils.getMatMultOperator(k)))
+			.collect(Collectors.toList());
+	}
+
+	// Permutes the matrix per worker and wraps every result block into a matrix object
+	private List<MatrixObject> doPartitioning(int k, MatrixBlock mb, List<MatrixBlock> permutations) {
+		List<MatrixBlock> permuted = partition(k, mb, permutations);
+		return permuted.stream()
+			.map(ParamservUtils::newMatrixObject)
+			.collect(Collectors.toList());
+	}
+
+	@Override
+	public Result doPartitioning(int workersNum, MatrixBlock features, MatrixBlock labels) {
+		// One permutation per worker, seeded with SEED plus the worker index;
+		// the same permutations are applied to features and labels
+		List<MatrixBlock> perms = IntStream.range(0, workersNum)
+			.mapToObj(w -> ParamservUtils.generatePermutation(features.getNumRows(), SEED+w))
+			.collect(Collectors.toList());
+		List<MatrixObject> featureParts = doPartitioning(workersNum, features, perms);
+		List<MatrixObject> labelParts = doPartitioning(workersNum, labels, perms);
+		return new Result(featureParts, labelParts);
+	}
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/ORSparkScheme.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/ORSparkScheme.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/ORSparkScheme.java
new file mode 100644
index 0000000..08b49b0
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/ORSparkScheme.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv.dp;
+
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils;
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+
+import scala.Tuple2;
+
+/**
+ * Spark data partitioner Overlap_Reshuffle:
+ *
+ */
+/**
+ * Spark data partitioner Overlap_Reshuffle:
+ * every row of a row block is emitted once per worker, keyed by the position
+ * read from that worker's entry of the global permutations.
+ */
+public class ORSparkScheme extends DataPartitionSparkScheme {
+
+	private static final long serialVersionUID = 6867567406403580311L;
+
+	protected ORSparkScheme() {
+		// No-args constructor used for deserialization
+	}
+
+	@Override
+	public Result doPartitioning(int numWorkers, int rblkID, MatrixBlock features, MatrixBlock labels) {
+		// Features and labels go through the same lookup
+		List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> featureParts = partition(numWorkers, rblkID, features);
+		List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> labelParts = partition(numWorkers, rblkID, labels);
+		return new Result(featureParts, labelParts);
+	}
+
+	// Emits (workerID, (position, single-row block)) for every worker and every row of mb
+	private List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> partition(int numWorkers, int rblkID, MatrixBlock mb) {
+		return IntStream.range(0, numWorkers).boxed().flatMap(wid -> {
+			// Part of this worker's permutation that belongs to the current row block
+			MatrixBlock permBlock = _globalPerms.get(wid).getBlock(rblkID, 1);
+			return IntStream.range(0, mb.getNumRows()).mapToObj(row -> {
+				long target = (long) permBlock.getValue(row, 0);
+				// Slice bounds are passed as row+1 (presumably 1-based; confirm in ParamservUtils)
+				MatrixBlock singleRow = ParamservUtils.sliceMatrixBlock(mb, row + 1, row + 1);
+				return new Tuple2<>(wid, new Tuple2<>(target, singleRow));
+			});
+		}).collect(Collectors.toList());
+	}
+}
[2/3] systemml git commit: [MINOR] Various paramserv refactorings and
code cleanups
Posted by mb...@apache.org.
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/SparkDataPartitioner.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/SparkDataPartitioner.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/SparkDataPartitioner.java
new file mode 100644
index 0000000..031150b
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/dp/SparkDataPartitioner.java
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv.dp;
+
+import static org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils.SEED;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+import org.apache.sysml.parser.Statement;
+import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
+import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils;
+import org.apache.sysml.runtime.instructions.spark.data.PartitionedBroadcast;
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+import org.apache.sysml.runtime.util.DataConverter;
+
+/**
+ * Spark data partitioner that selects the scheme implementation for a given
+ * {@link Statement.PSScheme} and prepares the broadcast lookup vectors
+ * (worker indicator and/or global permutations) that the scheme uses.
+ */
+public class SparkDataPartitioner implements Serializable {
+
+	private static final long serialVersionUID = 6841548626711057448L;
+	private DataPartitionSparkScheme _scheme;
+
+	/**
+	 * @param scheme data partitioning scheme
+	 * @param sec spark execution context used to create the broadcasts
+	 * @param numEntries number of entries (length of the lookup vectors)
+	 * @param numWorkers number of workers
+	 */
+	protected SparkDataPartitioner(Statement.PSScheme scheme, SparkExecutionContext sec, int numEntries, int numWorkers) {
+		switch (scheme) {
+			case DISJOINT_CONTIGUOUS:
+				_scheme = new DCSparkScheme();
+				// Create the worker id indicator
+				createDCIndicator(sec, numWorkers, numEntries);
+				break;
+			case DISJOINT_ROUND_ROBIN:
+				_scheme = new DRRSparkScheme();
+				// Create the worker id indicator
+				createDRIndicator(sec, numWorkers, numEntries);
+				break;
+			case DISJOINT_RANDOM:
+				_scheme = new DRSparkScheme();
+				// Create the global permutation
+				createGlobalPermutations(sec, numEntries, 1);
+				// Create the worker id indicator
+				createDCIndicator(sec, numWorkers, numEntries);
+				break;
+			case OVERLAP_RESHUFFLE:
+				_scheme = new ORSparkScheme();
+				// Create the global permutation separately for each worker
+				createGlobalPermutations(sec, numEntries, numWorkers);
+				break;
+			default:
+				// Fail fast instead of leaving _scheme null (consistent with LocalDataPartitioner)
+				throw new org.apache.sysml.runtime.DMLRuntimeException(
+					String.format("SparkDataPartitioner: not support data partition scheme '%s'", scheme));
+		}
+	}
+
+	// Round-robin indicator: entry n is assigned to worker (n mod numWorkers)
+	private void createDRIndicator(SparkExecutionContext sec, int numWorkers, int numEntries) {
+		double[] vector = IntStream.range(0, numEntries).mapToDouble(n -> n % numWorkers).toArray();
+		MatrixBlock vectorMB = DataConverter.convertToMatrixBlock(vector, true);
+		_scheme.setWorkerIndicator(sec.getBroadcastForMatrixObject(ParamservUtils.newMatrixObject(vectorMB)));
+	}
+
+	// Contiguous indicator: consecutive batches of ceil(numEntries/numWorkers) entries per worker
+	private void createDCIndicator(SparkExecutionContext sec, int numWorkers, int numEntries) {
+		double[] vector = new double[numEntries];
+		int batchSize = (int) Math.ceil((double) numEntries / numWorkers);
+		// Worker 0 keeps the array's default value 0, hence the loop starts at worker 1
+		for (int i = 1; i < numWorkers; i++) {
+			int begin = batchSize * i;
+			// With rounded-up batches the entries can run out before the workers do;
+			// Arrays.fill would then be called with begin > end and throw.
+			if (begin >= numEntries)
+				break;
+			int end = Math.min(begin + batchSize, numEntries);
+			Arrays.fill(vector, begin, end, i);
+		}
+		MatrixBlock vectorMB = DataConverter.convertToMatrixBlock(vector, true);
+		_scheme.setWorkerIndicator(sec.getBroadcastForMatrixObject(ParamservUtils.newMatrixObject(vectorMB)));
+	}
+
+	// Creates numPerm broadcast lookup vectors, permutation i being sampled with seed SEED+i
+	private void createGlobalPermutations(SparkExecutionContext sec, int numEntries, int numPerm) {
+		List<PartitionedBroadcast<MatrixBlock>> perms = IntStream.range(0, numPerm).mapToObj(i -> {
+			MatrixBlock perm = MatrixBlock.sampleOperations(numEntries, numEntries, false, SEED+i);
+			// Create the source-target id vector from the permutation ranging from 1 to number of entries
+			double[] permValues = perm.getDenseBlockValues();
+			double[] vector = new double[numEntries];
+			for (int j = 0; j < permValues.length; j++) {
+				// Sampled ids are shifted by -1 to become 0-based array positions
+				vector[(int) permValues[j] - 1] = j;
+			}
+			MatrixBlock vectorMB = DataConverter.convertToMatrixBlock(vector, true);
+			return sec.getBroadcastForMatrixObject(ParamservUtils.newMatrixObject(vectorMB));
+		}).collect(Collectors.toList());
+		_scheme.setGlobalPermutation(perms);
+	}
+
+	/**
+	 * Partition one row block with the selected scheme.
+	 * @param numWorkers number of workers
+	 * @param features features block
+	 * @param labels labels block
+	 * @param rowID row block id; narrowed to int before it is passed to the scheme
+	 * @return partitioning result of the selected scheme
+	 */
+	public DataPartitionSparkScheme.Result doPartitioning(int numWorkers, MatrixBlock features, MatrixBlock labels,
+			long rowID) {
+		// Set the rowID in order to get the according permutation
+		return _scheme.doPartitioning(numWorkers, (int) rowID, features, labels);
+	}
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcCall.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcCall.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcCall.java
new file mode 100644
index 0000000..8b0540b
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcCall.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv.rpc;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.caching.CacheDataOutput;
+import org.apache.sysml.runtime.instructions.cp.ListObject;
+import org.apache.sysml.runtime.util.ByteBufferDataInput;
+
+/**
+ * Rpc request sent to the parameter server: a method id (PUSH or PULL),
+ * the id of the calling worker and an optional list object payload.
+ */
+public class PSRpcCall extends PSRpcObject {
+
+	private int _method;
+	private int _workerID;
+	private ListObject _data;
+
+	/**
+	 * @param method rpc method id
+	 * @param workerID id of the calling worker
+	 * @param data payload, may be null
+	 */
+	public PSRpcCall(int method, int workerID, ListObject data) {
+		_method = method;
+		_workerID = workerID;
+		_data = data;
+	}
+
+	/**
+	 * Create a call by deserializing the given buffer.
+	 * @param buffer serialized call
+	 * @throws IOException if reading from the buffer fails
+	 */
+	public PSRpcCall(ByteBuffer buffer) throws IOException {
+		deserialize(buffer);
+	}
+
+	public int getMethod() {
+		return _method;
+	}
+
+	public int getWorkerID() {
+		return _workerID;
+	}
+
+	public ListObject getData() {
+		return _data;
+	}
+
+	@Override
+	public void deserialize(ByteBuffer buffer) throws IOException {
+		ByteBufferDataInput in = new ByteBufferDataInput(buffer);
+		_method = in.readInt();
+		validateMethod(_method);
+		_workerID = in.readInt();
+		// The payload is optional and only read if more than one byte is left
+		if (in.available() > 1) {
+			_data = readAndDeserialize(in);
+		}
+	}
+
+	@Override
+	public ByteBuffer serialize() throws IOException {
+		// Two ints (method, worker id) followed by the optional payload
+		int size = 8 + getExactSerializedSize(_data);
+		CacheDataOutput out = new CacheDataOutput(size);
+		out.writeInt(_method);
+		out.writeInt(_workerID);
+		if (_data != null) {
+			serializeAndWriteListObject(_data, out);
+		}
+		return ByteBuffer.wrap(out.getBytes());
+	}
+
+	// Rejects every method id other than PUSH and PULL
+	private void validateMethod(int method) {
+		if (method != PUSH && method != PULL) {
+			throw new DMLRuntimeException("PSRpcCall: only support rpc method 'push' or 'pull'");
+		}
+	}
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcFactory.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcFactory.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcFactory.java
new file mode 100644
index 0000000..a7db756
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcFactory.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv.rpc;
+
+import java.io.IOException;
+import java.util.Collections;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.network.TransportContext;
+import org.apache.spark.network.netty.SparkTransportConf;
+import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.util.TransportConf;
+import org.apache.spark.util.LongAccumulator;
+import org.apache.sysml.runtime.controlprogram.paramserv.LocalParamServer;
+import org.apache.sysml.runtime.controlprogram.paramserv.SparkPSProxy;
+
+/**
+ * Factory for the rpc server and client proxy of the parameter server,
+ * built on the Spark network transport classes.
+ */
+public class PSRpcFactory {
+
+	private static final String MODULE_NAME = "ps";
+
+	// Builds a transport context whose rpc handler wraps the given parameter server
+	private static TransportContext createTransportContext(SparkConf conf, LocalParamServer ps) {
+		TransportConf transportConf = SparkTransportConf.fromSparkConf(conf, MODULE_NAME, 0);
+		return new TransportContext(transportConf, new PSRpcHandler(ps));
+	}
+
+	/**
+	 * Create and start the server
+	 * @param conf spark configuration
+	 * @param ps parameter server that handles the incoming calls
+	 * @param host host name passed to the transport server
+	 * @return server
+	 */
+	public static TransportServer createServer(SparkConf conf, LocalParamServer ps, String host) {
+		// bind rpc to an ephemeral port
+		return createTransportContext(conf, ps).createServer(host, 0, Collections.emptyList());
+	}
+
+	/**
+	 * Create a proxy that connects to the server at "spark.driver.host" and the given port.
+	 * The rpc timeout is read from "spark.rpc.askTimeout" if present,
+	 * otherwise from "spark.network.timeout" with a default of 120s.
+	 * @param conf spark configuration
+	 * @param port server port
+	 * @param aRPC accumulator passed through to the proxy
+	 * @return proxy wrapping the transport client
+	 * @throws IOException declared for creating the transport client
+	 */
+	public static SparkPSProxy createSparkPSProxy(SparkConf conf, int port, LongAccumulator aRPC) throws IOException {
+		long timeoutMs;
+		if (conf.contains("spark.rpc.askTimeout")) {
+			timeoutMs = conf.getTimeAsMs("spark.rpc.askTimeout");
+		} else {
+			timeoutMs = conf.getTimeAsMs("spark.network.timeout", "120s");
+		}
+		String driverHost = conf.get("spark.driver.host");
+		TransportContext clientContext = createTransportContext(conf, new LocalParamServer());
+		return new SparkPSProxy(clientContext.createClientFactory().createClient(driverHost, port), timeoutMs, aRPC);
+	}
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcHandler.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcHandler.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcHandler.java
new file mode 100644
index 0000000..cf8de6d
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcHandler.java
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv.rpc;
+
+import static org.apache.sysml.runtime.controlprogram.paramserv.rpc.PSRpcCall.PULL;
+import static org.apache.sysml.runtime.controlprogram.paramserv.rpc.PSRpcCall.PUSH;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+
+import org.apache.commons.lang.exception.ExceptionUtils;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.server.OneForOneStreamManager;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.StreamManager;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.paramserv.LocalParamServer;
+import org.apache.sysml.runtime.controlprogram.paramserv.rpc.PSRpcResponse.Type;
+import org.apache.sysml.runtime.instructions.cp.ListObject;
+
+/**
+ * Rpc handler that dispatches incoming push / pull calls to a local parameter
+ * server and answers the caller with a serialized {@link PSRpcResponse}.
+ */
+public final class PSRpcHandler extends RpcHandler {
+
+	private LocalParamServer _server;
+
+	protected PSRpcHandler(LocalParamServer server) {
+		_server = server;
+	}
+
+	/**
+	 * Deserialize the call, execute it on the parameter server and send the response.
+	 * A {@link DMLRuntimeException} raised by push / pull is reported to the caller as an
+	 * ERROR response carrying the stack trace; any other exception propagates unchanged.
+	 * @param client transport client (unused)
+	 * @param buffer serialized {@link PSRpcCall}
+	 * @param callback callback that receives the serialized response
+	 */
+	@Override
+	public void receive(TransportClient client, ByteBuffer buffer, RpcResponseCallback callback) {
+		PSRpcCall call;
+		try {
+			call = new PSRpcCall(buffer);
+		} catch (IOException e) {
+			throw new DMLRuntimeException("PSRpcHandler: some error occurred when deserializing the rpc call.", e);
+		}
+		PSRpcResponse response;
+		switch (call.getMethod()) {
+			case PUSH:
+				try {
+					_server.push(call.getWorkerID(), call.getData());
+					response = new PSRpcResponse(Type.SUCCESS_EMPTY);
+				} catch (DMLRuntimeException exception) {
+					response = new PSRpcResponse(Type.ERROR, ExceptionUtils.getFullStackTrace(exception));
+				}
+				break;
+			case PULL:
+				try {
+					ListObject data = _server.pull(call.getWorkerID());
+					response = new PSRpcResponse(Type.SUCCESS, data);
+				} catch (DMLRuntimeException exception) {
+					response = new PSRpcResponse(Type.ERROR, ExceptionUtils.getFullStackTrace(exception));
+				}
+				break;
+			default:
+				throw new DMLRuntimeException(String.format("Does not support the rpc call for method %s", call.getMethod()));
+		}
+		// Respond outside of a finally block: if push / pull fails with an unexpected exception,
+		// no response exists yet and dereferencing it would mask the real failure with an NPE.
+		respond(callback, response);
+	}
+
+	// Serializes the response and hands it to the rpc callback
+	private static void respond(RpcResponseCallback callback, PSRpcResponse response) {
+		try {
+			callback.onSuccess(response.serialize());
+		} catch (IOException e) {
+			throw new DMLRuntimeException("PSRpcHandler: some error occrred when wrapping the rpc response.", e);
+		}
+	}
+
+	@Override
+	public StreamManager getStreamManager() {
+		return new OneForOneStreamManager();
+	}
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcObject.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcObject.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcObject.java
new file mode 100644
index 0000000..38d80a2
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcObject.java
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv.rpc;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils;
+import org.apache.sysml.runtime.instructions.cp.Data;
+import org.apache.sysml.runtime.instructions.cp.ListObject;
+import org.apache.sysml.runtime.io.IOUtilFunctions;
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+
+/**
+ * Base class of the parameter server rpc messages, providing the rpc method ids
+ * and the shared (de)serialization of list objects that contain only matrices.
+ */
+public abstract class PSRpcObject {
+
+	// Rpc method ids
+	public static final int PUSH = 1;
+	public static final int PULL = 2;
+
+	public abstract void deserialize(ByteBuffer buffer) throws IOException;
+
+	public abstract ByteBuffer serialize() throws IOException;
+
+	/**
+	 * Deep serialize and write of a list object (currently only support list containing matrices)
+	 * @param lo a list object containing only matrices
+	 * @param output output data to write to
+	 * @throws IOException if writing to the output fails
+	 */
+	protected void serializeAndWriteListObject(ListObject lo, DataOutput output) throws IOException {
+		validateListObject(lo);
+		output.writeInt(lo.getLength()); //write list length
+		output.writeBoolean(lo.isNamedList()); //write list named
+		for (int i = 0; i < lo.getLength(); i++) {
+			if (lo.isNamedList())
+				output.writeUTF(lo.getName(i)); //write name
+			((MatrixObject) lo.getData().get(i))
+				.acquireReadAndRelease().write(output); //write matrix
+		}
+		// Cleanup the list object
+		// because it is transferred to remote worker in binary format
+		ParamservUtils.cleanupListObject(lo);
+	}
+
+	/**
+	 * Read a list object in the format written by
+	 * {@link #serializeAndWriteListObject(ListObject, DataOutput)}.
+	 * @param input input data to read from
+	 * @return list object of matrices, named if the named flag was set
+	 * @throws IOException if reading from the input fails
+	 */
+	protected ListObject readAndDeserialize(DataInput input) throws IOException {
+		int listLen = input.readInt();
+		List<Data> data = new ArrayList<>();
+		List<String> names = input.readBoolean() ?
+			new ArrayList<>() : null;
+		for(int i=0; i<listLen; i++) {
+			if( names != null )
+				names.add(input.readUTF());
+			MatrixBlock mb = new MatrixBlock();
+			mb.readFields(input);
+			data.add(ParamservUtils.newMatrixObject(mb, false));
+		}
+		return new ListObject(data, names);
+	}
+
+	/**
+	 * Get serialization size of a list object
+	 * (scheme: length|named|[name]|matrix|[name]|matrix|...)
+	 * @param lo list object, may be null
+	 * @return serialization size, 0 for a null list object
+	 */
+	protected int getExactSerializedSize(ListObject lo) {
+		if( lo == null ) return 0;
+		// Validate before casting below, so that a non-matrix entry is reported
+		// with the descriptive error instead of a ClassCastException
+		validateListObject(lo);
+		long result = 4 + 1; // list length (int) and named flag (boolean)
+		if (lo.isNamedList()) //size for names incl length
+			result += lo.getNames().stream().mapToLong(s -> IOUtilFunctions.getUTFSize(s)).sum();
+		result += lo.getData().stream().mapToLong(d ->
+			((MatrixObject)d).acquireReadAndRelease().getExactSizeOnDisk()).sum();
+		if( result > Integer.MAX_VALUE )
+			throw new DMLRuntimeException("Serialized size ("+result+") larger than Integer.MAX_VALUE.");
+		return (int) result;
+	}
+
+	// Ensures that every entry of the list object is a matrix object
+	private void validateListObject(ListObject lo) {
+		for (Data d : lo.getData()) {
+			if (!(d instanceof MatrixObject)) {
+				throw new DMLRuntimeException(String.format("Paramserv func:"
+					+ " Unsupported deep serialize of %s, which is not matrix.", d.getDebugName()));
+			}
+		}
+	}
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcResponse.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcResponse.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcResponse.java
new file mode 100644
index 0000000..68e1dd1
--- /dev/null
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/rpc/PSRpcResponse.java
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.controlprogram.paramserv.rpc;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+
+import org.apache.sysml.runtime.util.ByteBufferDataInput;
+import org.apache.sysml.runtime.controlprogram.caching.CacheDataOutput;
+import org.apache.sysml.runtime.instructions.cp.ListObject;
+import org.apache.sysml.runtime.io.IOUtilFunctions;
+
+public class PSRpcResponse extends PSRpcObject {
+ public enum Type {
+ SUCCESS,
+ SUCCESS_EMPTY,
+ ERROR,
+ }
+
+ private Type _status;
+ private Object _data; // Could be list object or exception
+
+ public PSRpcResponse(ByteBuffer buffer) throws IOException {
+ deserialize(buffer);
+ }
+
+ public PSRpcResponse(Type status) {
+ this(status, null);
+ }
+
+ public PSRpcResponse(Type status, Object data) {
+ _status = status;
+ _data = data;
+ if( _status == Type.SUCCESS && data == null )
+ _status = Type.SUCCESS_EMPTY;
+ }
+
+ public boolean isSuccessful() {
+ return _status != Type.ERROR;
+ }
+
+ public String getErrorMessage() {
+ return (String) _data;
+ }
+
+ public ListObject getResultModel() {
+ return (ListObject) _data;
+ }
+
+ @Override
+ public void deserialize(ByteBuffer buffer) throws IOException {
+ ByteBufferDataInput dis = new ByteBufferDataInput(buffer);
+ _status = Type.values()[dis.readInt()];
+ switch (_status) {
+ case SUCCESS:
+ _data = readAndDeserialize(dis);
+ break;
+ case SUCCESS_EMPTY:
+ break;
+ case ERROR:
+ _data = dis.readUTF();
+ break;
+ }
+ }
+
+ @Override
+ public ByteBuffer serialize() throws IOException {
+ int len = 4 + (_status==Type.SUCCESS ? getExactSerializedSize((ListObject)_data) :
+ _status==Type.SUCCESS_EMPTY ? 0 : IOUtilFunctions.getUTFSize((String)_data));
+ CacheDataOutput dos = new CacheDataOutput(len);
+ dos.writeInt(_status.ordinal());
+ switch (_status) {
+ case SUCCESS:
+ serializeAndWriteListObject((ListObject) _data, dos);
+ break;
+ case SUCCESS_EMPTY:
+ break;
+ case ERROR:
+ dos.writeUTF(_data.toString());
+ break;
+ }
+ return ByteBuffer.wrap(dos.getBytes());
+ }
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DCSparkScheme.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DCSparkScheme.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DCSparkScheme.java
deleted file mode 100644
index 666b891..0000000
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DCSparkScheme.java
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.runtime.controlprogram.paramserv.spark;
-
-import java.util.List;
-
-import org.apache.sysml.runtime.matrix.data.MatrixBlock;
-
-import scala.Tuple2;
-
-/**
- * Spark Disjoint_Contiguous data partitioner:
- * <p>
- * For each row, find out the shifted place according to the workerID indicator
- * (both matrices are partitioned by the inherited nonShuffledPartition).
- */
-public class DCSparkScheme extends DataPartitionSparkScheme {
-
- private static final long serialVersionUID = -2786906947020788787L;
-
- protected DCSparkScheme() {
- // No-args constructor used for deserialization
- }
- /** Partitions features and labels of one row block independently; numWorkers is not used by this scheme. */
- @Override
- public Result doPartitioning(int numWorkers, int rblkID, MatrixBlock features, MatrixBlock labels) {
- List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pfs = nonShuffledPartition(rblkID, features);
- List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pls = nonShuffledPartition(rblkID, labels);
- return new Result(pfs, pls); // partitioned features first, partitioned labels second
- }
-}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DRRSparkScheme.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DRRSparkScheme.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DRRSparkScheme.java
deleted file mode 100644
index 7683251..0000000
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DRRSparkScheme.java
+++ /dev/null
@@ -1,45 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.runtime.controlprogram.paramserv.spark;
-
-import java.util.List;
-
-import org.apache.sysml.runtime.matrix.data.MatrixBlock;
-
-import scala.Tuple2;
-
-/**
- * Spark Disjoint_Round_Robin data partitioner:
- * <p>
- * The partitioning code is the same as in DCSparkScheme; the two schemes differ
- * only in the worker indicator that SparkDataPartitioner sets for them.
- */
-public class DRRSparkScheme extends DataPartitionSparkScheme {
-
- private static final long serialVersionUID = -3130831851505549672L;
-
- protected DRRSparkScheme() {
- // No-args constructor used for deserialization
- }
- /** Partitions features and labels of one row block independently; numWorkers is not used by this scheme. */
- @Override
- public Result doPartitioning(int numWorkers, int rblkID, MatrixBlock features, MatrixBlock labels) {
- List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pfs = nonShuffledPartition(rblkID, features);
- List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pls = nonShuffledPartition(rblkID, labels);
- return new Result(pfs, pls); // partitioned features first, partitioned labels second
- }
-}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DRSparkScheme.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DRSparkScheme.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DRSparkScheme.java
deleted file mode 100644
index 51cc523..0000000
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DRSparkScheme.java
+++ /dev/null
@@ -1,69 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.runtime.controlprogram.paramserv.spark;
-
-import java.util.List;
-import java.util.stream.Collectors;
-import java.util.stream.IntStream;
-
-import org.apache.sysml.hops.OptimizerUtils;
-import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils;
-import org.apache.sysml.runtime.matrix.data.MatrixBlock;
-
-import scala.Tuple2;
-
-/**
- * Spark data partitioner Disjoint_Random:
- *
- * For the current row block, find the shifted place of each row: WorkerID => (shifted position, row matrix)
- */
-public class DRSparkScheme extends DataPartitionSparkScheme {
-
- private static final long serialVersionUID = -7655310624144544544L;
-
- protected DRSparkScheme() {
- // No-args constructor used for deserialization
- }
- /** Partitions features and labels of one row block with the same permutation block; numWorkers is not used here. */
- @Override
- public Result doPartitioning(int numWorkers, int rblkID, MatrixBlock features, MatrixBlock labels) {
- List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pfs = partition(rblkID, features);
- List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pls = partition(rblkID, labels);
- return new Result(pfs, pls);
- }
- /** Maps every row of mb to (workerID, (shifted position, row matrix)) using the first global permutation. */
- private List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> partition(int rblkID, MatrixBlock mb) {
- MatrixBlock partialPerm = _globalPerms.get(0).getBlock(rblkID, 1);
-
- // For each row, find out the shifted place
- return IntStream.range(0, mb.getNumRows()).mapToObj(r -> {
- MatrixBlock rowMB = ParamservUtils.sliceMatrixBlock(mb, r + 1, r + 1);
- long shiftedPosition = (long) partialPerm.getValue(r, 0);
-
- // Get the shifted block and position
- int shiftedBlkID = (int) (shiftedPosition / OptimizerUtils.DEFAULT_BLOCKSIZE + 1); // + 1: block IDs appear to be 1-based
- // NOTE(review): the row index passed to getValue below is the quotient (shiftedBlkID - 1), not the offset inside the block; a modulo may have been intended -- confirm
- MatrixBlock indicator = _workerIndicator.getBlock(shiftedBlkID, 1);
- int workerID = (int) indicator.getValue((int) shiftedPosition / OptimizerUtils.DEFAULT_BLOCKSIZE, 0);
- return new Tuple2<>(workerID, new Tuple2<>(shiftedPosition, rowMB));
- }).collect(Collectors.toList());
- }
-
-}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DataPartitionSparkScheme.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DataPartitionSparkScheme.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DataPartitionSparkScheme.java
deleted file mode 100644
index 9875dd2..0000000
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DataPartitionSparkScheme.java
+++ /dev/null
@@ -1,76 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.runtime.controlprogram.paramserv.spark;
-
-import java.io.Serializable;
-import java.util.List;
-import java.util.stream.Collectors;
-import java.util.stream.LongStream;
-
-import org.apache.sysml.hops.OptimizerUtils;
-import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils;
-import org.apache.sysml.runtime.instructions.spark.data.PartitionedBroadcast;
-import org.apache.sysml.runtime.matrix.data.MatrixBlock;
-
-import scala.Tuple2;
-
-public abstract class DataPartitionSparkScheme implements Serializable {
- /** Partitioning output of one row block: one tuple list for the features and one for the labels. */
- protected final class Result {
- protected final List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pFeatures; // WorkerID => (rowID, matrix)
- protected final List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pLabels; // same tuple type as pFeatures
-
- protected Result(List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pFeatures, List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pLabels) {
- this.pFeatures = pFeatures;
- this.pLabels = pLabels;
- }
- }
-
- private static final long serialVersionUID = -3462829818083371171L;
-
- protected List<PartitionedBroadcast<MatrixBlock>> _globalPerms; // a list of global permutations
- protected PartitionedBroadcast<MatrixBlock> _workerIndicator; // a matrix indicating to which worker the given row belongs
-
- protected void setGlobalPermutation(List<PartitionedBroadcast<MatrixBlock>> gps) {
- _globalPerms = gps;
- }
-
- protected void setWorkerIndicator(PartitionedBroadcast<MatrixBlock> wi) {
- _workerIndicator = wi;
- }
-
- /**
- * Do non-reshuffled data partitioning according to worker indicator
- * @param rblkID row block ID (used 1-based: block rblkID starts at position (rblkID - 1) * DEFAULT_BLOCKSIZE)
- * @param mb Matrix
- * @return list of tuple (workerID, (row position r + (rblkID - 1) * DEFAULT_BLOCKSIZE, matrix row))
- */
- protected List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> nonShuffledPartition(int rblkID, MatrixBlock mb) {
- MatrixBlock indicator = _workerIndicator.getBlock(rblkID, 1);
- return LongStream.range(0, mb.getNumRows()).mapToObj(r -> {
- int workerID = (int) indicator.getValue((int) r, 0); // worker of row r, read from column 0 of the indicator block
- MatrixBlock rowMB = ParamservUtils.sliceMatrixBlock(mb, r + 1, r + 1); // presumably the single row r (1-based bounds) -- confirm
- long shiftedPosition = r + (rblkID - 1) * OptimizerUtils.DEFAULT_BLOCKSIZE; // NOTE(review): the product is not widened to long before the addition if DEFAULT_BLOCKSIZE is an int -- confirm
- return new Tuple2<>(workerID, new Tuple2<>(shiftedPosition, rowMB));
- }).collect(Collectors.toList());
- }
- /** Partitions the features and labels of row block rblkID; implemented by the concrete schemes. */
- public abstract Result doPartitioning(int numWorkers, int rblkID, MatrixBlock features, MatrixBlock labels);
-}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DataPartitionerSparkAggregator.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DataPartitionerSparkAggregator.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DataPartitionerSparkAggregator.java
deleted file mode 100644
index 39b8adf..0000000
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DataPartitionerSparkAggregator.java
+++ /dev/null
@@ -1,66 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.runtime.controlprogram.paramserv.spark;
-
-import java.io.Serializable;
-import java.util.LinkedList;
-
-import org.apache.spark.api.java.function.PairFunction;
-import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
-import org.apache.sysml.runtime.matrix.data.MatrixBlock;
-
-import scala.Tuple2;
-
-public class DataPartitionerSparkAggregator implements PairFunction<Tuple2<Integer,LinkedList<Tuple2<Long,Tuple2<MatrixBlock,MatrixBlock>>>>, Integer, Tuple2<MatrixBlock, MatrixBlock>>, Serializable {
- // Stacks the (features, labels) row pairs collected for one worker into one features and one labels matrix.
- private static final long serialVersionUID = -1245300852709085117L;
- private long _fcol; // column count used for the features output
- private long _lcol; // column count used for the labels output
-
- public DataPartitionerSparkAggregator() {
- // no-args constructor: _fcol and _lcol keep their default value 0
- }
-
- public DataPartitionerSparkAggregator(long fcol, long lcol) {
- _fcol = fcol;
- _lcol = lcol;
- }
-
- /**
- * Row-wise combine the matrix
- * (entry i of the input list is written to row i of both outputs; the Long key is not read here)
- * @param input workerID => ordered list [(rowBlockID, (features, labels))]
- * @return workerID => [(features, labels)]
- * @throws Exception Some exception
- */
- @Override
- public Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>> call(Tuple2<Integer, LinkedList<Tuple2<Long, Tuple2<MatrixBlock, MatrixBlock>>>> input) throws Exception {
- MatrixBlock fmb = new MatrixBlock(input._2.size(), (int) _fcol, false);
- MatrixBlock lmb = new MatrixBlock(input._2.size(), (int) _lcol, false);
- // NOTE(review): LinkedList.get(i) is a linear-time lookup, and it is called twice per iteration below
- for (int i = 0; i < input._2.size(); i++) {
- MatrixBlock tmpFMB = input._2.get(i)._2._1;
- MatrixBlock tmpLMB = input._2.get(i)._2._2;
- // Row-wise aggregation
- fmb = fmb.leftIndexingOperations(tmpFMB, i, i, 0, (int) _fcol - 1, fmb, MatrixObject.UpdateType.INPLACE_PINNED);
- lmb = lmb.leftIndexingOperations(tmpLMB, i, i, 0, (int) _lcol - 1, lmb, MatrixObject.UpdateType.INPLACE_PINNED);
- }
- return new Tuple2<>(input._1, new Tuple2<>(fmb, lmb));
- }
-}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DataPartitionerSparkMapper.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DataPartitionerSparkMapper.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DataPartitionerSparkMapper.java
deleted file mode 100644
index 2a69986..0000000
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/DataPartitionerSparkMapper.java
+++ /dev/null
@@ -1,70 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.runtime.controlprogram.paramserv.spark;
-
-import java.io.Serializable;
-import java.util.Iterator;
-import java.util.LinkedList;
-import java.util.List;
-
-import org.apache.spark.api.java.function.PairFlatMapFunction;
-import org.apache.sysml.parser.Statement;
-import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
-import org.apache.sysml.runtime.matrix.data.MatrixBlock;
-
-import scala.Tuple2;
-
-public class DataPartitionerSparkMapper implements PairFlatMapFunction<Tuple2<Long, Tuple2<MatrixBlock, MatrixBlock>>, Integer, Tuple2<Long, Tuple2<MatrixBlock, MatrixBlock>>>, Serializable {
- // Splits one (features, labels) row block into per-row tuples keyed by worker ID, using a SparkDataPartitioner.
- private static final long serialVersionUID = 1710721606050403296L;
- private int _workersNum; // passed to the partitioner as numWorkers on every call
-
- private SparkDataPartitioner _dp; // created in the constructor for the given scheme
-
- protected DataPartitionerSparkMapper() {
- // No-args constructor used for deserialization
- }
-
- public DataPartitionerSparkMapper(Statement.PSScheme scheme, int workersNum, SparkExecutionContext sec, int numEntries) {
- _workersNum = workersNum;
- _dp = new SparkDataPartitioner(scheme, sec, numEntries, workersNum);
- }
-
- /**
- * Do data partitioning
- * (worker ID and Long key of every output tuple are taken from the features tuple)
- * @param input RowBlockID => (features, labels)
- * @return WorkerID => (rowBlockID, (single row features, single row labels))
- * @throws Exception Some exception
- */
- @Override
- public Iterator<Tuple2<Integer, Tuple2<Long, Tuple2<MatrixBlock, MatrixBlock>>>> call(Tuple2<Long,Tuple2<MatrixBlock,MatrixBlock>> input)
- throws Exception {
- List<Tuple2<Integer, Tuple2<Long,Tuple2<MatrixBlock,MatrixBlock>>>> partitions = new LinkedList<>();
- MatrixBlock features = input._2._1;
- MatrixBlock labels = input._2._2;
- DataPartitionSparkScheme.Result result = _dp.doPartitioning(_workersNum, features, labels, input._1);
- for (int i = 0; i < result.pFeatures.size(); i++) { // assumes pLabels has an entry for every index of pFeatures
- Tuple2<Integer, Tuple2<Long, MatrixBlock>> ft = result.pFeatures.get(i);
- Tuple2<Integer, Tuple2<Long, MatrixBlock>> lt = result.pLabels.get(i);
- partitions.add(new Tuple2<>(ft._1, new Tuple2<>(ft._2._1, new Tuple2<>(ft._2._2, lt._2._2)))); // only the matrix of the labels tuple is used
- }
- return partitions.iterator();
- }
-}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/ORSparkScheme.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/ORSparkScheme.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/ORSparkScheme.java
deleted file mode 100644
index 16ce516..0000000
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/ORSparkScheme.java
+++ /dev/null
@@ -1,60 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.runtime.controlprogram.paramserv.spark;
-
-import java.util.List;
-import java.util.stream.Collectors;
-import java.util.stream.IntStream;
-
-import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils;
-import org.apache.sysml.runtime.matrix.data.MatrixBlock;
-
-import scala.Tuple2;
-
-/**
- * Spark data partitioner Overlap_Reshuffle:
- * Every worker gets every row of the block, placed according to that worker's own global permutation.
- */
-public class ORSparkScheme extends DataPartitionSparkScheme {
-
- private static final long serialVersionUID = 6867567406403580311L;
-
- protected ORSparkScheme() {
- // No-args constructor used for deserialization
- }
- /** Partitions features and labels of one row block with the same per-worker permutations. */
- @Override
- public Result doPartitioning(int numWorkers, int rblkID, MatrixBlock features, MatrixBlock labels) {
- List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pfs = partition(numWorkers, rblkID, features);
- List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pls = partition(numWorkers, rblkID, labels);
- return new Result(pfs, pls);
- }
- /** Emits numWorkers * rows tuples (workerID, (shifted position, row matrix)), grouped by worker ID. */
- private List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> partition(int numWorkers, int rblkID, MatrixBlock mb) {
- return IntStream.range(0, numWorkers).mapToObj(i -> i).flatMap(workerID -> {
- MatrixBlock partialPerm = _globalPerms.get(workerID).getBlock(rblkID, 1); // block rblkID of the permutation of this worker
- return IntStream.range(0, mb.getNumRows()).mapToObj(r -> {
- MatrixBlock rowMB = ParamservUtils.sliceMatrixBlock(mb, r + 1, r + 1);
- long shiftedPosition = (long) partialPerm.getValue(r, 0); // target position of row r for this worker
- return new Tuple2<>(workerID, new Tuple2<>(shiftedPosition, rowMB));
- });
- }).collect(Collectors.toList());
- }
-}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkDataPartitioner.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkDataPartitioner.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkDataPartitioner.java
deleted file mode 100644
index 6883d0f..0000000
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkDataPartitioner.java
+++ /dev/null
@@ -1,106 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.runtime.controlprogram.paramserv.spark;
-
-import static org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils.SEED;
-
-import java.io.Serializable;
-import java.util.Arrays;
-import java.util.List;
-import java.util.stream.Collectors;
-import java.util.stream.IntStream;
-
-import org.apache.sysml.parser.Statement;
-import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
-import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils;
-import org.apache.sysml.runtime.instructions.spark.data.PartitionedBroadcast;
-import org.apache.sysml.runtime.matrix.data.MatrixBlock;
-import org.apache.sysml.runtime.util.DataConverter;
-
-public class SparkDataPartitioner implements Serializable {
- // Creates the scheme object for a partitioning scheme together with the broadcast vectors that scheme reads.
- private static final long serialVersionUID = 6841548626711057448L;
- private DataPartitionSparkScheme _scheme;
-
- protected SparkDataPartitioner(Statement.PSScheme scheme, SparkExecutionContext sec, int numEntries, int numWorkers) {
- switch (scheme) { // no default branch: _scheme stays null for any value not listed here
- case DISJOINT_CONTIGUOUS:
- _scheme = new DCSparkScheme();
- // Create the worker id indicator
- createDCIndicator(sec, numWorkers, numEntries);
- break;
- case DISJOINT_ROUND_ROBIN:
- _scheme = new DRRSparkScheme();
- // Create the worker id indicator
- createDRIndicator(sec, numWorkers, numEntries);
- break;
- case DISJOINT_RANDOM:
- _scheme = new DRSparkScheme();
- // Create the global permutation
- createGlobalPermutations(sec, numEntries, 1);
- // Create the worker id indicator
- createDCIndicator(sec, numWorkers, numEntries);
- break;
- case OVERLAP_RESHUFFLE:
- _scheme = new ORSparkScheme();
- // Create the global permutation separately for each worker
- createGlobalPermutations(sec, numEntries, numWorkers);
- break;
- }
- }
- /** Round-robin indicator: entry n is assigned to worker n % numWorkers. */
- private void createDRIndicator(SparkExecutionContext sec, int numWorkers, int numEntries) {
- double[] vector = IntStream.range(0, numEntries).mapToDouble(n -> n % numWorkers).toArray();
- MatrixBlock vectorMB = DataConverter.convertToMatrixBlock(vector, true);
- _scheme.setWorkerIndicator(sec.getBroadcastForMatrixObject(ParamservUtils.newMatrixObject(vectorMB)));
- }
- /** Contiguous indicator: worker i is assigned the range [batchSize * i, batchSize * (i + 1)), cut off at numEntries. */
- private void createDCIndicator(SparkExecutionContext sec, int numWorkers, int numEntries) {
- double[] vector = new double[numEntries];
- int batchSize = (int) Math.ceil((double) numEntries / numWorkers);
- for (int i = 1; i < numWorkers; i++) { // starts at 1: the entries of worker 0 keep the array default 0
- int begin = batchSize * i;
- int end = Math.min(begin + batchSize, numEntries); // NOTE(review): begin can exceed numEntries (e.g. 5 entries, 4 workers), and Arrays.fill rejects begin > end -- confirm
- Arrays.fill(vector, begin, end, i);
- }
- MatrixBlock vectorMB = DataConverter.convertToMatrixBlock(vector, true);
- _scheme.setWorkerIndicator(sec.getBroadcastForMatrixObject(ParamservUtils.newMatrixObject(vectorMB)));
- }
- /** Creates numPerm broadcast vectors, one per sampled permutation (seed SEED+i), and sets them on the scheme. */
- private void createGlobalPermutations(SparkExecutionContext sec, int numEntries, int numPerm) {
- List<PartitionedBroadcast<MatrixBlock>> perms = IntStream.range(0, numPerm).mapToObj(i -> {
- MatrixBlock perm = MatrixBlock.sampleOperations(numEntries, numEntries, false, SEED+i);
- // Create the source-target id vector from the permutation ranging from 1 to number of entries
- double[] vector = new double[numEntries];
- for (int j = 0; j < perm.getDenseBlockValues().length; j++) {
- vector[(int) perm.getDenseBlockValues()[j] - 1] = j; // inverse mapping: the 1-based value at index j is stored as index (value - 1) => j
- }
- MatrixBlock vectorMB = DataConverter.convertToMatrixBlock(vector, true);
- return sec.getBroadcastForMatrixObject(ParamservUtils.newMatrixObject(vectorMB));
- }).collect(Collectors.toList());
- _scheme.setGlobalPermutation(perms);
- }
- /** Delegates to the scheme; rowID is narrowed to int and passed on as the row block ID. */
- public DataPartitionSparkScheme.Result doPartitioning(int numWorkers, MatrixBlock features, MatrixBlock labels,
- long rowID) {
- // Set the rowID in order to get the according permutation
- return _scheme.doPartitioning(numWorkers, (int) rowID, features, labels);
- }
-}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSBody.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSBody.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSBody.java
deleted file mode 100644
index 9354025..0000000
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSBody.java
+++ /dev/null
@@ -1,44 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.runtime.controlprogram.paramserv.spark;
-
-import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
-
-/**
- * Wrapper class containing all needed for launching spark remote worker
- * (in this version: only the execution context).
- */
-public class SparkPSBody {
-
- private ExecutionContext _ec; // stays null until set via constructor or setEc
-
- public SparkPSBody() {}
-
- public SparkPSBody(ExecutionContext ec) {
- _ec = ec;
- }
- /** @return the wrapped execution context, or null if none was set */
- public ExecutionContext getEc() {
- return _ec;
- }
- /** @param ec execution context to wrap, replacing the current one */
- public void setEc(ExecutionContext ec) {
- this._ec = ec;
- }
-}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSProxy.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSProxy.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSProxy.java
deleted file mode 100644
index 48a4883..0000000
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSProxy.java
+++ /dev/null
@@ -1,85 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.runtime.controlprogram.paramserv.spark;
-
-import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcObject.PULL;
-import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcObject.PUSH;
-
-import java.io.IOException;
-
-import org.apache.spark.network.client.TransportClient;
-import org.apache.spark.util.LongAccumulator;
-import org.apache.sysml.api.DMLScript;
-import org.apache.sysml.runtime.DMLRuntimeException;
-import org.apache.sysml.runtime.controlprogram.paramserv.ParamServer;
-import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcCall;
-import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcResponse;
-import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing;
-import org.apache.sysml.runtime.instructions.cp.ListObject;
-
-public class SparkPSProxy extends ParamServer {
-
- private final TransportClient _client;
- private final long _rpcTimeout;
- private final LongAccumulator _aRPC;
-
- public SparkPSProxy(TransportClient client, long rpcTimeout, LongAccumulator aRPC) {
- super();
- _client = client;
- _rpcTimeout = rpcTimeout;
- _aRPC = aRPC;
- }
-
- private void accRpcRequestTime(Timing tRpc) {
- if (DMLScript.STATISTICS)
- _aRPC.add((long) tRpc.stop());
- }
-
- @Override
- public void push(int workerID, ListObject value) {
- Timing tRpc = DMLScript.STATISTICS ? new Timing(true) : null;
- PSRpcResponse response;
- try {
- response = new PSRpcResponse(_client.sendRpcSync(new PSRpcCall(PUSH, workerID, value).serialize(), _rpcTimeout));
- } catch (IOException e) {
- throw new DMLRuntimeException(String.format("SparkPSProxy: spark worker_%d failed to push gradients.", workerID), e);
- }
- accRpcRequestTime(tRpc);
- if (!response.isSuccessful()) {
- throw new DMLRuntimeException(String.format("SparkPSProxy: spark worker_%d failed to push gradients. \n%s", workerID, response.getErrorMessage()));
- }
- }
-
- @Override
- public ListObject pull(int workerID) {
- Timing tRpc = DMLScript.STATISTICS ? new Timing(true) : null;
- PSRpcResponse response;
- try {
- response = new PSRpcResponse(_client.sendRpcSync(new PSRpcCall(PULL, workerID, null).serialize(), _rpcTimeout));
- } catch (IOException e) {
- throw new DMLRuntimeException(String.format("SparkPSProxy: spark worker_%d failed to pull models.", workerID), e);
- }
- accRpcRequestTime(tRpc);
- if (!response.isSuccessful()) {
- throw new DMLRuntimeException(String.format("SparkPSProxy: spark worker_%d failed to pull models. \n%s", workerID, response.getErrorMessage()));
- }
- return response.getResultModel();
- }
-}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java
deleted file mode 100644
index cb3e729..0000000
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/SparkPSWorker.java
+++ /dev/null
@@ -1,168 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.runtime.controlprogram.paramserv.spark;
-
-import java.io.IOException;
-import java.util.HashMap;
-import java.util.Map;
-
-import org.apache.spark.SparkConf;
-import org.apache.spark.api.java.function.VoidFunction;
-import org.apache.spark.util.LongAccumulator;
-import org.apache.sysml.parser.Statement;
-import org.apache.sysml.runtime.codegen.CodegenUtils;
-import org.apache.sysml.runtime.controlprogram.paramserv.LocalPSWorker;
-import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils;
-import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcFactory;
-import org.apache.sysml.runtime.controlprogram.parfor.RemoteParForUtils;
-import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing;
-import org.apache.sysml.runtime.matrix.data.MatrixBlock;
-import org.apache.sysml.runtime.util.ProgramConverter;
-
-import scala.Tuple2;
-
-public class SparkPSWorker extends LocalPSWorker implements VoidFunction<Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>>> {
-
- private static final long serialVersionUID = -8674739573419648732L;
-
- private final String _program;
- private final HashMap<String, byte[]> _clsMap;
- private final SparkConf _conf;
- private final int _port; // rpc port
- private final String _aggFunc;
- private final LongAccumulator _aSetup; // accumulator for setup time
- private final LongAccumulator _aWorker; // accumulator for worker number
- private final LongAccumulator _aUpdate; // accumulator for model update
- private final LongAccumulator _aIndex; // accumulator for batch indexing
- private final LongAccumulator _aGrad; // accumulator for gradients computing
- private final LongAccumulator _aRPC; // accumulator for rpc request
- private final LongAccumulator _nBatches; //number of executed batches
- private final LongAccumulator _nEpochs; //number of executed epoches
-
- public SparkPSWorker(String updFunc, String aggFunc, Statement.PSFrequency freq, int epochs, long batchSize, String program, HashMap<String, byte[]> clsMap, SparkConf conf, int port, LongAccumulator aSetup, LongAccumulator aWorker, LongAccumulator aUpdate, LongAccumulator aIndex, LongAccumulator aGrad, LongAccumulator aRPC, LongAccumulator aBatches, LongAccumulator aEpochs) {
- _updFunc = updFunc;
- _aggFunc = aggFunc;
- _freq = freq;
- _epochs = epochs;
- _batchSize = batchSize;
- _program = program;
- _clsMap = clsMap;
- _conf = conf;
- _port = port;
- _aSetup = aSetup;
- _aWorker = aWorker;
- _aUpdate = aUpdate;
- _aIndex = aIndex;
- _aGrad = aGrad;
- _aRPC = aRPC;
- _nBatches = aBatches;
- _nEpochs = aEpochs;
- }
-
- @Override
- public String getWorkerName() {
- return String.format("Spark worker_%d", _workerID);
- }
-
- @Override
- public void call(Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>> input) throws Exception {
- Timing tSetup = new Timing(true);
- configureWorker(input);
- accSetupTime(tSetup);
-
- call(); // Launch the worker
- }
-
- private void configureWorker(Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>> input) throws IOException {
- _workerID = input._1;
-
- // Initialize codegen class cache (before program parsing)
- for (Map.Entry<String, byte[]> e : _clsMap.entrySet()) {
- CodegenUtils.getClassSync(e.getKey(), e.getValue());
- }
-
- // Deserialize the body to initialize the execution context
- SparkPSBody body = ProgramConverter.parseSparkPSBody(_program, _workerID);
- _ec = body.getEc();
-
- // Initialize the buffer pool and register it in the jvm shutdown hook in order to be cleanuped at the end
- RemoteParForUtils.setupBufferPool(_workerID);
-
- // Get some configurations
- long rpcTimeout = _conf.contains("spark.rpc.askTimeout") ?
- _conf.getTimeAsMs("spark.rpc.askTimeout") :
- _conf.getTimeAsMs("spark.network.timeout", "120s");
- String host = _conf.get("spark.driver.host");
-
- // Create the ps proxy
- _ps = PSRpcFactory.createSparkPSProxy(_conf, host, _port, rpcTimeout, _aRPC);
-
- // Initialize the update function
- setupUpdateFunction(_updFunc, _ec);
-
- // Initialize the agg function
- _ps.setupAggFunc(_ec, _aggFunc);
-
- // Lazy initialize the matrix of features and labels
- setFeatures(ParamservUtils.newMatrixObject(input._2._1));
- setLabels(ParamservUtils.newMatrixObject(input._2._2));
- _features.enableCleanup(false);
- _labels.enableCleanup(false);
- }
-
-
- @Override
- protected void incWorkerNumber() {
- _aWorker.add(1);
- }
-
- @Override
- protected void accLocalModelUpdateTime(Timing time) {
- if( time != null )
- _aUpdate.add((long) time.stop());
- }
-
- @Override
- protected void accBatchIndexingTime(Timing time) {
- if( time != null )
- _aIndex.add((long) time.stop());
- }
-
- @Override
- protected void accGradientComputeTime(Timing time) {
- if( time != null )
- _aGrad.add((long) time.stop());
- }
-
- @Override
- protected void accNumEpochs(int n) {
- _nEpochs.add(n);
- }
-
- @Override
- protected void accNumBatches(int n) {
- _nBatches.add(n);
- }
-
- private void accSetupTime(Timing time) {
- if( time != null )
- _aSetup.add((long) time.stop());
- }
-}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcCall.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcCall.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcCall.java
deleted file mode 100644
index a33fda2..0000000
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcCall.java
+++ /dev/null
@@ -1,86 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc;
-
-import java.io.IOException;
-import java.nio.ByteBuffer;
-
-import org.apache.sysml.runtime.DMLRuntimeException;
-import org.apache.sysml.runtime.controlprogram.caching.CacheDataOutput;
-import org.apache.sysml.runtime.instructions.cp.ListObject;
-import org.apache.sysml.runtime.util.ByteBufferDataInput;
-
-public class PSRpcCall extends PSRpcObject {
-
- private int _method;
- private int _workerID;
- private ListObject _data;
-
- public PSRpcCall(int method, int workerID, ListObject data) {
- _method = method;
- _workerID = workerID;
- _data = data;
- }
-
- public PSRpcCall(ByteBuffer buffer) throws IOException {
- deserialize(buffer);
- }
-
- public int getMethod() {
- return _method;
- }
-
- public int getWorkerID() {
- return _workerID;
- }
-
- public ListObject getData() {
- return _data;
- }
-
- public void deserialize(ByteBuffer buffer) throws IOException {
- ByteBufferDataInput dis = new ByteBufferDataInput(buffer);
- _method = dis.readInt();
- validateMethod(_method);
- _workerID = dis.readInt();
- if (dis.available() > 1)
- _data = readAndDeserialize(dis);
- }
-
- public ByteBuffer serialize() throws IOException {
- int len = 8 + getExactSerializedSize(_data);
- CacheDataOutput dos = new CacheDataOutput(len);
- dos.writeInt(_method);
- dos.writeInt(_workerID);
- if (_data != null)
- serializeAndWriteListObject(_data, dos);
- return ByteBuffer.wrap(dos.getBytes());
- }
-
- private void validateMethod(int method) {
- switch (method) {
- case PUSH:
- case PULL:
- break;
- default:
- throw new DMLRuntimeException("PSRpcCall: only support rpc method 'push' or 'pull'");
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcFactory.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcFactory.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcFactory.java
deleted file mode 100644
index 5e76d23..0000000
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcFactory.java
+++ /dev/null
@@ -1,57 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc;
-
-import java.io.IOException;
-import java.util.Collections;
-
-import org.apache.spark.SparkConf;
-import org.apache.spark.network.TransportContext;
-import org.apache.spark.network.netty.SparkTransportConf;
-import org.apache.spark.network.server.TransportServer;
-import org.apache.spark.network.util.TransportConf;
-import org.apache.spark.util.LongAccumulator;
-import org.apache.sysml.runtime.controlprogram.paramserv.LocalParamServer;
-import org.apache.sysml.runtime.controlprogram.paramserv.spark.SparkPSProxy;
-
-public class PSRpcFactory {
-
- private static final String MODULE_NAME = "ps";
-
- private static TransportContext createTransportContext(SparkConf conf, LocalParamServer ps) {
- TransportConf tc = SparkTransportConf.fromSparkConf(conf, MODULE_NAME, 0);
- PSRpcHandler handler = new PSRpcHandler(ps);
- return new TransportContext(tc, handler);
- }
-
- /**
- * Create and start the server
- * @return server
- */
- public static TransportServer createServer(SparkConf conf, LocalParamServer ps, String host) {
- TransportContext context = createTransportContext(conf, ps);
- return context.createServer(host, 0, Collections.emptyList()); // bind rpc to an ephemeral port
- }
-
- public static SparkPSProxy createSparkPSProxy(SparkConf conf, String host, int port, long rpcTimeout, LongAccumulator aRPC) throws IOException {
- TransportContext context = createTransportContext(conf, new LocalParamServer());
- return new SparkPSProxy(context.createClientFactory().createClient(host, port), rpcTimeout, aRPC);
- }
-}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcHandler.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcHandler.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcHandler.java
deleted file mode 100644
index a2c311e..0000000
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcHandler.java
+++ /dev/null
@@ -1,95 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc;
-
-import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcCall.PULL;
-import static org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcCall.PUSH;
-
-import java.io.IOException;
-import java.nio.ByteBuffer;
-
-import org.apache.commons.lang.exception.ExceptionUtils;
-import org.apache.spark.network.client.RpcResponseCallback;
-import org.apache.spark.network.client.TransportClient;
-import org.apache.spark.network.server.OneForOneStreamManager;
-import org.apache.spark.network.server.RpcHandler;
-import org.apache.spark.network.server.StreamManager;
-import org.apache.sysml.runtime.DMLRuntimeException;
-import org.apache.sysml.runtime.controlprogram.paramserv.LocalParamServer;
-import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcResponse.Type;
-import org.apache.sysml.runtime.instructions.cp.ListObject;
-
-public final class PSRpcHandler extends RpcHandler {
-
- private LocalParamServer _server;
-
- protected PSRpcHandler(LocalParamServer server) {
- _server = server;
- }
-
- @Override
- public void receive(TransportClient client, ByteBuffer buffer, RpcResponseCallback callback) {
- PSRpcCall call;
- try {
- call = new PSRpcCall(buffer);
- } catch (IOException e) {
- throw new DMLRuntimeException("PSRpcHandler: some error occurred when deserializing the rpc call.", e);
- }
- PSRpcResponse response = null;
- switch (call.getMethod()) {
- case PUSH:
- try {
- _server.push(call.getWorkerID(), call.getData());
- response = new PSRpcResponse(Type.SUCCESS_EMPTY);
- } catch (DMLRuntimeException exception) {
- response = new PSRpcResponse(Type.ERROR, ExceptionUtils.getFullStackTrace(exception));
- } finally {
- try {
- callback.onSuccess(response.serialize());
- } catch (IOException e) {
- throw new DMLRuntimeException("PSRpcHandler: some error occrred when wrapping the rpc response.", e);
- }
- }
- break;
- case PULL:
- ListObject data;
- try {
- data = _server.pull(call.getWorkerID());
- response = new PSRpcResponse(Type.SUCCESS, data);
- } catch (DMLRuntimeException exception) {
- response = new PSRpcResponse(Type.ERROR, ExceptionUtils.getFullStackTrace(exception));
- } finally {
- try {
- callback.onSuccess(response.serialize());
- } catch (IOException e) {
- throw new DMLRuntimeException("PSRpcHandler: some error occrred when wrapping the rpc response.", e);
- }
- }
- break;
- default:
- throw new DMLRuntimeException(String.format("Does not support the rpc call for method %s", call.getMethod()));
- }
- }
-
- @Override
- public StreamManager getStreamManager() {
- return new OneForOneStreamManager();
- }
-}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcObject.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcObject.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcObject.java
deleted file mode 100644
index 816cefd..0000000
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcObject.java
+++ /dev/null
@@ -1,107 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc;
-
-import java.io.DataInput;
-import java.io.DataOutput;
-import java.io.IOException;
-import java.nio.ByteBuffer;
-import java.util.ArrayList;
-import java.util.List;
-
-import org.apache.sysml.runtime.DMLRuntimeException;
-import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
-import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils;
-import org.apache.sysml.runtime.instructions.cp.Data;
-import org.apache.sysml.runtime.instructions.cp.ListObject;
-import org.apache.sysml.runtime.io.IOUtilFunctions;
-import org.apache.sysml.runtime.matrix.data.MatrixBlock;
-
-public abstract class PSRpcObject {
-
- public static final int PUSH = 1;
- public static final int PULL = 2;
-
- public abstract void deserialize(ByteBuffer buffer) throws IOException;
-
- public abstract ByteBuffer serialize() throws IOException;
-
- /**
- * Deep serialize and write of a list object (currently only support list containing matrices)
- * @param lo a list object containing only matrices
- * @param output output data to write to
- */
- protected void serializeAndWriteListObject(ListObject lo, DataOutput output) throws IOException {
- validateListObject(lo);
- output.writeInt(lo.getLength()); //write list length
- output.writeBoolean(lo.isNamedList()); //write list named
- for (int i = 0; i < lo.getLength(); i++) {
- if (lo.isNamedList())
- output.writeUTF(lo.getName(i)); //write name
- ((MatrixObject) lo.getData().get(i))
- .acquireReadAndRelease().write(output); //write matrix
- }
- // Cleanup the list object
- // because it is transferred to remote worker in binary format
- ParamservUtils.cleanupListObject(lo);
- }
-
- protected ListObject readAndDeserialize(DataInput input) throws IOException {
- int listLen = input.readInt();
- List<Data> data = new ArrayList<>();
- List<String> names = input.readBoolean() ?
- new ArrayList<>() : null;
- for(int i=0; i<listLen; i++) {
- if( names != null )
- names.add(input.readUTF());
- MatrixBlock mb = new MatrixBlock();
- mb.readFields(input);
- data.add(ParamservUtils.newMatrixObject(mb, false));
- }
- return new ListObject(data, names);
- }
-
- /**
- * Get serialization size of a list object
- * (scheme: size|name|size|matrix)
- * @param lo list object
- * @return serialization size
- */
- protected int getExactSerializedSize(ListObject lo) {
- if( lo == null ) return 0;
- long result = 4 + 1; // list length and of named
- if (lo.isNamedList()) //size for names incl length
- result += lo.getNames().stream().mapToLong(s -> IOUtilFunctions.getUTFSize(s)).sum();
- result += lo.getData().stream().mapToLong(d ->
- ((MatrixObject)d).acquireReadAndRelease().getExactSizeOnDisk()).sum();
- if( result > Integer.MAX_VALUE )
- throw new DMLRuntimeException("Serialized size ("+result+") larger than Integer.MAX_VALUE.");
- return (int) result;
- }
-
- private void validateListObject(ListObject lo) {
- for (Data d : lo.getData()) {
- if (!(d instanceof MatrixObject)) {
- throw new DMLRuntimeException(String.format("Paramserv func:"
- + " Unsupported deep serialize of %s, which is not matrix.", d.getDebugName()));
- }
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcResponse.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcResponse.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcResponse.java
deleted file mode 100644
index 010481e..0000000
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/spark/rpc/PSRpcResponse.java
+++ /dev/null
@@ -1,101 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc;
-
-import java.io.IOException;
-import java.nio.ByteBuffer;
-
-import org.apache.sysml.runtime.util.ByteBufferDataInput;
-import org.apache.sysml.runtime.controlprogram.caching.CacheDataOutput;
-import org.apache.sysml.runtime.instructions.cp.ListObject;
-import org.apache.sysml.runtime.io.IOUtilFunctions;
-
-public class PSRpcResponse extends PSRpcObject {
- public enum Type {
- SUCCESS,
- SUCCESS_EMPTY,
- ERROR,
- }
-
- private Type _status;
- private Object _data; // Could be list object or exception
-
- public PSRpcResponse(ByteBuffer buffer) throws IOException {
- deserialize(buffer);
- }
-
- public PSRpcResponse(Type status) {
- this(status, null);
- }
-
- public PSRpcResponse(Type status, Object data) {
- _status = status;
- _data = data;
- if( _status == Type.SUCCESS && data == null )
- _status = Type.SUCCESS_EMPTY;
- }
-
- public boolean isSuccessful() {
- return _status != Type.ERROR;
- }
-
- public String getErrorMessage() {
- return (String) _data;
- }
-
- public ListObject getResultModel() {
- return (ListObject) _data;
- }
-
- @Override
- public void deserialize(ByteBuffer buffer) throws IOException {
- ByteBufferDataInput dis = new ByteBufferDataInput(buffer);
- _status = Type.values()[dis.readInt()];
- switch (_status) {
- case SUCCESS:
- _data = readAndDeserialize(dis);
- break;
- case SUCCESS_EMPTY:
- break;
- case ERROR:
- _data = dis.readUTF();
- break;
- }
- }
-
- @Override
- public ByteBuffer serialize() throws IOException {
- int len = 4 + (_status==Type.SUCCESS ? getExactSerializedSize((ListObject)_data) :
- _status==Type.SUCCESS_EMPTY ? 0 : IOUtilFunctions.getUTFSize((String)_data));
- CacheDataOutput dos = new CacheDataOutput(len);
- dos.writeInt(_status.ordinal());
- switch (_status) {
- case SUCCESS:
- serializeAndWriteListObject((ListObject) _data, dos);
- break;
- case SUCCESS_EMPTY:
- break;
- case ERROR:
- dos.writeUTF(_data.toString());
- break;
- }
- return ByteBuffer.wrap(dos.getBytes());
- }
-}
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
index 6220bb6..83ec3f7 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
@@ -65,15 +65,15 @@ import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
-import org.apache.sysml.runtime.controlprogram.paramserv.DataPartitionScheme;
-import org.apache.sysml.runtime.controlprogram.paramserv.DataPartitioner;
+import org.apache.sysml.runtime.controlprogram.paramserv.dp.DataPartitionLocalScheme;
+import org.apache.sysml.runtime.controlprogram.paramserv.dp.LocalDataPartitioner;
import org.apache.sysml.runtime.controlprogram.paramserv.LocalPSWorker;
import org.apache.sysml.runtime.controlprogram.paramserv.LocalParamServer;
import org.apache.sysml.runtime.controlprogram.paramserv.ParamServer;
import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils;
-import org.apache.sysml.runtime.controlprogram.paramserv.spark.SparkPSBody;
-import org.apache.sysml.runtime.controlprogram.paramserv.spark.SparkPSWorker;
-import org.apache.sysml.runtime.controlprogram.paramserv.spark.rpc.PSRpcFactory;
+import org.apache.sysml.runtime.controlprogram.paramserv.SparkPSBody;
+import org.apache.sysml.runtime.controlprogram.paramserv.SparkPSWorker;
+import org.apache.sysml.runtime.controlprogram.paramserv.rpc.PSRpcFactory;
import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysml.runtime.matrix.operators.Operator;
@@ -350,7 +350,7 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
switch (mode) {
case LOCAL:
case REMOTE_SPARK:
- return new LocalParamServer(model, aggFunc, updateType, ec, workerNum);
+ return LocalParamServer.create(model, aggFunc, updateType, ec, workerNum);
default:
throw new DMLRuntimeException("Unsupported parameter server: "+mode.name());
}
@@ -379,9 +379,7 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc
private void partitionLocally(PSScheme scheme, ExecutionContext ec, List<LocalPSWorker> workers) {
MatrixObject features = ec.getMatrixObject(getParam(PS_FEATURES));
MatrixObject labels = ec.getMatrixObject(getParam(PS_LABELS));
- DataPartitionScheme.Result result = new DataPartitioner(scheme).doPartitioning(workers.size(), features.acquireRead(), labels.acquireRead());
- features.release();
- labels.release();
+ DataPartitionLocalScheme.Result result = new LocalDataPartitioner(scheme).doPartitioning(workers.size(), features.acquireReadAndRelease(), labels.acquireReadAndRelease());
List<MatrixObject> pfs = result.pFeatures;
List<MatrixObject> pls = result.pLabels;
if (pfs.size() < workers.size()) {
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/main/java/org/apache/sysml/runtime/util/ProgramConverter.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/util/ProgramConverter.java b/src/main/java/org/apache/sysml/runtime/util/ProgramConverter.java
index fc9d9b4..21e6bd3 100644
--- a/src/main/java/org/apache/sysml/runtime/util/ProgramConverter.java
+++ b/src/main/java/org/apache/sysml/runtime/util/ProgramConverter.java
@@ -69,7 +69,7 @@ import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject.UpdateType;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
-import org.apache.sysml.runtime.controlprogram.paramserv.spark.SparkPSBody;
+import org.apache.sysml.runtime.controlprogram.paramserv.SparkPSBody;
import org.apache.sysml.runtime.controlprogram.parfor.ParForBody;
import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysml.runtime.instructions.CPInstructionParser;
http://git-wip-us.apache.org/repos/asf/systemml/blob/382f847d/src/test/java/org/apache/sysml/test/integration/functions/paramserv/BaseDataPartitionerTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/BaseDataPartitionerTest.java b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/BaseDataPartitionerTest.java
index 0092aed..2f39c91 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/paramserv/BaseDataPartitionerTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/paramserv/BaseDataPartitionerTest.java
@@ -22,8 +22,8 @@ package org.apache.sysml.test.integration.functions.paramserv;
import java.util.stream.IntStream;
import org.apache.sysml.parser.Statement;
-import org.apache.sysml.runtime.controlprogram.paramserv.DataPartitionScheme;
-import org.apache.sysml.runtime.controlprogram.paramserv.DataPartitioner;
+import org.apache.sysml.runtime.controlprogram.paramserv.dp.DataPartitionLocalScheme;
+import org.apache.sysml.runtime.controlprogram.paramserv.dp.LocalDataPartitioner;
import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.util.DataConverter;
@@ -54,26 +54,26 @@ public abstract class BaseDataPartitionerTest {
return IntStream.range(from, to).mapToDouble(i -> (double) i).toArray();
}
- protected DataPartitionScheme.Result launchLocalDataPartitionerDC() {
- DataPartitioner dp = new DataPartitioner(Statement.PSScheme.DISJOINT_CONTIGUOUS);
+ protected DataPartitionLocalScheme.Result launchLocalDataPartitionerDC() {
+ LocalDataPartitioner dp = new LocalDataPartitioner(Statement.PSScheme.DISJOINT_CONTIGUOUS);
MatrixBlock[] mbs = generateData();
return dp.doPartitioning(WORKER_NUM, mbs[0], mbs[1]);
}
- protected DataPartitionScheme.Result launchLocalDataPartitionerDR(MatrixBlock[] mbs) {
+ protected DataPartitionLocalScheme.Result launchLocalDataPartitionerDR(MatrixBlock[] mbs) {
ParamservUtils.SEED = System.nanoTime();
- DataPartitioner dp = new DataPartitioner(Statement.PSScheme.DISJOINT_RANDOM);
+ LocalDataPartitioner dp = new LocalDataPartitioner(Statement.PSScheme.DISJOINT_RANDOM);
return dp.doPartitioning(WORKER_NUM, mbs[0], mbs[1]);
}
- protected DataPartitionScheme.Result launchLocalDataPartitionerDRR() {
- DataPartitioner dp = new DataPartitioner(Statement.PSScheme.DISJOINT_ROUND_ROBIN);
+ protected DataPartitionLocalScheme.Result launchLocalDataPartitionerDRR() {
+ LocalDataPartitioner dp = new LocalDataPartitioner(Statement.PSScheme.DISJOINT_ROUND_ROBIN);
MatrixBlock[] mbs = generateData();
return dp.doPartitioning(WORKER_NUM, mbs[0], mbs[1]);
}
- protected DataPartitionScheme.Result launchLocalDataPartitionerOR() {
- DataPartitioner dp = new DataPartitioner(Statement.PSScheme.OVERLAP_RESHUFFLE);
+ protected DataPartitionLocalScheme.Result launchLocalDataPartitionerOR() {
+ LocalDataPartitioner dp = new LocalDataPartitioner(Statement.PSScheme.OVERLAP_RESHUFFLE);
MatrixBlock[] mbs = generateData();
return dp.doPartitioning(WORKER_NUM, mbs[0], mbs[1]);
}