You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by li...@apache.org on 2022/02/25 07:56:44 UTC
[flink-ml] branch master updated: [FLINK-26263] (followup) Check data size in LogisticRegression
This is an automated email from the ASF dual-hosted git repository.
lindong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push:
new bd6d67f [FLINK-26263] (followup) Check data size in LogisticRegression
bd6d67f is described below
commit bd6d67f78fe5341ec992d72b633b26e2def7783a
Author: zhangzp <zh...@gmail.com>
AuthorDate: Fri Feb 25 11:18:46 2022 +0800
[FLINK-26263] (followup) Check data size in LogisticRegression
This closes #66.
---
.../flink/ml/common/datastream/AllReduceImpl.java | 4 +--
.../ml/common/datastream/DataStreamUtils.java | 6 ++--
.../ml/common/datastream/AllReduceImplTest.java | 17 +++++++++
.../logisticregression/LogisticRegression.java | 41 ++++++++++------------
4 files changed, 41 insertions(+), 27 deletions(-)
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/AllReduceImpl.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/AllReduceImpl.java
index b8571a0..760b5db 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/AllReduceImpl.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/AllReduceImpl.java
@@ -37,7 +37,7 @@ import java.util.HashMap;
import java.util.Map;
/**
- * Applies all-reduce on a data stream where each partition contains only one double array.
+ * Applies all-reduce on a data stream where each partition contains up to one double array.
*
* <p>AllReduce is a communication primitive widely used in MPI. In this implementation, all workers
* do reduce on a partition of the whole data and they all get the final reduce result. In detail,
@@ -55,7 +55,7 @@ class AllReduceImpl {
/**
* Applies allReduceSum on the input data stream. The input data stream is supposed to contain
- * one double array in each worker. The result data stream has the same parallelism as the
+ * up to one double array in each worker. The result data stream has the same parallelism as the
* input, where each worker contains one double array that sums all of the double arrays in the
* input data stream.
*
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
index 4cad85e..58eae62 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
@@ -35,9 +35,9 @@ import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
public class DataStreamUtils {
/**
* Applies allReduceSum on the input data stream. The input data stream is supposed to contain
- * one double array in each partition. The result data stream has the same parallelism as the
- * input, where each partition contains one double array that sums all of the double arrays in
- * the input data stream.
+ * up to one double array in each partition. The result data stream has the same parallelism as
+ * the input, where each partition contains one double array that sums all of the double arrays
+ * in the input data stream.
*
* <p>Note that we throw exception when one of the following two cases happen:
* <li>There exists one partition that contains more than one double array.
diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/AllReduceImplTest.java b/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/AllReduceImplTest.java
index 1ee0201..6b0136b 100644
--- a/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/AllReduceImplTest.java
+++ b/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/AllReduceImplTest.java
@@ -20,6 +20,7 @@ package org.apache.flink.ml.common.datastream;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.sink.SinkFunction;
@@ -33,8 +34,10 @@ import org.junit.runners.Parameterized;
import java.util.Arrays;
import java.util.Collection;
+import java.util.Iterator;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
import static org.junit.Assert.fail;
/** Tests the {@link AllReduceImpl}. */
@@ -161,5 +164,19 @@ public class AllReduceImplTest {
e.getCause().getCause().getMessage());
}
}
+
+ @Test
+ public void testAllReduceWithEmptyInput() throws Exception {
+ StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+ env.setParallelism(parallelism);
+ DataStream<double[]> elements =
+ env.fromParallelCollection(
+ new NumberSequenceIterator(1L, parallelism),
+ BasicTypeInfo.LONG_TYPE_INFO)
+ .flatMap((FlatMapFunction<Long, double[]>) (value, out) -> {})
+ .returns(PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO);
+ Iterator<double[]> result = DataStreamUtils.allReduceSum(elements).executeAndCollect();
+ assertFalse(result.hasNext());
+ }
}
}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java
index 587f225..08d9c78 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java
@@ -129,10 +129,11 @@ public class LogisticRegression
return new LabeledPointWithWeight(features, label, weight);
});
DataStream<double[]> initModelData =
- trainData.transform(
- "genInitModelData",
- PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO,
- new GenInitModelData());
+ trainData
+ .transform("getModelDim", BasicTypeInfo.INT_TYPE_INFO, new GetModelDim())
+ .setParallelism(1)
+ .broadcast()
+ .map(double[]::new);
DataStream<LogisticRegressionModelData> modelData = train(trainData, initModelData);
LogisticRegressionModel model =
@@ -141,12 +142,9 @@ public class LogisticRegression
return model;
}
- /**
- * Generates initialized model data. Note that the parallelism of model data is same as the
- * input train data, not one.
- */
- private static class GenInitModelData extends AbstractStreamOperator<double[]>
- implements OneInputStreamOperator<LabeledPointWithWeight, double[]>, BoundedOneInput {
+ /** Gets the dimension of the model data. */
+ private static class GetModelDim extends AbstractStreamOperator<Integer>
+ implements OneInputStreamOperator<LabeledPointWithWeight, Integer>, BoundedOneInput {
private int dim = 0;
@@ -154,7 +152,7 @@ public class LogisticRegression
@Override
public void endInput() {
- output.collect(new StreamRecord<>(new double[dim]));
+ output.collect(new StreamRecord<>(dim));
}
@Override
@@ -363,9 +361,6 @@ public class LogisticRegression
public void onEpochWatermarkIncremented(
int epochWatermark, Context context, Collector<double[]> collector)
throws Exception {
- if (!trainDataState.get().iterator().hasNext()) {
- return;
- }
if (epochWatermark == 0) {
coefficient = new DenseVector(feedbackBuffer);
coefficientDim = coefficient.size();
@@ -378,14 +373,16 @@ public class LogisticRegression
if (trainData == null) {
trainData = IteratorUtils.toList(trainDataState.get().iterator());
}
- miniBatchData = getMiniBatchData(trainData, localBatchSize);
- Tuple2<Double, Double> weightSumAndLossSum =
- logisticGradient.computeLoss(miniBatchData, coefficient);
- logisticGradient.computeGradient(miniBatchData, coefficient, gradient);
- System.arraycopy(gradient.values, 0, feedbackBuffer, 0, gradient.size());
- feedbackBuffer[coefficientDim] = weightSumAndLossSum.f0;
- feedbackBuffer[coefficientDim + 1] = weightSumAndLossSum.f1;
- collector.collect(feedbackBuffer);
+ if (trainData.size() > 0) {
+ miniBatchData = getMiniBatchData(trainData, localBatchSize);
+ Tuple2<Double, Double> weightSumAndLossSum =
+ logisticGradient.computeLoss(miniBatchData, coefficient);
+ logisticGradient.computeGradient(miniBatchData, coefficient, gradient);
+ System.arraycopy(gradient.values, 0, feedbackBuffer, 0, gradient.size());
+ feedbackBuffer[coefficientDim] = weightSumAndLossSum.f0;
+ feedbackBuffer[coefficientDim + 1] = weightSumAndLossSum.f1;
+ collector.collect(feedbackBuffer);
+ }
}
@Override