Posted to commits@flink.apache.org by li...@apache.org on 2022/02/25 07:56:44 UTC

[flink-ml] branch master updated: [FLINK-26263] (followup) Check data size in LogisticRegression

This is an automated email from the ASF dual-hosted git repository.

lindong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git


The following commit(s) were added to refs/heads/master by this push:
     new bd6d67f  [FLINK-26263] (followup) Check data size in LogisticRegression
bd6d67f is described below

commit bd6d67f78fe5341ec992d72b633b26e2def7783a
Author: zhangzp <zh...@gmail.com>
AuthorDate: Fri Feb 25 11:18:46 2022 +0800

    [FLINK-26263] (followup) Check data size in LogisticRegression
    
    This closes #66.
---
 .../flink/ml/common/datastream/AllReduceImpl.java  |  4 +--
 .../ml/common/datastream/DataStreamUtils.java      |  6 ++--
 .../ml/common/datastream/AllReduceImplTest.java    | 17 +++++++++
 .../logisticregression/LogisticRegression.java     | 41 ++++++++++------------
 4 files changed, 41 insertions(+), 27 deletions(-)

diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/AllReduceImpl.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/AllReduceImpl.java
index b8571a0..760b5db 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/AllReduceImpl.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/AllReduceImpl.java
@@ -37,7 +37,7 @@ import java.util.HashMap;
 import java.util.Map;
 
 /**
- * Applies all-reduce on a data stream where each partition contains only one double array.
+ * Applies all-reduce on a data stream where each partition contains up to one double array.
  *
  * <p>AllReduce is a communication primitive widely used in MPI. In this implementation, all workers
  * do reduce on a partition of the whole data and they all get the final reduce result. In detail,
@@ -55,7 +55,7 @@ class AllReduceImpl {
 
     /**
      * Applies allReduceSum on the input data stream. The input data stream is supposed to contain
-     * one double array in each worker. The result data stream has the same parallelism as the
+     * up to one double array in each worker. The result data stream has the same parallelism as the
      * input, where each worker contains one double array that sums all of the double arrays in the
      * input data stream.
      *
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
index 4cad85e..58eae62 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
@@ -35,9 +35,9 @@ import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 public class DataStreamUtils {
     /**
      * Applies allReduceSum on the input data stream. The input data stream is supposed to contain
-     * one double array in each partition. The result data stream has the same parallelism as the
-     * input, where each partition contains one double array that sums all of the double arrays in
-     * the input data stream.
+     * up to one double array in each partition. The result data stream has the same parallelism as
+     * the input, where each partition contains one double array that sums all of the double arrays
+     * in the input data stream.
      *
      * <p>Note that we throw exception when one of the following two cases happen:
      * <li>There exists one partition that contains more than one double array.
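 
For reference, the contract documented above can be exercised as in the
sketch below, which is modeled on the tests in AllReduceImplTest. The
parallelism and array values are illustrative assumptions; only
DataStreamUtils.allReduceSum itself comes from this repository.

    import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
    import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
    import org.apache.flink.ml.common.datastream.DataStreamUtils;
    import org.apache.flink.streaming.api.datastream.DataStream;
    import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
    import org.apache.flink.util.NumberSequenceIterator;

    import java.util.Iterator;

    StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
    int parallelism = 4; // illustrative value
    env.setParallelism(parallelism);
    // Each of the `parallelism` partitions emits exactly one double[],
    // which satisfies the documented contract.
    DataStream<double[]> input =
            env.fromParallelCollection(
                            new NumberSequenceIterator(1L, parallelism),
                            BasicTypeInfo.LONG_TYPE_INFO)
                    .map(x -> new double[] {1.0, 2.0})
                    .returns(PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO);
    // Every partition of the result holds the element-wise sum,
    // i.e. {parallelism * 1.0, parallelism * 2.0}.
    Iterator<double[]> result = DataStreamUtils.allReduceSum(input).executeAndCollect();
 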
diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/AllReduceImplTest.java b/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/AllReduceImplTest.java
index 1ee0201..6b0136b 100644
--- a/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/AllReduceImplTest.java
+++ b/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/AllReduceImplTest.java
@@ -20,6 +20,7 @@ package org.apache.flink.ml.common.datastream;
 
 import org.apache.flink.api.common.functions.FlatMapFunction;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.streaming.api.functions.sink.SinkFunction;
@@ -33,8 +34,10 @@ import org.junit.runners.Parameterized;
 
 import java.util.Arrays;
 import java.util.Collection;
+import java.util.Iterator;
 
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.fail;
 
 /** Tests the {@link AllReduceImpl}. */
@@ -161,5 +164,19 @@ public class AllReduceImplTest {
                         e.getCause().getCause().getMessage());
             }
         }
+
+        @Test
+        public void testAllReduceWithEmptyInput() throws Exception {
+            StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+            env.setParallelism(parallelism);
+            DataStream<double[]> elements =
+                    env.fromParallelCollection(
+                                    new NumberSequenceIterator(1L, parallelism),
+                                    BasicTypeInfo.LONG_TYPE_INFO)
+                            .flatMap((FlatMapFunction<Long, double[]>) (value, out) -> {})
+                            .returns(PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO);
+            Iterator<double[]> result = DataStreamUtils.allReduceSum(elements).executeAndCollect();
+            assertFalse(result.hasNext());
+        }
     }
 }
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java
index 587f225..08d9c78 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java
@@ -129,10 +129,11 @@ public class LogisticRegression
                                     return new LabeledPointWithWeight(features, label, weight);
                                 });
         DataStream<double[]> initModelData =
-                trainData.transform(
-                        "genInitModelData",
-                        PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO,
-                        new GenInitModelData());
+                trainData
+                        .transform("getModelDim", BasicTypeInfo.INT_TYPE_INFO, new GetModelDim())
+                        .setParallelism(1)
+                        .broadcast()
+                        .map(double[]::new);
 
         DataStream<LogisticRegressionModelData> modelData = train(trainData, initModelData);
         LogisticRegressionModel model =
@@ -141,12 +142,9 @@ public class LogisticRegression
         return model;
     }
 
-    /**
-     * Generates initialized model data. Note that the parallelism of model data is same as the
-     * input train data, not one.
-     */
-    private static class GenInitModelData extends AbstractStreamOperator<double[]>
-            implements OneInputStreamOperator<LabeledPointWithWeight, double[]>, BoundedOneInput {
+    /** Gets the dimension of the model data. */
+    private static class GetModelDim extends AbstractStreamOperator<Integer>
+            implements OneInputStreamOperator<LabeledPointWithWeight, Integer>, BoundedOneInput {
 
         private int dim = 0;
 
@@ -154,7 +152,7 @@ public class LogisticRegression
 
         @Override
         public void endInput() {
-            output.collect(new StreamRecord<>(new double[dim]));
+            output.collect(new StreamRecord<>(dim));
         }
 
         @Override
@@ -363,9 +361,6 @@ public class LogisticRegression
         public void onEpochWatermarkIncremented(
                 int epochWatermark, Context context, Collector<double[]> collector)
                 throws Exception {
-            if (!trainDataState.get().iterator().hasNext()) {
-                return;
-            }
             if (epochWatermark == 0) {
                 coefficient = new DenseVector(feedbackBuffer);
                 coefficientDim = coefficient.size();
@@ -378,14 +373,16 @@ public class LogisticRegression
             if (trainData == null) {
                 trainData = IteratorUtils.toList(trainDataState.get().iterator());
             }
-            miniBatchData = getMiniBatchData(trainData, localBatchSize);
-            Tuple2<Double, Double> weightSumAndLossSum =
-                    logisticGradient.computeLoss(miniBatchData, coefficient);
-            logisticGradient.computeGradient(miniBatchData, coefficient, gradient);
-            System.arraycopy(gradient.values, 0, feedbackBuffer, 0, gradient.size());
-            feedbackBuffer[coefficientDim] = weightSumAndLossSum.f0;
-            feedbackBuffer[coefficientDim + 1] = weightSumAndLossSum.f1;
-            collector.collect(feedbackBuffer);
+            if (trainData.size() > 0) {
+                miniBatchData = getMiniBatchData(trainData, localBatchSize);
+                Tuple2<Double, Double> weightSumAndLossSum =
+                        logisticGradient.computeLoss(miniBatchData, coefficient);
+                logisticGradient.computeGradient(miniBatchData, coefficient, gradient);
+                System.arraycopy(gradient.values, 0, feedbackBuffer, 0, gradient.size());
+                feedbackBuffer[coefficientDim] = weightSumAndLossSum.f0;
+                feedbackBuffer[coefficientDim + 1] = weightSumAndLossSum.f1;
+                collector.collect(feedbackBuffer);
+            }
         }
 
         @Override
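 
As a closing note on the LogisticRegression change: the new initialization
path works because a single Integer (the model dimension, computed by
GetModelDim at parallelism 1) can be broadcast and expanded into a
zero-filled coefficient array on every subtask. Below is a standalone sketch
of the same pattern with illustrative values; the explicit returns() hint is
an assumption added here to be safe about lambda type extraction, not part
of the commit.

    import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
    import org.apache.flink.streaming.api.datastream.DataStream;
    import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;

    StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
    env.setParallelism(3); // illustrative
    DataStream<double[]> initModelData =
            env.fromElements(4) // stands in for the dimension emitted by GetModelDim
                    .broadcast()
                    // double[]::new maps the broadcast dimension to a
                    // zero-initialized double[4] on each parallel subtask.
                    .map(double[]::new)
                    .returns(PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO);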