Posted to issues@flink.apache.org by GitBox <gi...@apache.org> on 2022/03/03 07:38:39 UTC

[GitHub] [flink-ml] yunfengzhou-hub opened a new pull request #70: [FLINK-26313] Support Streaming KMeans in Flink ML

yunfengzhou-hub opened a new pull request #70:
URL: https://github.com/apache/flink-ml/pull/70


   ## What is the purpose of the change
   This PR adds an Estimator and a Transformer for the Streaming KMeans algorithm.
   
   Compared with the existing KMeans operator, Streaming KMeans can train a KMeans model continuously on an unbounded stream of training data. The corresponding Model operator also supports updating its model data dynamically from a DataStream.
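
   The mini-batch update behind this continuous training is described in the `OnlineKMeans` JavaDoc in the diff further down: the new centroid is a weighted average of the current centroid (weight: its point count multiplied by the decay factor) and the centroid estimated on the current batch (weight: the number of points assigned to it in that batch). A minimal, Flink-free sketch of that arithmetic follows. `DecayedCentroidUpdate` and its methods are hypothetical names, and carrying the weight forward as "decayed weight plus batch count" is an assumption that the diff does not state.

```java
/**
 * Hypothetical helper (not part of Flink ML) illustrating one decayed mini-batch update of a
 * single centroid.
 */
final class DecayedCentroidUpdate {
    private DecayedCentroidUpdate() {}

    /** Weight kept by the current centroid: its accumulated point count times the decay factor. */
    static double decayedWeight(double accumulatedWeight, double decayFactor) {
        return accumulatedWeight * decayFactor;
    }

    /**
     * Weighted average of the current centroid and the centroid estimated on the new batch.
     * Assumes equal dimensions and decayedWeight + batchCount > 0.
     */
    static double[] updateCentroid(
            double[] current, double decayedWeight, double[] batchEstimate, long batchCount) {
        if (current.length != batchEstimate.length) {
            throw new IllegalArgumentException("dimension mismatch");
        }
        double total = decayedWeight + batchCount;
        double[] updated = new double[current.length];
        for (int i = 0; i < updated.length; i++) {
            updated[i] = (current[i] * decayedWeight + batchEstimate[i] * batchCount) / total;
        }
        return updated;
    }

    /** Assumed weight carried into the next batch: decayed old weight plus the new points. */
    static double updatedWeight(double decayedWeight, long batchCount) {
        return decayedWeight + batchCount;
    }
}
```

   For example, with a current centroid (0, 0) of weight 2 and a batch estimate (3, 3) computed from a single point, a decay factor of 1 yields (1, 1), while a decay factor of 0 yields (3, 3), i.e. only the recent data counts.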
   
   In addition, this PR adds simple test infrastructure for online algorithms, which makes it possible to control the order in which training data and prediction data are consumed.
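
   The excerpt does not include the implementation of `MockBlockingQueueSourceFunction`, `MockBlockingQueueSinkFunction` or `TestBlockingQueueManager`. As a rough, Flink-free illustration of the general technique (an assumption, not code from this PR), the sketch below uses a worker thread as a stand-in for the operator under test: it only processes records that the test explicitly releases into a source queue, and publishes results to a sink queue that the test waits on. `GatedPipeline` is a hypothetical name.

```java
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;

/**
 * Hypothetical sketch (not the PR's test utilities): a worker thread stands in for the operator
 * under test and only ever sees records that the test has explicitly released.
 */
final class GatedPipeline<I, O> {
    private final BlockingQueue<I> sourceQueue = new LinkedBlockingQueue<>();
    private final BlockingQueue<O> sinkQueue = new LinkedBlockingQueue<>();
    private final Thread worker;

    GatedPipeline(Function<I, O> operator) {
        worker = new Thread(() -> {
            try {
                while (!Thread.currentThread().isInterrupted()) {
                    I record = sourceQueue.take();         // blocks until the test releases a record
                    sinkQueue.put(operator.apply(record)); // publishes the result to the test
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();         // stop() interrupts the blocking take()
            }
        });
        worker.setDaemon(true); // never keeps the JVM alive if a test forgets to call stop()
    }

    void start() {
        worker.start();
    }

    /** Lets exactly one record flow into the "source". */
    void release(I record) {
        sourceQueue.add(record);
    }

    /** Waits up to timeoutMillis for the next output; returns null if none arrived in time. */
    O awaitOutput(long timeoutMillis) {
        try {
            return sinkQueue.poll(timeoutMillis, TimeUnit.MILLISECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return null;
        }
    }

    void stop() {
        worker.interrupt();
    }
}
```

   A test can release a batch of training points, wait for the resulting model update on the output side, and only then release prediction data, so the interleaving is decided by the test rather than by thread scheduling.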
   
   ## Brief change log
   - Adds the `StreamingKMeans`, `StreamingKMeansModel` and `StreamingKMeansParams` classes to support the Streaming KMeans algorithm, and a `StreamingKMeansTest` class to test them.
   - Adds the `HasBatchStrategy` and `HasDecayFactor` interfaces to represent the corresponding parameters of online algorithms.
   - Adds `MockBlockingQueueSinkFunction`, `MockBlockingQueueSourceFunction` and `TestBlockingQueueManager` to control the rate at which streams are consumed in the test cases of online algorithms.
   
   ## Does this pull request potentially affect one of the following parts:
   - Dependencies (does it add or upgrade a dependency): (no)
   - The public API, i.e., is any changed class annotated with @Public(Evolving): (no)
   - Does this pull request introduce a new feature? (yes)
   - If yes, how is the feature documented? (JavaDocs)


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r835994436



##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java
##########
@@ -177,11 +177,20 @@ public void testFewerDistinctPointsThanCluster() {
         KMeans kmeans = new KMeans().setK(2);
         KMeansModel model = kmeans.fit(input);
         Table output = model.transform(input)[0];
-        List<Set<DenseVector>> expectedGroups =
-                Collections.singletonList(Collections.singleton(Vectors.dense(0.0, 0.1)));
-        List<Set<DenseVector>> actualGroups =
-                executeAndCollect(output, kmeans.getFeaturesCol(), kmeans.getPredictionCol());
-        assertTrue(CollectionUtils.isEqualCollection(expectedGroups, actualGroups));
+
+        try {

Review comment:
       I think the current definition of K contradicts the expected behavior. If we want to keep the same behavior, I think we should adopt one of the following:
   
   - Change `K`'s description from `The number of clusters to create` to `The max number of clusters to create`
   - If there are fewer distinct points than clusters, the training process would still create K clusters, but some of them are identical.
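
   The second option can be sketched as follows: deduplicate the input, then cycle through the distinct points so that exactly K initial centroids are always produced, some of them identical when there are fewer distinct points than K. `InitialCentroids.pick` is a hypothetical, deterministic helper for illustration only, and is not meant to reproduce the initialization actually used by `KMeans`.

```java
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

/** Hypothetical illustration of option 2: always return exactly k (possibly identical) centroids. */
final class InitialCentroids {
    private InitialCentroids() {}

    static List<List<Double>> pick(List<List<Double>> points, int k) {
        if (k <= 0 || points.isEmpty()) {
            throw new IllegalArgumentException("k must be positive and points must be non-empty");
        }
        // Deduplicate while keeping first-seen order so the result is deterministic.
        Set<List<Double>> unique = new LinkedHashSet<>(points);
        List<List<Double>> distinct = new ArrayList<>(unique);
        List<List<Double>> centroids = new ArrayList<>(k);
        for (int i = 0; i < k; i++) {
            // Cycle through the distinct points; once they run out, later centroids repeat earlier ones.
            centroids.add(distinct.get(i % distinct.size()));
        }
        return centroids;
    }
}
```

   With this behavior, K keeps its meaning of "the number of clusters to create", and the duplicated centroids simply end up with no points assigned beyond those of their twins.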







[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r835704039



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasBatchStrategy.java
##########
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringParam;
+import org.apache.flink.ml.param.WithParams;
+
+/** Interface for the shared batch strategy param. */
+public interface HasBatchStrategy<T> extends WithParams<T> {
+    String COUNT_STRATEGY = "count";

Review comment:
       `"count"` represents the strategy of creating mini-batches with a fixed batch size. Other options to be added in the future include creating batches with fixed sliding or tumbling windows. I'll add a JavaDoc explaining `count` here.
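
   As a Flink-free illustration of the `count` strategy, the sketch below buffers elements and emits a mini-batch each time exactly `batchSize` of them have arrived. `CountBatcher` is a hypothetical name; inside Flink, the later revision of `OnlineKMeans` in this thread expresses the same idea with `points.countWindowAll(batchSize)`.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;

/** Hypothetical sketch of the "count" batch strategy: emit a mini-batch every batchSize elements. */
final class CountBatcher<T> {
    private final int batchSize;
    private final Consumer<List<T>> downstream;
    private List<T> buffer;

    CountBatcher(int batchSize, Consumer<List<T>> downstream) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive");
        }
        this.batchSize = batchSize;
        this.downstream = downstream;
        this.buffer = new ArrayList<>(batchSize);
    }

    void add(T element) {
        buffer.add(element);
        if (buffer.size() == batchSize) {
            downstream.accept(buffer);            // hand the full mini-batch downstream
            buffer = new ArrayList<>(batchSize);  // start a fresh buffer; emitted lists are not reused
        }
    }

    /** Number of elements waiting for the current batch to fill up. */
    int pending() {
        return buffer.size();
    }
}
```

   Elements left over after the last full batch stay buffered until enough new data arrives, which is the behavior one would expect from a tumbling count window on an unbounded stream.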







[GitHub] [flink-ml] lindong28 commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
lindong28 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r836068872



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,407 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Supplier;
+
+/**
+ * OnlineKMeans extends the function of {@link KMeans}, supporting to train a K-Means model
+ * continuously according to an unbounded stream of train data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids from the weighted average between the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of points assigned to
+ * them. The weight of the original centroids is also the number of points, but additionally
+ * multiplying with the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If the decay
+ * factor is 1, all batches are weighted equally. If the decay factor is 0, new centroids are
+ * determined entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelData =
+                KMeansModelData.getModelDataStream(initModelDataTable);
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getK(),
+                        getDecayFactor(),
+                        getGlobalBatchSize());
+
+        DataStream<KMeansModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Saves the metadata AND bounded model data table (if exists) to the given path. */
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans onlineKMeans = ReadWriteUtils.loadStageParam(path);
+
+        String initModelDataPath = ReadWriteUtils.getDataPath(path);
+        if (Files.exists(Paths.get(initModelDataPath))) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            onlineKMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+        }
+
+        return onlineKMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final int k;
+        private final double decayFactor;
+        private final int batchSize;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure, int k, double decayFactor, int batchSize) {
+            this.distanceMeasure = distanceMeasure;
+            this.k = k;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<KMeansModelData> modelData = variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();

Review comment:
       @zhipeng93 Do you think there is any reason that users would want to run OnlineKMeans with parallelism > batchSize? If so, could you provide some details?
   
   If not, could you explain why we should support running the algorithm with `parallelism > batchSize`?
   
   I am also wondering what `consistent user experience` means here.
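
   To make the trade-off behind this question concrete: in the later revision of `OnlineKMeansIterationBody.process` in this thread, each global batch produced by `countWindowAll(batchSize)` is handed to `GlobalBatchSplitter(parallelism)` and spread over `parallelism` `ModelDataLocalUpdater` subtasks. The splitter's body is not part of this excerpt, so the sketch below assumes an even, contiguous split (`BatchSplitSketch` is a hypothetical name). Under that assumption, whenever `parallelism > batchSize`, exactly `parallelism - batchSize` subtasks receive an empty local batch for every global batch.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical sketch: split one global mini-batch into `parallelism` contiguous local batches. */
final class BatchSplitSketch {
    private BatchSplitSketch() {}

    static <T> List<List<T>> split(List<T> globalBatch, int parallelism) {
        if (parallelism <= 0) {
            throw new IllegalArgumentException("parallelism must be positive");
        }
        int size = globalBatch.size();
        List<List<T>> localBatches = new ArrayList<>(parallelism);
        for (int i = 0; i < parallelism; i++) {
            // Subtask i receives [i * size / p, (i + 1) * size / p): local sizes differ by at most one,
            // so when size < parallelism each subtask gets either one element or none.
            int from = (int) ((long) i * size / parallelism);
            int to = (int) ((long) (i + 1) * size / parallelism);
            localBatches.add(new ArrayList<>(globalBatch.subList(from, to)));
        }
        return localBatches;
    }
}
```

   For a global batch of 3 points and a parallelism of 5, two of the five local batches are empty, yet the downstream `countWindowAll(parallelism)` still has to wait for a result from every subtask.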








[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r835696729



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,437 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.function.Supplier;
+
+/**
+ * OnlineKMeans extends the function of {@link KMeans}, supporting to train a K-Means model
+ * continuously according to an unbounded stream of train data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids from the weighted average between the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of points assigned to
+ * them. The weight of the original centroids is also the number of points, but additionally
+ * multiplying with the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If decay
+ * factor is 1, all batches are weighted equally. If decay factor is 0, new centroids are determined
+ * entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelData =
+                KMeansModelData.getModelDataStream(initModelDataTable);
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize());
+
+        DataStream<KMeansModelData> finalModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelData);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Saves the metadata AND bounded model data table (if exists) to the given path. */
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        String initModelDataPath = ReadWriteUtils.getDataPath(path);
+        if (Files.exists(Paths.get(initModelDataPath))) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure, double decayFactor, int batchSize) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<KMeansModelData> modelData = variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<KMeansModelData> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator())
+                            .flatMap(new GlobalBatchSplitter(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(KMeansModelData.class),
+                                    new ModelDataLocalUpdater(distanceMeasure, decayFactor))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new ModelDataGlobalReducer());
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelData), DataStreamList.of(modelData));
+        }
+    }
+
+    private static class ModelDataGlobalReducer implements ReduceFunction<KMeansModelData> {
+        @Override
+        public KMeansModelData reduce(KMeansModelData modelData, KMeansModelData newModelData) {
+            DenseVector weights = modelData.weights;
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector newWeights = newModelData.weights;
+            DenseVector[] newCentroids = newModelData.centroids;
+
+            int k = newCentroids.length;
+            int dim = newCentroids[0].size();
+
+            for (int i = 0; i < k; i++) {
+                for (int j = 0; j < dim; j++) {
+                    centroids[i].values[j] =
+                            (centroids[i].values[j] * weights.values[i]
+                                            + newCentroids[i].values[j] * newWeights.values[i])
+                                    / Math.max(weights.values[i] + newWeights.values[i], 1e-16);
+                }
+                weights.values[i] += newWeights.values[i];
+            }
+
+            return new KMeansModelData(centroids, weights);
+        }
+    }
+
+    private static class ModelDataLocalUpdater extends AbstractStreamOperator<KMeansModelData>
+            implements TwoInputStreamOperator<DenseVector[], KMeansModelData, KMeansModelData> {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private ListState<DenseVector[]> localBatchState;
+        private ListState<KMeansModelData> modelDataState;
+
+        private ModelDataLocalUpdater(DistanceMeasure distanceMeasure, double decayFactor) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            localBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("localBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> pointsRecord) throws Exception {
+            localBatchState.add(pointsRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<KMeansModelData> modelDataRecord)
+                throws Exception {
+            modelDataState.add(modelDataRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        private void alignAndComputeModelData() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !localBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            KMeansModelData modelData =
+                    OperatorStateUtils.getUniqueElement(modelDataState, "modelData")
+                            .orElseThrow((Supplier<Exception>) NullPointerException::new);
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector weights = modelData.weights;
+            modelDataState.clear();
+
+            List<DenseVector[]> pointsList = IteratorUtils.toList(localBatchState.get().iterator());
+            DenseVector[] points = pointsList.remove(0);
+            localBatchState.update(pointsList);
+
+            int dim = centroids[0].size();
+            int k = centroids.length;
+            int parallelism = getRuntimeContext().getNumberOfParallelSubtasks();
+
+            // Computes new centroids.
+            DenseVector[] sums = new DenseVector[k];
+            int[] counts = new int[k];
+
+            for (int i = 0; i < k; i++) {
+                sums[i] = new DenseVector(dim);
+                counts[i] = 0;
+            }
+            for (DenseVector point : points) {
+                int closestCentroidId =
+                        KMeans.findClosestCentroidId(centroids, point, distanceMeasure);
+                counts[closestCentroidId]++;
+                for (int j = 0; j < dim; j++) {
+                    sums[closestCentroidId].values[j] += point.values[j];
+                }
+            }
+
+            // Considers weight and decay factor when updating centroids.
+            BLAS.scal(decayFactor / parallelism, weights);
+            for (int i = 0; i < k; i++) {
+                if (counts[i] == 0) {
+                    continue;
+                }
+
+                DenseVector centroid = centroids[i];
+                weights.values[i] = weights.values[i] + counts[i];
+                double lambda = counts[i] / weights.values[i];
+
+                BLAS.scal(1.0 - lambda, centroid);
+                BLAS.axpy(lambda / counts[i], sums[i], centroid);
+            }
+
+            output.collect(new StreamRecord<>(new KMeansModelData(centroids, weights)));
+        }
+    }
+
+    private static class FeaturesExtractor implements MapFunction<Row, DenseVector> {
+        private final String featuresCol;
+
+        private FeaturesExtractor(String featuresCol) {
+            this.featuresCol = featuresCol;
+        }
+
+        @Override
+        public DenseVector map(Row row) throws Exception {
+            return (DenseVector) row.getField(featuresCol);
+        }
+    }
+
+    // An operator that splits a global batch into evenly-sized local batches and distributes them
+    // to the downstream operator.
+    private static class GlobalBatchSplitter
+            implements FlatMapFunction<DenseVector[], DenseVector[]> {
+        private final int downStreamParallelism;
+
+        private GlobalBatchSplitter(int downStreamParallelism) {
+            this.downStreamParallelism = downStreamParallelism;
+        }
+
+        @Override
+        public void flatMap(DenseVector[] values, Collector<DenseVector[]> collector) {
+            // Calculate the batch sizes to be distributed on each subtask.
+            List<Integer> sizes = new ArrayList<>();

Review comment:
       Similarly, I changed the implementation to this:
   ```java
   int div = values.length / downStreamParallelism;
   int mod = values.length % downStreamParallelism;
   
   int offset = 0;
   int i = 0;
   
   int size = div + 1;
   for (; i < mod; i++) {
       collector.collect(Arrays.copyOfRange(values, offset, offset + size));
       offset += size;
   }
   
   size = div;
   for (; i < downStreamParallelism; i++) {
       collector.collect(Arrays.copyOfRange(values, offset, offset + size));
       offset += size;
   }
   ```
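       For reference, the same div/mod splitting can be pulled out of `flatMap` into a standalone helper and run in isolation. This is a sketch only: the class `BatchSplitSketch` and the method `splitEvenly` are hypothetical names, not part of the PR. The first `values.length % parallelism` chunks each receive one extra element, so chunk sizes differ by at most one:
   ```java
   import java.util.ArrayList;
   import java.util.Arrays;
   import java.util.List;

   public class BatchSplitSketch {
       // Hypothetical helper mirroring the div/mod logic above: splits `values` into
       // `parallelism` contiguous chunks whose sizes differ by at most one.
       static <T> List<T[]> splitEvenly(T[] values, int parallelism) {
           int div = values.length / parallelism;
           int mod = values.length % parallelism;
           List<T[]> chunks = new ArrayList<>(parallelism);
           int offset = 0;
           for (int i = 0; i < parallelism; i++) {
               // The first `mod` chunks absorb the remainder, one extra element each.
               int size = i < mod ? div + 1 : div;
               chunks.add(Arrays.copyOfRange(values, offset, offset + size));
               offset += size;
           }
           return chunks;
       }

       public static void main(String[] args) {
           Integer[] batch = {1, 2, 3, 4, 5, 6, 7};
           // Prints [1, 2, 3], then [4, 5], then [6, 7].
           for (Integer[] chunk : splitEvenly(batch, 3)) {
               System.out.println(Arrays.toString(chunk));
           }
       }
   }
   ```
       Note that when a global batch holds fewer elements than the parallelism, the trailing chunks are empty arrays, so every downstream subtask still receives exactly one (possibly empty) local batch per global batch.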




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r832772257



##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/util/InMemorySinkFunction.java
##########
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util;
+
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.streaming.api.functions.sink.RichSinkFunction;
+import org.apache.flink.streaming.api.functions.sink.SinkFunction;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.UUID;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
+
+/** A {@link SinkFunction} implementation that makes all collected records available for tests. */
+@SuppressWarnings({"unchecked", "rawtypes"})
+public class InMemorySinkFunction<T> extends RichSinkFunction<T> {
+    private static final Map<UUID, BlockingQueue> queueMap = new ConcurrentHashMap<>();
+    private final UUID id;
+    private BlockingQueue<T> queue;
+
+    public InMemorySinkFunction() {
+        id = UUID.randomUUID();
+        queue = new LinkedBlockingQueue();
+        queueMap.put(id, queue);
+    }
+
+    @Override
+    public void open(Configuration parameters) throws Exception {
+        super.open(parameters);
+        queue = queueMap.get(id);
+    }
+
+    @Override
+    public void close() throws Exception {
+        super.close();
+        queueMap.remove(id);
+    }
+
+    @Override
+    public void invoke(T value, Context context) {
+        if (!queue.offer(value)) {
+            throw new RuntimeException(
+                    "Failed to offer " + value + " to blocking queue " + id + ".");
+        }
+    }
+
+    public List<T> poll(int num) throws InterruptedException {

Review comment:
       `take()` would block indefinitely if the Flink job fails to output the expected number of records to this sink, which means the relevant tests might never terminate. `poll()` can at least unblock the thread, and the returned null value then causes an exception further on. While in `InMemorySourceFunction` we may add dummy values to the queue to unblock `take()`, we cannot expect every developer to remember to do this in their test cases.
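       To illustrate the difference, here is a minimal sketch of a timeout-based `poll(int num)` over a `BlockingQueue`. The class name `TimedQueueReader`, the per-element timeout, and the choice to throw a `TimeoutException` as soon as a record is missing are assumptions for illustration, not the PR's implementation. Unlike `take()`, `BlockingQueue.poll(long, TimeUnit)` returns `null` once the timeout elapses, so a test fails instead of hanging:
   ```java
   import java.util.ArrayList;
   import java.util.List;
   import java.util.concurrent.BlockingQueue;
   import java.util.concurrent.LinkedBlockingQueue;
   import java.util.concurrent.TimeUnit;
   import java.util.concurrent.TimeoutException;

   public class TimedQueueReader<T> {
       private final BlockingQueue<T> queue;
       private final long timeoutMillis;

       public TimedQueueReader(BlockingQueue<T> queue, long timeoutMillis) {
           this.queue = queue;
           this.timeoutMillis = timeoutMillis;
       }

       // Collects exactly `num` records. Each wait is bounded by the timeout, so a job that
       // emits too few records yields a TimeoutException instead of blocking forever.
       public List<T> poll(int num) throws InterruptedException, TimeoutException {
           List<T> result = new ArrayList<>(num);
           for (int i = 0; i < num; i++) {
               T value = queue.poll(timeoutMillis, TimeUnit.MILLISECONDS);
               if (value == null) {
                   throw new TimeoutException(
                           "Received only " + i + " of " + num + " expected records.");
               }
               result.add(value);
           }
           return result;
       }

       public static void main(String[] args) throws Exception {
           BlockingQueue<String> queue = new LinkedBlockingQueue<>();
           queue.add("a");
           queue.add("b");
           TimedQueueReader<String> reader = new TimedQueueReader<>(queue, 100);
           System.out.println(reader.poll(2)); // [a, b]
           try {
               reader.poll(1); // queue is empty: gives up after ~100 ms instead of hanging
           } catch (TimeoutException e) {
               System.out.println(e.getMessage());
           }
       }
   }
   ```
       Throwing at the point of the missing record is a design choice in this sketch; it reports how many records actually arrived, which is easier to diagnose than a later failure on a null element.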






[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r832808860



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeansParams.java
##########
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.common.param.HasDecayFactor;
+import org.apache.flink.ml.common.param.HasSeed;
+import org.apache.flink.ml.param.DoubleArrayParam;
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringParam;
+
+/**
+ * Params of {@link OnlineKMeans}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface OnlineKMeansParams<T>
+        extends HasBatchStrategy<T>, HasDecayFactor<T>, HasSeed<T>, KMeansModelParams<T> {
+    Param<String> INIT_MODE =
+            new StringParam(
+                    "initMode",
+                    "How to initialize the model data of the online KMeans algorithm. Supported options: 'random', 'direct'.",
+                    "random",
+                    ParamValidators.inArray("random", "direct"));
+
+    Param<Integer> DIMS =

Review comment:
       I agree. I'll make the change.






[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r832896416



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,562 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Random;
+
+/**
+ * OnlineKMeans extends the functionality of {@link KMeans} to support training a K-Means model
+ * continuously on an unbounded stream of training data.
+ *
+ * <p>OnlineKMeans updates the model with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids as the weighted average of the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of points assigned to
+ * them. The weight of the original centroids is the number of points assigned to them so far,
+ * additionally multiplied by the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters estimated thus far. If the decay
+ * factor is 1, all batches are weighted equally. If the decay factor is 0, the new centroids are
+ * determined entirely by the most recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public OnlineKMeans(Table initModelDataTable) {
+        this.initModelDataTable = initModelDataTable;
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            Preconditions.checkState(initModelDataTable == null);
+            initModelDataStream = createRandomCentroids(env, getDim(), getK(), getSeed());
+        } else {
+            initModelDataStream = KMeansModelData.getModelDataStream(initModelDataTable);
+        }
+        DataStream<Tuple2<KMeansModelData, DenseVector>> initModelDataWithWeightsStream =
+                initModelDataStream.map(new InitWeightAssigner(getInitWeights()));
+        initModelDataWithWeightsStream.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize(),
+                        getK(),
+                        getDim());
+
+        DataStream<KMeansModelData> finalModelDataStream =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelDataWithWeightsStream),
+                                DataStreamList.of(points),
+                                body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelDataStream);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class InitWeightAssigner
+            implements MapFunction<KMeansModelData, Tuple2<KMeansModelData, DenseVector>> {
+        private final double[] initWeights;
+
+        private InitWeightAssigner(Double[] initWeights) {
+            this.initWeights = ArrayUtils.toPrimitive(initWeights);
+        }
+
+        @Override
+        public Tuple2<KMeansModelData, DenseVector> map(KMeansModelData modelData)
+                throws Exception {
+            return Tuple2.of(modelData, Vectors.dense(initWeights));
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        Path initModelDataPath = Paths.get(path, "data");
+        if (Files.exists(initModelDataPath)) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+            kMeans.setInitMode("direct");
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+        private final int k;
+        private final int dim;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure,
+                double decayFactor,
+                int batchSize,
+                int k,
+                int dim) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+            this.k = k;
+            this.dim = dim;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<Tuple2<KMeansModelData, DenseVector>> modelDataWithWeights =
+                    variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<Tuple2<KMeansModelData, DenseVector>> newModelDataWithWeights =
+                    points.countWindowAll(batchSize)
+                            .aggregate(new MiniBatchCreator())
+                            .flatMap(new MiniBatchDistributor(parallelism))
+                            .rebalance()
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "ModelDataPartialUpdater",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new ModelDataPartialUpdater(distanceMeasure, k))
+                            .setParallelism(parallelism)
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "ModelDataGlobalUpdater",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new ModelDataGlobalUpdater(k, dim, parallelism, decayFactor))
+                            .setParallelism(1);
+
+            DataStream<KMeansModelData> outputModelData =
+                    modelDataWithWeights.map(
+                            (MapFunction<Tuple2<KMeansModelData, DenseVector>, KMeansModelData>)
+                                    x -> x.f0);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelDataWithWeights), DataStreamList.of(outputModelData));
+        }
+    }
+
+    private static class ModelDataGlobalUpdater
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, DenseVector>>
+            implements TwoInputStreamOperator<
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final int k;
+        private final int dim;
+        private final int upstreamParallelism;
+        private final double decayFactor;
+
+        private ListState<Integer> partialModelDataReceivingState;
+        private ListState<Boolean> initModelDataReceivingState;
+        private ListState<KMeansModelData> modelDataState;
+        private ListState<DenseVector> weightsState;
+
+        private ModelDataGlobalUpdater(
+                int k, int dim, int upstreamParallelism, double decayFactor) {
+            this.k = k;
+            this.dim = dim;
+            this.upstreamParallelism = upstreamParallelism;
+            this.decayFactor = decayFactor;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            partialModelDataReceivingState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "partialModelDataReceiving",
+                                            BasicTypeInfo.INT_TYPE_INFO));
+
+            initModelDataReceivingState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "initModelDataReceiving",
+                                            BasicTypeInfo.BOOLEAN_TYPE_INFO));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+
+            weightsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "weights", DenseVectorTypeInfo.INSTANCE));
+
+            initStateValues();
+        }
+
+        private void initStateValues() throws Exception {
+            partialModelDataReceivingState.update(Collections.singletonList(0));
+            initModelDataReceivingState.update(Collections.singletonList(false));
+            DenseVector[] emptyCentroids = new DenseVector[k];
+            for (int i = 0; i < k; i++) {
+                emptyCentroids[i] = new DenseVector(dim);
+            }
+            modelDataState.update(Collections.singletonList(new KMeansModelData(emptyCentroids)));
+            weightsState.update(Collections.singletonList(new DenseVector(k)));
+        }
+
+        @Override
+        public void processElement1(
+                StreamRecord<Tuple2<KMeansModelData, DenseVector>> partialModelDataUpdateRecord)
+                throws Exception {
+            int partialModelDataReceiving =
+                    getUniqueElement(partialModelDataReceivingState, "partialModelDataReceiving");
+            Preconditions.checkState(partialModelDataReceiving < upstreamParallelism);
+            partialModelDataReceivingState.update(
+                    Collections.singletonList(partialModelDataReceiving + 1));
+            processElement(
+                    partialModelDataUpdateRecord.getValue().f0.centroids,
+                    partialModelDataUpdateRecord.getValue().f1,
+                    1.0);
+        }
+
+        @Override
+        public void processElement2(
+                StreamRecord<Tuple2<KMeansModelData, DenseVector>> initModelDataRecord)
+                throws Exception {
+            boolean initModelDataReceiving =
+                    getUniqueElement(initModelDataReceivingState, "initModelDataReceiving");
+            Preconditions.checkState(!initModelDataReceiving);
+            initModelDataReceivingState.update(Collections.singletonList(true));
+            processElement(
+                    initModelDataRecord.getValue().f0.centroids,
+                    initModelDataRecord.getValue().f1,
+                    decayFactor);
+        }
+
+        private void processElement(
+                DenseVector[] newCentroids, DenseVector newWeights, double decayFactor)
+                throws Exception {
+            DenseVector weights = getUniqueElement(weightsState, "weights");
+            DenseVector[] centroids = getUniqueElement(modelDataState, "modelData").centroids;
+
+            for (int i = 0; i < k; i++) {
+                newWeights.values[i] *= decayFactor;
+                for (int j = 0; j < dim; j++) {
+                    centroids[i].values[j] =
+                            (centroids[i].values[j] * weights.values[i]
+                                            + newCentroids[i].values[j] * newWeights.values[i])
+                                    / Math.max(weights.values[i] + newWeights.values[i], 1e-16);
+                }
+                weights.values[i] += newWeights.values[i];
+            }
+
+            if (getUniqueElement(initModelDataReceivingState, "initModelDataReceiving")
+                    && getUniqueElement(partialModelDataReceivingState, "partialModelDataReceiving")
+                            >= upstreamParallelism) {
+                output.collect(
+                        new StreamRecord<>(Tuple2.of(new KMeansModelData(centroids), weights)));
+                initStateValues();
+            } else {
+                modelDataState.update(Collections.singletonList(new KMeansModelData(centroids)));
+                weightsState.update(Collections.singletonList(weights));
+            }
+        }
+    }
+
+    private static <T> T getUniqueElement(ListState<T> state, String stateName) throws Exception {
+        T value = OperatorStateUtils.getUniqueElement(state, stateName).orElse(null);
+        return Objects.requireNonNull(value);
+    }
+
+    private static class ModelDataPartialUpdater
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, DenseVector>>
+            implements TwoInputStreamOperator<
+                    DenseVector[],
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final DistanceMeasure distanceMeasure;
+        private final int k;
+        private ListState<DenseVector[]> miniBatchState;
+        private ListState<KMeansModelData> modelDataState;
+
+        private ModelDataPartialUpdater(DistanceMeasure distanceMeasure, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.k = k;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            miniBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("miniBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> pointsRecord) throws Exception {
+            miniBatchState.add(pointsRecord.getValue());
+            processElement();
+        }
+
+        @Override
+        public void processElement2(
+                StreamRecord<Tuple2<KMeansModelData, DenseVector>> modelDataAndWeightsRecord)
+                throws Exception {
+            modelDataState.add(modelDataAndWeightsRecord.getValue().f0);
+            processElement();
+        }
+
+        private void processElement() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !miniBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            DenseVector[] centroids = getUniqueElement(modelDataState, "modelData").centroids;
+            modelDataState.clear();
+
+            List<DenseVector[]> pointsList = IteratorUtils.toList(miniBatchState.get().iterator());
+            DenseVector[] points = pointsList.get(0);
+            pointsList.remove(0);
+            miniBatchState.clear();
+            miniBatchState.addAll(pointsList);
+
+            // Computes new centroids.
+            DenseVector[] sums = new DenseVector[k];
+            int dim = centroids[0].size();
+            DenseVector counts = new DenseVector(k);
+
+            for (int i = 0; i < k; i++) {
+                sums[i] = new DenseVector(dim);
+                counts.values[i] = 0;
+            }
+            for (DenseVector point : points) {
+                int closestCentroidId =
+                        KMeans.findClosestCentroidId(centroids, point, distanceMeasure);
+                counts.values[closestCentroidId]++;
+                for (int j = 0; j < dim; j++) {
+                    sums[closestCentroidId].values[j] += point.values[j];
+                }
+            }
+
+            for (int i = 0; i < k; i++) {
+                if (counts.values[i] < 1e-5) {
+                    continue;
+                }
+                BLAS.scal(1.0 / counts.values[i], sums[i]);
+            }
+
+            output.collect(new StreamRecord<>(Tuple2.of(new KMeansModelData(sums), counts)));
+        }
+    }
+
+    private static class FeaturesExtractor implements MapFunction<Row, DenseVector> {
+        private final String featuresCol;
+
+        private FeaturesExtractor(String featuresCol) {
+            this.featuresCol = featuresCol;
+        }
+
+        @Override
+        public DenseVector map(Row row) throws Exception {
+            return (DenseVector) row.getField(featuresCol);
+        }
+    }
+
+    private static class MiniBatchDistributor
+            implements FlatMapFunction<DenseVector[], DenseVector[]> {
+        private final int downStreamParallelism;
+        private int shift = 0;
+
+        private MiniBatchDistributor(int downStreamParallelism) {
+            this.downStreamParallelism = downStreamParallelism;
+        }
+
+        @Override
+        public void flatMap(DenseVector[] values, Collector<DenseVector[]> collector) {
+            // Calculate the batch sizes to be distributed on each subtask.
+            List<Integer> sizes = new ArrayList<>();
+            for (int i = 0; i < downStreamParallelism; i++) {
+                int start = i * values.length / downStreamParallelism;
+                int end = (i + 1) * values.length / downStreamParallelism;
+                sizes.add(end - start);
+            }
+
+            // Reduce accumulated imbalance among distributed batches.
+            Collections.rotate(sizes, shift);

Review comment:
       For example, when `batchSize == 128 && parallelism == 3`, a global batch is split into local batches of 43, 43 and 42 records, so one of the subtasks would always receive one record fewer than the others. The shift here aims to spread that deficit across subtasks instead of letting it accumulate on the same one.
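A minimal Java sketch of the arithmetic in this comment. It reuses the split formula from `MiniBatchDistributor.flatMap` above, which in subtask order yields sizes [42, 43, 43] (the same multiset as 43, 43, 42). The class `MiniBatchSplitDemo` and its methods are hypothetical. The sketch also assumes that `shift` advances by one (mod parallelism) per global batch, because the quoted hunk ends before `shift` is updated.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

final class MiniBatchSplitDemo {
    // Mirrors the size computation in MiniBatchDistributor.flatMap:
    // subtask i receives records [i * n / p, (i + 1) * n / p).
    static List<Integer> splitSizes(int batchSize, int parallelism) {
        List<Integer> sizes = new ArrayList<>();
        for (int i = 0; i < parallelism; i++) {
            int start = i * batchSize / parallelism;
            int end = (i + 1) * batchSize / parallelism;
            sizes.add(end - start);
        }
        return sizes;
    }

    // Sums the records each subtask receives over numBatches global batches.
    // ASSUMPTION (hypothetical): shift grows by one per global batch, mod parallelism.
    static int[] cumulativeCounts(int batchSize, int parallelism, int numBatches, boolean rotate) {
        int[] totals = new int[parallelism];
        int shift = 0;
        for (int b = 0; b < numBatches; b++) {
            List<Integer> sizes = splitSizes(batchSize, parallelism);
            if (rotate) {
                Collections.rotate(sizes, shift);
                shift = (shift + 1) % parallelism;
            }
            for (int i = 0; i < parallelism; i++) {
                totals[i] += sizes.get(i);
            }
        }
        return totals;
    }

    public static void main(String[] args) {
        System.out.println(splitSizes(128, 3)); // [42, 43, 43]
        System.out.println(Arrays.toString(cumulativeCounts(128, 3, 3, false))); // [126, 129, 129]
        System.out.println(Arrays.toString(cumulativeCounts(128, 3, 3, true))); // [128, 128, 128]
    }
}
```

Without rotation, subtask 0 falls one record further behind with every global batch. Under the assumed rotation, the per-subtask totals even out every `parallelism` batches.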




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [flink-ml] lindong28 commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
lindong28 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r832909011



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,562 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Random;
+
+/**
+ * OnlineKMeans extends the functionality of {@link KMeans} by supporting continuous training of
+ * a K-Means model from an unbounded stream of training data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids of the current batch have been estimated,
+ * OnlineKMeans computes the new centroids as the weighted average of the original and the
+ * estimated centroids. The weight of each estimated centroid is the number of points assigned to
+ * it. The weight of each original centroid is also its number of points, additionally multiplied
+ * by the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters estimated so far. If the decay
+ * factor is 1, all batches are weighted equally. If the decay factor is 0, the new centroids are
+ * determined entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public OnlineKMeans(Table initModelDataTable) {
+        this.initModelDataTable = initModelDataTable;
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            Preconditions.checkState(initModelDataTable == null);
+            initModelDataStream = createRandomCentroids(env, getDim(), getK(), getSeed());
+        } else {
+            initModelDataStream = KMeansModelData.getModelDataStream(initModelDataTable);
+        }
+        DataStream<Tuple2<KMeansModelData, DenseVector>> initModelDataWithWeightsStream =
+                initModelDataStream.map(new InitWeightAssigner(getInitWeights()));
+        initModelDataWithWeightsStream.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize(),
+                        getK(),
+                        getDim());
+
+        DataStream<KMeansModelData> finalModelDataStream =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelDataWithWeightsStream),
+                                DataStreamList.of(points),
+                                body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelDataStream);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class InitWeightAssigner
+            implements MapFunction<KMeansModelData, Tuple2<KMeansModelData, DenseVector>> {
+        private final double[] initWeights;
+
+        private InitWeightAssigner(Double[] initWeights) {
+            this.initWeights = ArrayUtils.toPrimitive(initWeights);
+        }
+
+        @Override
+        public Tuple2<KMeansModelData, DenseVector> map(KMeansModelData modelData)
+                throws Exception {
+            return Tuple2.of(modelData, Vectors.dense(initWeights));
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        Path initModelDataPath = Paths.get(path, "data");
+        if (Files.exists(initModelDataPath)) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+            kMeans.setInitMode("direct");
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+        private final int k;
+        private final int dim;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure,
+                double decayFactor,
+                int batchSize,
+                int k,
+                int dim) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+            this.k = k;
+            this.dim = dim;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<Tuple2<KMeansModelData, DenseVector>> modelDataWithWeights =
+                    variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<Tuple2<KMeansModelData, DenseVector>> newModelDataWithWeights =
+                    points.countWindowAll(batchSize)
+                            .aggregate(new MiniBatchCreator())
+                            .flatMap(new MiniBatchDistributor(parallelism))
+                            .rebalance()
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "ModelDataPartialUpdater",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new ModelDataPartialUpdater(distanceMeasure, k))
+                            .setParallelism(parallelism)
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "ModelDataGlobalUpdater",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new ModelDataGlobalUpdater(k, dim, parallelism, decayFactor))
+                            .setParallelism(1);
+
+            DataStream<KMeansModelData> outputModelData =
+                    modelDataWithWeights.map(
+                            (MapFunction<Tuple2<KMeansModelData, DenseVector>, KMeansModelData>)
+                                    x -> x.f0);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelDataWithWeights), DataStreamList.of(outputModelData));
+        }
+    }
+
+    private static class ModelDataGlobalUpdater
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, DenseVector>>
+            implements TwoInputStreamOperator<
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final int k;
+        private final int dim;
+        private final int upstreamParallelism;
+        private final double decayFactor;
+
+        private ListState<Integer> partialModelDataReceivingState;
+        private ListState<Boolean> initModelDataReceivingState;
+        private ListState<KMeansModelData> modelDataState;
+        private ListState<DenseVector> weightsState;
+
+        private ModelDataGlobalUpdater(
+                int k, int dim, int upstreamParallelism, double decayFactor) {
+            this.k = k;
+            this.dim = dim;
+            this.upstreamParallelism = upstreamParallelism;
+            this.decayFactor = decayFactor;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            partialModelDataReceivingState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "partialModelDataReceiving",
+                                            BasicTypeInfo.INT_TYPE_INFO));
+
+            initModelDataReceivingState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "initModelDataReceiving",
+                                            BasicTypeInfo.BOOLEAN_TYPE_INFO));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+
+            weightsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "weights", DenseVectorTypeInfo.INSTANCE));
+
+            initStateValues();
+        }
+
+        private void initStateValues() throws Exception {
+            partialModelDataReceivingState.update(Collections.singletonList(0));
+            initModelDataReceivingState.update(Collections.singletonList(false));
+            DenseVector[] emptyCentroids = new DenseVector[k];
+            for (int i = 0; i < k; i++) {
+                emptyCentroids[i] = new DenseVector(dim);
+            }
+            modelDataState.update(Collections.singletonList(new KMeansModelData(emptyCentroids)));
+            weightsState.update(Collections.singletonList(new DenseVector(k)));
+        }
+
+        @Override
+        public void processElement1(
+                StreamRecord<Tuple2<KMeansModelData, DenseVector>> partialModelDataUpdateRecord)
+                throws Exception {
+            int partialModelDataReceiving =
+                    getUniqueElement(partialModelDataReceivingState, "partialModelDataReceiving");
+            Preconditions.checkState(partialModelDataReceiving < upstreamParallelism);
+            partialModelDataReceivingState.update(
+                    Collections.singletonList(partialModelDataReceiving + 1));
+            processElement(
+                    partialModelDataUpdateRecord.getValue().f0.centroids,
+                    partialModelDataUpdateRecord.getValue().f1,
+                    1.0);
+        }
+
+        @Override
+        public void processElement2(
+                StreamRecord<Tuple2<KMeansModelData, DenseVector>> initModelDataRecord)
+                throws Exception {
+            boolean initModelDataReceiving =
+                    getUniqueElement(initModelDataReceivingState, "initModelDataReceiving");
+            Preconditions.checkState(!initModelDataReceiving);
+            initModelDataReceivingState.update(Collections.singletonList(true));
+            processElement(
+                    initModelDataRecord.getValue().f0.centroids,
+                    initModelDataRecord.getValue().f1,
+                    decayFactor);
+        }
+
+        private void processElement(
+                DenseVector[] newCentroids, DenseVector newWeights, double decayFactor)
+                throws Exception {
+            DenseVector weights = getUniqueElement(weightsState, "weights");
+            DenseVector[] centroids = getUniqueElement(modelDataState, "modelData").centroids;
+
+            for (int i = 0; i < k; i++) {
+                newWeights.values[i] *= decayFactor;
+                for (int j = 0; j < dim; j++) {
+                    centroids[i].values[j] =
+                            (centroids[i].values[j] * weights.values[i]
+                                            + newCentroids[i].values[j] * newWeights.values[i])
+                                    / Math.max(weights.values[i] + newWeights.values[i], 1e-16);
+                }
+                weights.values[i] += newWeights.values[i];
+            }
+
+            if (getUniqueElement(initModelDataReceivingState, "initModelDataReceiving")
+                    && getUniqueElement(partialModelDataReceivingState, "partialModelDataReceiving")
+                            >= upstreamParallelism) {
+                output.collect(
+                        new StreamRecord<>(Tuple2.of(new KMeansModelData(centroids), weights)));
+                initStateValues();
+            } else {
+                modelDataState.update(Collections.singletonList(new KMeansModelData(centroids)));
+                weightsState.update(Collections.singletonList(weights));
+            }
+        }
+    }
+
+    private static <T> T getUniqueElement(ListState<T> state, String stateName) throws Exception {
+        T value = OperatorStateUtils.getUniqueElement(state, stateName).orElse(null);
+        return Objects.requireNonNull(value);
+    }
+
+    private static class ModelDataPartialUpdater
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, DenseVector>>
+            implements TwoInputStreamOperator<
+                    DenseVector[],
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final DistanceMeasure distanceMeasure;
+        private final int k;
+        private ListState<DenseVector[]> miniBatchState;
+        private ListState<KMeansModelData> modelDataState;
+
+        private ModelDataPartialUpdater(DistanceMeasure distanceMeasure, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.k = k;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            miniBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("miniBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> pointsRecord) throws Exception {
+            miniBatchState.add(pointsRecord.getValue());
+            processElement();
+        }
+
+        @Override
+        public void processElement2(
+                StreamRecord<Tuple2<KMeansModelData, DenseVector>> modelDataAndWeightsRecord)
+                throws Exception {
+            modelDataState.add(modelDataAndWeightsRecord.getValue().f0);
+            processElement();
+        }
+
+        private void processElement() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !miniBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            DenseVector[] centroids = getUniqueElement(modelDataState, "modelData").centroids;
+            modelDataState.clear();
+
+            List<DenseVector[]> pointsList = IteratorUtils.toList(miniBatchState.get().iterator());
+            DenseVector[] points = pointsList.get(0);
+            pointsList.remove(0);
+            miniBatchState.clear();
+            miniBatchState.addAll(pointsList);
+
+            // Computes new centroids.
+            DenseVector[] sums = new DenseVector[k];
+            int dim = centroids[0].size();
+            DenseVector counts = new DenseVector(k);
+
+            for (int i = 0; i < k; i++) {
+                sums[i] = new DenseVector(dim);
+                counts.values[i] = 0;
+            }
+            for (DenseVector point : points) {
+                int closestCentroidId =
+                        KMeans.findClosestCentroidId(centroids, point, distanceMeasure);
+                counts.values[closestCentroidId]++;
+                for (int j = 0; j < dim; j++) {
+                    sums[closestCentroidId].values[j] += point.values[j];
+                }
+            }
+
+            for (int i = 0; i < k; i++) {
+                if (counts.values[i] < 1e-5) {
+                    continue;
+                }
+                BLAS.scal(1.0 / counts.values[i], sums[i]);
+            }
+
+            output.collect(new StreamRecord<>(Tuple2.of(new KMeansModelData(sums), counts)));
+        }
+    }
+
+    private static class FeaturesExtractor implements MapFunction<Row, DenseVector> {
+        private final String featuresCol;
+
+        private FeaturesExtractor(String featuresCol) {
+            this.featuresCol = featuresCol;
+        }
+
+        @Override
+        public DenseVector map(Row row) throws Exception {
+            return (DenseVector) row.getField(featuresCol);
+        }
+    }
+
+    private static class MiniBatchDistributor
+            implements FlatMapFunction<DenseVector[], DenseVector[]> {
+        private final int downStreamParallelism;
+        private int shift = 0;
+
+        private MiniBatchDistributor(int downStreamParallelism) {
+            this.downStreamParallelism = downStreamParallelism;
+        }
+
+        @Override
+        public void flatMap(DenseVector[] values, Collector<DenseVector[]> collector) {
+            // Calculate the batch sizes to be distributed on each subtask.
+            List<Integer> sizes = new ArrayList<>();
+            for (int i = 0; i < downStreamParallelism; i++) {
+                int start = i * values.length / downStreamParallelism;
+                int end = (i + 1) * values.length / downStreamParallelism;
+                sizes.add(end - start);
+            }
+
+            // Reduce accumulated imbalance among distributed batches.
+            Collections.rotate(sizes, shift);

Review comment:
       Could you explain what problem it solves? For example, would we get better performance by making this shift?

   Note that even if we shift it, the time to process one global batch (i.e. one step) still depends on the time to process the largest local batch. A subtask cannot continue to process its buffered batches until it receives the global model data from the previous step.
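To make the reviewer's point concrete, here is a toy Java cost model. It rests on an assumption: each step takes time proportional to its largest local batch, because no subtask can start its next buffered batch before the previous step's global model data arrives. `StepTimeDemo` is hypothetical and not part of the PR. It reuses the split formula from `MiniBatchDistributor.flatMap`. Rotation only permutes the local batch sizes, so the slowest subtask in every step still handles 43 records, and the total cost does not change.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

final class StepTimeDemo {
    // Same split as MiniBatchDistributor.flatMap: subtask i gets [i * n / p, (i + 1) * n / p).
    static List<Integer> splitSizes(int batchSize, int parallelism) {
        List<Integer> sizes = new ArrayList<>();
        for (int i = 0; i < parallelism; i++) {
            sizes.add((i + 1) * batchSize / parallelism - i * batchSize / parallelism);
        }
        return sizes;
    }

    // ASSUMPTION: a step costs as many time units as the largest local batch has records,
    // since every subtask waits for the global model data before starting the next step.
    // Returns the summed cost over numSteps steps.
    static int totalStepCost(int batchSize, int parallelism, int numSteps, boolean rotate) {
        int total = 0;
        int shift = 0;
        for (int step = 0; step < numSteps; step++) {
            List<Integer> sizes = splitSizes(batchSize, parallelism);
            if (rotate) {
                Collections.rotate(sizes, shift);
                shift = (shift + 1) % parallelism;
            }
            total += Collections.max(sizes); // the slowest subtask bounds the step
        }
        return total;
    }

    public static void main(String[] args) {
        System.out.println(totalStepCost(128, 3, 100, false)); // 4300
        System.out.println(totalStepCost(128, 3, 100, true)); // 4300
    }
}
```

Under this model, the shift changes which subtask is idle for one record's worth of time, but not how long each step takes.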







[GitHub] [flink-ml] lindong28 commented on a change in pull request #70: [FLINK-26313] Support Online KMeans in Flink ML

Posted by GitBox <gi...@apache.org>.
lindong28 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r830051567



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/StreamingKMeans.java
##########
@@ -0,0 +1,404 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * StreamingKMeans extends the functionality of {@link KMeans} to support training a K-Means model

Review comment:
       Sounds good.







[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r833165485



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,562 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Random;
+
+/**
+ * OnlineKMeans extends the functionality of {@link KMeans} to support training a K-Means model
+ * continuously on an unbounded stream of training data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids of the current batch are estimated,
+ * OnlineKMeans computes the new centroids as the weighted average of the original and the
+ * estimated centroids. The weight of each estimated centroid is the number of points assigned to
+ * it in the current batch. The weight of each original centroid is its accumulated weight from
+ * previous batches, multiplied by the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If the decay
+ * factor is 1, all batches are weighted equally. If the decay factor is 0, new centroids are
+ * determined entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public OnlineKMeans(Table initModelDataTable) {
+        this.initModelDataTable = initModelDataTable;
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            Preconditions.checkState(initModelDataTable == null);
+            initModelDataStream = createRandomCentroids(env, getDim(), getK(), getSeed());
+        } else {
+            initModelDataStream = KMeansModelData.getModelDataStream(initModelDataTable);
+        }
+        DataStream<Tuple2<KMeansModelData, DenseVector>> initModelDataWithWeightsStream =
+                initModelDataStream.map(new InitWeightAssigner(getInitWeights()));
+        initModelDataWithWeightsStream.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize(),
+                        getK(),
+                        getDim());
+
+        DataStream<KMeansModelData> finalModelDataStream =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelDataWithWeightsStream),
+                                DataStreamList.of(points),
+                                body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelDataStream);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class InitWeightAssigner
+            implements MapFunction<KMeansModelData, Tuple2<KMeansModelData, DenseVector>> {
+        private final double[] initWeights;
+
+        private InitWeightAssigner(Double[] initWeights) {
+            this.initWeights = ArrayUtils.toPrimitive(initWeights);
+        }
+
+        @Override
+        public Tuple2<KMeansModelData, DenseVector> map(KMeansModelData modelData)
+                throws Exception {
+            return Tuple2.of(modelData, Vectors.dense(initWeights));
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        Path initModelDataPath = Paths.get(path, "data");
+        if (Files.exists(initModelDataPath)) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+            kMeans.setInitMode("direct");
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+        private final int k;
+        private final int dim;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure,
+                double decayFactor,
+                int batchSize,
+                int k,
+                int dim) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+            this.k = k;
+            this.dim = dim;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<Tuple2<KMeansModelData, DenseVector>> modelDataWithWeights =
+                    variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<Tuple2<KMeansModelData, DenseVector>> newModelDataWithWeights =
+                    points.countWindowAll(batchSize)
+                            .aggregate(new MiniBatchCreator())
+                            .flatMap(new MiniBatchDistributor(parallelism))
+                            .rebalance()
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "ModelDataPartialUpdater",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new ModelDataPartialUpdater(distanceMeasure, k))
+                            .setParallelism(parallelism)
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "ModelDataGlobalUpdater",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new ModelDataGlobalUpdater(k, dim, parallelism, decayFactor))
+                            .setParallelism(1);
+
+            DataStream<KMeansModelData> outputModelData =
+                    modelDataWithWeights.map(
+                            (MapFunction<Tuple2<KMeansModelData, DenseVector>, KMeansModelData>)
+                                    x -> x.f0);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelDataWithWeights), DataStreamList.of(outputModelData));
+        }
+    }
+
+    private static class ModelDataGlobalUpdater
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, DenseVector>>
+            implements TwoInputStreamOperator<
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final int k;
+        private final int dim;
+        private final int upstreamParallelism;
+        private final double decayFactor;
+
+        private ListState<Integer> partialModelDataReceivingState;
+        private ListState<Boolean> initModelDataReceivingState;
+        private ListState<KMeansModelData> modelDataState;
+        private ListState<DenseVector> weightsState;
+
+        private ModelDataGlobalUpdater(
+                int k, int dim, int upstreamParallelism, double decayFactor) {
+            this.k = k;
+            this.dim = dim;
+            this.upstreamParallelism = upstreamParallelism;
+            this.decayFactor = decayFactor;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            partialModelDataReceivingState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "partialModelDataReceiving",
+                                            BasicTypeInfo.INT_TYPE_INFO));
+
+            initModelDataReceivingState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "initModelDataReceiving",
+                                            BasicTypeInfo.BOOLEAN_TYPE_INFO));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+
+            weightsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "weights", DenseVectorTypeInfo.INSTANCE));
+
+            initStateValues();
+        }
+
+        private void initStateValues() throws Exception {
+            partialModelDataReceivingState.update(Collections.singletonList(0));
+            initModelDataReceivingState.update(Collections.singletonList(false));
+            DenseVector[] emptyCentroids = new DenseVector[k];
+            for (int i = 0; i < k; i++) {
+                emptyCentroids[i] = new DenseVector(dim);
+            }
+            modelDataState.update(Collections.singletonList(new KMeansModelData(emptyCentroids)));
+            weightsState.update(Collections.singletonList(new DenseVector(k)));
+        }
+
+        @Override
+        public void processElement1(
+                StreamRecord<Tuple2<KMeansModelData, DenseVector>> partialModelDataUpdateRecord)
+                throws Exception {
+            int partialModelDataReceiving =
+                    getUniqueElement(partialModelDataReceivingState, "partialModelDataReceiving");
+            Preconditions.checkState(partialModelDataReceiving < upstreamParallelism);
+            partialModelDataReceivingState.update(
+                    Collections.singletonList(partialModelDataReceiving + 1));
+            processElement(
+                    partialModelDataUpdateRecord.getValue().f0.centroids,
+                    partialModelDataUpdateRecord.getValue().f1,
+                    1.0);
+        }
+
+        @Override
+        public void processElement2(
+                StreamRecord<Tuple2<KMeansModelData, DenseVector>> initModelDataRecord)
+                throws Exception {
+            boolean initModelDataReceiving =
+                    getUniqueElement(initModelDataReceivingState, "initModelDataReceiving");
+            Preconditions.checkState(!initModelDataReceiving);
+            initModelDataReceivingState.update(Collections.singletonList(true));
+            processElement(
+                    initModelDataRecord.getValue().f0.centroids,
+                    initModelDataRecord.getValue().f1,
+                    decayFactor);
+        }
+
+        private void processElement(
+                DenseVector[] newCentroids, DenseVector newWeights, double decayFactor)
+                throws Exception {
+            DenseVector weights = getUniqueElement(weightsState, "weights");
+            DenseVector[] centroids = getUniqueElement(modelDataState, "modelData").centroids;
+
+            for (int i = 0; i < k; i++) {
+                newWeights.values[i] *= decayFactor;
+                for (int j = 0; j < dim; j++) {
+                    centroids[i].values[j] =
+                            (centroids[i].values[j] * weights.values[i]
+                                            + newCentroids[i].values[j] * newWeights.values[i])
+                                    / Math.max(weights.values[i] + newWeights.values[i], 1e-16);
+                }
+                weights.values[i] += newWeights.values[i];
+            }
+
+            if (getUniqueElement(initModelDataReceivingState, "initModelDataReceiving")
+                    && getUniqueElement(partialModelDataReceivingState, "partialModelDataReceiving")
+                            >= upstreamParallelism) {
+                output.collect(
+                        new StreamRecord<>(Tuple2.of(new KMeansModelData(centroids), weights)));
+                initStateValues();
+            } else {
+                modelDataState.update(Collections.singletonList(new KMeansModelData(centroids)));
+                weightsState.update(Collections.singletonList(weights));
+            }
+        }
+    }
+
+    private static <T> T getUniqueElement(ListState<T> state, String stateName) throws Exception {
+        T value = OperatorStateUtils.getUniqueElement(state, stateName).orElse(null);
+        return Objects.requireNonNull(value);
+    }
+
+    private static class ModelDataPartialUpdater
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, DenseVector>>
+            implements TwoInputStreamOperator<
+                    DenseVector[],
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final DistanceMeasure distanceMeasure;
+        private final int k;
+        private ListState<DenseVector[]> miniBatchState;
+        private ListState<KMeansModelData> modelDataState;
+
+        private ModelDataPartialUpdater(DistanceMeasure distanceMeasure, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.k = k;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            miniBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("miniBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> pointsRecord) throws Exception {
+            miniBatchState.add(pointsRecord.getValue());
+            processElement();
+        }
+
+        @Override
+        public void processElement2(
+                StreamRecord<Tuple2<KMeansModelData, DenseVector>> modelDataAndWeightsRecord)
+                throws Exception {
+            modelDataState.add(modelDataAndWeightsRecord.getValue().f0);
+            processElement();
+        }
+
+        private void processElement() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !miniBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            DenseVector[] centroids = getUniqueElement(modelDataState, "modelData").centroids;
+            modelDataState.clear();
+
+            List<DenseVector[]> pointsList = IteratorUtils.toList(miniBatchState.get().iterator());
+            DenseVector[] points = pointsList.get(0);
+            pointsList.remove(0);
+            miniBatchState.clear();
+            miniBatchState.addAll(pointsList);
+
+            // Computes new centroids.
+            DenseVector[] sums = new DenseVector[k];
+            int dim = centroids[0].size();
+            DenseVector counts = new DenseVector(k);
+
+            for (int i = 0; i < k; i++) {
+                sums[i] = new DenseVector(dim);
+                counts.values[i] = 0;
+            }
+            for (DenseVector point : points) {
+                int closestCentroidId =
+                        KMeans.findClosestCentroidId(centroids, point, distanceMeasure);
+                counts.values[closestCentroidId]++;
+                for (int j = 0; j < dim; j++) {
+                    sums[closestCentroidId].values[j] += point.values[j];
+                }
+            }
+
+            for (int i = 0; i < k; i++) {
+                if (counts.values[i] < 1e-5) {
+                    continue;
+                }
+                BLAS.scal(1.0 / counts.values[i], sums[i]);
+            }
+
+            output.collect(new StreamRecord<>(Tuple2.of(new KMeansModelData(sums), counts)));
+        }
+    }
+
+    private static class FeaturesExtractor implements MapFunction<Row, DenseVector> {
+        private final String featuresCol;
+
+        private FeaturesExtractor(String featuresCol) {
+            this.featuresCol = featuresCol;
+        }
+
+        @Override
+        public DenseVector map(Row row) throws Exception {
+            return (DenseVector) row.getField(featuresCol);
+        }
+    }
+
+    private static class MiniBatchDistributor
+            implements FlatMapFunction<DenseVector[], DenseVector[]> {
+        private final int downStreamParallelism;
+        private int shift = 0;
+
+        private MiniBatchDistributor(int downStreamParallelism) {
+            this.downStreamParallelism = downStreamParallelism;
+        }
+
+        @Override
+        public void flatMap(DenseVector[] values, Collector<DenseVector[]> collector) {
+            // Calculate the batch sizes to be distributed on each subtask.
+            List<Integer> sizes = new ArrayList<>();
+            for (int i = 0; i < downStreamParallelism; i++) {
+                int start = i * values.length / downStreamParallelism;
+                int end = (i + 1) * values.length / downStreamParallelism;
+                sizes.add(end - start);
+            }
+
+            // Reduce accumulated imbalance among distributed batches.
+            Collections.rotate(sizes, shift);

Review comment:
       Yes, it seems that the overall progress cannot advance faster with this shift. I'll remove it.
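   The decayed mini-batch update described in the OnlineKMeans Javadoc quoted above, and split in the PR between `ModelDataPartialUpdater` (assign points, average them) and `ModelDataGlobalUpdater` (weighted merge with decayed old weights), can be sketched without Flink as follows. This is an illustrative single-process sketch, not the PR's code: the class name is hypothetical, squared Euclidean distance stands in for the configurable `DistanceMeasure`, and a centroid whose total weight is zero is left unchanged instead of dividing by a small epsilon as the quoted `processElement` does.

```java
import java.util.Arrays;

/**
 * Standalone sketch of one decayed mini-batch K-Means step on plain arrays. Not the PR's
 * operators: no Flink state, no parallel partial updates, and squared Euclidean distance is
 * assumed in place of a configurable DistanceMeasure.
 */
final class DecayedMiniBatchKMeansSketch {

    /** Returns the index of the centroid nearest to the point (squared Euclidean distance). */
    static int closest(double[][] centroids, double[] point) {
        int best = 0;
        double bestDist = Double.POSITIVE_INFINITY;
        for (int i = 0; i < centroids.length; i++) {
            double dist = 0;
            for (int j = 0; j < point.length; j++) {
                double diff = centroids[i][j] - point[j];
                dist += diff * diff;
            }
            if (dist < bestDist) {
                bestDist = dist;
                best = i;
            }
        }
        return best;
    }

    /**
     * Updates centroids and weights in place. Each cluster's batch estimate (the mean of the
     * points assigned to it, weighted by their count) is averaged with the old centroid, whose
     * accumulated weight is first multiplied by the decay factor.
     */
    static void step(double[][] centroids, double[] weights, double[][] batch, double decay) {
        int k = centroids.length;
        int dim = centroids[0].length;
        double[][] sums = new double[k][dim];
        double[] counts = new double[k];
        for (double[] point : batch) {
            int c = closest(centroids, point);
            counts[c]++;
            for (int j = 0; j < dim; j++) {
                sums[c][j] += point[j];
            }
        }
        for (int i = 0; i < k; i++) {
            double oldWeight = weights[i] * decay;
            double total = oldWeight + counts[i];
            if (total > 0) {
                for (int j = 0; j < dim; j++) {
                    // sums[i][j] equals batchMean * count, so this is the weighted average.
                    centroids[i][j] = (centroids[i][j] * oldWeight + sums[i][j]) / total;
                }
            }
            weights[i] = total;
        }
    }

    public static void main(String[] args) {
        double[][] centroids = {{0.0}, {10.0}};
        double[] weights = {2.0, 2.0};
        double[][] batch = {{1.0}, {3.0}, {9.0}};
        step(centroids, weights, batch, 0.5);
        // Cluster 0: (0 * 1 + (1 + 3)) / (1 + 2) = 4/3; cluster 1: (10 * 1 + 9) / (1 + 1) = 9.5.
        System.out.println(Arrays.deepToString(centroids) + " " + Arrays.toString(weights));
    }
}
```

   With a decay factor of 0 the same batch yields centroids 2.0 and 9.0, i.e. the plain means of the newly assigned points, which matches the Javadoc's statement that a decay factor of 0 makes the new centroids depend entirely on recent data.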







[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r833172433



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,562 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Random;
+
+/**
+ * OnlineKMeans extends the functionality of {@link KMeans} to support training a K-Means model
+ * continuously on an unbounded stream of training data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids of the current batch are estimated,
+ * OnlineKMeans computes the new centroids as the weighted average of the original and the
+ * estimated centroids. The weight of each estimated centroid is the number of points assigned to
+ * it in the current batch. The weight of each original centroid is its accumulated weight from
+ * previous batches, multiplied by the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If the decay
+ * factor is 1, all batches are weighted equally. If the decay factor is 0, new centroids are
+ * determined entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public OnlineKMeans(Table initModelDataTable) {
+        this.initModelDataTable = initModelDataTable;
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            Preconditions.checkState(initModelDataTable == null);
+            initModelDataStream = createRandomCentroids(env, getDim(), getK(), getSeed());
+        } else {
+            initModelDataStream = KMeansModelData.getModelDataStream(initModelDataTable);
+        }
+        DataStream<Tuple2<KMeansModelData, DenseVector>> initModelDataWithWeightsStream =
+                initModelDataStream.map(new InitWeightAssigner(getInitWeights()));
+        initModelDataWithWeightsStream.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize(),
+                        getK(),
+                        getDim());
+
+        DataStream<KMeansModelData> finalModelDataStream =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelDataWithWeightsStream),
+                                DataStreamList.of(points),
+                                body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelDataStream);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class InitWeightAssigner
+            implements MapFunction<KMeansModelData, Tuple2<KMeansModelData, DenseVector>> {
+        private final double[] initWeights;
+
+        private InitWeightAssigner(Double[] initWeights) {
+            this.initWeights = ArrayUtils.toPrimitive(initWeights);
+        }
+
+        @Override
+        public Tuple2<KMeansModelData, DenseVector> map(KMeansModelData modelData)
+                throws Exception {
+            return Tuple2.of(modelData, Vectors.dense(initWeights));
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        Path initModelDataPath = Paths.get(path, "data");
+        if (Files.exists(initModelDataPath)) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+            kMeans.setInitMode("direct");
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+        private final int k;
+        private final int dim;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure,
+                double decayFactor,
+                int batchSize,
+                int k,
+                int dim) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+            this.k = k;
+            this.dim = dim;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<Tuple2<KMeansModelData, DenseVector>> modelDataWithWeights =
+                    variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<Tuple2<KMeansModelData, DenseVector>> newModelDataWithWeights =
+                    points.countWindowAll(batchSize)
+                            .aggregate(new MiniBatchCreator())
+                            .flatMap(new MiniBatchDistributor(parallelism))
+                            .rebalance()
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "ModelDataPartialUpdater",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new ModelDataPartialUpdater(distanceMeasure, k))
+                            .setParallelism(parallelism)
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "ModelDataGlobalUpdater",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new ModelDataGlobalUpdater(k, dim, parallelism, decayFactor))

Review comment:
       Hmm, the very first model data still has to be passed into `ModelDataGlobalUpdater`. Since the model data stream is unbounded, broadcast cannot be used here, so `ModelDataGlobalUpdater` must be a `TwoInputOperator` that keeps receiving from the model data stream.
   
   It might be better to have `ModelDataGlobalUpdater` start working as soon as it has received enough local model data, ignoring all inputs from the feedback edge except the first one. This could keep the global model data from becoming a potential blocker of the whole process. What do you think?
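The update rule from the `OnlineKMeans` Javadoc can be combined with the behavior proposed in this comment: start a global update once enough local results have arrived, and use only the first model from the feedback edge. The plain-Java sketch below shows one way to do that. It is a hypothetical illustration with no Flink dependencies, not the PR's `ModelDataGlobalUpdater`. The class `GlobalUpdaterSketch`, the `Partial` holder, and the assumption that every subtask contributes exactly one partial result per mini-batch, in order, were all made up for this example.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical, Flink-free sketch of a global model updater. It keeps its own copy of the
 * model, accepts only the first model arriving on the "feedback" input, and performs one
 * global update whenever it has buffered one partial result per upstream subtask.
 */
class GlobalUpdaterSketch {
    /** Per-subtask partial result: sums and counts of the points assigned to each centroid. */
    static final class Partial {
        final double[][] sums;
        final double[] counts;

        Partial(double[][] sums, double[] counts) {
            this.sums = sums;
            this.counts = counts;
        }
    }

    private final int parallelism;
    private final double decayFactor;
    private final List<Partial> buffer = new ArrayList<>();
    private double[][] centroids; // null until the initial model arrives
    private double[] weights;
    private int updateCount = 0;

    GlobalUpdaterSketch(int parallelism, double decayFactor) {
        this.parallelism = parallelism;
        this.decayFactor = decayFactor;
    }

    /** Feedback input: only the first model is used, later ones are ignored. */
    void onModel(double[][] initCentroids, double[] initWeights) {
        if (centroids != null) {
            return;
        }
        centroids = new double[initCentroids.length][];
        for (int i = 0; i < initCentroids.length; i++) {
            centroids[i] = initCentroids[i].clone();
        }
        weights = initWeights.clone();
        tryUpdate(); // partial results may have arrived before the initial model
    }

    /** Partial-result input: buffers the result and updates once enough have arrived. */
    void onPartial(Partial partial) {
        buffer.add(partial);
        tryUpdate();
    }

    private void tryUpdate() {
        while (centroids != null && buffer.size() >= parallelism) {
            List<Partial> batch = buffer.subList(0, parallelism);
            for (int i = 0; i < centroids.length; i++) {
                // Merge the partial results of all subtasks for centroid i.
                double count = 0;
                double[] sum = new double[centroids[i].length];
                for (Partial p : batch) {
                    count += p.counts[i];
                    for (int j = 0; j < sum.length; j++) {
                        sum[j] += p.sums[i][j];
                    }
                }
                // Old weight is scaled by the decay factor; the batch weight is its point count.
                double decayedWeight = weights[i] * decayFactor;
                double newWeight = decayedWeight + count;
                if (newWeight > 0) {
                    for (int j = 0; j < sum.length; j++) {
                        centroids[i][j] = (decayedWeight * centroids[i][j] + sum[j]) / newWeight;
                    }
                }
                weights[i] = newWeight;
            }
            batch.clear(); // removes the consumed partial results from the buffer
            updateCount++;
        }
    }

    double[][] getCentroids() { return centroids; }

    double[] getWeights() { return weights; }

    int getUpdateCount() { return updateCount; }
}
```

The sketch treats the first `parallelism` buffered results as one round. That is a simplification, and a real operator would need some way to tell which mini-batch each partial result belongs to.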




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [flink-ml] yunfengzhou-hub commented on pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#issuecomment-1078703906


   Thanks for the comments, @lindong28. I have updated the PR accordingly.
   
   As @zhipeng93 might be reviewing the PR, I have not rebased the commits yet.





[GitHub] [flink-ml] lindong28 commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
lindong28 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r834995617



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModelData.java
##########
@@ -47,8 +47,16 @@
 
     public DenseVector[] centroids;
 
+    public DenseVector weights;
+
+    public KMeansModelData(DenseVector[] centroids, DenseVector weights) {
+        this.centroids = centroids;
+        this.weights = weights;
+    }
+
     public KMeansModelData(DenseVector[] centroids) {
         this.centroids = centroids;
+        this.weights = new DenseVector(centroids.length);

Review comment:
       If this model data is guaranteed to be used only by `OnlineKMeansModel`, I agree we can do that.
   
   For code that constructs a KMeansModelData instance, if it is not guaranteed that this instance is only used by `OnlineKMeansModel`, I suggest we still make this model data `self-contained` with valid weights.
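Below is a minimal sketch of a self-contained model data class. It uses plain `double` arrays in place of `DenseVector` and a hypothetical class name, `ModelDataSketch`. The constructor rejects weights that do not match the centroids, so any consumer can rely on them, and the caller has to state the default weight explicitly. Under the decayed update rule in the `OnlineKMeans` Javadoc, a weight of 0 means the old centroid contributes nothing to the next update. An implicit zero weight is therefore only harmless for some consumers.

```java
import java.util.Arrays;

/**
 * Hypothetical stand-in for KMeansModelData (double arrays instead of DenseVector) that stays
 * self-contained: each centroid carries a weight, and the weights are checked on construction.
 */
final class ModelDataSketch {
    final double[][] centroids;
    final double[] weights;

    ModelDataSketch(double[][] centroids, double[] weights) {
        if (weights.length != centroids.length) {
            throw new IllegalArgumentException(
                    "Expected " + centroids.length + " weights but got " + weights.length + ".");
        }
        for (double weight : weights) {
            if (!(weight >= 0)) { // also rejects NaN
                throw new IllegalArgumentException(
                        "Weights must be non-negative: " + Arrays.toString(weights));
            }
        }
        this.centroids = centroids;
        this.weights = weights;
    }

    /** Makes the default weight an explicit choice of the caller instead of an implicit 0. */
    ModelDataSketch(double[][] centroids, double defaultWeight) {
        this(centroids, filled(centroids.length, defaultWeight));
    }

    private static double[] filled(int length, double value) {
        double[] result = new double[length];
        Arrays.fill(result, value);
        return result;
    }
}
```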
   
   







[GitHub] [flink-ml] yunfengzhou-hub commented on pull request #70: [FLINK-26313] Support Online KMeans in Flink ML

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#issuecomment-1072229314


   Thanks for the comments, @lindong28. I have updated the PR and replied to the comments.





[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r832102967



##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/util/MockMessageQueues.java
##########
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
+
+/** A class that manages global message queues used in unit tests. */
+@SuppressWarnings({"unchecked", "rawtypes"})
+public class MockMessageQueues {
+    private static final Map<String, BlockingQueue> queueMap = new HashMap<>();
+    private static long counter = 0;
+
+    public static synchronized String createMessageQueue() {
+        String id = String.valueOf(counter);
+        queueMap.put(id, new LinkedBlockingQueue<>());
+        counter++;
+        return id;
+    }
+
+    @SafeVarargs
+    public static <T> void offerAll(String id, T... values) throws InterruptedException {
+        for (T value : values) {
+            offer(id, value);
+        }
+    }
+
+    public static <T> void offer(String id, T value) throws InterruptedException {
+        offer(id, value, 1, TimeUnit.MINUTES);
+    }
+
+    public static <T> void offer(String id, T value, long timeout, TimeUnit unit)
+            throws InterruptedException {
+        boolean success = queueMap.get(id).offer(value, timeout, unit);
+        if (!success) {
+            throw new RuntimeException(
+                    "Failed to offer " + value + " to blocking queue " + id + ".");
+        }
+    }
+
+    public static <T> List<T> poll(String id, int num) throws InterruptedException {
+        List<T> result = new ArrayList<>();
+        for (int i = 0; i < num; i++) {
+            result.add(poll(id));
+        }
+        return result;
+    }
+
+    public static <T> T poll(String id) throws InterruptedException {
+        return poll(id, 1, TimeUnit.MINUTES);
+    }
+
+    public static <T> T poll(String id, long timeout, TimeUnit unit) throws InterruptedException {
+        T value = (T) queueMap.get(id).poll(timeout, unit);
+        if (value == null) {
+            throw new RuntimeException("Failed to poll next value from blocking queue " + id + ".");
+        }
+        return value;
+    }
+
+    public static void deleteBlockingQueue(String id) {

Review comment:
       I tried mimicking the implementation of `InMemoryReporter` and introduced `InMemorySourceFunction` and `InMemorySinkFunction`. Users now only need to create instances of these two classes in their test classes, and the static queues are hidden from them. The statically created queues are automatically deleted when the sources/sinks are closed, so users need not delete each queue individually or manage instances of any `group` concept. As there does not seem to be a built-in implementation of such classes in Flink (comparable to `InMemoryReporter`), this is the most suitable solution I can think of for now. Please have a look and see whether it is clean enough.
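The approach described in this comment can be sketched without Flink. It keeps a hidden static registry of queues, gives the serializable source or sink only the queue id, and removes the queue when its owner is closed. `InMemoryQueueRegistry` and `InMemoryQueueHandle` are hypothetical names used only for this illustration, and the PR's `InMemorySourceFunction`/`InMemorySinkFunction` may differ. The sketch uses a `ConcurrentHashMap` because test code and task threads would access the registry concurrently; the `HashMap` in `MockMessageQueues` above is only guarded in `createMessageQueue`.

```java
import java.io.Serializable;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;

/** Hypothetical static registry: queues live in the JVM, handles only carry the queue id. */
final class InMemoryQueueRegistry {
    private static final Map<String, BlockingQueue<Object>> QUEUES = new ConcurrentHashMap<>();

    private InMemoryQueueRegistry() {}

    static String create() {
        String id = UUID.randomUUID().toString();
        QUEUES.put(id, new LinkedBlockingQueue<>());
        return id;
    }

    static BlockingQueue<Object> get(String id) {
        BlockingQueue<Object> queue = QUEUES.get(id);
        if (queue == null) {
            throw new IllegalStateException("Queue " + id + " does not exist or was closed.");
        }
        return queue;
    }

    static void remove(String id) {
        QUEUES.remove(id);
    }

    static boolean exists(String id) {
        return QUEUES.containsKey(id);
    }
}

/**
 * Hypothetical handle that a source or sink function could own. Only the id is serialized,
 * and closing the handle deletes the queue, so tests never delete queues by hand.
 */
final class InMemoryQueueHandle<T> implements Serializable, AutoCloseable {
    private static final long serialVersionUID = 1L;
    private final String id = InMemoryQueueRegistry.create();

    void offer(T value) throws InterruptedException {
        if (!InMemoryQueueRegistry.get(id).offer(value, 1, TimeUnit.MINUTES)) {
            throw new IllegalStateException("Failed to offer " + value + " to queue " + id + ".");
        }
    }

    @SuppressWarnings("unchecked")
    T poll() throws InterruptedException {
        T value = (T) InMemoryQueueRegistry.get(id).poll(1, TimeUnit.MINUTES);
        if (value == null) {
            throw new IllegalStateException("Timed out polling queue " + id + ".");
        }
        return value;
    }

    String getId() {
        return id;
    }

    @Override
    public void close() {
        InMemoryQueueRegistry.remove(id);
    }
}
```

In a test, such a handle could be used with try-with-resources so that the queue is removed even if an assertion fails.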







[GitHub] [flink-ml] lindong28 commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
lindong28 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r832765505



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,569 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Random;
+
+/**
+ * OnlineKMeans extends the function of {@link KMeans}, supporting to train a K-Means model
+ * continuously according to an unbounded stream of train data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids from the weighted average between the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of points assigned to
+ * them. The weight of the original centroids is also the number of points, but additionally
+ * multiplying with the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If decay
+ * factor is 1, all batches are weighted equally. If decay factor is 0, new centroids are determined
+ * entirely by recent data. Lower values correspond to more forgetting.
+ *
+ * <p>NOTE: This class's current naive implementation performs the training process in a
+ * single-threaded way. Correctness is not affected but there are performance issues.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public OnlineKMeans(Table... initModelDataTables) {
+        Preconditions.checkArgument(initModelDataTables.length == 1);
+        this.initModelDataTable = initModelDataTables[0];
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        Preconditions.checkArgument(HasBatchStrategy.COUNT_STRATEGY.equals(getBatchStrategy()));

Review comment:
       Given that the validator of BATCH_STRATEGY should already guarantee that `getBatchStrategy() == count`, would it be simpler to remove this check? This would also be more consistent with how NaiveBayes handles modelType.
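The argument above relies on invalid values being rejected when a parameter is set. Under that assumption, a getter can never return a value outside the allowed set. The plain-Java sketch below illustrates the idea. `ValidatedStringParam` is hypothetical and only mirrors the spirit of `ParamValidators.inArray` used in `OnlineKMeansParams`; it is not Flink ML's implementation.

```java
import java.util.Arrays;
import java.util.List;

/** Hypothetical string parameter that validates values at set time. */
final class ValidatedStringParam {
    private final List<String> allowed;
    private String value;

    ValidatedStringParam(String defaultValue, String... allowed) {
        this.allowed = Arrays.asList(allowed);
        set(defaultValue); // the default must pass the same validation
    }

    void set(String newValue) {
        if (!allowed.contains(newValue)) {
            throw new IllegalArgumentException(
                    "Parameter value " + newValue + " is not in " + allowed + ".");
        }
        this.value = newValue;
    }

    String get() {
        // Always one of the allowed values, so callers such as fit() need not re-check it.
        return value;
    }
}
```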

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeansParams.java
##########
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.common.param.HasDecayFactor;
+import org.apache.flink.ml.common.param.HasSeed;
+import org.apache.flink.ml.param.DoubleArrayParam;
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringParam;
+
+/**
+ * Params of {@link OnlineKMeans}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface OnlineKMeansParams<T>
+        extends HasBatchStrategy<T>, HasDecayFactor<T>, HasSeed<T>, KMeansModelParams<T> {
+    Param<String> INIT_MODE =
+            new StringParam(
+                    "initMode",
+                    "How to initialize the model data of the online KMeans algorithm. Supported options: 'random', 'direct'.",
+                    "random",
+                    ParamValidators.inArray("random", "direct"));
+
+    Param<Integer> DIMS =

Review comment:
       This parameter indicates the dimension of the input vector. Would it be more intuitive to name it `dim` instead of `dims`?
   
   If we agree to do so, we probably need to rename other related variables as well.

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,569 @@
+    public OnlineKMeans(Table... initModelDataTables) {

Review comment:
       Given that we know for sure `initModelDataTables.length == 1`, would it be simpler to just use `Table initModelDataTable` as input parameter?
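The plain-Java sketch below illustrates the difference between the two constructor shapes. The class names are hypothetical, and `Object` stands in for `Table`. The varargs form compiles for any number of arguments, so the arity can only be rejected at runtime. A single parameter lets the compiler enforce it.

```java
import java.util.Objects;

/** Varargs: `new VarargsInit(a, b)` compiles, so the arity is only checked at runtime. */
final class VarargsInit {
    final Object initModelData;

    VarargsInit(Object... initModelDataTables) {
        if (initModelDataTables.length != 1) {
            throw new IllegalArgumentException(
                    "Expected exactly 1 table but got " + initModelDataTables.length + ".");
        }
        this.initModelData = initModelDataTables[0];
    }
}

/** Single parameter: a call with two arguments is rejected by the compiler. */
final class SingleInit {
    final Object initModelData;

    SingleInit(Object initModelDataTable) {
        this.initModelData = Objects.requireNonNull(initModelDataTable);
    }
}
```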
   







[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r833161087



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,562 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Random;
+
+/**
+ * OnlineKMeans extends the functionality of {@link KMeans} by continuously training a K-Means
+ * model on an unbounded stream of training data.
+ *
+ * <p>OnlineKMeans updates the model with the "mini-batch" K-Means rule, generalized to incorporate
+ * forgetfulness (i.e. decay). Once the centroids estimated on the current batch are available,
+ * OnlineKMeans computes the new centroids as the weighted average of the original and the
+ * estimated centroids. The weight of each estimated centroid is the number of batch points
+ * assigned to it. The weight of each original centroid is the number of points it has
+ * accumulated so far, multiplied by the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters estimated so far. If the decay
+ * factor is 1, all batches are weighted equally. If the decay factor is 0, the new centroids are
+ * determined entirely by the most recent batch. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public OnlineKMeans(Table initModelDataTable) {
+        this.initModelDataTable = initModelDataTable;
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            Preconditions.checkState(initModelDataTable == null);
+            initModelDataStream = createRandomCentroids(env, getDim(), getK(), getSeed());
+        } else {
+            initModelDataStream = KMeansModelData.getModelDataStream(initModelDataTable);
+        }
+        DataStream<Tuple2<KMeansModelData, DenseVector>> initModelDataWithWeightsStream =
+                initModelDataStream.map(new InitWeightAssigner(getInitWeights()));
+        initModelDataWithWeightsStream.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize(),
+                        getK(),
+                        getDim());
+
+        DataStream<KMeansModelData> finalModelDataStream =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelDataWithWeightsStream),
+                                DataStreamList.of(points),
+                                body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelDataStream);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class InitWeightAssigner
+            implements MapFunction<KMeansModelData, Tuple2<KMeansModelData, DenseVector>> {
+        private final double[] initWeights;
+
+        private InitWeightAssigner(Double[] initWeights) {
+            this.initWeights = ArrayUtils.toPrimitive(initWeights);
+        }
+
+        @Override
+        public Tuple2<KMeansModelData, DenseVector> map(KMeansModelData modelData)
+                throws Exception {
+            return Tuple2.of(modelData, Vectors.dense(initWeights));
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        Path initModelDataPath = Paths.get(path, "data");
+        if (Files.exists(initModelDataPath)) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+            kMeans.setInitMode("direct");
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+        private final int k;
+        private final int dim;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure,
+                double decayFactor,
+                int batchSize,
+                int k,
+                int dim) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+            this.k = k;
+            this.dim = dim;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<Tuple2<KMeansModelData, DenseVector>> modelDataWithWeights =
+                    variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<Tuple2<KMeansModelData, DenseVector>> newModelDataWithWeights =
+                    points.countWindowAll(batchSize)
+                            .aggregate(new MiniBatchCreator())
+                            .flatMap(new MiniBatchDistributor(parallelism))
+                            .rebalance()
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "ModelDataPartialUpdater",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new ModelDataPartialUpdater(distanceMeasure, k))
+                            .setParallelism(parallelism)
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "ModelDataGlobalUpdater",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new ModelDataGlobalUpdater(k, dim, parallelism, decayFactor))
+                            .setParallelism(1);
+
+            DataStream<KMeansModelData> outputModelData =
+                    modelDataWithWeights.map(
+                            (MapFunction<Tuple2<KMeansModelData, DenseVector>, KMeansModelData>)
+                                    x -> x.f0);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelDataWithWeights), DataStreamList.of(outputModelData));
+        }
+    }
+
+    private static class ModelDataGlobalUpdater
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, DenseVector>>
+            implements TwoInputStreamOperator<
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final int k;
+        private final int dim;
+        private final int upstreamParallelism;
+        private final double decayFactor;
+
+        private ListState<Integer> partialModelDataReceivingState;
+        private ListState<Boolean> initModelDataReceivingState;
+        private ListState<KMeansModelData> modelDataState;
+        private ListState<DenseVector> weightsState;
+
+        private ModelDataGlobalUpdater(
+                int k, int dim, int upstreamParallelism, double decayFactor) {
+            this.k = k;
+            this.dim = dim;
+            this.upstreamParallelism = upstreamParallelism;
+            this.decayFactor = decayFactor;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            partialModelDataReceivingState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "partialModelDataReceiving",
+                                            BasicTypeInfo.INT_TYPE_INFO));
+
+            initModelDataReceivingState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "initModelDataReceiving",
+                                            BasicTypeInfo.BOOLEAN_TYPE_INFO));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+
+            weightsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "weights", DenseVectorTypeInfo.INSTANCE));
+
+            if (!context.isRestored()) {
+                initStateValues();
+            }
+        }
+
+        private void initStateValues() throws Exception {
+            partialModelDataReceivingState.update(Collections.singletonList(0));
+            initModelDataReceivingState.update(Collections.singletonList(false));
+            DenseVector[] emptyCentroids = new DenseVector[k];
+            for (int i = 0; i < k; i++) {
+                emptyCentroids[i] = new DenseVector(dim);
+            }
+            modelDataState.update(Collections.singletonList(new KMeansModelData(emptyCentroids)));
+            weightsState.update(Collections.singletonList(new DenseVector(k)));
+        }
+
+        @Override
+        public void processElement1(
+                StreamRecord<Tuple2<KMeansModelData, DenseVector>> partialModelDataUpdateRecord)
+                throws Exception {
+            int partialModelDataReceiving =
+                    getUniqueElement(partialModelDataReceivingState, "partialModelDataReceiving");
+            Preconditions.checkState(partialModelDataReceiving < upstreamParallelism);
+            partialModelDataReceivingState.update(
+                    Collections.singletonList(partialModelDataReceiving + 1));
+            processElement(
+                    partialModelDataUpdateRecord.getValue().f0.centroids,
+                    partialModelDataUpdateRecord.getValue().f1,
+                    1.0);
+        }
+
+        @Override
+        public void processElement2(
+                StreamRecord<Tuple2<KMeansModelData, DenseVector>> initModelDataRecord)
+                throws Exception {
+            boolean initModelDataReceiving =
+                    getUniqueElement(initModelDataReceivingState, "initModelDataReceiving");
+            Preconditions.checkState(!initModelDataReceiving);
+            initModelDataReceivingState.update(Collections.singletonList(true));
+            processElement(
+                    initModelDataRecord.getValue().f0.centroids,
+                    initModelDataRecord.getValue().f1,
+                    decayFactor);
+        }
+
+        private void processElement(
+                DenseVector[] newCentroids, DenseVector newWeights, double decayFactor)
+                throws Exception {
+            DenseVector weights = getUniqueElement(weightsState, "weights");
+            DenseVector[] centroids = getUniqueElement(modelDataState, "modelData").centroids;
+
+            for (int i = 0; i < k; i++) {
+                newWeights.values[i] *= decayFactor;
+                for (int j = 0; j < dim; j++) {
+                    centroids[i].values[j] =
+                            (centroids[i].values[j] * weights.values[i]
+                                            + newCentroids[i].values[j] * newWeights.values[i])
+                                    / Math.max(weights.values[i] + newWeights.values[i], 1e-16);
+                }
+                weights.values[i] += newWeights.values[i];
+            }
+
+            if (getUniqueElement(initModelDataReceivingState, "initModelDataReceiving")
+                    && getUniqueElement(partialModelDataReceivingState, "partialModelDataReceiving")
+                            >= upstreamParallelism) {
+                output.collect(
+                        new StreamRecord<>(Tuple2.of(new KMeansModelData(centroids), weights)));
+                initStateValues();
+            } else {
+                modelDataState.update(Collections.singletonList(new KMeansModelData(centroids)));
+                weightsState.update(Collections.singletonList(weights));
+            }
+        }
+    }
+
+    private static <T> T getUniqueElement(ListState<T> state, String stateName) throws Exception {
+        T value = OperatorStateUtils.getUniqueElement(state, stateName).orElse(null);
+        return Objects.requireNonNull(value);
+    }
+
+    private static class ModelDataPartialUpdater
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, DenseVector>>
+            implements TwoInputStreamOperator<
+                    DenseVector[],
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final DistanceMeasure distanceMeasure;
+        private final int k;
+        private ListState<DenseVector[]> miniBatchState;
+        private ListState<KMeansModelData> modelDataState;
+
+        private ModelDataPartialUpdater(DistanceMeasure distanceMeasure, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.k = k;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            miniBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("miniBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> pointsRecord) throws Exception {
+            miniBatchState.add(pointsRecord.getValue());
+            processElement();
+        }
+
+        @Override
+        public void processElement2(
+                StreamRecord<Tuple2<KMeansModelData, DenseVector>> modelDataAndWeightsRecord)
+                throws Exception {
+            modelDataState.add(modelDataAndWeightsRecord.getValue().f0);
+            processElement();
+        }
+
+        private void processElement() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !miniBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            DenseVector[] centroids = getUniqueElement(modelDataState, "modelData").centroids;
+            modelDataState.clear();
+
+            List<DenseVector[]> pointsList = IteratorUtils.toList(miniBatchState.get().iterator());
+            DenseVector[] points = pointsList.get(0);
+            pointsList.remove(0);
+            miniBatchState.clear();
+            miniBatchState.addAll(pointsList);
+
+            // Computes new centroids.
+            DenseVector[] sums = new DenseVector[k];
+            int dim = centroids[0].size();
+            DenseVector counts = new DenseVector(k);
+
+            for (int i = 0; i < k; i++) {
+                sums[i] = new DenseVector(dim);
+                counts.values[i] = 0;
+            }
+            for (DenseVector point : points) {
+                int closestCentroidId =
+                        KMeans.findClosestCentroidId(centroids, point, distanceMeasure);
+                counts.values[closestCentroidId]++;
+                for (int j = 0; j < dim; j++) {
+                    sums[closestCentroidId].values[j] += point.values[j];
+                }
+            }
+
+            for (int i = 0; i < k; i++) {
+                if (counts.values[i] < 1e-5) {
+                    continue;
+                }
+                BLAS.scal(1.0 / counts.values[i], sums[i]);
+            }
+
+            output.collect(new StreamRecord<>(Tuple2.of(new KMeansModelData(sums), counts)));
+        }
+    }
+
+    private static class FeaturesExtractor implements MapFunction<Row, DenseVector> {
+        private final String featuresCol;
+
+        private FeaturesExtractor(String featuresCol) {
+            this.featuresCol = featuresCol;
+        }
+
+        @Override
+        public DenseVector map(Row row) throws Exception {
+            return (DenseVector) row.getField(featuresCol);
+        }
+    }
+
+    private static class MiniBatchDistributor

Review comment:
       Got it. Since the term "mini batch" is ambiguous (as raised in other discussions), I'll rename the operators to `GlobalBatchCreator` and `LocalBatchDistributor`.
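The renamed operators suggest a two-step flow: `GlobalBatchCreator` collects `globalBatchSize` points (the `countWindowAll(batchSize)` step), and `LocalBatchDistributor` splits them into one local batch per parallel `ModelDataPartialUpdater` subtask. The body of `MiniBatchDistributor` is cut off in the quoted diff, so the splitting below is only an assumed, plain-Java illustration (contiguous, nearly equal chunks; `LocalBatchSplitter` is a hypothetical name). It emits exactly `parallelism` local batches, including empty ones, because `ModelDataGlobalUpdater` only emits a new model after receiving `upstreamParallelism` partial updates.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Hypothetical sketch: splits one global batch into `parallelism` local batches. */
final class LocalBatchSplitter {
    private LocalBatchSplitter() {}

    static <T> List<T[]> split(T[] globalBatch, int parallelism) {
        if (parallelism <= 0) {
            throw new IllegalArgumentException("parallelism must be positive");
        }
        List<T[]> localBatches = new ArrayList<>(parallelism);
        int base = globalBatch.length / parallelism;
        int remainder = globalBatch.length % parallelism;
        int start = 0;
        for (int i = 0; i < parallelism; i++) {
            // The first `remainder` local batches receive one extra point.
            int size = base + (i < remainder ? 1 : 0);
            // Always emit a batch, even an empty one, so every downstream subtask reports back.
            localBatches.add(Arrays.copyOfRange(globalBatch, start, start + size));
            start += size;
        }
        return localBatches;
    }
}
```

For example, five points split across two subtasks yield local batches of three and two points; a single point split across three subtasks yields one non-empty and two empty batches.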

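The decayed mini-batch update described in the OnlineKMeans Javadoc above can be sketched in plain Java. This is a simplified illustration, not the PR's code: it assumes centroids are plain `double[][]` arrays instead of Flink ML `DenseVector`s and merges a single batch estimate at a time, whereas `ModelDataGlobalUpdater` folds in one partial estimate per upstream subtask. The `1e-16` floor mirrors the division guard in `ModelDataGlobalUpdater`; `DecayedKMeansUpdate` is a hypothetical name.

```java
/** Hypothetical sketch of the decayed mini-batch K-Means merge step (plain arrays, no Flink). */
final class DecayedKMeansUpdate {
    private DecayedKMeansUpdate() {}

    /**
     * Merges centroids estimated on one batch into the current centroids, in place.
     *
     * @param centroids current centroids (k x dim), updated in place
     * @param weights accumulated weight of each current centroid (length k), updated in place
     * @param batchCentroids centroids estimated on the batch (k x dim)
     * @param batchCounts number of batch points assigned to each centroid (length k)
     * @param decayFactor factor in [0, 1] applied to the accumulated weights
     */
    static void update(
            double[][] centroids,
            double[] weights,
            double[][] batchCentroids,
            double[] batchCounts,
            double decayFactor) {
        for (int i = 0; i < centroids.length; i++) {
            // Existing centroids keep their accumulated weight, scaled by the decay factor.
            double oldWeight = weights[i] * decayFactor;
            double newWeight = batchCounts[i];
            // Avoid dividing by zero when neither side carries any weight.
            double total = Math.max(oldWeight + newWeight, 1e-16);
            for (int j = 0; j < centroids[i].length; j++) {
                centroids[i][j] =
                        (centroids[i][j] * oldWeight + batchCentroids[i][j] * newWeight) / total;
            }
            weights[i] = oldWeight + newWeight;
        }
    }
}
```

For example, with a decay factor of 0.5, a centroid at 0 carrying weight 2 and a batch estimate at 3 backed by one point merge to (0 * 1 + 3 * 1) / 2 = 1.5, and the new weight is 2. With a decay factor of 0 the result is just the batch estimate.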






[GitHub] [flink-ml] lindong28 commented on a change in pull request #70: [FLINK-26313] Support Streaming KMeans in Flink ML

Posted by GitBox <gi...@apache.org>.
lindong28 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r822304676



##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/util/MockKVStore.java
##########
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util;
+
+import java.util.HashMap;
+import java.util.Map;
+
+/** Class that manages global key-value pairs used in unit tests. */
+@SuppressWarnings({"unchecked"})
+public class MockKVStore {
+    private static final Map<String, Object> map = new HashMap<>();
+
+    private static long counter = 0;
+
+    /**
+     * Returns a prefix string to make sure key-value pairs created in different test cases would
+     * not have the same key.
+     */
+    public static synchronized String createNonDuplicatePrefix() {

Review comment:
       Could we simplify `MockKVStore` by removing this method and letting the caller generate a random int64 value instead?
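A minimal sketch of the suggested simplification, assuming the store only needs set/get/remove operations (`SimpleMockKVStore` and `removeWithPrefix` are hypothetical names, not code from the PR). Without a prefix generator there is no shared counter to synchronize, so a `ConcurrentHashMap` is enough. Unlike `HashMap`, it rejects null keys and values.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/** Hypothetical simplified MockKVStore: callers choose their own (random) key prefix. */
final class SimpleMockKVStore {
    private static final Map<String, Object> MAP = new ConcurrentHashMap<>();

    private SimpleMockKVStore() {}

    static void set(String key, Object value) {
        MAP.put(key, value);
    }

    @SuppressWarnings("unchecked")
    static <T> T get(String key) {
        return (T) MAP.get(key);
    }

    /** Removes every entry whose key starts with the given prefix, e.g. when a test finishes. */
    static void removeWithPrefix(String prefix) {
        MAP.keySet().removeIf(key -> key.startsWith(prefix));
    }
}
```

A test case would then pick its own prefix, e.g. `String prefix = Long.toHexString(ThreadLocalRandom.current().nextLong()) + "-";`. With 64 random bits, a collision between concurrently running test cases is very unlikely but not strictly impossible.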

##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/util/MockMessageQueues.java
##########
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
+
+/** A class that manages global message queues used in unit tests. */
+@SuppressWarnings({"unchecked", "rawtypes"})
+public class MockMessageQueues {
+    private static final Map<String, BlockingQueue> queueMap = new HashMap<>();
+    private static long counter = 0;
+
+    public static synchronized String createMessageQueue() {
+        String id = String.valueOf(counter);
+        queueMap.put(id, new LinkedBlockingQueue<>());
+        counter++;
+        return id;
+    }
+
+    @SafeVarargs
+    public static <T> void offerAll(String id, T... values) throws InterruptedException {
+        for (T value : values) {
+            offer(id, value);
+        }
+    }
+
+    public static <T> void offer(String id, T value) throws InterruptedException {
+        offer(id, value, 1, TimeUnit.MINUTES);
+    }
+
+    public static <T> void offer(String id, T value, long timeout, TimeUnit unit)
+            throws InterruptedException {
+        boolean success = queueMap.get(id).offer(value, timeout, unit);
+        if (!success) {
+            throw new RuntimeException(
+                    "Failed to offer " + value + " to blocking queue " + id + ".");
+        }
+    }
+
+    public static <T> List<T> poll(String id, int num) throws InterruptedException {
+        List<T> result = new ArrayList<>();
+        for (int i = 0; i < num; i++) {
+            result.add(poll(id));
+        }
+        return result;
+    }
+
+    public static <T> T poll(String id) throws InterruptedException {
+        return poll(id, 1, TimeUnit.MINUTES);
+    }
+
+    public static <T> T poll(String id, long timeout, TimeUnit unit) throws InterruptedException {
+        T value = (T) queueMap.get(id).poll(timeout, unit);
+        if (value == null) {
+            throw new RuntimeException("Failed to poll next value from blocking queue " + id + ".");
+        }
+        return value;
+    }
+
+    public static void deleteBlockingQueue(String id) {

Review comment:
       Given that we always remove all queues when this method is called, would it be simpler to replace this method with `clear()`, which removes all queues?
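A minimal sketch of what `MockMessageQueues` could look like with a `clear()` method in place of `deleteBlockingQueue(String)`. `MockMessageQueuesSketch` and `queueCount()` are hypothetical names. The sketch also swaps the `synchronized` counter and `HashMap` for `AtomicLong` and `ConcurrentHashMap`, which is a design choice here, not something the review asked for.

```java
import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;

/** Hypothetical variant of MockMessageQueues where clear() replaces per-queue deletion. */
final class MockMessageQueuesSketch {
    private static final Map<String, BlockingQueue<Object>> QUEUES = new ConcurrentHashMap<>();
    private static final AtomicLong COUNTER = new AtomicLong();

    private MockMessageQueuesSketch() {}

    static String createMessageQueue() {
        String id = String.valueOf(COUNTER.getAndIncrement());
        QUEUES.put(id, new LinkedBlockingQueue<>());
        return id;
    }

    static void offer(String id, Object value) throws InterruptedException {
        if (!QUEUES.get(id).offer(value, 1, TimeUnit.MINUTES)) {
            throw new RuntimeException("Failed to offer " + value + " to blocking queue " + id + ".");
        }
    }

    @SuppressWarnings("unchecked")
    static <T> T poll(String id) throws InterruptedException {
        T value = (T) QUEUES.get(id).poll(1, TimeUnit.MINUTES);
        if (value == null) {
            throw new RuntimeException("Failed to poll next value from blocking queue " + id + ".");
        }
        return value;
    }

    /** Removes all queues at once, e.g. from a test's teardown method. */
    static void clear() {
        QUEUES.clear();
    }

    /** Number of live queues; exposed only so the sketch is easy to check. */
    static int queueCount() {
        return QUEUES.size();
    }
}
```

A test class would then call `MockMessageQueuesSketch.clear()` once in its teardown instead of deleting each queue it created.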

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/StreamingKMeansModel.java
##########
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.co.CoProcessFunction;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * StreamingKMeansModel can be regarded as an advanced {@link KMeansModel} operator that updates
+ * its model data dynamically from a stream, such as the model data produced by {@link StreamingKMeans}.
+ */
+public class StreamingKMeansModel
+        implements Model<StreamingKMeansModel>, KMeansModelParams<StreamingKMeansModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public StreamingKMeansModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public StreamingKMeansModel setModelData(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        modelDataTable = inputs[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), Types.INT),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol()));
+
+        DataStream<Row> predictionResult =
+                KMeansModelData.getModelDataStream(modelDataTable)
+                        .broadcast()
+                        .connect(tEnv.toDataStream(inputs[0]))
+                        .process(
+                                new PredictLabelFunction(
+                                        getFeaturesCol(),
+                                        DistanceMeasure.getInstance(getDistanceMeasure())),
+                                outputTypeInfo);
+
+        return new Table[] {tEnv.fromDataStream(predictionResult)};
+    }
+
+    /** A utility function used for prediction. */
+    private static class PredictLabelFunction extends CoProcessFunction<KMeansModelData, Row, Row> {
+        private final String featuresCol;
+
+        private final DistanceMeasure distanceMeasure;
+
+        private DenseVector[] centroids;
+
+        // TODO: replace this with a complete solution of reading first model data from unbounded
+        // model data stream before processing the first predict data.
+        private final List<Row> cache = new ArrayList<>();

Review comment:
       It is not clear what this `cache` represents, and a cache is usually used to improve performance.
   
   How about renaming this variable to `bufferedPoints`?
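A plain-Java sketch of the role the renamed `bufferedPoints` field plays in `PredictLabelFunction`, following the TODO in the diff: points that arrive before the first model data are held back and predicted once centroids are known. To stay self-contained it has no Flink dependency. `onModelData`/`onPoint` stand in for `processElement1`/`processElement2` of the `CoProcessFunction`, a `Consumer` stands in for the `Collector`, and squared Euclidean distance replaces the configurable `DistanceMeasure`. All names are hypothetical.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;

/** Hypothetical sketch: buffers points until the first centroids arrive, then predicts. */
final class BufferedPredictor {
    private final List<double[]> bufferedPoints = new ArrayList<>();
    private final Consumer<Integer> out;
    private double[][] centroids; // null until the first model data arrives

    BufferedPredictor(Consumer<Integer> out) {
        this.out = out;
    }

    /** Counterpart of processElement1: new model data replaces the current centroids. */
    void onModelData(double[][] newCentroids) {
        centroids = newCentroids;
        // Flush points that were waiting for the first model.
        for (double[] point : bufferedPoints) {
            out.accept(closestCentroid(point));
        }
        bufferedPoints.clear();
    }

    /** Counterpart of processElement2: predict now, or buffer if no model is available yet. */
    void onPoint(double[] point) {
        if (centroids == null) {
            bufferedPoints.add(point);
        } else {
            out.accept(closestCentroid(point));
        }
    }

    private int closestCentroid(double[] point) {
        int best = -1;
        double bestDistance = Double.POSITIVE_INFINITY;
        for (int i = 0; i < centroids.length; i++) {
            double distance = 0.0;
            for (int j = 0; j < point.length; j++) {
                double diff = point[j] - centroids[i][j];
                distance += diff * diff; // squared Euclidean distance
            }
            if (distance < bestDistance) {
                bestDistance = distance;
                best = i;
            }
        }
        return best;
    }
}
```

Like the quoted `cache` field, this buffer is an ordinary in-memory list rather than Flink state, so buffered points would not survive a failover.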

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/StreamingKMeans.java
##########
@@ -0,0 +1,404 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * StreamingKMeans extends the function of {@link KMeans}, supporting to train a K-Means model

Review comment:
       Would it be useful to provide more detail about the algorithm here, similar to the Javadoc of Spark's StreamingKMeans?
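
For reference, Spark's StreamingKMeans uses a mini-batch update with decay ("forgetfulness"). For each cluster, with current centroid `c_t`, current weight `n_t`, decay factor `a`, and `m_t` new points assigned to the cluster with mean `x_t`: `c_{t+1} = (c_t * n_t * a + x_t * m_t) / (n_t * a + m_t)` and `n_{t+1} = n_t * a + m_t`. Whether this PR's `decayFactor` follows exactly this rule is an assumption, since the update code is not part of this excerpt. A plain-Java sketch of one cluster update (`DecayedKMeansUpdate` is a hypothetical name) could be:

```java
/** Hypothetical sketch: decayed mini-batch update for one cluster (Spark-style rule). */
final class DecayedKMeansUpdate {
    private DecayedKMeansUpdate() {}

    /** New centroid c_{t+1} and weight n_{t+1}. */
    static final class Result {
        final double[] centroid;
        final double weight;

        Result(double[] centroid, double weight) {
            this.centroid = centroid;
            this.weight = weight;
        }
    }

    /**
     * c_{t+1} = (c_t * n_t * a + sum(batch)) / (n_t * a + m_t), n_{t+1} = n_t * a + m_t,
     * where sum(batch) = x_t * m_t is the sum of the m_t new points.
     */
    static Result update(double[] centroid, double weight, double decayFactor, double[][] batch) {
        double discounted = weight * decayFactor; // n_t * a: history after forgetting
        double newWeight = discounted + batch.length; // n_{t+1}
        double[] newCentroid = new double[centroid.length];
        if (newWeight == 0) {
            // No remaining history and no new points: keep the old centroid.
            System.arraycopy(centroid, 0, newCentroid, 0, centroid.length);
            return new Result(newCentroid, 0);
        }
        for (int j = 0; j < centroid.length; j++) {
            double sum = 0;
            for (double[] point : batch) {
                sum += point[j];
            }
            newCentroid[j] = (centroid[j] * discounted + sum) / newWeight;
        }
        return new Result(newCentroid, newWeight);
    }
}
```

With `a = 0` the centroid jumps to the mean of the new batch; with `a = 1` every batch seen so far keeps its full weight.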

##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/util/MockMessageQueues.java
##########
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
+
+/** A class that manages global message queues used in unit tests. */
+@SuppressWarnings({"unchecked", "rawtypes"})
+public class MockMessageQueues {

Review comment:
       Would it be more self-explanatory to rename it `GlobalBlockingQueues`?
   
   Given that this class represents real queues, it is not clear what `mock` means here.
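
A sketch of what the renamed class could look like, assuming queues live in a static map so that source and sink function instances, which Flink serializes and deserializes into tasks, still reach the same queue when everything runs in one JVM (as with a local MiniCluster). `deleteBlockingQueue` mirrors the name used in `StreamingKMeansTest`; `GlobalBlockingQueues` is the name proposed above, and `createBlockingQueue`, `offer` and `poll` are hypothetical.

```java
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;

/** Hypothetical sketch of GlobalBlockingQueues: real JVM-wide queues keyed by id. */
final class GlobalBlockingQueues {
    // Static so that every deserialized function instance in this JVM sees the same queues.
    private static final Map<String, BlockingQueue<Object>> QUEUES = new ConcurrentHashMap<>();

    private GlobalBlockingQueues() {}

    /** Creates a new queue under a unique id and returns the id. */
    static String createBlockingQueue() {
        String id = UUID.randomUUID().toString();
        QUEUES.put(id, new LinkedBlockingQueue<>());
        return id;
    }

    static void offer(String id, Object value) {
        getQueue(id).add(value);
    }

    /** Waits up to the timeout; returns null if nothing arrived or the thread was interrupted. */
    static Object poll(String id, long timeout, TimeUnit unit) {
        try {
            return getQueue(id).poll(timeout, unit);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // preserve the interrupt status
            return null;
        }
    }

    static void deleteBlockingQueue(String id) {
        QUEUES.remove(id);
    }

    private static BlockingQueue<Object> getQueue(String id) {
        BlockingQueue<Object> queue = QUEUES.get(id);
        if (queue == null) {
            throw new IllegalStateException("No queue with id " + id);
        }
        return queue;
    }
}
```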

##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/util/MockKVStore.java
##########
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util;
+
+import java.util.HashMap;
+import java.util.Map;
+
+/** Class that manages global key-value pairs used in unit tests. */
+@SuppressWarnings({"unchecked"})
+public class MockKVStore {

Review comment:
       It looks like `MockKVStore` is only used to store the metrics reported by `TestMetricReporter`.
   
   Would it be simpler to remove `MockKVStore` and let `TestMetricReporter` expose `static <T> T get(String key)`?
   
   And if `TestMetricReporter` is expected to be a singleton, would it be more self-explanatory to rename it `SingletonInMemoryMetricReporter`?
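
A minimal sketch of the suggestion, assuming the reporter keeps reported values in a static map. The map is static because Flink instantiates the configured reporter class itself, so the test never holds a reference to that instance. `report` and `removeAllWithPrefix` are hypothetical; the wiring to Flink's `MetricReporter` interface is omitted, with `report` standing in for it.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Hypothetical sketch: the reporter owns the JVM-wide store, so MockKVStore is not needed.
 * The Flink MetricReporter wiring is omitted; report(...) stands in for it.
 */
final class SingletonInMemoryMetricReporter {
    // Static so values written by the Flink-created reporter instance are visible to the test.
    private static final Map<String, Object> VALUES = new ConcurrentHashMap<>();

    private SingletonInMemoryMetricReporter() {}

    /** Records a reported metric value (hypothetical hook called by the reporter). */
    static void report(String key, Object value) {
        VALUES.put(key, value);
    }

    /** The accessor suggested in the review; returns null if the key is absent. */
    @SuppressWarnings("unchecked")
    static <T> T get(String key) {
        return (T) VALUES.get(key);
    }

    /** Removes all keys with the given prefix, e.g. from a test's @After method. */
    static void removeAllWithPrefix(String prefix) {
        VALUES.keySet().removeIf(key -> key.startsWith(prefix));
    }
}
```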

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasDecayFactor.java
##########
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
+
+/** Interface for the shared decay factor param. */
+public interface HasDecayFactor<T> extends WithParams<T> {
+    Param<Double> DECAY_FACTOR =
+            new DoubleParam(
+                    "decayFactor",
+                    "The forgetfulness of the previous centroids.",
+                    0.,
+                    ParamValidators.gtEq(0));

Review comment:
       Does this value need to be smaller than 1?
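
For intuition, assuming a decayed update like Spark's StreamingKMeans (`n_{t+1} = n_t * a + m_t`), a batch seen `k` batches ago keeps relative weight `a^k`. Values in `[0, 1]` discount older data (0 keeps only the newest batch, 1 weighs all batches equally). A value above 1 would make older batches count more than newer ones, which seems unlikely to be intended. The check below is a hypothetical plain-Java sketch, not Flink ML's `ParamValidators` API.

```java
/** Hypothetical sketch: what a decay factor means, and a range check for it. */
final class DecayFactorCheck {
    private DecayFactorCheck() {}

    /** Accepts only 0 <= a <= 1, so previous centroids are never weighted up; rejects NaN. */
    static boolean isValid(double decayFactor) {
        return decayFactor >= 0.0 && decayFactor <= 1.0;
    }

    /** Relative weight of a batch observed batchesAgo batches before the latest one. */
    static double relativeWeight(double decayFactor, int batchesAgo) {
        return Math.pow(decayFactor, batchesAgo);
    }
}
```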

##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/StreamingKMeansTest.java
##########
@@ -0,0 +1,527 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.core.execution.JobClient;
+import org.apache.flink.ml.clustering.kmeans.KMeans;
+import org.apache.flink.ml.clustering.kmeans.KMeansModel;
+import org.apache.flink.ml.clustering.kmeans.KMeansModelData;
+import org.apache.flink.ml.clustering.kmeans.StreamingKMeans;
+import org.apache.flink.ml.clustering.kmeans.StreamingKMeansModel;
+import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.util.MockKVStore;
+import org.apache.flink.ml.util.MockMessageQueues;
+import org.apache.flink.ml.util.MockSinkFunction;
+import org.apache.flink.ml.util.MockSourceFunction;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.ml.util.TestMetricReporter;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.CollectionUtils;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import static org.apache.flink.ml.clustering.KMeansTest.groupFeaturesByPrediction;
+
+/** Tests {@link StreamingKMeans} and {@link StreamingKMeansModel}. */
+public class StreamingKMeansTest {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private static final DenseVector[] trainData1 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 0.0),
+                Vectors.dense(10.0, 0.3),
+                Vectors.dense(10.3, 0.0),
+                Vectors.dense(-10.0, 0.0),
+                Vectors.dense(-10.0, 0.6),
+                Vectors.dense(-10.6, 0.0)
+            };
+    private static final DenseVector[] trainData2 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 100.0),
+                Vectors.dense(10.0, 100.3),
+                Vectors.dense(10.3, 100.0),
+                Vectors.dense(-10.0, -100.0),
+                Vectors.dense(-10.0, -100.6),
+                Vectors.dense(-10.6, -100.0)
+            };
+    private static final DenseVector[] predictData =
+            new DenseVector[] {
+                Vectors.dense(10.0, 10.0),
+                Vectors.dense(10.3, 10.0),
+                Vectors.dense(10.0, 10.3),
+                Vectors.dense(-10.0, 10.0),
+                Vectors.dense(-10.3, 10.0),
+                Vectors.dense(-10.0, 10.3)
+            };
+    private static final List<Set<DenseVector>> expectedGroups1 =
+            Arrays.asList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3))),
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+    private static final List<Set<DenseVector>> expectedGroups2 =
+            Collections.singletonList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3),
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+
+    private String trainId;
+    private String predictId;
+    private String outputId;
+    private String modelDataId;
+    private String metricReporterPrefix;
+    private String modelDataVersionGaugeKey;
+    private String currentModelDataVersion;
+    private List<String> queueIds;
+    private List<String> kvStoreKeys;
+    private List<JobClient> clients;
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table offlineTrainTable;
+    private Table trainTable;
+    private Table predictTable;
+
+    @Before
+    public void before() {
+        metricReporterPrefix = MockKVStore.createNonDuplicatePrefix();
+        kvStoreKeys = new ArrayList<>();
+        kvStoreKeys.add(metricReporterPrefix);
+
+        modelDataVersionGaugeKey =
+                TestMetricReporter.getKey(metricReporterPrefix, "modelDataVersion");
+
+        currentModelDataVersion = "0";
+
+        trainId = MockMessageQueues.createMessageQueue();
+        predictId = MockMessageQueues.createMessageQueue();
+        outputId = MockMessageQueues.createMessageQueue();
+        modelDataId = MockMessageQueues.createMessageQueue();
+        queueIds = new ArrayList<>();
+        queueIds.addAll(Arrays.asList(trainId, predictId, outputId, modelDataId));
+
+        clients = new ArrayList<>();
+
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        config.setString("metrics.reporters", "test_reporter");
+        config.setString(
+                "metrics.reporter.test_reporter.class", TestMetricReporter.class.getName());
+        config.setString("metrics.reporter.test_reporter.interval", "100 MILLISECONDS");
+        config.setString("metrics.reporter.test_reporter.prefix", metricReporterPrefix);
+
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        Schema schema = Schema.newBuilder().column("f0", DataTypes.of(DenseVector.class)).build();
+
+        offlineTrainTable =
+                tEnv.fromDataStream(env.fromElements(trainData1), schema).as("features");
+        trainTable =
+                tEnv.fromDataStream(
+                                env.addSource(
+                                        new MockSourceFunction<>(trainId),
+                                        DenseVectorTypeInfo.INSTANCE),
+                                schema)
+                        .as("features");
+        predictTable =
+                tEnv.fromDataStream(
+                                env.addSource(
+                                        new MockSourceFunction<>(predictId),
+                                        DenseVectorTypeInfo.INSTANCE),
+                                schema)
+                        .as("features");
+    }
+
+    @After
+    public void after() {
+        for (JobClient client : clients) {
+            try {
+                client.cancel();
+            } catch (IllegalStateException e) {
+                if (!e.getMessage()
+                        .equals("MiniCluster is not yet running or has already been shut down.")) {
+                    throw e;
+                }
+            }
+        }
+        clients.clear();
+
+        for (String queueId : queueIds) {
+            MockMessageQueues.deleteBlockingQueue(queueId);
+        }
+        queueIds.clear();
+
+        for (String key : kvStoreKeys) {
+            MockKVStore.remove(key);
+        }
+        kvStoreKeys.clear();
+    }
+
+    /** Adds sinks for StreamingKMeansModel's transform output and model data. */
+    private void configModelSink(StreamingKMeansModel streamingModel) {

Review comment:
       Could we make this method name more self-explanatory by indicating that it involves the `transform(...)` operation?

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/StreamingKMeans.java
##########
@@ -0,0 +1,404 @@
+/**
+ * StreamingKMeans extends the function of {@link KMeans}, supporting to train a K-Means model
+ * continuously according to an unbounded stream of train data.
+ */
+public class StreamingKMeans
+        implements Estimator<StreamingKMeans, StreamingKMeansModel>,
+                StreamingKMeansParams<StreamingKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public StreamingKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public StreamingKMeans(Table... initModelDataTables) {
+        Preconditions.checkArgument(initModelDataTables.length == 1);
+        this.initModelDataTable = initModelDataTables[0];
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public StreamingKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        Preconditions.checkArgument(HasBatchStrategy.COUNT_STRATEGY.equals(getBatchStrategy()));
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+        points.getTransformation().setParallelism(1);

Review comment:
       Is there a way to enable parallel processing of the input stream, instead of limiting its parallelism to 1?
   
   If that is hard to implement, I think it is acceptable to use parallelism 1 in this PR. In that case, since the parallelism affects the performance of this class, could we explain it in the Javadoc of this class?
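
One possible way to lift the parallelism-1 restriction, offered as an assumption rather than a description of this PR: per-cluster sums and counts of a mini-batch are mergeable, so each parallel subtask could assign its points to the current centroids locally and emit partial statistics, which a single downstream operator merges before applying the decayed update. How mini-batch boundaries are coordinated across subtasks is not addressed here. `PartialClusterStats` is a hypothetical name.

```java
/** Hypothetical sketch: per-cluster sums and counts that parallel subtasks can merge. */
final class PartialClusterStats {
    final double[][] sums; // sums[i][j]: sum of feature j over points assigned to cluster i
    final long[] counts; // counts[i]: number of points assigned to cluster i

    PartialClusterStats(int k, int dims) {
        sums = new double[k][dims];
        counts = new long[k];
    }

    /** Adds one point that has already been assigned to the given cluster. */
    void add(int cluster, double[] point) {
        for (int j = 0; j < point.length; j++) {
            sums[cluster][j] += point[j];
        }
        counts[cluster]++;
    }

    /** Merges another subtask's statistics into this one; merge order only affects rounding. */
    void merge(PartialClusterStats other) {
        for (int i = 0; i < counts.length; i++) {
            for (int j = 0; j < sums[i].length; j++) {
                sums[i][j] += other.sums[i][j];
            }
            counts[i] += other.counts[i];
        }
    }
}
```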
   
   
   
   

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/StreamingKMeans.java
##########
@@ -0,0 +1,404 @@
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            initModelDataStream = createRandomCentroids(env, getDims(), getK(), getSeed());
+        } else {
+            initModelDataStream = KMeansModelData.getModelDataStream(initModelDataTable);
+        }
+        DataStream<Tuple2<KMeansModelData, DenseVector>> initModelDataWithWeightsStream =
+                initModelDataStream.map(new InitWeightAssigner(getInitWeights()));
+        initModelDataWithWeightsStream.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new StreamingKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getBatchSize(),
+                        getK());
+
+        DataStream<KMeansModelData> finalModelDataStream =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelDataWithWeightsStream),
+                                DataStreamList.of(points),
+                                body)
+                        .get(0);
+        finalModelDataStream = finalModelDataStream.union(initModelDataStream);

Review comment:
       Does this call guarantee that all values from `initModelDataStream` come before the values from the iteration's output (`finalModelDataStream`) in the output stream?
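
Flink's `union` merges its inputs without defining an order between records from different inputs, so the initial model data could interleave with, or arrive after, trained model data. If ordering matters, one hedged alternative is to tag each model data record with a monotonically increasing version and let consumers ignore anything not newer than what they have already applied. The test's `modelDataVersion` gauge suggests versions are tracked somewhere, but the actual field is not shown in this excerpt. `VersionedModelHolder` is a hypothetical name.

```java
/** Hypothetical sketch: keeps only the newest model data, whatever order it arrives in. */
final class VersionedModelHolder {
    private long currentVersion = -1; // no model applied yet
    private double[][] centroids;

    /** Applies the model only if it is newer than the current one; returns true if applied. */
    boolean offer(long version, double[][] newCentroids) {
        if (version <= currentVersion) {
            return false; // stale or duplicate model data, e.g. a late initial model
        }
        currentVersion = version;
        centroids = newCentroids;
        return true;
    }

    long currentVersion() {
        return currentVersion;
    }

    double[][] centroids() {
        return centroids;
    }
}
```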

##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/StreamingKMeansTest.java
##########
@@ -0,0 +1,527 @@
+/** Tests {@link StreamingKMeans} and {@link StreamingKMeansModel}. */
+public class StreamingKMeansTest {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private static final DenseVector[] trainData1 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 0.0),
+                Vectors.dense(10.0, 0.3),
+                Vectors.dense(10.3, 0.0),
+                Vectors.dense(-10.0, 0.0),
+                Vectors.dense(-10.0, 0.6),
+                Vectors.dense(-10.6, 0.0)
+            };
+    private static final DenseVector[] trainData2 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 100.0),
+                Vectors.dense(10.0, 100.3),
+                Vectors.dense(10.3, 100.0),
+                Vectors.dense(-10.0, -100.0),
+                Vectors.dense(-10.0, -100.6),
+                Vectors.dense(-10.6, -100.0)
+            };
+    private static final DenseVector[] predictData =
+            new DenseVector[] {
+                Vectors.dense(10.0, 10.0),
+                Vectors.dense(10.3, 10.0),
+                Vectors.dense(10.0, 10.3),
+                Vectors.dense(-10.0, 10.0),
+                Vectors.dense(-10.3, 10.0),
+                Vectors.dense(-10.0, 10.3)
+            };
+    private static final List<Set<DenseVector>> expectedGroups1 =
+            Arrays.asList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3))),
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+    private static final List<Set<DenseVector>> expectedGroups2 =
+            Collections.singletonList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3),
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+
+    private String trainId;
+    private String predictId;
+    private String outputId;
+    private String modelDataId;
+    private String metricReporterPrefix;
+    private String modelDataVersionGaugeKey;
+    private String currentModelDataVersion;
+    private List<String> queueIds;
+    private List<String> kvStoreKeys;
+    private List<JobClient> clients;
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table offlineTrainTable;
+    private Table trainTable;
+    private Table predictTable;
+
+    @Before
+    public void before() {
+        metricReporterPrefix = MockKVStore.createNonDuplicatePrefix();
+        kvStoreKeys = new ArrayList<>();
+        kvStoreKeys.add(metricReporterPrefix);
+
+        modelDataVersionGaugeKey =
+                TestMetricReporter.getKey(metricReporterPrefix, "modelDataVersion");
+
+        currentModelDataVersion = "0";
+
+        trainId = MockMessageQueues.createMessageQueue();
+        predictId = MockMessageQueues.createMessageQueue();
+        outputId = MockMessageQueues.createMessageQueue();
+        modelDataId = MockMessageQueues.createMessageQueue();
+        queueIds = new ArrayList<>();
+        queueIds.addAll(Arrays.asList(trainId, predictId, outputId, modelDataId));
+
+        clients = new ArrayList<>();
+
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        config.setString("metrics.reporters", "test_reporter");
+        config.setString(
+                "metrics.reporter.test_reporter.class", TestMetricReporter.class.getName());
+        config.setString("metrics.reporter.test_reporter.interval", "100 MILLISECONDS");
+        config.setString("metrics.reporter.test_reporter.prefix", metricReporterPrefix);
+
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        Schema schema = Schema.newBuilder().column("f0", DataTypes.of(DenseVector.class)).build();
+
+        offlineTrainTable =
+                tEnv.fromDataStream(env.fromElements(trainData1), schema).as("features");
+        trainTable =
+                tEnv.fromDataStream(
+                                env.addSource(
+                                        new MockSourceFunction<>(trainId),
+                                        DenseVectorTypeInfo.INSTANCE),
+                                schema)
+                        .as("features");
+        predictTable =
+                tEnv.fromDataStream(
+                                env.addSource(
+                                        new MockSourceFunction<>(predictId),
+                                        DenseVectorTypeInfo.INSTANCE),
+                                schema)
+                        .as("features");
+    }
+
+    @After
+    public void after() {
+        for (JobClient client : clients) {
+            try {
+                client.cancel();
+            } catch (IllegalStateException e) {
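+                // Only ignore the error raised when the MiniCluster has already shut down.
+                // Calling equals() on the constant avoids an NPE if the message is null.
+                if (!"MiniCluster is not yet running or has already been shut down."
+                        .equals(e.getMessage())) {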
+                    throw e;
+                }
+            }
+        }
+        clients.clear();
+
+        for (String queueId : queueIds) {
+            MockMessageQueues.deleteBlockingQueue(queueId);
+        }
+        queueIds.clear();
+
+        for (String key : kvStoreKeys) {
+            MockKVStore.remove(key);
+        }
+        kvStoreKeys.clear();
+    }
+
+    /** Adds sinks for StreamingKMeansModel's transform output and model data. */
+    private void configModelSink(StreamingKMeansModel streamingModel) {
+        Table outputTable = streamingModel.transform(predictTable)[0];
+        DataStream<Row> output = tEnv.toDataStream(outputTable);
+        output.addSink(new MockSinkFunction<>(outputId));
+
+        Table modelDataTable = streamingModel.getModelData()[0];
+        DataStream<KMeansModelData> modelDataStream =
+                KMeansModelData.getModelDataStream(modelDataTable);
+        modelDataStream.addSink(new MockSinkFunction<>(modelDataId));
+    }
+
+    /** Blocks the thread until the Model has set up its initial model data. */
+    private void waitInitModelDataSetup() {
+        while (!MockKVStore.containsKey(modelDataVersionGaugeKey)) {
+            Thread.yield();
+        }
+        waitModelDataUpdate();
+    }
+
+    /** Blocks the thread until the Model has received the next model-data-update event. */
+    private void waitModelDataUpdate() {
+        do {
+            String tmpModelDataVersion = MockKVStore.get(modelDataVersionGaugeKey);
+            if (tmpModelDataVersion.equals(currentModelDataVersion)) {
+                Thread.yield();
+            } else {
+                currentModelDataVersion = tmpModelDataVersion;
+                break;
+            }
+        } while (true);
+    }
+
+    /**
+     * Inserts the default predict data into the predict queue, fetches the prediction results,
+     * and asserts that the grouping result is as expected.
+     *
+     * @param expectedGroups A list containing sets of features, which is the expected group result
+     * @param featuresCol Name of the column in the table that contains the features
+     * @param predictionCol Name of the column in the table that contains the prediction result
+     */
+    private void predictAndAssert(
+            List<Set<DenseVector>> expectedGroups, String featuresCol, String predictionCol)
+            throws Exception {
+        MockMessageQueues.offerAll(predictId, StreamingKMeansTest.predictData);
+        List<Row> rawResult =
+                MockMessageQueues.poll(outputId, StreamingKMeansTest.predictData.length);
+        List<Set<DenseVector>> actualGroups =
+                groupFeaturesByPrediction(rawResult, featuresCol, predictionCol);
+        Assert.assertTrue(CollectionUtils.isEqualCollection(expectedGroups, actualGroups));
+    }
+
+    @Test
+    public void testParam() {
+        StreamingKMeans streamingKMeans = new StreamingKMeans();
+        Assert.assertEquals("features", streamingKMeans.getFeaturesCol());
+        Assert.assertEquals("prediction", streamingKMeans.getPredictionCol());
+        Assert.assertEquals(EuclideanDistanceMeasure.NAME, streamingKMeans.getDistanceMeasure());
+        Assert.assertEquals("random", streamingKMeans.getInitMode());
+        Assert.assertEquals(2, streamingKMeans.getK());
+        Assert.assertEquals(1, streamingKMeans.getDims());
+        Assert.assertEquals("count", streamingKMeans.getBatchStrategy());
+        Assert.assertEquals(1, streamingKMeans.getBatchSize());
+        Assert.assertEquals(0., streamingKMeans.getDecayFactor(), 1e-5);
+        Assert.assertEquals(StreamingKMeans.class.getName().hashCode(), streamingKMeans.getSeed());
+
+        streamingKMeans
+                .setK(9)
+                .setFeaturesCol("test_feature")
+                .setPredictionCol("test_prediction")
+                .setK(3)
+                .setDims(5)
+                .setBatchSize(5)
+                .setDecayFactor(0.25)
+                .setInitMode("direct")
+                .setSeed(100);
+
+        Assert.assertEquals("test_feature", streamingKMeans.getFeaturesCol());
+        Assert.assertEquals("test_prediction", streamingKMeans.getPredictionCol());
+        Assert.assertEquals(3, streamingKMeans.getK());
+        Assert.assertEquals(5, streamingKMeans.getDims());
+        Assert.assertEquals("count", streamingKMeans.getBatchStrategy());
+        Assert.assertEquals(5, streamingKMeans.getBatchSize());
+        Assert.assertEquals(0.25, streamingKMeans.getDecayFactor(), 1e-5);
+        Assert.assertEquals("direct", streamingKMeans.getInitMode());
+        Assert.assertEquals(100, streamingKMeans.getSeed());
+    }
+
+    @Test
+    public void testFitAndPredict() throws Exception {
+        StreamingKMeans streamingKMeans =
+                new StreamingKMeans()
+                        .setInitMode("random")
+                        .setDims(2)
+                        .setInitWeights(new Double[] {0., 0.})
+                        .setBatchSize(6)
+                        .setFeaturesCol("features")
+                        .setPredictionCol("prediction");
+        StreamingKMeansModel streamingModel = streamingKMeans.fit(trainTable);
+        configModelSink(streamingModel);
+
+        clients.add(env.executeAsync());
+        waitInitModelDataSetup();
+
+        MockMessageQueues.offerAll(trainId, trainData1);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups1,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+
+        MockMessageQueues.offerAll(trainId, trainData2);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups2,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+    }
+
+    @Test
+    public void testInitWithOfflineKMeans() throws Exception {
+        KMeans offlineKMeans =
+                new KMeans().setFeaturesCol("features").setPredictionCol("prediction");
+        KMeansModel offlineModel = offlineKMeans.fit(offlineTrainTable);
+
+        StreamingKMeans streamingKMeans =
+                new StreamingKMeans(offlineModel.getModelData())
+                        .setInitMode("direct")
+                        .setDims(2)
+                        .setInitWeights(new Double[] {0., 0.})
+                        .setBatchSize(6);
+        ReadWriteUtils.updateExistingParams(streamingKMeans, offlineKMeans.getParamMap());
+
+        StreamingKMeansModel streamingModel = streamingKMeans.fit(trainTable);
+        configModelSink(streamingModel);
+
+        clients.add(env.executeAsync());
+        waitInitModelDataSetup();
+        predictAndAssert(
+                expectedGroups1,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+
+        MockMessageQueues.offerAll(trainId, trainData2);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups2,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+    }
+
+    @Test
+    public void testDecayFactor() throws Exception {
+        StreamingKMeans streamingKMeans =
+                new StreamingKMeans()
+                        .setInitMode("random")
+                        .setDims(2)
+                        .setInitWeights(new Double[] {0., 0.})
+                        .setDecayFactor(100.0)
+                        .setBatchSize(6)
+                        .setFeaturesCol("features")
+                        .setPredictionCol("prediction");
+        StreamingKMeansModel streamingModel = streamingKMeans.fit(trainTable);
+        configModelSink(streamingModel);
+
+        clients.add(env.executeAsync());
+        waitInitModelDataSetup();
+
+        MockMessageQueues.offerAll(trainId, trainData1);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups1,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+
+        MockMessageQueues.offerAll(trainId, trainData2);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups1,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+    }
+
+    @Test
+    public void testSaveAndReload() throws Exception {
+        KMeans offlineKMeans =

Review comment:
       Would it be simpler to name the variable after its class name, i.e. `kMeans` instead of `offlineKMeans`?
   
   We currently use `KMeans` to refer to the offline algorithm and `StreamingKMeans` to refer to the online algorithm, so users are already expected to understand that `KMeans` refers to the offline algorithm.

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/StreamingKMeans.java
##########
@@ -0,0 +1,404 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * StreamingKMeans extends the functionality of {@link KMeans} by supporting continuous training
+ * of a K-Means model on an unbounded stream of training data.
+ */
+public class StreamingKMeans
+        implements Estimator<StreamingKMeans, StreamingKMeansModel>,
+                StreamingKMeansParams<StreamingKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public StreamingKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public StreamingKMeans(Table... initModelDataTables) {
+        Preconditions.checkArgument(initModelDataTables.length == 1);
+        this.initModelDataTable = initModelDataTables[0];
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public StreamingKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        Preconditions.checkArgument(HasBatchStrategy.COUNT_STRATEGY.equals(getBatchStrategy()));
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+        points.getTransformation().setParallelism(1);
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            initModelDataStream = createRandomCentroids(env, getDims(), getK(), getSeed());

Review comment:
       If initMode is `random`, should we assert that `initModelDataTable == null`? And if initMode is `direct`, should we assert that `initModelDataTable != null`?
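A minimal sketch of the suggested check, assuming `random` and `direct` are the only init modes that need it. `InitModeValidation` and `validateInitMode` are hypothetical names, and `Object` stands in for Flink's `Table` so the snippet compiles without Flink on the classpath. Inside `fit()`, the same conditions could instead be expressed with `Preconditions.checkArgument(boolean, Object)` from `org.apache.flink.util`, which the class already imports.

```java
/**
 * Hypothetical standalone helper mirroring the suggestion above: whether init model data is
 * present must agree with initMode. Object stands in for Flink's Table.
 */
public final class InitModeValidation {
    private InitModeValidation() {}

    /** Throws IllegalArgumentException if initMode and the init model data are inconsistent. */
    public static void validateInitMode(String initMode, Object initModelDataTable) {
        if ("random".equals(initMode) && initModelDataTable != null) {
            // Random centroids are generated internally, so a provided table would be ignored.
            throw new IllegalArgumentException(
                    "initModelDataTable must be null when initMode is 'random'.");
        }
        if ("direct".equals(initMode) && initModelDataTable == null) {
            // Centroids are read from the provided model data, so it has to be present.
            throw new IllegalArgumentException(
                    "initModelDataTable must not be null when initMode is 'direct'.");
        }
        // Assumption: other initMode values are rejected by parameter validation elsewhere.
    }

    public static void main(String[] args) {
        validateInitMode("random", null); // consistent: passes
        validateInitMode("direct", new Object()); // consistent: passes
        try {
            validateInitMode("direct", null); // inconsistent: throws
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage());
        }
    }
}
```

Failing fast like this would surface a misconfiguration when the pipeline is built: otherwise a table passed with `random` is silently ignored, and `direct` without a table reaches `KMeansModelData.getModelDataStream(initModelDataTable)` with a null argument.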

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/StreamingKMeans.java
##########
@@ -0,0 +1,404 @@
+public class StreamingKMeans

Review comment:
       Should this be renamed `OnlineKMeans`?
   
   Same for other class names.

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/StreamingKMeans.java
##########
@@ -0,0 +1,404 @@
+public class StreamingKMeans
+        implements Estimator<StreamingKMeans, StreamingKMeansModel>,
+                StreamingKMeansParams<StreamingKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public StreamingKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public StreamingKMeans(Table... initModelDataTables) {
+        Preconditions.checkArgument(initModelDataTables.length == 1);
+        this.initModelDataTable = initModelDataTables[0];
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public StreamingKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        Preconditions.checkArgument(HasBatchStrategy.COUNT_STRATEGY.equals(getBatchStrategy()));
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+        points.getTransformation().setParallelism(1);
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            initModelDataStream = createRandomCentroids(env, getDims(), getK(), getSeed());
+        } else {
+            initModelDataStream = KMeansModelData.getModelDataStream(initModelDataTable);
+        }
+        DataStream<Tuple2<KMeansModelData, DenseVector>> initModelDataWithWeightsStream =
+                initModelDataStream.map(new InitWeightAssigner(getInitWeights()));
+        initModelDataWithWeightsStream.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new StreamingKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getBatchSize(),
+                        getK());
+
+        DataStream<KMeansModelData> finalModelDataStream =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelDataWithWeightsStream),
+                                DataStreamList.of(points),
+                                body)
+                        .get(0);
+        finalModelDataStream = finalModelDataStream.union(initModelDataStream);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelDataStream);
+        StreamingKMeansModel model = new StreamingKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class InitWeightAssigner
+            implements MapFunction<KMeansModelData, Tuple2<KMeansModelData, DenseVector>> {
+        private final double[] initWeights;
+
+        private InitWeightAssigner(Double[] initWeights) {
+            this.initWeights = ArrayUtils.toPrimitive(initWeights);
+        }
+
+        @Override
+        public Tuple2<KMeansModelData, DenseVector> map(KMeansModelData modelData)
+                throws Exception {
+            return Tuple2.of(modelData, Vectors.dense(initWeights));
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static StreamingKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        StreamingKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        Path initModelDataPath = Paths.get(path, "data");
+        if (Files.exists(initModelDataPath)) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+            kMeans.setInitMode("direct");
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class StreamingKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+        private final int k;
+
+        public StreamingKMeansIterationBody(
+                DistanceMeasure distanceMeasure, double decayFactor, int batchSize, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+            this.k = k;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<Tuple2<KMeansModelData, DenseVector>> modelDataWithWeights =
+                    variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            DataStream<Tuple2<KMeansModelData, DenseVector>> newModelDataWithWeights =
+                    points.countWindowAll(batchSize)
+                            .aggregate(new MiniBatchCreator())
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "UpdateModelData",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new UpdateModelDataOperator(distanceMeasure, decayFactor, k))
+                            .setParallelism(1);
+
+            DataStream<KMeansModelData> newModelData =
+                    newModelDataWithWeights.map(
+                            (MapFunction<Tuple2<KMeansModelData, DenseVector>, KMeansModelData>)
+                                    x -> x.f0);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelDataWithWeights), DataStreamList.of(newModelData));
+        }
+    }
+
+    // TODO: change this single-threaded implementation to support training in a distributed way,
+    // after the model data version mechanism is implemented.
+    private static class UpdateModelDataOperator
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, DenseVector>>
+            implements TwoInputStreamOperator<
+                    DenseVector[],
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int k;
+        private ListState<DenseVector[]> miniBatchState;
+        private ListState<KMeansModelData> modelDataState;
+        private ListState<DenseVector> weightsState;
+
+        public UpdateModelDataOperator(DistanceMeasure distanceMeasure, double decayFactor, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.k = k;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            miniBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("miniBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+
+            weightsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "weights", DenseVectorTypeInfo.INSTANCE));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> streamRecord) throws Exception {
+            miniBatchState.add(streamRecord.getValue());
+            processElement();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<Tuple2<KMeansModelData, DenseVector>> streamRecord)
+                throws Exception {
+            modelDataState.add(streamRecord.getValue().f0);
+            weightsState.add(streamRecord.getValue().f1);
+            processElement();
+        }
+
+        private void processElement() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !miniBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            // Retrieves data from states.
+            List<KMeansModelData> modelDataList =
+                    IteratorUtils.toList(modelDataState.get().iterator());
+            if (modelDataList.size() != 1) {
+                throw new RuntimeException(
+                        "The operator received "
+                                + modelDataList.size()
+                                + " model data records in this round");
+            }
+            DenseVector[] centroids = modelDataList.get(0).centroids;
+            modelDataState.clear();

Review comment:
       Is it correct to clear the modelDataState here?
   
   Doing so means that a model data record can only be used to process input points received before it, not input points received after it, right?
   
   I think it is not guaranteed that there is exactly one batch of data between two model data records.
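   A Flink-free sketch of the buffering in `UpdateModelDataOperator` may help frame the question. The class `PairingSimulation` and its methods are hypothetical: strings stand in for mini-batches and model data records, `weightsState` is omitted, and a single operator instance is assumed to see every record in arrival order (as the parallelism of 1 suggests). It mirrors `processElement()`: an update fires only when both buffers are non-empty, only the oldest mini-batch is consumed, and the model buffer is then cleared.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;

// Hypothetical, Flink-free model of the buffering in UpdateModelDataOperator.
// Strings stand in for mini-batches and model data records.
class PairingSimulation {
    private final Deque<String> miniBatches = new ArrayDeque<>(); // like miniBatchState
    private final List<String> models = new ArrayList<>(); // like modelDataState
    private final List<String> updates = new ArrayList<>(); // "model+batch" pairs that fired

    void onMiniBatch(String batch) { // like processElement1
        miniBatches.addLast(batch);
        tryUpdate();
    }

    void onModel(String model) { // like processElement2
        models.add(model);
        tryUpdate();
    }

    List<String> updates() {
        return updates;
    }

    private void tryUpdate() { // like processElement
        if (models.isEmpty() || miniBatches.isEmpty()) {
            return; // wait until both inputs have something buffered
        }
        if (models.size() != 1) {
            throw new IllegalStateException(
                    "received " + models.size() + " model data records in this round");
        }
        String model = models.get(0);
        models.clear(); // the step the review comment asks about
        String batch = miniBatches.pollFirst(); // only the oldest batch is consumed
        updates.add(model + "+" + batch);
    }

    public static void main(String[] args) {
        PairingSimulation op = new PairingSimulation();
        op.onMiniBatch("b1");
        op.onMiniBatch("b2"); // two batches arrive before any model
        op.onModel("m0"); // m0 is used for b1 only; b2 stays buffered
        op.onModel("m1"); // the fed-back model m1 is used for b2
        System.out.println(op.updates()); // [m0+b1, m1+b2]
    }
}
```

   Under these assumptions each model record is used for exactly one mini-batch, and surplus batches wait for the next fed-back model rather than being processed with an older one. If two model records were buffered before any batch arrived, the size check would throw instead.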

##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java
##########
@@ -207,25 +211,21 @@ public void testSaveLoadAndPredict() throws Exception {
         KMeansModel loadedModel =
                 StageTestUtils.saveAndReload(env, model, tempFolder.newFolder().getAbsolutePath());
         Table output = loadedModel.transform(dataTable)[0];
-        assertEquals(
-                Collections.singletonList("centroids"),
-                loadedModel.getModelData()[0].getResolvedSchema().getColumnNames());
         assertEquals(
                 Arrays.asList("features", "prediction"),
                 output.getResolvedSchema().getColumnNames());
 
+        List<Row> results = IteratorUtils.toList(output.execute().collect());
         List<Set<DenseVector>> actualGroups =
-                executeAndCollect(output, kmeans.getFeaturesCol(), kmeans.getPredictionCol());
+                groupFeaturesByPrediction(
+                        results, kmeans.getFeaturesCol(), kmeans.getPredictionCol());
         assertTrue(CollectionUtils.isEqualCollection(expectedGroups, actualGroups));
     }
 
     @Test
     public void testGetModelData() throws Exception {
         KMeans kmeans = new KMeans().setMaxIter(2).setK(2);
         KMeansModel model = kmeans.fit(dataTable);
-        assertEquals(
-                Collections.singletonList("centroids"),

Review comment:
       Hmm.. why do we remove this check?
   
   After removing this, is there a test that verifies the model data table schema is correct?

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/StreamingKMeans.java
##########
@@ -0,0 +1,404 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * StreamingKMeans extends the functionality of {@link KMeans} by supporting continuous training of
+ * a K-Means model on an unbounded stream of training data.
+ */
+public class StreamingKMeans
+        implements Estimator<StreamingKMeans, StreamingKMeansModel>,
+                StreamingKMeansParams<StreamingKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public StreamingKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public StreamingKMeans(Table... initModelDataTables) {
+        Preconditions.checkArgument(initModelDataTables.length == 1);
+        this.initModelDataTable = initModelDataTables[0];
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public StreamingKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        Preconditions.checkArgument(HasBatchStrategy.COUNT_STRATEGY.equals(getBatchStrategy()));
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+        points.getTransformation().setParallelism(1);
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            initModelDataStream = createRandomCentroids(env, getDims(), getK(), getSeed());
+        } else {
+            initModelDataStream = KMeansModelData.getModelDataStream(initModelDataTable);
+        }
+        DataStream<Tuple2<KMeansModelData, DenseVector>> initModelDataWithWeightsStream =
+                initModelDataStream.map(new InitWeightAssigner(getInitWeights()));
+        initModelDataWithWeightsStream.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new StreamingKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getBatchSize(),
+                        getK());
+
+        DataStream<KMeansModelData> finalModelDataStream =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelDataWithWeightsStream),
+                                DataStreamList.of(points),
+                                body)
+                        .get(0);
+        finalModelDataStream = finalModelDataStream.union(initModelDataStream);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelDataStream);
+        StreamingKMeansModel model = new StreamingKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class InitWeightAssigner
+            implements MapFunction<KMeansModelData, Tuple2<KMeansModelData, DenseVector>> {
+        private final double[] initWeights;
+
+        private InitWeightAssigner(Double[] initWeights) {
+            this.initWeights = ArrayUtils.toPrimitive(initWeights);
+        }
+
+        @Override
+        public Tuple2<KMeansModelData, DenseVector> map(KMeansModelData modelData)
+                throws Exception {
+            return Tuple2.of(modelData, Vectors.dense(initWeights));
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static StreamingKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        StreamingKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        Path initModelDataPath = Paths.get(path, "data");
+        if (Files.exists(initModelDataPath)) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+            kMeans.setInitMode("direct");
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class StreamingKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+        private final int k;
+
+        public StreamingKMeansIterationBody(
+                DistanceMeasure distanceMeasure, double decayFactor, int batchSize, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+            this.k = k;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<Tuple2<KMeansModelData, DenseVector>> modelDataWithWeights =
+                    variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            DataStream<Tuple2<KMeansModelData, DenseVector>> newModelDataWithWeights =
+                    points.countWindowAll(batchSize)
+                            .aggregate(new MiniBatchCreator())
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "UpdateModelData",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new UpdateModelDataOperator(distanceMeasure, decayFactor, k))
+                            .setParallelism(1);
+
+            DataStream<KMeansModelData> newModelData =
+                    newModelDataWithWeights.map(
+                            (MapFunction<Tuple2<KMeansModelData, DenseVector>, KMeansModelData>)
+                                    x -> x.f0);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelDataWithWeights), DataStreamList.of(newModelData));
+        }
+    }
+
+    // TODO: change this single-threaded implementation to support training in a distributed way,
+    // after the model data version mechanism is implemented.
+    private static class UpdateModelDataOperator
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, DenseVector>>
+            implements TwoInputStreamOperator<
+                    DenseVector[],
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int k;
+        private ListState<DenseVector[]> miniBatchState;
+        private ListState<KMeansModelData> modelDataState;
+        private ListState<DenseVector> weightsState;
+
+        public UpdateModelDataOperator(DistanceMeasure distanceMeasure, double decayFactor, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.k = k;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            miniBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("miniBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+
+            weightsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "weights", DenseVectorTypeInfo.INSTANCE));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> streamRecord) throws Exception {
+            miniBatchState.add(streamRecord.getValue());
+            processElement();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<Tuple2<KMeansModelData, DenseVector>> streamRecord)
+                throws Exception {
+            modelDataState.add(streamRecord.getValue().f0);
+            weightsState.add(streamRecord.getValue().f1);
+            processElement();
+        }
+
+        private void processElement() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !miniBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            // Retrieves data from states.
+            List<KMeansModelData> modelDataList =
+                    IteratorUtils.toList(modelDataState.get().iterator());
+            if (modelDataList.size() != 1) {
+                throw new RuntimeException(
+                        "The operator received "
+                                + modelDataList.size()
+                                + " model data records in this round");
+            }
+            DenseVector[] centroids = modelDataList.get(0).centroids;
+            modelDataState.clear();
+
+            List<DenseVector> weightsList = IteratorUtils.toList(weightsState.get().iterator());
+            if (weightsList.size() != 1) {
+                throw new RuntimeException(
+                        "The operator received "
+                                + weightsList.size()
+                                + " weight vectors in this round");
+            }
+            DenseVector weights = weightsList.get(0);
+            weightsState.clear();
+
+            List<DenseVector[]> pointsList = IteratorUtils.toList(miniBatchState.get().iterator());
+            DenseVector[] points = pointsList.get(0);
+            pointsList.remove(0);
+            miniBatchState.clear();
+            miniBatchState.addAll(pointsList);
+
+            // Computes new centroids.
+            DenseVector[] sums = new DenseVector[k];
+            int dims = centroids[0].size();
+            int[] counts = new int[k];
+
+            for (int i = 0; i < k; i++) {
+                sums[i] = new DenseVector(dims);
+                counts[i] = 0;
+            }
+            for (DenseVector point : points) {
+                int closestCentroidId =
+                        KMeans.findClosestCentroidId(centroids, point, distanceMeasure);
+                counts[closestCentroidId]++;
+                for (int j = 0; j < dims; j++) {
+                    sums[closestCentroidId].values[j] += point.values[j];
+                }
+            }
+
+            // Considers weight and decay factor when updating centroids.
+            BLAS.scal(decayFactor, weights);
+            for (int i = 0; i < k; i++) {
+                DenseVector centroid = centroids[i];
+
+                double updatedWeight = weights.values[i] + counts[i];

Review comment:
       Would it be simpler to replace `updatedWeight` with `weights.values[i]`?
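   For reference, a sketch of the suggested in-place form. The rest of the method is not shown in this hunk, so it assumes each centroid becomes the weighted average of the old centroid (weighted by its decayed weight) and the mini-batch sum, and that the decayed weight plus the new count becomes the new weight: c' = (c * a * w + s) / (a * w + m), w' = a * w + m. Plain arrays stand in for `DenseVector`/`BLAS`, and `DecayedCentroidUpdate` is a hypothetical name.

```java
import java.util.Arrays;

// Hypothetical sketch: plain arrays stand in for Flink ML's DenseVector and BLAS.
class DecayedCentroidUpdate {
    /**
     * Merges one mini-batch into the centroids, updating centroids and weights in place.
     *
     * @param centroids k x dims centroid coordinates
     * @param weights per-centroid weights
     * @param sums per-centroid sum of the points assigned in this mini-batch
     * @param counts per-centroid number of points assigned in this mini-batch
     * @param decayFactor factor applied to the old weights before merging
     */
    static void update(
            double[][] centroids, double[] weights, double[][] sums, int[] counts, double decayFactor) {
        for (int i = 0; i < centroids.length; i++) {
            weights[i] *= decayFactor; // same effect as BLAS.scal(decayFactor, weights)
            if (counts[i] == 0) {
                continue; // no new points: the centroid is unchanged, only its weight decays
            }
            for (int j = 0; j < centroids[i].length; j++) {
                centroids[i][j] = centroids[i][j] * weights[i] + sums[i][j];
            }
            weights[i] += counts[i]; // updated in place instead of a separate updatedWeight
            for (int j = 0; j < centroids[i].length; j++) {
                centroids[i][j] /= weights[i];
            }
        }
    }

    public static void main(String[] args) {
        double[][] centroids = {{0.0, 0.0}, {10.0, 10.0}};
        double[] weights = {2.0, 4.0};
        double[][] sums = {{6.0, 6.0}, {0.0, 0.0}}; // e.g. points (2, 2) and (4, 4) for centroid 0
        int[] counts = {2, 0};
        update(centroids, weights, sums, counts, 0.5);
        System.out.println(Arrays.deepToString(centroids)); // [[2.0, 2.0], [10.0, 10.0]]
        System.out.println(Arrays.toString(weights)); // [3.0, 2.0]
    }
}
```

   Scaling the centroid by the decayed weight before incrementing `weights[i]` means the old weight is no longer needed afterwards, so no temporary is required. The `counts[i] == 0` guard also avoids a division by zero when a centroid has neither weight nor new points.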

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/StreamingKMeans.java
##########
@@ -0,0 +1,404 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * StreamingKMeans extends the functionality of {@link KMeans} by supporting continuous training of
+ * a K-Means model on an unbounded stream of training data.
+ */
+public class StreamingKMeans
+        implements Estimator<StreamingKMeans, StreamingKMeansModel>,
+                StreamingKMeansParams<StreamingKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public StreamingKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public StreamingKMeans(Table... initModelDataTables) {
+        Preconditions.checkArgument(initModelDataTables.length == 1);
+        this.initModelDataTable = initModelDataTables[0];
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public StreamingKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        Preconditions.checkArgument(HasBatchStrategy.COUNT_STRATEGY.equals(getBatchStrategy()));
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+        points.getTransformation().setParallelism(1);
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            initModelDataStream = createRandomCentroids(env, getDims(), getK(), getSeed());
+        } else {
+            initModelDataStream = KMeansModelData.getModelDataStream(initModelDataTable);
+        }
+        DataStream<Tuple2<KMeansModelData, DenseVector>> initModelDataWithWeightsStream =
+                initModelDataStream.map(new InitWeightAssigner(getInitWeights()));
+        initModelDataWithWeightsStream.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new StreamingKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getBatchSize(),
+                        getK());
+
+        DataStream<KMeansModelData> finalModelDataStream =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelDataWithWeightsStream),
+                                DataStreamList.of(points),
+                                body)
+                        .get(0);
+        finalModelDataStream = finalModelDataStream.union(initModelDataStream);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelDataStream);
+        StreamingKMeansModel model = new StreamingKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class InitWeightAssigner
+            implements MapFunction<KMeansModelData, Tuple2<KMeansModelData, DenseVector>> {
+        private final double[] initWeights;
+
+        private InitWeightAssigner(Double[] initWeights) {
+            this.initWeights = ArrayUtils.toPrimitive(initWeights);
+        }
+
+        @Override
+        public Tuple2<KMeansModelData, DenseVector> map(KMeansModelData modelData)
+                throws Exception {
+            return Tuple2.of(modelData, Vectors.dense(initWeights));
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static StreamingKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        StreamingKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        Path initModelDataPath = Paths.get(path, "data");
+        if (Files.exists(initModelDataPath)) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+            kMeans.setInitMode("direct");
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class StreamingKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+        private final int k;
+
+        public StreamingKMeansIterationBody(
+                DistanceMeasure distanceMeasure, double decayFactor, int batchSize, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+            this.k = k;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<Tuple2<KMeansModelData, DenseVector>> modelDataWithWeights =
+                    variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            DataStream<Tuple2<KMeansModelData, DenseVector>> newModelDataWithWeights =
+                    points.countWindowAll(batchSize)
+                            .aggregate(new MiniBatchCreator())
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "UpdateModelData",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new UpdateModelDataOperator(distanceMeasure, decayFactor, k))
+                            .setParallelism(1);
+
+            DataStream<KMeansModelData> newModelData =
+                    newModelDataWithWeights.map(
+                            (MapFunction<Tuple2<KMeansModelData, DenseVector>, KMeansModelData>)
+                                    x -> x.f0);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelDataWithWeights), DataStreamList.of(newModelData));
+        }
+    }
+
+    // TODO: change this single-threaded implementation to support training in a distributed way,
+    // after the model data version mechanism is implemented.
+    private static class UpdateModelDataOperator
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, DenseVector>>
+            implements TwoInputStreamOperator<
+                    DenseVector[],
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int k;
+        private ListState<DenseVector[]> miniBatchState;
+        private ListState<KMeansModelData> modelDataState;
+        private ListState<DenseVector> weightsState;
+
+        public UpdateModelDataOperator(DistanceMeasure distanceMeasure, double decayFactor, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.k = k;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            miniBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("miniBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+
+            weightsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "weights", DenseVectorTypeInfo.INSTANCE));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> streamRecord) throws Exception {
+            miniBatchState.add(streamRecord.getValue());
+            processElement();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<Tuple2<KMeansModelData, DenseVector>> streamRecord)
+                throws Exception {
+            modelDataState.add(streamRecord.getValue().f0);
+            weightsState.add(streamRecord.getValue().f1);
+            processElement();
+        }
+
+        private void processElement() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !miniBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            // Retrieves data from states.
+            List<KMeansModelData> modelDataList =
+                    IteratorUtils.toList(modelDataState.get().iterator());
+            if (modelDataList.size() != 1) {
+                throw new RuntimeException(
+                        "The operator received "
+                                + modelDataList.size()
+                                + " model data records in this round, but expected exactly 1.");
+            }
+            DenseVector[] centroids = modelDataList.get(0).centroids;
+            modelDataState.clear();
+
+            List<DenseVector> weightsList = IteratorUtils.toList(weightsState.get().iterator());
+            if (weightsList.size() != 1) {
+                throw new RuntimeException(
+                        "The operator received "
+                                + weightsList.size()
+                                + " weight vectors in this round, but expected exactly 1.");
+            }
+            DenseVector weights = weightsList.get(0);
+            weightsState.clear();
+
+            List<DenseVector[]> pointsList = IteratorUtils.toList(miniBatchState.get().iterator());
+            DenseVector[] points = pointsList.get(0);
+            pointsList.remove(0);
+            miniBatchState.clear();
+            miniBatchState.addAll(pointsList);

Review comment:
       Hmm... why do we add back these points? Would it be simpler to just process all buffered points in this method call?
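   The suggestion above can be sketched as a hypothetical helper that drains every buffered mini-batch in a single call, so nothing has to be written back to `miniBatchState`. This is not the PR's `UpdateModelDataOperator` (its update logic is not shown in this excerpt): it assumes plain `double[]` arrays instead of `DenseVector`, squared Euclidean distance instead of the configurable `DistanceMeasure`, and an assumed decayed weighted-mean update where each centroid's weight becomes `weight * decayFactor + count`.

```java
import java.util.List;

/**
 * Hypothetical sketch (not the PR's UpdateModelDataOperator): applies every buffered
 * mini-batch in one call. Uses double[] in place of DenseVector and squared Euclidean
 * distance in place of the configurable DistanceMeasure.
 */
final class MiniBatchKMeansSketch {

    /** Returns the index of the centroid nearest to {@code point} (squared Euclidean distance). */
    static int nearest(double[][] centroids, double[] point) {
        int best = 0;
        double bestDist = Double.POSITIVE_INFINITY;
        for (int i = 0; i < centroids.length; i++) {
            double dist = 0.0;
            for (int j = 0; j < point.length; j++) {
                double diff = centroids[i][j] - point[j];
                dist += diff * diff;
            }
            if (dist < bestDist) {
                bestDist = dist;
                best = i;
            }
        }
        return best;
    }

    /**
     * Applies the buffered mini-batches in arrival order, updating centroids and weights in place.
     * Assumed update for centroid i, given n assigned points whose coordinate sum is s:
     *   w' = w * decayFactor + n,   c' = (c * w * decayFactor + s) / w'
     */
    static void updateAll(
            double[][] centroids, double[] weights, List<double[][]> batches, double decayFactor) {
        int k = centroids.length;
        int dims = centroids[0].length;
        for (double[][] batch : batches) {
            // Assign each point of this batch to its nearest centroid and accumulate sums.
            double[][] sums = new double[k][dims];
            int[] counts = new int[k];
            for (double[] point : batch) {
                int c = nearest(centroids, point);
                counts[c]++;
                for (int j = 0; j < dims; j++) {
                    sums[c][j] += point[j];
                }
            }
            // Blend the decayed history of each centroid with the new batch.
            for (int i = 0; i < k; i++) {
                double decayed = weights[i] * decayFactor;
                double newWeight = decayed + counts[i];
                if (newWeight > 0) {
                    for (int j = 0; j < dims; j++) {
                        centroids[i][j] = (centroids[i][j] * decayed + sums[i][j]) / newWeight;
                    }
                }
                weights[i] = newWeight;
            }
        }
    }
}
```

   One trade-off to weigh: draining everything at once emits one updated model per call rather than one per mini-batch, which changes how many model versions flow back through the iteration.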

##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/StreamingKMeansTest.java
##########
@@ -0,0 +1,527 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.core.execution.JobClient;
+import org.apache.flink.ml.clustering.kmeans.KMeans;
+import org.apache.flink.ml.clustering.kmeans.KMeansModel;
+import org.apache.flink.ml.clustering.kmeans.KMeansModelData;
+import org.apache.flink.ml.clustering.kmeans.StreamingKMeans;
+import org.apache.flink.ml.clustering.kmeans.StreamingKMeansModel;
+import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.util.MockKVStore;
+import org.apache.flink.ml.util.MockMessageQueues;
+import org.apache.flink.ml.util.MockSinkFunction;
+import org.apache.flink.ml.util.MockSourceFunction;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.ml.util.TestMetricReporter;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.CollectionUtils;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import static org.apache.flink.ml.clustering.KMeansTest.groupFeaturesByPrediction;
+
+/** Tests {@link StreamingKMeans} and {@link StreamingKMeansModel}. */
+public class StreamingKMeansTest {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private static final DenseVector[] trainData1 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 0.0),
+                Vectors.dense(10.0, 0.3),
+                Vectors.dense(10.3, 0.0),
+                Vectors.dense(-10.0, 0.0),
+                Vectors.dense(-10.0, 0.6),
+                Vectors.dense(-10.6, 0.0)
+            };
+    private static final DenseVector[] trainData2 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 100.0),
+                Vectors.dense(10.0, 100.3),
+                Vectors.dense(10.3, 100.0),
+                Vectors.dense(-10.0, -100.0),
+                Vectors.dense(-10.0, -100.6),
+                Vectors.dense(-10.6, -100.0)
+            };
+    private static final DenseVector[] predictData =
+            new DenseVector[] {
+                Vectors.dense(10.0, 10.0),
+                Vectors.dense(10.3, 10.0),
+                Vectors.dense(10.0, 10.3),
+                Vectors.dense(-10.0, 10.0),
+                Vectors.dense(-10.3, 10.0),
+                Vectors.dense(-10.0, 10.3)
+            };
+    private static final List<Set<DenseVector>> expectedGroups1 =
+            Arrays.asList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3))),
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+    private static final List<Set<DenseVector>> expectedGroups2 =
+            Collections.singletonList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3),
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+
+    private String trainId;
+    private String predictId;
+    private String outputId;
+    private String modelDataId;
+    private String metricReporterPrefix;
+    private String modelDataVersionGaugeKey;
+    private String currentModelDataVersion;
+    private List<String> queueIds;
+    private List<String> kvStoreKeys;
+    private List<JobClient> clients;
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table offlineTrainTable;
+    private Table trainTable;
+    private Table predictTable;
+
+    @Before
+    public void before() {
+        metricReporterPrefix = MockKVStore.createNonDuplicatePrefix();
+        kvStoreKeys = new ArrayList<>();
+        kvStoreKeys.add(metricReporterPrefix);
+
+        modelDataVersionGaugeKey =
+                TestMetricReporter.getKey(metricReporterPrefix, "modelDataVersion");
+
+        currentModelDataVersion = "0";
+
+        trainId = MockMessageQueues.createMessageQueue();
+        predictId = MockMessageQueues.createMessageQueue();
+        outputId = MockMessageQueues.createMessageQueue();
+        modelDataId = MockMessageQueues.createMessageQueue();
+        queueIds = new ArrayList<>();
+        queueIds.addAll(Arrays.asList(trainId, predictId, outputId, modelDataId));
+
+        clients = new ArrayList<>();
+
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        config.setString("metrics.reporters", "test_reporter");
+        config.setString(
+                "metrics.reporter.test_reporter.class", TestMetricReporter.class.getName());
+        config.setString("metrics.reporter.test_reporter.interval", "100 MILLISECONDS");
+        config.setString("metrics.reporter.test_reporter.prefix", metricReporterPrefix);
+
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        Schema schema = Schema.newBuilder().column("f0", DataTypes.of(DenseVector.class)).build();
+
+        offlineTrainTable =
+                tEnv.fromDataStream(env.fromElements(trainData1), schema).as("features");
+        trainTable =
+                tEnv.fromDataStream(
+                                env.addSource(
+                                        new MockSourceFunction<>(trainId),
+                                        DenseVectorTypeInfo.INSTANCE),
+                                schema)
+                        .as("features");
+        predictTable =
+                tEnv.fromDataStream(
+                                env.addSource(
+                                        new MockSourceFunction<>(predictId),
+                                        DenseVectorTypeInfo.INSTANCE),
+                                schema)
+                        .as("features");
+    }
+
+    @After
+    public void after() {
+        for (JobClient client : clients) {
+            try {
+                client.cancel();
+            } catch (IllegalStateException e) {
+                if (!e.getMessage()
+                        .equals("MiniCluster is not yet running or has already been shut down.")) {
+                    throw e;
+                }
+            }
+        }
+        clients.clear();
+
+        for (String queueId : queueIds) {
+            MockMessageQueues.deleteBlockingQueue(queueId);
+        }
+        queueIds.clear();
+
+        for (String key : kvStoreKeys) {
+            MockKVStore.remove(key);
+        }
+        kvStoreKeys.clear();
+    }
+
+    /** Adds sinks for StreamingKMeansModel's transform output and model data. */
+    private void configModelSink(StreamingKMeansModel streamingModel) {
+        Table outputTable = streamingModel.transform(predictTable)[0];
+        DataStream<Row> output = tEnv.toDataStream(outputTable);
+        output.addSink(new MockSinkFunction<>(outputId));
+
+        Table modelDataTable = streamingModel.getModelData()[0];
+        DataStream<KMeansModelData> modelDataStream =
+                KMeansModelData.getModelDataStream(modelDataTable);
+        modelDataStream.addSink(new MockSinkFunction<>(modelDataId));
+    }
+
+    /** Blocks the thread until the Model has set up the initial model data. */
+    private void waitInitModelDataSetup() {
+        while (!MockKVStore.containsKey(modelDataVersionGaugeKey)) {
+            Thread.yield();
+        }
+        waitModelDataUpdate();
+    }
+
+    /** Blocks the thread until the Model has received the next model-data-update event. */
+    private void waitModelDataUpdate() {
+        do {
+            String tmpModelDataVersion = MockKVStore.get(modelDataVersionGaugeKey);
+            if (tmpModelDataVersion.equals(currentModelDataVersion)) {
+                Thread.yield();
+            } else {
+                currentModelDataVersion = tmpModelDataVersion;
+                break;
+            }
+        } while (true);
+    }
+
+    /**
+     * Inserts the default predict data into the predict queue, fetches the prediction results, and
+     * asserts that the grouping result is as expected.
+     *
+     * @param expectedGroups A list containing sets of features, which is the expected group result
+     * @param featuresCol Name of the column in the table that contains the features
+     * @param predictionCol Name of the column in the table that contains the prediction result
+     */
+    private void predictAndAssert(
+            List<Set<DenseVector>> expectedGroups, String featuresCol, String predictionCol)
+            throws Exception {
+        MockMessageQueues.offerAll(predictId, StreamingKMeansTest.predictData);
+        List<Row> rawResult =
+                MockMessageQueues.poll(outputId, StreamingKMeansTest.predictData.length);
+        List<Set<DenseVector>> actualGroups =
+                groupFeaturesByPrediction(rawResult, featuresCol, predictionCol);
+        Assert.assertTrue(CollectionUtils.isEqualCollection(expectedGroups, actualGroups));
+    }
+
+    @Test
+    public void testParam() {
+        StreamingKMeans streamingKMeans = new StreamingKMeans();
+        Assert.assertEquals("features", streamingKMeans.getFeaturesCol());
+        Assert.assertEquals("prediction", streamingKMeans.getPredictionCol());
+        Assert.assertEquals(EuclideanDistanceMeasure.NAME, streamingKMeans.getDistanceMeasure());
+        Assert.assertEquals("random", streamingKMeans.getInitMode());
+        Assert.assertEquals(2, streamingKMeans.getK());
+        Assert.assertEquals(1, streamingKMeans.getDims());
+        Assert.assertEquals("count", streamingKMeans.getBatchStrategy());
+        Assert.assertEquals(1, streamingKMeans.getBatchSize());
+        Assert.assertEquals(0., streamingKMeans.getDecayFactor(), 1e-5);
+        Assert.assertEquals("random", streamingKMeans.getInitMode());
+        Assert.assertEquals(StreamingKMeans.class.getName().hashCode(), streamingKMeans.getSeed());
+
+        streamingKMeans
+                .setK(9)
+                .setFeaturesCol("test_feature")
+                .setPredictionCol("test_prediction")
+                .setK(3)
+                .setDims(5)
+                .setBatchSize(5)
+                .setDecayFactor(0.25)
+                .setInitMode("direct")
+                .setSeed(100);
+
+        Assert.assertEquals("test_feature", streamingKMeans.getFeaturesCol());
+        Assert.assertEquals("test_prediction", streamingKMeans.getPredictionCol());
+        Assert.assertEquals(3, streamingKMeans.getK());
+        Assert.assertEquals(5, streamingKMeans.getDims());
+        Assert.assertEquals("count", streamingKMeans.getBatchStrategy());
+        Assert.assertEquals(5, streamingKMeans.getBatchSize());
+        Assert.assertEquals(0.25, streamingKMeans.getDecayFactor(), 1e-5);
+        Assert.assertEquals("direct", streamingKMeans.getInitMode());
+        Assert.assertEquals(100, streamingKMeans.getSeed());
+    }
+
+    @Test
+    public void testFitAndPredict() throws Exception {
+        StreamingKMeans streamingKMeans =
+                new StreamingKMeans()
+                        .setInitMode("random")
+                        .setDims(2)
+                        .setInitWeights(new Double[] {0., 0.})
+                        .setBatchSize(6)
+                        .setFeaturesCol("features")
+                        .setPredictionCol("prediction");
+        StreamingKMeansModel streamingModel = streamingKMeans.fit(trainTable);
+        configModelSink(streamingModel);
+
+        clients.add(env.executeAsync());
+        waitInitModelDataSetup();
+
+        MockMessageQueues.offerAll(trainId, trainData1);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups1,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+
+        MockMessageQueues.offerAll(trainId, trainData2);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups2,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+    }
+
+    @Test
+    public void testInitWithOfflineKMeans() throws Exception {
+        KMeans offlineKMeans =
+                new KMeans().setFeaturesCol("features").setPredictionCol("prediction");
+        KMeansModel offlineModel = offlineKMeans.fit(offlineTrainTable);
+
+        StreamingKMeans streamingKMeans =
+                new StreamingKMeans(offlineModel.getModelData())
+                        .setInitMode("direct")
+                        .setDims(2)
+                        .setInitWeights(new Double[] {0., 0.})
+                        .setBatchSize(6);
+        ReadWriteUtils.updateExistingParams(streamingKMeans, offlineKMeans.getParamMap());
+
+        StreamingKMeansModel streamingModel = streamingKMeans.fit(trainTable);
+        configModelSink(streamingModel);
+
+        clients.add(env.executeAsync());
+        waitInitModelDataSetup();
+        predictAndAssert(
+                expectedGroups1,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+
+        MockMessageQueues.offerAll(trainId, trainData2);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups2,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+    }
+
+    @Test
+    public void testDecayFactor() throws Exception {
+        StreamingKMeans streamingKMeans =
+                new StreamingKMeans()
+                        .setInitMode("random")
+                        .setDims(2)
+                        .setInitWeights(new Double[] {0., 0.})
+                        .setDecayFactor(100.0)
+                        .setBatchSize(6)
+                        .setFeaturesCol("features")
+                        .setPredictionCol("prediction");
+        StreamingKMeansModel streamingModel = streamingKMeans.fit(trainTable);
+        configModelSink(streamingModel);
+
+        clients.add(env.executeAsync());
+        waitInitModelDataSetup();
+
+        MockMessageQueues.offerAll(trainId, trainData1);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups1,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+
+        MockMessageQueues.offerAll(trainId, trainData2);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups1,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+    }
+
+    @Test
+    public void testSaveAndReload() throws Exception {
+        KMeans offlineKMeans =
+                new KMeans().setFeaturesCol("features").setPredictionCol("prediction");
+        KMeansModel offlineModel = offlineKMeans.fit(offlineTrainTable);
+
+        StreamingKMeans streamingKMeans =
+                new StreamingKMeans(offlineModel.getModelData())
+                        .setDims(2)
+                        .setBatchSize(6)
+                        .setInitWeights(new Double[] {0., 0.});
+        ReadWriteUtils.updateExistingParams(streamingKMeans, offlineKMeans.getParamMap());
+
+        StreamingKMeans loadedKMeans =
+                StageTestUtils.saveAndReload(
+                        env, streamingKMeans, tempFolder.newFolder().getAbsolutePath());
+
+        StreamingKMeansModel streamingModel = loadedKMeans.fit(trainTable);
+
+        String modelDataPassId = MockMessageQueues.createMessageQueue();
+        queueIds.add(modelDataPassId);
+
+        String modelSavePath = tempFolder.newFolder().getAbsolutePath();
+        streamingModel.save(modelSavePath);
+        KMeansModelData.getModelDataStream(streamingModel.getModelData()[0])
+                .addSink(new MockSinkFunction<>(modelDataPassId));
+        clients.add(env.executeAsync());

Review comment:
       Hmm... what is the issue with `save()` if we use `executeAsync()`?
   
   `KafkaConsumerTestBase::runMultipleSourcesOnePartitionExactlyOnceTest()` also uses an unbounded source and verifies that the expected sequence of values is generated. Maybe we can follow that practice here?
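   As a sketch of the verification side of that practice: a deadline-bounded helper can collect exactly the expected number of outputs from a `java.util.concurrent.BlockingQueue` while the job keeps running, and fail instead of spinning forever if they never arrive. `QueueAssertions.pollExactly` is a hypothetical helper, not part of `MockMessageQueues` or Flink's test utilities, and it assumes the test sink delivers its records into such a queue.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

/** Hypothetical test helper; not part of MockMessageQueues or Flink's test utilities. */
final class QueueAssertions {

    /**
     * Collects exactly {@code n} elements from {@code queue} in arrival order, throwing a
     * TimeoutException if they do not all arrive before the deadline. The job feeding the
     * queue may keep running, so this works with an unbounded source.
     */
    static <T> List<T> pollExactly(BlockingQueue<T> queue, int n, long timeoutMillis)
            throws InterruptedException, TimeoutException {
        long deadline = System.nanoTime() + TimeUnit.MILLISECONDS.toNanos(timeoutMillis);
        List<T> received = new ArrayList<>(n);
        while (received.size() < n) {
            long remainingNanos = deadline - System.nanoTime();
            if (remainingNanos <= 0) {
                throw new TimeoutException(
                        "Received " + received.size() + " of " + n + " expected elements");
            }
            // Blocks for at most the remaining time; returns null if nothing arrived.
            T element = queue.poll(remainingNanos, TimeUnit.NANOSECONDS);
            if (element != null) {
                received.add(element);
            }
        }
        return received;
    }
}
```

   The test would then compare the returned list against the expected sequence and cancel the job in `after()`, so it never has to wait for the unbounded job to finish.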





[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Support Online KMeans in Flink ML

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r829811996



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/StreamingKMeans.java
##########
@@ -0,0 +1,404 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * StreamingKMeans extends the functionality of {@link KMeans} to support training a K-Means model
+ * continuously on an unbounded stream of training data.
+ */
+public class StreamingKMeans
+        implements Estimator<StreamingKMeans, StreamingKMeansModel>,
+                StreamingKMeansParams<StreamingKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public StreamingKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public StreamingKMeans(Table... initModelDataTables) {
+        Preconditions.checkArgument(initModelDataTables.length == 1);
+        this.initModelDataTable = initModelDataTables[0];
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public StreamingKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        Preconditions.checkArgument(HasBatchStrategy.COUNT_STRATEGY.equals(getBatchStrategy()));
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+        points.getTransformation().setParallelism(1);
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            initModelDataStream = createRandomCentroids(env, getDims(), getK(), getSeed());
+        } else {
+            initModelDataStream = KMeansModelData.getModelDataStream(initModelDataTable);
+        }
+        DataStream<Tuple2<KMeansModelData, DenseVector>> initModelDataWithWeightsStream =
+                initModelDataStream.map(new InitWeightAssigner(getInitWeights()));
+        initModelDataWithWeightsStream.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new StreamingKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getBatchSize(),
+                        getK());
+
+        DataStream<KMeansModelData> finalModelDataStream =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelDataWithWeightsStream),
+                                DataStreamList.of(points),
+                                body)
+                        .get(0);
+        finalModelDataStream = finalModelDataStream.union(initModelDataStream);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelDataStream);
+        StreamingKMeansModel model = new StreamingKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class InitWeightAssigner
+            implements MapFunction<KMeansModelData, Tuple2<KMeansModelData, DenseVector>> {
+        private final double[] initWeights;
+
+        private InitWeightAssigner(Double[] initWeights) {
+            this.initWeights = ArrayUtils.toPrimitive(initWeights);
+        }
+
+        @Override
+        public Tuple2<KMeansModelData, DenseVector> map(KMeansModelData modelData)
+                throws Exception {
+            return Tuple2.of(modelData, Vectors.dense(initWeights));
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static StreamingKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        StreamingKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        Path initModelDataPath = Paths.get(path, "data");
+        if (Files.exists(initModelDataPath)) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+            kMeans.setInitMode("direct");
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class StreamingKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+        private final int k;
+
+        public StreamingKMeansIterationBody(
+                DistanceMeasure distanceMeasure, double decayFactor, int batchSize, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+            this.k = k;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<Tuple2<KMeansModelData, DenseVector>> modelDataWithWeights =
+                    variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            DataStream<Tuple2<KMeansModelData, DenseVector>> newModelDataWithWeights =
+                    points.countWindowAll(batchSize)
+                            .aggregate(new MiniBatchCreator())
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "UpdateModelData",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new UpdateModelDataOperator(distanceMeasure, decayFactor, k))
+                            .setParallelism(1);
+
+            DataStream<KMeansModelData> newModelData =
+                    newModelDataWithWeights.map(
+                            (MapFunction<Tuple2<KMeansModelData, DenseVector>, KMeansModelData>)
+                                    x -> x.f0);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelDataWithWeights), DataStreamList.of(newModelData));
+        }
+    }
+
+    // TODO: change this single-threaded implementation to support training in a distributed way,
+    // after model data version mechanism is implemented.
+    private static class UpdateModelDataOperator
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, DenseVector>>
+            implements TwoInputStreamOperator<
+                    DenseVector[],
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int k;
+        private ListState<DenseVector[]> miniBatchState;
+        private ListState<KMeansModelData> modelDataState;
+        private ListState<DenseVector> weightsState;
+
+        public UpdateModelDataOperator(DistanceMeasure distanceMeasure, double decayFactor, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.k = k;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            miniBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("miniBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+
+            weightsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "weights", DenseVectorTypeInfo.INSTANCE));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> streamRecord) throws Exception {
+            miniBatchState.add(streamRecord.getValue());
+            processElement();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<Tuple2<KMeansModelData, DenseVector>> streamRecord)
+                throws Exception {
+            modelDataState.add(streamRecord.getValue().f0);
+            weightsState.add(streamRecord.getValue().f1);
+            processElement();
+        }
+
+        private void processElement() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !miniBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            // Retrieves data from states.
+            List<KMeansModelData> modelDataList =
+                    IteratorUtils.toList(modelDataState.get().iterator());
+            if (modelDataList.size() != 1) {

Review comment:
       Got it. I'll also apply it to `KMeans`.
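   The early return in the quoted `processElement()` means `UpdateModelDataOperator` only does work once it has buffered both a mini-batch and a model version. The plain-Java sketch below illustrates that two-input alignment pattern. The class and method names are hypothetical, an `ArrayDeque` stands in for Flink's `ListState`, and the update rule in `main` is purely illustrative rather than the K-Means update.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Deque;
import java.util.List;
import java.util.function.BiFunction;

/**
 * Hypothetical sketch (plain Java, no Flink APIs): buffers elements from two inputs and applies
 * {@code update} only once one element from each side is available. ArrayDeque stands in for
 * the operator's ListState.
 */
final class TwoInputAligner<A, B, R> {
    private final Deque<A> pendingBatches = new ArrayDeque<>();
    private final Deque<B> pendingModels = new ArrayDeque<>();
    private final List<R> outputs = new ArrayList<>();
    private final BiFunction<A, B, R> update;

    TwoInputAligner(BiFunction<A, B, R> update) {
        this.update = update;
    }

    /** First input: a mini-batch of training data. */
    void processBatch(A batch) {
        pendingBatches.addLast(batch);
        tryUpdate();
    }

    /** Second input: a new version of the model data. */
    void processModel(B model) {
        pendingModels.addLast(model);
        tryUpdate();
    }

    private void tryUpdate() {
        // Same idea as the early return in processElement(): wait until both sides have data.
        if (pendingBatches.isEmpty() || pendingModels.isEmpty()) {
            return;
        }
        // Consume the oldest batch together with the oldest model version.
        outputs.add(update.apply(pendingBatches.pollFirst(), pendingModels.pollFirst()));
    }

    List<R> getOutputs() {
        return outputs;
    }

    public static void main(String[] args) {
        // Illustrative update rule: move a 1-D model halfway towards the batch mean.
        TwoInputAligner<double[], Double, Double> aligner =
                new TwoInputAligner<>(
                        (batch, model) ->
                                (model + Arrays.stream(batch).average().orElse(model)) / 2);
        aligner.processModel(0.0); // initial model: waits for the first batch
        aligner.processBatch(new double[] {1.0, 3.0}); // pairs with model 0.0
        aligner.processBatch(new double[] {5.0, 7.0}); // buffered: no model version left
        // In the real iteration each output is fed back as the next model version.
        aligner.processModel(aligner.getOutputs().get(0));
        System.out.println(aligner.getOutputs()); // prints [1.0, 3.5]
    }
}
```

   Buffering both sides (instead of dropping a batch that arrives before the model) keeps the result independent of which input happens to arrive first.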




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [flink-ml] lindong28 commented on a change in pull request #70: [FLINK-26313] Support Online KMeans in Flink ML

Posted by GitBox <gi...@apache.org>.
lindong28 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r830429315



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasDecayFactor.java
##########
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
+
+/** Interface for the shared decay factor param. */
+public interface HasDecayFactor<T> extends WithParams<T> {
+    Param<Double> DECAY_FACTOR =
+            new DoubleParam(
+                    "decayFactor",
+                    "The forgetfulness of the previous centroids.",
+                    0.,
+                    ParamValidators.gtEq(0));

Review comment:
       Sounds good. Let's limit it to be smaller than 1 for now.
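   To make the effect of the quoted `DECAY_FACTOR` param concrete, here is a Flink-free sketch of a decayed mini-batch update for one centroid. It follows the rule described in the later `OnlineKMeans` Javadoc: the old weight is scaled by the decay factor, and the centroid becomes a weighted average of itself and the batch mean. The class and method names are hypothetical. The range check is also an assumption: the thread says "smaller than 1", while the later Javadoc describes the behaviour at exactly 1, so this sketch accepts the closed interval [0, 1].

```java
import java.util.Arrays;

/**
 * Hypothetical helper (not part of Flink ML) showing how a decay factor enters a
 * mini-batch K-Means centroid update for a single cluster.
 */
final class DecayedCentroidUpdate {
    private DecayedCentroidUpdate() {}

    /** Assumed range check: rejects NaN and anything outside the closed interval [0, 1]. */
    static void checkDecayFactor(double decayFactor) {
        if (!(decayFactor >= 0.0 && decayFactor <= 1.0)) {
            throw new IllegalArgumentException(
                    "decayFactor must be in [0, 1], but was " + decayFactor);
        }
    }

    /**
     * Moves {@code centroid} (in place) towards the mean of the points assigned to it in the
     * current batch and returns the cluster's new weight.
     */
    static double update(
            double[] centroid, double weight, double[][] assignedPoints, double decayFactor) {
        checkDecayFactor(decayFactor);
        // Forget part of the history: the old weight is scaled by the decay factor.
        double decayedWeight = weight * decayFactor;
        int count = assignedPoints.length;
        if (count == 0) {
            return decayedWeight; // no new points: centroid unchanged, weight still decays
        }
        double newWeight = decayedWeight + count;
        double lambda = count / newWeight; // share of the new batch in the weighted average
        for (int j = 0; j < centroid.length; j++) {
            double sum = 0.0;
            for (double[] point : assignedPoints) {
                sum += point[j];
            }
            centroid[j] = (1.0 - lambda) * centroid[j] + lambda * (sum / count);
        }
        return newWeight;
    }

    public static void main(String[] args) {
        double[][] batch = {{2.0}, {4.0}}; // two points with mean 3.0
        for (double decay : new double[] {1.0, 0.5, 0.0}) {
            double[] centroid = {0.0};
            double weight = update(centroid, 2.0, batch, decay);
            System.out.println(
                    "decay=" + decay + " centroid=" + Arrays.toString(centroid)
                            + " weight=" + weight);
        }
    }
}
```

   With an initial weight of 2, a decay factor of 1 moves the centroid from 0.0 to 1.5 (history and the two new points count equally), while a decay factor of 0 jumps straight to the batch mean 3.0.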







[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Support Streaming KMeans in Flink ML

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r829714815



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/StreamingKMeans.java
##########
@@ -0,0 +1,404 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * StreamingKMeans extends the function of {@link KMeans}, supporting to train a K-Means model
+ * continuously according to an unbounded stream of train data.
+ */
+public class StreamingKMeans
+        implements Estimator<StreamingKMeans, StreamingKMeansModel>,
+                StreamingKMeansParams<StreamingKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public StreamingKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public StreamingKMeans(Table... initModelDataTables) {
+        Preconditions.checkArgument(initModelDataTables.length == 1);
+        this.initModelDataTable = initModelDataTables[0];
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public StreamingKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        Preconditions.checkArgument(HasBatchStrategy.COUNT_STRATEGY.equals(getBatchStrategy()));
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+        points.getTransformation().setParallelism(1);

Review comment:
       I think this is hard to achieve for now. We need to create mini-batches of a fixed size from the training data, but if the parallelism is larger than 1, we do not have a mechanism to count the total number of records received across all subtasks.
   
   One possible solution I have wanted to propose is to insert barriers into the training data stream, so that even if the training data is distributed across different subtasks, each subtask still knows when to finish the current batch, as long as it receives the barrier. We have not yet had a chance to discuss this problem and possible solutions offline.
   
   For now I would prefer to keep this limitation in this PR. I'll add a note about it to the Javadoc.
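   The barrier proposal above can be illustrated without Flink. In the sketch below, each parallel subtask buffers the records it receives and closes its local share of the current batch when it sees a barrier. No subtask needs a global record count. Everything here (the `BARRIER` marker, `LocalBatcher`, round-robin distribution, a single-threaded simulation of the subtasks) is a hypothetical illustration of the idea, not an existing Flink ML mechanism.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical sketch of barrier-delimited mini-batches across parallel subtasks. */
final class BarrierBatching {
    /** Marker element that the (hypothetical) source broadcasts to every subtask. */
    static final Object BARRIER = new Object();

    /** Per-subtask buffer: collects points until a barrier closes the current batch. */
    static final class LocalBatcher {
        private List<Double> current = new ArrayList<>();
        private final List<List<Double>> closedBatches = new ArrayList<>();

        void receive(Object element) {
            if (element == BARRIER) {
                // Barrier reached: this subtask's part of the global batch is complete.
                closedBatches.add(current);
                current = new ArrayList<>();
            } else {
                current.add((Double) element);
            }
        }

        List<List<Double>> closedBatches() {
            return closedBatches;
        }
    }

    /**
     * Simulates a source that distributes points round-robin over {@code parallelism} subtasks
     * and broadcasts a barrier to all of them after every {@code batchSize} points.
     */
    static List<LocalBatcher> run(double[] points, int batchSize, int parallelism) {
        List<LocalBatcher> subtasks = new ArrayList<>();
        for (int i = 0; i < parallelism; i++) {
            subtasks.add(new LocalBatcher());
        }
        for (int i = 0; i < points.length; i++) {
            subtasks.get(i % parallelism).receive(points[i]);
            if ((i + 1) % batchSize == 0) {
                for (LocalBatcher subtask : subtasks) {
                    subtask.receive(BARRIER);
                }
            }
        }
        return subtasks;
    }

    public static void main(String[] args) {
        List<LocalBatcher> subtasks = run(new double[] {1, 2, 3, 4, 5, 6}, 3, 2);
        for (int i = 0; i < subtasks.size(); i++) {
            System.out.println("subtask " + i + ": " + subtasks.get(i).closedBatches());
        }
        // subtask 0: [[1.0, 3.0], [5.0]]
        // subtask 1: [[2.0], [4.0, 6.0]]
    }
}
```

   The local batches differ in size, but for each global batch the local parts sum to `batchSize`. That is the property the count-based `countWindowAll(batchSize)` currently obtains by pinning the input to parallelism 1.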







[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Support Streaming KMeans in Flink ML

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r829711423



##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/util/MockMessageQueues.java
##########
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
+
+/** A class that manages global message queues used in unit tests. */
+@SuppressWarnings({"unchecked", "rawtypes"})
+public class MockMessageQueues {
+    private static final Map<String, BlockingQueue> queueMap = new HashMap<>();
+    private static long counter = 0;
+
+    public static synchronized String createMessageQueue() {
+        String id = String.valueOf(counter);
+        queueMap.put(id, new LinkedBlockingQueue<>());
+        counter++;
+        return id;
+    }
+
+    @SafeVarargs
+    public static <T> void offerAll(String id, T... values) throws InterruptedException {
+        for (T value : values) {
+            offer(id, value);
+        }
+    }
+
+    public static <T> void offer(String id, T value) throws InterruptedException {
+        offer(id, value, 1, TimeUnit.MINUTES);
+    }
+
+    public static <T> void offer(String id, T value, long timeout, TimeUnit unit)
+            throws InterruptedException {
+        boolean success = queueMap.get(id).offer(value, timeout, unit);
+        if (!success) {
+            throw new RuntimeException(
+                    "Failed to offer " + value + " to blocking queue " + id + ".");
+        }
+    }
+
+    public static <T> List<T> poll(String id, int num) throws InterruptedException {
+        List<T> result = new ArrayList<>();
+        for (int i = 0; i < num; i++) {
+            result.add(poll(id));
+        }
+        return result;
+    }
+
+    public static <T> T poll(String id) throws InterruptedException {
+        return poll(id, 1, TimeUnit.MINUTES);
+    }
+
+    public static <T> T poll(String id, long timeout, TimeUnit unit) throws InterruptedException {
+        T value = (T) queueMap.get(id).poll(timeout, unit);
+        if (value == null) {
+            throw new RuntimeException("Failed to poll next value from blocking queue " + id + ".");
+        }
+        return value;
+    }
+
+    public static void deleteBlockingQueue(String id) {

Review comment:
       When multiple test classes run in parallel, I'm afraid that calling a `clear()` method in one test case would affect the others. Thus I would prefer to have each test class create and delete its own queues.
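   The sketch below shows the per-test-class ownership described above. Each test class creates queues through a shared registry and, in its teardown, deletes only the IDs it created. A parallel test class therefore never loses its queues to a global `clear()`. `QueueRegistry` and `OwnedQueues` are hypothetical, simplified stand-ins for `MockMessageQueues`. The use of `ConcurrentHashMap` and `AtomicLong` (instead of the quoted `HashMap` plus a `synchronized` counter) is an assumption about making concurrent access safe, not what the PR does.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;

/** Hypothetical, simplified registry of named blocking queues shared by test classes. */
final class QueueRegistry {
    private static final Map<String, BlockingQueue<Object>> QUEUES = new ConcurrentHashMap<>();
    private static final AtomicLong COUNTER = new AtomicLong();

    private QueueRegistry() {}

    static String create() {
        String id = String.valueOf(COUNTER.getAndIncrement());
        QUEUES.put(id, new LinkedBlockingQueue<>());
        return id;
    }

    static void offer(String id, Object value) throws InterruptedException {
        if (!queue(id).offer(value, 1, TimeUnit.MINUTES)) {
            throw new IllegalStateException("Failed to offer to queue " + id);
        }
    }

    static Object poll(String id) throws InterruptedException {
        Object value = queue(id).poll(1, TimeUnit.MINUTES);
        if (value == null) {
            throw new IllegalStateException("Failed to poll from queue " + id);
        }
        return value;
    }

    /** Removes a single queue; queues owned by other test classes are untouched. */
    static void delete(String id) {
        QUEUES.remove(id);
    }

    static boolean exists(String id) {
        return QUEUES.containsKey(id);
    }

    private static BlockingQueue<Object> queue(String id) {
        BlockingQueue<Object> queue = QUEUES.get(id);
        if (queue == null) {
            throw new IllegalStateException("Unknown queue " + id);
        }
        return queue;
    }
}

/** Hypothetical per-test-class fixture: remembers the queues it created, deletes only those. */
final class OwnedQueues implements AutoCloseable {
    private final List<String> ids = new ArrayList<>();

    String newQueue() {
        String id = QueueRegistry.create();
        ids.add(id);
        return id;
    }

    @Override
    public void close() {
        // Would run from an @After/@AfterClass method in a real test class.
        for (String id : ids) {
            QueueRegistry.delete(id);
        }
        ids.clear();
    }
}
```

   Because `delete` removes only the given ID, one test class's teardown cannot drop queues that a concurrently running class is still using, which is the concern raised in the comment.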







[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r835702554



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,437 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.function.Supplier;
+
+/**
+ * OnlineKMeans extends the function of {@link KMeans}, supporting to train a K-Means model
+ * continuously according to an unbounded stream of train data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids from the weighted average between the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of points assigned to
+ * them. The weight of the original centroids is also the number of points, but additionally
+ * multiplying with the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If decay
+ * factor is 1, all batches are weighted equally. If decay factor is 0, new centroids are determined
+ * entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelData =
+                KMeansModelData.getModelDataStream(initModelDataTable);
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize());
+
+        DataStream<KMeansModelData> finalModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelData);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Saves the metadata AND bounded model data table (if exists) to the given path. */
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        String initModelDataPath = ReadWriteUtils.getDataPath(path);
+        if (Files.exists(Paths.get(initModelDataPath))) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure, double decayFactor, int batchSize) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<KMeansModelData> modelData = variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<KMeansModelData> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator())
+                            .flatMap(new GlobalBatchSplitter(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(KMeansModelData.class),
+                                    new ModelDataLocalUpdater(distanceMeasure, decayFactor))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new ModelDataGlobalReducer());
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelData), DataStreamList.of(modelData));
+        }
+    }
+
+    private static class ModelDataGlobalReducer implements ReduceFunction<KMeansModelData> {
+        @Override
+        public KMeansModelData reduce(KMeansModelData modelData, KMeansModelData newModelData) {
+            DenseVector weights = modelData.weights;
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector newWeights = newModelData.weights;
+            DenseVector[] newCentroids = newModelData.centroids;
+
+            int k = newCentroids.length;
+            int dim = newCentroids[0].size();
+
+            for (int i = 0; i < k; i++) {
+                for (int j = 0; j < dim; j++) {
+                    centroids[i].values[j] =
+                            (centroids[i].values[j] * weights.values[i]
+                                            + newCentroids[i].values[j] * newWeights.values[i])
+                                    / Math.max(weights.values[i] + newWeights.values[i], 1e-16);
+                }
+                weights.values[i] += newWeights.values[i];
+            }
+
+            return new KMeansModelData(centroids, weights);
+        }
+    }
+
+    private static class ModelDataLocalUpdater extends AbstractStreamOperator<KMeansModelData>
+            implements TwoInputStreamOperator<DenseVector[], KMeansModelData, KMeansModelData> {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private ListState<DenseVector[]> localBatchState;
+        private ListState<KMeansModelData> modelDataState;
+
+        private ModelDataLocalUpdater(DistanceMeasure distanceMeasure, double decayFactor) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            localBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("localBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> pointsRecord) throws Exception {
+            localBatchState.add(pointsRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<KMeansModelData> modelDataRecord)
+                throws Exception {
+            modelDataState.add(modelDataRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        private void alignAndComputeModelData() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !localBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            KMeansModelData modelData =
+                    OperatorStateUtils.getUniqueElement(modelDataState, "modelData")
+                            .orElseThrow((Supplier<Exception>) NullPointerException::new);
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector weights = modelData.weights;
+            modelDataState.clear();
+
+            List<DenseVector[]> pointsList = IteratorUtils.toList(localBatchState.get().iterator());
+            DenseVector[] points = pointsList.remove(0);
+            localBatchState.update(pointsList);
+
+            int dim = centroids[0].size();
+            int k = centroids.length;
+            int parallelism = getRuntimeContext().getNumberOfParallelSubtasks();
+
+            // Computes new centroids.
+            DenseVector[] sums = new DenseVector[k];
+            int[] counts = new int[k];
+
+            for (int i = 0; i < k; i++) {
+                sums[i] = new DenseVector(dim);
+                counts[i] = 0;
+            }
+            for (DenseVector point : points) {
+                int closestCentroidId =
+                        KMeans.findClosestCentroidId(centroids, point, distanceMeasure);
+                counts[closestCentroidId]++;
+                for (int j = 0; j < dim; j++) {
+                    sums[closestCentroidId].values[j] += point.values[j];
+                }
+            }
+
+            // Considers weight and decay factor when updating centroids.
+            BLAS.scal(decayFactor / parallelism, weights);
+            for (int i = 0; i < k; i++) {
+                if (counts[i] == 0) {
+                    continue;
+                }
+
+                DenseVector centroid = centroids[i];
+                weights.values[i] = weights.values[i] + counts[i];
+                double lambda = counts[i] / weights.values[i];
+
+                BLAS.scal(1.0 - lambda, centroid);
+                BLAS.axpy(lambda / counts[i], sums[i], centroid);
+            }
+
+            output.collect(new StreamRecord<>(new KMeansModelData(centroids, weights)));
+        }
+    }
+
+    private static class FeaturesExtractor implements MapFunction<Row, DenseVector> {
+        private final String featuresCol;
+
+        private FeaturesExtractor(String featuresCol) {
+            this.featuresCol = featuresCol;
+        }
+
+        @Override
+        public DenseVector map(Row row) throws Exception {
+            return (DenseVector) row.getField(featuresCol);
+        }
+    }
+
+    // An operator that splits a global batch into evenly-sized local batches, and distributes them
+    // to the downstream operator.
+    private static class GlobalBatchSplitter
+            implements FlatMapFunction<DenseVector[], DenseVector[]> {
+        private final int downStreamParallelism;
+
+        private GlobalBatchSplitter(int downStreamParallelism) {
+            this.downStreamParallelism = downStreamParallelism;
+        }
+
+        @Override
+        public void flatMap(DenseVector[] values, Collector<DenseVector[]> collector) {
+            // Calculate the batch sizes to be distributed on each subtask.
+            List<Integer> sizes = new ArrayList<>();
+            for (int i = 0; i < downStreamParallelism; i++) {
+                int start = i * values.length / downStreamParallelism;
+                int end = (i + 1) * values.length / downStreamParallelism;
+                sizes.add(end - start);
+            }
+
+            int offset = 0;
+            for (Integer size : sizes) {
+                collector.collect(Arrays.copyOfRange(values, offset, offset + size));
+                offset += size;
+            }
+        }
+    }
+
+    private static class GlobalBatchCreator
+            implements AllWindowFunction<DenseVector, DenseVector[], GlobalWindow> {
+        @Override
+        public void apply(
+                GlobalWindow timeWindow,
+                Iterable<DenseVector> iterable,
+                Collector<DenseVector[]> collector) {
+            List<DenseVector> points = IteratorUtils.toList(iterable.iterator());
+            collector.collect(points.toArray(new DenseVector[0]));
+        }
+    }
+
+    /**
+     * Sets the initial model data of the online training process with the provided model data
+     * table.
+     *
+     * <p>This would override the effect of previously invoked {@link #setInitialModelData(Table)}
+     * or {@link #setRandomCentroids(StreamTableEnvironment, int, int, double)}.
+     */
+    public OnlineKMeans setInitialModelData(Table initModelDataTable) {
+        this.initModelDataTable = initModelDataTable;
+        return this;
+    }
+
+    /**
+     * Sets the initial model data of the online training process with randomly created centroids.
+     *
+     * <p>This would override the effect of previously invoked {@link #setInitialModelData(Table)}
+     * or {@link #setRandomCentroids(StreamTableEnvironment, int, int, double)}.
+     *
+     * @param tEnv The stream table environment to create the centroids in.
+     * @param dim The dimension of the centroids to create.
+     * @param k The number of centroids to create.
+     * @param weight The weight of the centroids to create.
+     */
+    public OnlineKMeans setRandomCentroids(

Review comment:
       As agreed in our offline discussion, I'll add a `KMeansUtils.generateRandomModelData()` method and remove the `setRandomCentroids` method from `OnlineKMeans`.
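For reference, the kind of random initialization that the `setRandomCentroids(tEnv, dim, k, weight)` Javadoc describes (k centroids of dimension dim, each with the same weight), and that the proposed `KMeansUtils.generateRandomModelData()` would take over, can be sketched without Flink. This is a hypothetical standalone helper, not Flink ML's API. Plain `double` arrays stand in for `DenseVector`/`KMeansModelData`, coordinates are assumed to be drawn uniformly from [0, 1), and the `seed` parameter is an added assumption for reproducibility.

```java
import java.util.Arrays;
import java.util.Random;

/**
 * Hypothetical, Flink-free sketch of generating random initial KMeans model data.
 * Plain double arrays stand in for DenseVector / KMeansModelData.
 */
public class RandomModelDataSketch {

    /** Minimal stand-in for KMeansModelData: k centroids plus one weight per centroid. */
    public static final class ModelData {
        public final double[][] centroids;
        public final double[] weights;

        ModelData(double[][] centroids, double[] weights) {
            this.centroids = centroids;
            this.weights = weights;
        }
    }

    /**
     * Creates k centroids of the given dimension with coordinates drawn uniformly from [0, 1),
     * each assigned the same initial weight. A fixed seed makes the result reproducible.
     */
    public static ModelData generateRandomModelData(int k, int dim, double weight, long seed) {
        if (k <= 0 || dim <= 0) {
            throw new IllegalArgumentException("k and dim must be positive");
        }
        Random random = new Random(seed);
        double[][] centroids = new double[k][dim];
        for (int i = 0; i < k; i++) {
            for (int j = 0; j < dim; j++) {
                centroids[i][j] = random.nextDouble();
            }
        }
        double[] weights = new double[k];
        Arrays.fill(weights, weight);
        return new ModelData(centroids, weights);
    }

    public static void main(String[] args) {
        ModelData data = generateRandomModelData(2, 3, 0.0, 42L);
        System.out.println(Arrays.deepToString(data.centroids));
        System.out.println(Arrays.toString(data.weights));
    }
}
```

Keeping generation in a utility that returns model data (rather than a setter on the estimator) means the result can be passed to `setInitialModelData(Table)` like any other initial model, which matches the direction of the comment above.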







[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r834984079



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,437 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.function.Supplier;
+
+/**
+ * OnlineKMeans extends the functionality of {@link KMeans} to support training a K-Means model
+ * continuously on an unbounded stream of training data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids from the weighted average between the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of points assigned to
+ * them. The weight of the original centroids is also the number of points, additionally
+ * multiplied by the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If the decay
+ * factor is 1, all batches are weighted equally. If the decay factor is 0, new centroids are
+ * determined entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelData =
+                KMeansModelData.getModelDataStream(initModelDataTable);
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize());
+
+        DataStream<KMeansModelData> finalModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelData);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Saves the metadata AND bounded model data table (if it exists) to the given path. */
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        String initModelDataPath = ReadWriteUtils.getDataPath(path);
+        if (Files.exists(Paths.get(initModelDataPath))) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure, double decayFactor, int batchSize) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<KMeansModelData> modelData = variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<KMeansModelData> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator())
+                            .flatMap(new GlobalBatchSplitter(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(KMeansModelData.class),
+                                    new ModelDataLocalUpdater(distanceMeasure, decayFactor))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new ModelDataGlobalReducer());
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelData), DataStreamList.of(modelData));
+        }
+    }
+
+    private static class ModelDataGlobalReducer implements ReduceFunction<KMeansModelData> {
+        @Override
+        public KMeansModelData reduce(KMeansModelData modelData, KMeansModelData newModelData) {
+            DenseVector weights = modelData.weights;
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector newWeights = newModelData.weights;
+            DenseVector[] newCentroids = newModelData.centroids;
+
+            int k = newCentroids.length;
+            int dim = newCentroids[0].size();
+
+            for (int i = 0; i < k; i++) {
+                for (int j = 0; j < dim; j++) {
+                    centroids[i].values[j] =
+                            (centroids[i].values[j] * weights.values[i]
+                                            + newCentroids[i].values[j] * newWeights.values[i])
+                                    / Math.max(weights.values[i] + newWeights.values[i], 1e-16);
+                }
+                weights.values[i] += newWeights.values[i];
+            }
+
+            return new KMeansModelData(centroids, weights);
+        }
+    }
+
+    private static class ModelDataLocalUpdater extends AbstractStreamOperator<KMeansModelData>
+            implements TwoInputStreamOperator<DenseVector[], KMeansModelData, KMeansModelData> {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private ListState<DenseVector[]> localBatchState;
+        private ListState<KMeansModelData> modelDataState;
+
+        private ModelDataLocalUpdater(DistanceMeasure distanceMeasure, double decayFactor) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            localBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("localBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> pointsRecord) throws Exception {
+            localBatchState.add(pointsRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<KMeansModelData> modelDataRecord)
+                throws Exception {
+            modelDataState.add(modelDataRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        private void alignAndComputeModelData() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !localBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            KMeansModelData modelData =
+                    OperatorStateUtils.getUniqueElement(modelDataState, "modelData")
+                            .orElseThrow((Supplier<Exception>) NullPointerException::new);
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector weights = modelData.weights;
+            modelDataState.clear();
+
+            List<DenseVector[]> pointsList = IteratorUtils.toList(localBatchState.get().iterator());
+            DenseVector[] points = pointsList.remove(0);
+            localBatchState.update(pointsList);
+
+            int dim = centroids[0].size();
+            int k = centroids.length;
+            int parallelism = getRuntimeContext().getNumberOfParallelSubtasks();
+
+            // Computes new centroids.
+            DenseVector[] sums = new DenseVector[k];
+            int[] counts = new int[k];
+
+            for (int i = 0; i < k; i++) {
+                sums[i] = new DenseVector(dim);
+                counts[i] = 0;
+            }
+            for (DenseVector point : points) {
+                int closestCentroidId =
+                        KMeans.findClosestCentroidId(centroids, point, distanceMeasure);
+                counts[closestCentroidId]++;
+                for (int j = 0; j < dim; j++) {
+                    sums[closestCentroidId].values[j] += point.values[j];
+                }
+            }
+
+            // Considers weight and decay factor when updating centroids.
+            BLAS.scal(decayFactor / parallelism, weights);
+            for (int i = 0; i < k; i++) {
+                if (counts[i] == 0) {
+                    continue;
+                }
+
+                DenseVector centroid = centroids[i];
+                weights.values[i] = weights.values[i] + counts[i];
+                double lambda = counts[i] / weights.values[i];
+
+                BLAS.scal(1.0 - lambda, centroid);
+                BLAS.axpy(lambda / counts[i], sums[i], centroid);
+            }
+
+            output.collect(new StreamRecord<>(new KMeansModelData(centroids, weights)));
+        }
+    }
+
+    private static class FeaturesExtractor implements MapFunction<Row, DenseVector> {
+        private final String featuresCol;
+
+        private FeaturesExtractor(String featuresCol) {
+            this.featuresCol = featuresCol;
+        }
+
+        @Override
+        public DenseVector map(Row row) throws Exception {
+            return (DenseVector) row.getField(featuresCol);
+        }
+    }
+
+    // An operator that splits a global batch into evenly-sized local batches, and distributes them
+    // to the downstream operator.
+    private static class GlobalBatchSplitter
+            implements FlatMapFunction<DenseVector[], DenseVector[]> {
+        private final int downStreamParallelism;
+
+        private GlobalBatchSplitter(int downStreamParallelism) {
+            this.downStreamParallelism = downStreamParallelism;
+        }
+
+        @Override
+        public void flatMap(DenseVector[] values, Collector<DenseVector[]> collector) {
+            // Calculate the batch sizes to be distributed on each subtask.
+            List<Integer> sizes = new ArrayList<>();
+            for (int i = 0; i < downStreamParallelism; i++) {
+                int start = i * values.length / downStreamParallelism;
+                int end = (i + 1) * values.length / downStreamParallelism;
+                sizes.add(end - start);
+            }
+
+            int offset = 0;
+            for (Integer size : sizes) {
+                collector.collect(Arrays.copyOfRange(values, offset, offset + size));
+                offset += size;
+            }
+        }
+    }
+
+    private static class GlobalBatchCreator
+            implements AllWindowFunction<DenseVector, DenseVector[], GlobalWindow> {
+        @Override
+        public void apply(
+                GlobalWindow timeWindow,
+                Iterable<DenseVector> iterable,
+                Collector<DenseVector[]> collector) {
+            List<DenseVector> points = IteratorUtils.toList(iterable.iterator());
+            collector.collect(points.toArray(new DenseVector[0]));
+        }
+    }
+
+    /**
+     * Sets the initial model data of the online training process with the provided model data
+     * table.
+     *
+     * <p>This would override the effect of previously invoked {@link #setInitialModelData(Table)}
+     * or {@link #setRandomCentroids(StreamTableEnvironment, int, int, double)}.
+     */
+    public OnlineKMeans setInitialModelData(Table initModelDataTable) {
+        this.initModelDataTable = initModelDataTable;
+        return this;
+    }
+
+    /**
+     * Sets the initial model data of the online training process with randomly created centroids.
+     *
+     * <p>This would override the effect of previously invoked {@link #setInitialModelData(Table)}
+     * or {@link #setRandomCentroids(StreamTableEnvironment, int, int, double)}.
+     *
+     * @param tEnv The stream table environment to create the centroids in.
+     * @param dim The dimension of the centroids to create.
+     * @param k The number of centroids to create.
+     * @param weight The weight of the centroids to create.
+     */
+    public OnlineKMeans setRandomCentroids(
+            StreamTableEnvironment tEnv, int dim, int k, double weight) {

Review comment:
       +1
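The decayed mini-batch rule described in the `OnlineKMeans` Javadoc and implemented by `ModelDataLocalUpdater` can be illustrated with a minimal, Flink-free sketch for a single subtask, so that `decayFactor / parallelism` reduces to `decayFactor`. Assumptions: the class name is hypothetical, plain `double` arrays replace `DenseVector`, and squared Euclidean distance replaces the configurable `DistanceMeasure` when choosing the closest centroid.

```java
import java.util.Arrays;

/**
 * Flink-free sketch of the decayed mini-batch KMeans update, mirroring ModelDataLocalUpdater
 * with a single subtask. Squared Euclidean distance is assumed for the nearest-centroid search.
 */
public class MiniBatchUpdateSketch {

    /** Returns the index of the centroid closest to the point (squared Euclidean distance). */
    public static int closestCentroid(double[][] centroids, double[] point) {
        int best = 0;
        double bestDist = Double.POSITIVE_INFINITY;
        for (int i = 0; i < centroids.length; i++) {
            double dist = 0.0;
            for (int j = 0; j < point.length; j++) {
                double diff = centroids[i][j] - point[j];
                dist += diff * diff;
            }
            if (dist < bestDist) {
                bestDist = dist;
                best = i;
            }
        }
        return best;
    }

    /**
     * Updates centroids and weights in place with one mini-batch:
     * w_new = decay * w + n, c_new = (1 - n / w_new) * c + (n / w_new) * mean(points assigned to c).
     */
    public static void update(double[][] centroids, double[] weights, double[][] batch, double decay) {
        int k = centroids.length;
        int dim = centroids[0].length;
        double[][] sums = new double[k][dim];
        int[] counts = new int[k];
        for (double[] point : batch) {
            int id = closestCentroid(centroids, point);
            counts[id]++;
            for (int j = 0; j < dim; j++) {
                sums[id][j] += point[j];
            }
        }
        for (int i = 0; i < k; i++) {
            // Old estimates are discounted even if no point was assigned in this batch.
            weights[i] *= decay;
            if (counts[i] == 0) {
                continue;
            }
            weights[i] += counts[i];
            double lambda = counts[i] / weights[i];
            for (int j = 0; j < dim; j++) {
                centroids[i][j] = (1.0 - lambda) * centroids[i][j] + lambda * sums[i][j] / counts[i];
            }
        }
    }

    public static void main(String[] args) {
        double[][] centroids = {{0.0, 0.0}, {10.0, 10.0}};
        double[] weights = {2.0, 2.0};
        double[][] batch = {{1.0, 1.0}, {3.0, 3.0}, {9.0, 9.0}};
        update(centroids, weights, batch, 0.5);
        System.out.println(Arrays.deepToString(centroids) + " " + Arrays.toString(weights));
    }
}
```

With `decay = 1` every point seen so far keeps its full weight. With `decay = 0`, a centroid that receives points in the batch moves to the mean of those points, while a centroid that receives none keeps its position but its weight drops to zero.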

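One property worth checking in the quoted pipeline is that splitting a global batch across subtasks does not change the result. Each subtask scales the broadcast weights by `decayFactor / parallelism` and adds its local counts, and `ModelDataGlobalReducer` merges the partial models by a weighted average. The merged weight is therefore `decay * w + n` and the merged centroid is `(decay * w * c + sum of points) / (decay * w + n)`, the same as one global update. The sketch below checks this for a single cluster; the class and method names are hypothetical, it assumes (as in the pipeline) that every subtask sees the same broadcast centroids, and it uses the same `[i * n / p, (i + 1) * n / p)` split as `GlobalBatchSplitter`.

```java
import java.util.Arrays;

/**
 * Flink-free sketch comparing the split / local-update / reduce pipeline against a single
 * global update, for one cluster. Results are encoded as {centroid..., weight}.
 */
public class ParallelUpdateSketch {

    /** Local update on one subtask, scaling the old weight by decay / parallelism. */
    public static double[] localUpdate(
            double[] centroid, double weight, double[][] localPoints, double decay, int parallelism) {
        int dim = centroid.length;
        double newWeight = weight * decay / parallelism;
        double[] result = Arrays.copyOf(centroid, dim + 1);
        if (localPoints.length > 0) {
            newWeight += localPoints.length;
            double lambda = localPoints.length / newWeight;
            for (int j = 0; j < dim; j++) {
                double sum = 0.0;
                for (double[] point : localPoints) {
                    sum += point[j];
                }
                result[j] = (1.0 - lambda) * centroid[j] + lambda * sum / localPoints.length;
            }
        }
        result[dim] = newWeight;
        return result;
    }

    /** Weighted average of two partial results, as in ModelDataGlobalReducer. */
    public static double[] reduce(double[] a, double[] b) {
        int dim = a.length - 1;
        double[] out = new double[dim + 1];
        double total = a[dim] + b[dim];
        for (int j = 0; j < dim; j++) {
            out[j] = (a[j] * a[dim] + b[j] * b[dim]) / Math.max(total, 1e-16);
        }
        out[dim] = total;
        return out;
    }

    /** Splits the global batch like GlobalBatchSplitter, updates locally, then reduces. */
    public static double[] parallelUpdate(
            double[] centroid, double weight, double[][] batch, double decay, int parallelism) {
        double[] merged = null;
        for (int i = 0; i < parallelism; i++) {
            int start = i * batch.length / parallelism;
            int end = (i + 1) * batch.length / parallelism;
            double[] local =
                    localUpdate(centroid, weight, Arrays.copyOfRange(batch, start, end), decay, parallelism);
            merged = (merged == null) ? local : reduce(merged, local);
        }
        return merged;
    }

    public static void main(String[] args) {
        double[][] batch = {{1.0, 2.0}, {3.0, 4.0}, {5.0, 0.0}};
        // Parallelism 1 is a single global update; parallelism 4 leaves one subtask empty.
        System.out.println(Arrays.toString(parallelUpdate(new double[] {0.0, 0.0}, 4.0, batch, 0.5, 1)));
        System.out.println(Arrays.toString(parallelUpdate(new double[] {0.0, 0.0}, 4.0, batch, 0.5, 4)));
    }
}
```

Scaling by `decayFactor / parallelism` on each subtask is what makes the reduced weight come out as `decay * w` rather than `parallelism * decay * w`, since every subtask starts from the same broadcast weights.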






[GitHub] [flink-ml] lindong28 commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
lindong28 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r836068872



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,407 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Supplier;
+
+/**
+ * OnlineKMeans extends the functionality of {@link KMeans} to support training a K-Means model
+ * continuously on an unbounded stream of training data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids from the weighted average between the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of points assigned to
+ * them. The weight of the original centroids is also the number of points, additionally
+ * multiplied by the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If the decay
+ * factor is 1, all batches are weighted equally. If the decay factor is 0, new centroids are
+ * determined entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelData =
+                KMeansModelData.getModelDataStream(initModelDataTable);
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getK(),
+                        getDecayFactor(),
+                        getGlobalBatchSize());
+
+        DataStream<KMeansModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Saves the metadata AND bounded model data table (if it exists) to the given path. */
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans onlineKMeans = ReadWriteUtils.loadStageParam(path);
+
+        String initModelDataPath = ReadWriteUtils.getDataPath(path);
+        if (Files.exists(Paths.get(initModelDataPath))) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            onlineKMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+        }
+
+        return onlineKMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final int k;
+        private final double decayFactor;
+        private final int batchSize;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure, int k, double decayFactor, int batchSize) {
+            this.distanceMeasure = distanceMeasure;
+            this.k = k;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<KMeansModelData> modelData = variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();

Review comment:
       @zhipeng93 Do you think there is any reason that a user would actually want to run OnlineKMeans with parallelism < batchSize? If yes, could you provide some details?
   
   And if not, could you explain why we should actually support this?
   
   I am also wondering what `consistent user experience` means here.
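To make the `parallelism` vs `batchSize` question concrete, here is a minimal sketch that assumes a global batch of `batchSize` points is divided as evenly as possible across `parallelism` subtasks. `LocalBatchSizes` and `split` are hypothetical names. This is not the PR's `GlobalBatchSplitter`, whose exact splitting logic is not quoted here.

```java
import java.util.Arrays;

/**
 * Hypothetical helper (not the PR's GlobalBatchSplitter): divides a global mini-batch of
 * batchSize points across `parallelism` subtasks as evenly as possible and returns how many
 * points each subtask receives.
 */
final class LocalBatchSizes {
    static int[] split(int batchSize, int parallelism) {
        int[] sizes = new int[parallelism];
        int base = batchSize / parallelism; // points every subtask receives
        int remainder = batchSize % parallelism; // the first `remainder` subtasks get one extra
        for (int i = 0; i < parallelism; i++) {
            sizes[i] = base + (i < remainder ? 1 : 0);
        }
        return sizes;
    }

    public static void main(String[] args) {
        // parallelism < batchSize: every subtask gets a non-empty local batch.
        System.out.println(Arrays.toString(split(10, 4))); // [3, 3, 2, 2]
        // batchSize < parallelism: some subtasks get an empty local batch.
        System.out.println(Arrays.toString(split(2, 4))); // [1, 1, 0, 0]
    }
}
```

Under this even-split assumption, parallelism < batchSize gives every subtask some points, while batchSize < parallelism leaves some subtasks with an empty local batch in every round.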




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [flink-ml] lindong28 commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
lindong28 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r836068872



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,407 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Supplier;
+
+/**
+ * OnlineKMeans extends the functionality of {@link KMeans} to support training a K-Means model
+ * continuously on an unbounded stream of training data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids as the weighted average of the original and the
+ * estimated centroids. The weight of each estimated centroid is the number of points in the
+ * current batch assigned to it. The weight of each original centroid is the number of points
+ * assigned to it so far, multiplied by the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If the decay
+ * factor is 1, all batches are weighted equally. If the decay factor is 0, new centroids are
+ * determined entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelData =
+                KMeansModelData.getModelDataStream(initModelDataTable);
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getK(),
+                        getDecayFactor(),
+                        getGlobalBatchSize());
+
+        DataStream<KMeansModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Saves the metadata AND bounded model data table (if exists) to the given path. */
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans onlineKMeans = ReadWriteUtils.loadStageParam(path);
+
+        String initModelDataPath = ReadWriteUtils.getDataPath(path);
+        if (Files.exists(Paths.get(initModelDataPath))) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            onlineKMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+        }
+
+        return onlineKMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final int k;
+        private final double decayFactor;
+        private final int batchSize;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure, int k, double decayFactor, int batchSize) {
+            this.distanceMeasure = distanceMeasure;
+            this.k = k;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<KMeansModelData> modelData = variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();

Review comment:
       @zhipeng93 Do you think there is any reason that a user would actually want to run OnlineKMeans with parallelism < batchSize? If yes, could you provide some details?
   
   And if not, could you explain why we should actually support running the algorithm with `parallelism < batchSize`?
   
   I am also wondering what `consistent user experience` means here.
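The decayed mini-batch update described in the OnlineKMeans Javadoc quoted above can be sketched for a single centroid as follows. `DecayedCentroidUpdate.update` is a hypothetical helper, not code from the PR. It only computes the new centroid and does not cover how the accumulated point counts are stored between batches.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch of the decayed mini-batch rule from the OnlineKMeans Javadoc, for one
 * centroid. Not code from the PR.
 */
final class DecayedCentroidUpdate {
    /**
     * @param centroid current centroid (left unmodified)
     * @param count points assigned to this centroid so far
     * @param batchMean centroid estimated from the current batch (mean of its assigned points)
     * @param batchCount points of the current batch assigned to this centroid
     * @param decay decay factor in [0, 1]
     * @return the new centroid
     */
    static double[] update(
            double[] centroid, double count, double[] batchMean, int batchCount, double decay) {
        double oldWeight = count * decay; // original centroid: points so far times decay
        double newWeight = batchCount; // estimated centroid: points in this batch
        double total = oldWeight + newWeight;
        double[] result = new double[centroid.length];
        for (int j = 0; j < centroid.length; j++) {
            // If neither side carries any weight, keep the old coordinate instead of 0 / 0.
            result[j] =
                    total == 0
                            ? centroid[j]
                            : (centroid[j] * oldWeight + batchMean[j] * newWeight) / total;
        }
        return result;
    }

    public static void main(String[] args) {
        double[] c = {0.0, 0.0}; // two earlier points averaged to (0, 0)
        double[] x = {3.0, 3.0}; // one new point at (3, 3)
        System.out.println(Arrays.toString(update(c, 2, x, 1, 1.0))); // [1.0, 1.0]
        System.out.println(Arrays.toString(update(c, 2, x, 1, 0.5))); // [1.5, 1.5]
        System.out.println(Arrays.toString(update(c, 2, x, 1, 0.0))); // [3.0, 3.0]
    }
}
```

With decay 1 the result is the plain running mean of all three points; with decay 0 the old centroid is ignored and the batch estimate is taken as-is, matching the two extremes described in the Javadoc.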







[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r834927926



##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java
##########
@@ -0,0 +1,464 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.RestOptions;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.clustering.kmeans.KMeans;
+import org.apache.flink.ml.clustering.kmeans.KMeansModel;
+import org.apache.flink.ml.clustering.kmeans.KMeansModelData;
+import org.apache.flink.ml.clustering.kmeans.OnlineKMeans;
+import org.apache.flink.ml.clustering.kmeans.OnlineKMeansModel;
+import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.util.InMemorySinkFunction;
+import org.apache.flink.ml.util.InMemorySourceFunction;
+import org.apache.flink.runtime.minicluster.MiniCluster;
+import org.apache.flink.runtime.minicluster.MiniClusterConfiguration;
+import org.apache.flink.runtime.testutils.InMemoryReporter;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.CollectionUtils;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import static org.apache.flink.ml.clustering.KMeansTest.groupFeaturesByPrediction;
+
+/** Tests {@link OnlineKMeans} and {@link OnlineKMeansModel}. */
+public class OnlineKMeansTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private static final DenseVector[] trainData1 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 0.0),
+                Vectors.dense(10.0, 0.3),
+                Vectors.dense(10.3, 0.0),
+                Vectors.dense(-10.0, 0.0),
+                Vectors.dense(-10.0, 0.6),
+                Vectors.dense(-10.6, 0.0)
+            };
+    private static final DenseVector[] trainData2 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 100.0),
+                Vectors.dense(10.0, 100.3),
+                Vectors.dense(10.3, 100.0),
+                Vectors.dense(-10.0, -100.0),
+                Vectors.dense(-10.0, -100.6),
+                Vectors.dense(-10.6, -100.0)
+            };
+    private static final DenseVector[] predictData =
+            new DenseVector[] {
+                Vectors.dense(10.0, 10.0),
+                Vectors.dense(10.3, 10.0),
+                Vectors.dense(10.0, 10.3),
+                Vectors.dense(-10.0, 10.0),
+                Vectors.dense(-10.3, 10.0),
+                Vectors.dense(-10.0, 10.3)
+            };
+    private static final List<Set<DenseVector>> expectedGroups1 =
+            Arrays.asList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3))),
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+    private static final List<Set<DenseVector>> expectedGroups2 =
+            Collections.singletonList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3),
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+
+    private static final int defaultParallelism = 4;
+    private static final int numTaskManagers = 2;
+    private static final int numSlotsPerTaskManager = 2;
+
+    private int currentModelDataVersion;
+
+    private InMemorySourceFunction<DenseVector> trainSource;
+    private InMemorySourceFunction<DenseVector> predictSource;
+    private InMemorySinkFunction<Row> outputSink;
+    private InMemorySinkFunction<KMeansModelData> modelDataSink;
+
+    private InMemoryReporter reporter;
+    private MiniCluster miniCluster;
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+
+    private Table offlineTrainTable;
+    private Table trainTable;
+    private Table predictTable;
+
+    @Before
+    public void before() throws Exception {
+        currentModelDataVersion = 0;
+
+        trainSource = new InMemorySourceFunction<>();
+        predictSource = new InMemorySourceFunction<>();
+        outputSink = new InMemorySinkFunction<>();
+        modelDataSink = new InMemorySinkFunction<>();
+
+        Configuration config = new Configuration();
+        config.set(RestOptions.BIND_PORT, "18081-19091");
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        reporter = InMemoryReporter.createWithRetainedMetrics();
+        reporter.addToConfiguration(config);
+
+        miniCluster =
+                new MiniCluster(
+                        new MiniClusterConfiguration.Builder()
+                                .setConfiguration(config)
+                                .setNumTaskManagers(numTaskManagers)
+                                .setNumSlotsPerTaskManager(numSlotsPerTaskManager)
+                                .build());
+        miniCluster.start();

Review comment:
       OK. I'll add it.
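The expected groups in the test above can be checked by hand. This sketch makes a simplifying assumption: each trained centroid sits exactly at the mean of one of the quoted training clusters (the first three and the last three points of `trainData1` or `trainData2`). In the real test the online centroids also depend on the initial model, the decay factor, and batch boundaries. `NearestCentroidCheck` and its methods are hypothetical names.

```java
import java.util.Arrays;

/** Hand check of OnlineKMeansTest's expected groups, assuming centroids at the cluster means. */
final class NearestCentroidCheck {
    static final double[][] TRAIN_DATA_1 = {
        {10.0, 0.0}, {10.0, 0.3}, {10.3, 0.0}, {-10.0, 0.0}, {-10.0, 0.6}, {-10.6, 0.0}
    };
    static final double[][] TRAIN_DATA_2 = {
        {10.0, 100.0}, {10.0, 100.3}, {10.3, 100.0}, {-10.0, -100.0}, {-10.0, -100.6}, {-10.6, -100.0}
    };
    static final double[][] PREDICT_DATA = {
        {10.0, 10.0}, {10.3, 10.0}, {10.0, 10.3}, {-10.0, 10.0}, {-10.3, 10.0}, {-10.0, 10.3}
    };

    /** Component-wise mean of points[from..to). */
    static double[] mean(double[][] points, int from, int to) {
        double[] sum = new double[points[from].length];
        for (int i = from; i < to; i++) {
            for (int j = 0; j < sum.length; j++) {
                sum[j] += points[i][j];
            }
        }
        for (int j = 0; j < sum.length; j++) {
            sum[j] /= (to - from);
        }
        return sum;
    }

    /** Centroids at the means of the first three and the last three training points. */
    static double[][] clusterMeans(double[][] trainData) {
        return new double[][] {mean(trainData, 0, 3), mean(trainData, 3, 6)};
    }

    /** For each prediction point, the index of the centroid with the smallest squared distance. */
    static int[] assign(double[][] centroids) {
        int[] labels = new int[PREDICT_DATA.length];
        for (int p = 0; p < PREDICT_DATA.length; p++) {
            double best = Double.POSITIVE_INFINITY;
            for (int c = 0; c < centroids.length; c++) {
                double dist = 0;
                for (int j = 0; j < centroids[c].length; j++) {
                    double diff = PREDICT_DATA[p][j] - centroids[c][j];
                    dist += diff * diff;
                }
                if (dist < best) {
                    best = dist;
                    labels[p] = c;
                }
            }
        }
        return labels;
    }

    public static void main(String[] args) {
        // Two groups split by the sign of x, as in expectedGroups1.
        System.out.println(Arrays.toString(assign(clusterMeans(TRAIN_DATA_1)))); // [0, 0, 0, 1, 1, 1]
        // A single group, as in expectedGroups2.
        System.out.println(Arrays.toString(assign(clusterMeans(TRAIN_DATA_2)))); // [0, 0, 0, 0, 0, 0]
    }
}
```

With centroids near (10, 100) and (-10, -100), every prediction point (all at y = 10) is nearer to the former, which is why `expectedGroups2` contains a single group.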







[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Support Online KMeans in Flink ML

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r829744134



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/StreamingKMeans.java
##########
@@ -0,0 +1,404 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * StreamingKMeans extends the functionality of {@link KMeans} to support training a K-Means model

Review comment:
       I agree. I'll refer to Spark's documentation and add similar JavaDocs here.
   
   Besides, Spark's JavaDoc is mainly on `StreamingKMeansModel` rather than `StreamingKMeans`. I think it is `StreamingKMeans` that should be explained in more detail. Like Spark's JavaDoc, the documentation I add will mainly describe the training process, so I'll put the detailed documentation on `OnlineKMeans`.
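For reference, Spark MLlib's streaming k-means guide writes the per-cluster update as shown below. In that notation, c_t is the previous cluster center, n_t the number of points assigned to the cluster so far, x_t the center estimated from the current batch, m_t the number of points the current batch adds to the cluster, and alpha the decay factor. This is Spark's notation, quoted for comparison only; the OnlineKMeans Javadoc in this PR is what defines Flink ML's behavior.

```latex
c_{t+1} = \frac{c_t n_t \alpha + x_t m_t}{n_t \alpha + m_t},
\qquad
n_{t+1} = n_t + m_t
```

With alpha = 1 this reduces to a running mean over all batches; with alpha = 0 the new center is just x_t, which lines up with the two extremes described in the OnlineKMeans Javadoc.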







[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r834923924



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,437 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.function.Supplier;
+
+/**
+ * OnlineKMeans extends the functionality of {@link KMeans} to support training a K-Means model
+ * continuously on an unbounded stream of training data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids as the weighted average of the original and the
+ * estimated centroids. The weight of each estimated centroid is the number of points in the
+ * current batch assigned to it. The weight of each original centroid is the number of points
+ * assigned to it so far, multiplied by the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If the decay
+ * factor is 1, all batches are weighted equally. If the decay factor is 0, new centroids are
+ * determined entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelData =
+                KMeansModelData.getModelDataStream(initModelDataTable);
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize());
+
+        DataStream<KMeansModelData> finalModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelData);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Saves the metadata AND bounded model data table (if exists) to the given path. */
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        String initModelDataPath = ReadWriteUtils.getDataPath(path);
+        if (Files.exists(Paths.get(initModelDataPath))) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure, double decayFactor, int batchSize) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<KMeansModelData> modelData = variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<KMeansModelData> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator())
+                            .flatMap(new GlobalBatchSplitter(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(KMeansModelData.class),
+                                    new ModelDataLocalUpdater(distanceMeasure, decayFactor))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new ModelDataGlobalReducer());
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelData), DataStreamList.of(modelData));
+        }
+    }
+
+    private static class ModelDataGlobalReducer implements ReduceFunction<KMeansModelData> {
+        @Override
+        public KMeansModelData reduce(KMeansModelData modelData, KMeansModelData newModelData) {
+            DenseVector weights = modelData.weights;
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector newWeights = newModelData.weights;
+            DenseVector[] newCentroids = newModelData.centroids;
+
+            int k = newCentroids.length;
+            int dim = newCentroids[0].size();
+
+            for (int i = 0; i < k; i++) {
+                for (int j = 0; j < dim; j++) {
+                    centroids[i].values[j] =
+                            (centroids[i].values[j] * weights.values[i]
+                                            + newCentroids[i].values[j] * newWeights.values[i])
+                                    / Math.max(weights.values[i] + newWeights.values[i], 1e-16);
+                }
+                weights.values[i] += newWeights.values[i];
+            }
+
+            return new KMeansModelData(centroids, weights);
+        }
+    }
+
+    private static class ModelDataLocalUpdater extends AbstractStreamOperator<KMeansModelData>
+            implements TwoInputStreamOperator<DenseVector[], KMeansModelData, KMeansModelData> {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private ListState<DenseVector[]> localBatchState;
+        private ListState<KMeansModelData> modelDataState;
+
+        private ModelDataLocalUpdater(DistanceMeasure distanceMeasure, double decayFactor) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            localBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("localBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> pointsRecord) throws Exception {
+            localBatchState.add(pointsRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<KMeansModelData> modelDataRecord)
+                throws Exception {
+            modelDataState.add(modelDataRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        private void alignAndComputeModelData() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !localBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            KMeansModelData modelData =
+                    OperatorStateUtils.getUniqueElement(modelDataState, "modelData")
+                            .orElseThrow((Supplier<Exception>) NullPointerException::new);
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector weights = modelData.weights;
+            modelDataState.clear();
+
+            List<DenseVector[]> pointsList = IteratorUtils.toList(localBatchState.get().iterator());
+            DenseVector[] points = pointsList.remove(0);
+            localBatchState.update(pointsList);
+
+            int dim = centroids[0].size();
+            int k = centroids.length;
+            int parallelism = getRuntimeContext().getNumberOfParallelSubtasks();
+
+            // Computes new centroids.
+            DenseVector[] sums = new DenseVector[k];
+            int[] counts = new int[k];
+
+            for (int i = 0; i < k; i++) {
+                sums[i] = new DenseVector(dim);
+                counts[i] = 0;
+            }
+            for (DenseVector point : points) {
+                int closestCentroidId =
+                        KMeans.findClosestCentroidId(centroids, point, distanceMeasure);
+                counts[closestCentroidId]++;
+                for (int j = 0; j < dim; j++) {
+                    sums[closestCentroidId].values[j] += point.values[j];
+                }
+            }
+
+            // Considers weight and decay factor when updating centroids.
+            BLAS.scal(decayFactor / parallelism, weights);
+            for (int i = 0; i < k; i++) {
+                if (counts[i] == 0) {
+                    continue;
+                }
+
+                DenseVector centroid = centroids[i];
+                weights.values[i] = weights.values[i] + counts[i];
+                double lambda = counts[i] / weights.values[i];
+
+                BLAS.scal(1.0 - lambda, centroid);
+                BLAS.axpy(lambda / counts[i], sums[i], centroid);
+            }
+
+            output.collect(new StreamRecord<>(new KMeansModelData(centroids, weights)));
+        }
+    }
+
+    private static class FeaturesExtractor implements MapFunction<Row, DenseVector> {
+        private final String featuresCol;
+
+        private FeaturesExtractor(String featuresCol) {
+            this.featuresCol = featuresCol;
+        }
+
+        @Override
+        public DenseVector map(Row row) throws Exception {
+            return (DenseVector) row.getField(featuresCol);
+        }
+    }
+
+    // An operator that splits a global batch into evenly-sized local batches, and distributes them
+    // to downstream operator.
+    private static class GlobalBatchSplitter
+            implements FlatMapFunction<DenseVector[], DenseVector[]> {
+        private final int downStreamParallelism;
+
+        private GlobalBatchSplitter(int downStreamParallelism) {
+            this.downStreamParallelism = downStreamParallelism;
+        }
+
+        @Override
+        public void flatMap(DenseVector[] values, Collector<DenseVector[]> collector) {
+            // Calculate the batch sizes to be distributed on each subtask.
+            List<Integer> sizes = new ArrayList<>();
+            for (int i = 0; i < downStreamParallelism; i++) {
+                int start = i * values.length / downStreamParallelism;
+                int end = (i + 1) * values.length / downStreamParallelism;
+                sizes.add(end - start);
+            }
+
+            int offset = 0;
+            for (Integer size : sizes) {
+                collector.collect(Arrays.copyOfRange(values, offset, offset + size));
+                offset += size;
+            }
+        }
+    }
+
+    private static class GlobalBatchCreator
+            implements AllWindowFunction<DenseVector, DenseVector[], GlobalWindow> {
+        @Override
+        public void apply(
+                GlobalWindow timeWindow,
+                Iterable<DenseVector> iterable,
+                Collector<DenseVector[]> collector) {
+            List<DenseVector> points = IteratorUtils.toList(iterable.iterator());
+            collector.collect(points.toArray(new DenseVector[0]));
+        }
+    }
+
+    /**
+     * Sets the initial model data of the online training process with the provided model data
+     * table.
+     *
+     * <p>This would override the effect of previously invoked {@link #setInitialModelData(Table)}
+     * or {@link #setRandomCentroids(StreamTableEnvironment, int, int, double)}.
+     */
+    public OnlineKMeans setInitialModelData(Table initModelDataTable) {
+        this.initModelDataTable = initModelDataTable;
+        return this;
+    }
+
+    /**
+     * Sets the initial model data of the online training process with randomly created centroids.
+     *
+     * <p>This would override the effect of previously invoked {@link #setInitialModelData(Table)}
+     * or {@link #setRandomCentroids(StreamTableEnvironment, int, int, double)}.
+     *
+     * @param tEnv The stream table environment to create the centroids in.
+     * @param dim The dimension of the centroids to create.
+     * @param k The number of centroids to create.
+     * @param weight The weight of the centroids to create.
+     */
+    public OnlineKMeans setRandomCentroids(
+            StreamTableEnvironment tEnv, int dim, int k, double weight) {

Review comment:
       The introduction of `k` in `OnlineKMeans` might cause the following problems.
   
   - `setK()` must be invoked before `setRandomCentroids()`.
   - If users use `setInitialModelData`, then we need to check whether the initial model data provides exactly `k` centroids.
   
   Because of these problems, I am not sure whether we should add the `k` parameter to `OnlineKMeans`.
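The check from the second bullet could look roughly like the following minimal sketch. It is not code from this PR: `InitialModelDataValidator` and `checkCentroids` are hypothetical names, and centroids are modeled as plain `double[][]` instead of `DenseVector[]` so the snippet runs without Flink on the classpath.

```java
import java.util.Arrays;

/** Hypothetical helper, not part of Flink ML; centroids are plain double[][] for brevity. */
public class InitialModelDataValidator {

    /**
     * Throws IllegalArgumentException unless there are exactly k centroids and they all share
     * the same dimension.
     */
    public static void checkCentroids(double[][] centroids, int k) {
        if (centroids.length != k) {
            throw new IllegalArgumentException(
                    "Expected " + k + " centroids in the initial model data, but got "
                            + centroids.length + ".");
        }
        // Centroids of different dimensions would make distance computation undefined.
        long distinctDims = Arrays.stream(centroids).mapToInt(c -> c.length).distinct().count();
        if (distinctDims > 1) {
            throw new IllegalArgumentException(
                    "Centroids in the initial model data differ in dimension.");
        }
    }

    public static void main(String[] args) {
        double[][] centroids = {{10.0, 0.0}, {-10.0, 0.0}};
        checkCentroids(centroids, 2); // passes silently
        try {
            checkCentroids(centroids, 3);
        } catch (IllegalArgumentException e) {
            // Prints: Expected 3 centroids in the initial model data, but got 2.
            System.out.println(e.getMessage());
        }
    }
}
```

Note that when the initial model data arrives as a `Table`, the number of centroids is presumably only known once the job runs, so a check like this would likely have to live inside an operator rather than in `fit()`, which is part of what makes the extra `k` parameter awkward.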




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r834913368



##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java
##########
@@ -0,0 +1,464 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.RestOptions;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.clustering.kmeans.KMeans;
+import org.apache.flink.ml.clustering.kmeans.KMeansModel;
+import org.apache.flink.ml.clustering.kmeans.KMeansModelData;
+import org.apache.flink.ml.clustering.kmeans.OnlineKMeans;
+import org.apache.flink.ml.clustering.kmeans.OnlineKMeansModel;
+import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.util.InMemorySinkFunction;
+import org.apache.flink.ml.util.InMemorySourceFunction;
+import org.apache.flink.runtime.minicluster.MiniCluster;
+import org.apache.flink.runtime.minicluster.MiniClusterConfiguration;
+import org.apache.flink.runtime.testutils.InMemoryReporter;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.CollectionUtils;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import static org.apache.flink.ml.clustering.KMeansTest.groupFeaturesByPrediction;
+
+/** Tests {@link OnlineKMeans} and {@link OnlineKMeansModel}. */
+public class OnlineKMeansTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private static final DenseVector[] trainData1 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 0.0),
+                Vectors.dense(10.0, 0.3),
+                Vectors.dense(10.3, 0.0),
+                Vectors.dense(-10.0, 0.0),
+                Vectors.dense(-10.0, 0.6),
+                Vectors.dense(-10.6, 0.0)
+            };
+    private static final DenseVector[] trainData2 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 100.0),
+                Vectors.dense(10.0, 100.3),
+                Vectors.dense(10.3, 100.0),
+                Vectors.dense(-10.0, -100.0),
+                Vectors.dense(-10.0, -100.6),
+                Vectors.dense(-10.6, -100.0)
+            };
+    private static final DenseVector[] predictData =
+            new DenseVector[] {
+                Vectors.dense(10.0, 10.0),
+                Vectors.dense(10.3, 10.0),
+                Vectors.dense(10.0, 10.3),
+                Vectors.dense(-10.0, 10.0),
+                Vectors.dense(-10.3, 10.0),
+                Vectors.dense(-10.0, 10.3)
+            };
+    private static final List<Set<DenseVector>> expectedGroups1 =
+            Arrays.asList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3))),
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+    private static final List<Set<DenseVector>> expectedGroups2 =
+            Collections.singletonList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3),
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+
+    private static final int defaultParallelism = 4;
+    private static final int numTaskManagers = 2;
+    private static final int numSlotsPerTaskManager = 2;
+
+    private int currentModelDataVersion;
+
+    private InMemorySourceFunction<DenseVector> trainSource;
+    private InMemorySourceFunction<DenseVector> predictSource;
+    private InMemorySinkFunction<Row> outputSink;
+    private InMemorySinkFunction<KMeansModelData> modelDataSink;
+
+    private InMemoryReporter reporter;
+    private MiniCluster miniCluster;
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+
+    private Table offlineTrainTable;
+    private Table trainTable;
+    private Table predictTable;
+
+    @Before
+    public void before() throws Exception {
+        currentModelDataVersion = 0;
+
+        trainSource = new InMemorySourceFunction<>();
+        predictSource = new InMemorySourceFunction<>();
+        outputSink = new InMemorySinkFunction<>();
+        modelDataSink = new InMemorySinkFunction<>();
+
+        Configuration config = new Configuration();
+        config.set(RestOptions.BIND_PORT, "18081-19091");
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        reporter = InMemoryReporter.createWithRetainedMetrics();
+        reporter.addToConfiguration(config);
+
+        miniCluster =
+                new MiniCluster(
+                        new MiniClusterConfiguration.Builder()
+                                .setConfiguration(config)
+                                .setNumTaskManagers(numTaskManagers)
+                                .setNumSlotsPerTaskManager(numSlotsPerTaskManager)
+                                .build());
+        miniCluster.start();

Review comment:
       It seems that `InMemoryReporter` cannot be used together with `AbstractTestBase`. `InMemoryReporter` needs to be registered in the mini cluster's configuration, but `AbstractTestBase` starts the cluster before I get the chance to do that.
   
   In Flink 1.14.0, `InMemoryReporter` is used in `SourceMetricsITCase` and `SinkMetricsITCase`, and both of these classes extend `TestLogger`, a parent class of `AbstractTestBase`. They both use a mini cluster, but create it themselves in their `@Before` methods. I'd like to follow this practice.
   
   In the latest Flink code, `SourceMetricsITCase` can create the in-memory reporter and the mini cluster just once per test class, but that does not seem applicable yet while Flink ML still depends on 1.14.0.







[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r836000412



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeansModel.java
##########
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.co.CoProcessFunction;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * OnlineKMeansModel can be regarded as an advanced {@link KMeansModel} operator which can update
+ * model data in a streaming format, using the model data provided by {@link OnlineKMeans}.
+ */
+public class OnlineKMeansModel
+        implements Model<OnlineKMeansModel>, KMeansModelParams<OnlineKMeansModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public OnlineKMeansModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel setModelData(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        modelDataTable = inputs[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), Types.INT),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol()));
+
+        DataStream<Row> predictionResult =
+                KMeansModelData.getModelDataStream(modelDataTable)
+                        .broadcast()
+                        .connect(tEnv.toDataStream(inputs[0]))
+                        .process(
+                                new PredictLabelFunction(
+                                        getFeaturesCol(),
+                                        DistanceMeasure.getInstance(getDistanceMeasure())),
+                                outputTypeInfo);
+
+        return new Table[] {tEnv.fromDataStream(predictionResult)};
+    }
+
+    /** A utility function used for prediction. */
+    private static class PredictLabelFunction extends CoProcessFunction<KMeansModelData, Row, Row> {
+        private final String featuresCol;
+
+        private final DistanceMeasure distanceMeasure;
+
+        private DenseVector[] centroids;
+
+        // TODO: replace this with a complete solution of reading first model data from unbounded
+        // model data stream before processing the first predict data.
+        private final List<Row> bufferedPoints = new ArrayList<>();

Review comment:
       As I see it, adding algorithms to flink-ml has the following goals:
   - A technically correct implementation, which may only work on small datasets.
   - A technically correct implementation that also works on large datasets (i.e. in production).
   
   For now, if we do not checkpoint the buffered data points, we perhaps cannot guarantee correctness even for a small dataset.
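To make the concern concrete: points held in a plain `List` field are lost on failover, so their predictions are never emitted. Below is a minimal, Flink-free sketch of the intended behavior, assuming the buffer is copied into and restored from durable state on each checkpoint. `BufferingPredictor`, `snapshotBuffer` and `restoreBuffer` are hypothetical names; in the real operator they would presumably map to a `ListState` written in `snapshotState` and read back in `initializeState`. Prediction here uses squared Euclidean distance for simplicity.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical sketch: buffers predict points until the first model data arrives. */
public class BufferingPredictor {
    private double[][] centroids; // null until the first model data arrives
    private final List<double[]> bufferedPoints = new ArrayList<>();
    private final List<Integer> output = new ArrayList<>();

    /** Called for each predict record. */
    public void processPoint(double[] point) {
        if (centroids == null) {
            bufferedPoints.add(point);
        } else {
            output.add(predict(point));
        }
    }

    /** Called for each model data record; flushes any buffered points. */
    public void processModelData(double[][] newCentroids) {
        centroids = newCentroids;
        for (double[] point : bufferedPoints) {
            output.add(predict(point));
        }
        bufferedPoints.clear();
    }

    /** What a checkpoint must persist; without it, buffered points vanish on restart. */
    public List<double[]> snapshotBuffer() {
        return new ArrayList<>(bufferedPoints);
    }

    /** Restores the buffer after a failover. */
    public void restoreBuffer(List<double[]> snapshot) {
        bufferedPoints.clear();
        bufferedPoints.addAll(snapshot);
    }

    public List<Integer> getOutput() {
        return output;
    }

    /** Returns the index of the closest centroid by squared Euclidean distance. */
    private int predict(double[] point) {
        int best = 0;
        double bestDist = Double.MAX_VALUE;
        for (int i = 0; i < centroids.length; i++) {
            double dist = 0.0;
            for (int j = 0; j < point.length; j++) {
                double diff = point[j] - centroids[i][j];
                dist += diff * diff;
            }
            if (dist < bestDist) {
                bestDist = dist;
                best = i;
            }
        }
        return best;
    }
}
```

Simulating a restart: buffer two points in one instance, take `snapshotBuffer()`, call `restoreBuffer(...)` on a fresh instance, then feed it centroids `{{10.0, 0.0}, {-10.0, 0.0}}`. The restored instance emits predictions `[0, 1]` for points `(10, 10)` and `(-10, 10)`, whereas without the snapshot those points would simply be missing from the output.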







[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r836001627



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,407 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Supplier;
+
+/**
+ * OnlineKMeans extends the function of {@link KMeans}, supporting to train a K-Means model
+ * continuously according to an unbounded stream of train data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids from the weighted average between the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of points assigned to
+ * them. The weight of the original centroids is also the number of points, but additionally
+ * multiplying with the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If the decay
+ * factor is 1, all batches are weighted equally. If the decay factor is 0, new centroids are
+ * determined entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelData =
+                KMeansModelData.getModelDataStream(initModelDataTable);
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getK(),
+                        getDecayFactor(),
+                        getGlobalBatchSize());
+
+        DataStream<KMeansModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Saves the metadata AND bounded model data table (if exists) to the given path. */
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans onlineKMeans = ReadWriteUtils.loadStageParam(path);
+
+        String initModelDataPath = ReadWriteUtils.getDataPath(path);
+        if (Files.exists(Paths.get(initModelDataPath))) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            onlineKMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+        }
+
+        return onlineKMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final int k;
+        private final double decayFactor;
+        private final int batchSize;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure, int k, double decayFactor, int batchSize) {
+            this.distanceMeasure = distanceMeasure;
+            this.k = k;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<KMeansModelData> modelData = variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();

Review comment:
       I guess what you want is to make sure that `parallelism <= batchSize`, and to throw an exception if `parallelism > batchSize`. I cannot think of a valid reason for this check.
   
   I think the training process should still function correctly even if the number of elements in a batch is smaller than the parallelism. When batches are created according to a time window, the operator has no idea how many elements are in each batch.
   
   If that is correct, the only purpose of the `parallelism <= batchSize` check I can think of is performance. But the same issue also exists when `batchSize` is only slightly larger than `parallelism`, in which case each subtask is allocated only one or two training samples. A strict `parallelism <= batchSize` threshold would not solve that performance problem.
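   
   To make the arithmetic behind this argument concrete, here is a minimal sketch. It assumes a global batch is divided as evenly as possible among subtasks. The class `BatchSplitSketch` and its method `localBatchSizes` are hypothetical and not part of Flink ML. With `batchSize < parallelism` some subtasks receive nothing, and with `batchSize` just above `parallelism` each subtask receives only one or two elements.
   
   ```java
   import java.util.Arrays;
   
   /**
    * Hypothetical sketch: splits one global mini-batch across parallel subtasks as evenly as
    * possible. It only illustrates the arithmetic discussed above, not the OnlineKMeans code.
    */
   public class BatchSplitSketch {
   
       /** Returns how many elements of one global batch each subtask receives. */
       static int[] localBatchSizes(int globalBatchSize, int parallelism) {
           if (globalBatchSize < 0 || parallelism <= 0) {
               throw new IllegalArgumentException("Batch size must be >= 0 and parallelism > 0.");
           }
           int[] sizes = new int[parallelism];
           for (int i = 0; i < parallelism; i++) {
               // The first (globalBatchSize % parallelism) subtasks receive one extra element.
               sizes[i] = globalBatchSize / parallelism + (i < globalBatchSize % parallelism ? 1 : 0);
           }
           return sizes;
       }
   
       public static void main(String[] args) {
           // batchSize < parallelism: one subtask receives no training data.
           System.out.println(Arrays.toString(localBatchSizes(3, 4))); // [1, 1, 1, 0]
           // batchSize slightly larger than parallelism: only 1 or 2 elements per subtask.
           System.out.println(Arrays.toString(localBatchSizes(6, 4))); // [2, 2, 1, 1]
       }
   }
   ```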




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r836015733



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,407 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Supplier;
+
+/**
+ * OnlineKMeans extends the function of {@link KMeans}, supporting training a K-Means model
+ * continuously on an unbounded stream of training data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids from the weighted average between the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of points assigned to
+ * them. The weight of the original centroids is also a number of points, but it is additionally
+ * multiplied by the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If the decay
+ * factor is 1, all batches are weighted equally. If the decay factor is 0, new centroids are
+ * determined entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelData =
+                KMeansModelData.getModelDataStream(initModelDataTable);
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getK(),
+                        getDecayFactor(),
+                        getGlobalBatchSize());
+
+        DataStream<KMeansModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Saves the metadata AND bounded model data table (if exists) to the given path. */
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans onlineKMeans = ReadWriteUtils.loadStageParam(path);
+
+        String initModelDataPath = ReadWriteUtils.getDataPath(path);
+        if (Files.exists(Paths.get(initModelDataPath))) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            onlineKMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+        }
+
+        return onlineKMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final int k;
+        private final double decayFactor;
+        private final int batchSize;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure, int k, double decayFactor, int batchSize) {
+            this.distanceMeasure = distanceMeasure;
+            this.k = k;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<KMeansModelData> modelData = variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();

Review comment:
       OK. I'll add the check and a relevant test case.
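   
   Here is a minimal sketch of such a check and what its test might exercise. It assumes the intent is to fail fast when the global batch size is smaller than the parallelism. `BatchSizeCheckSketch` and `checkBatchSize` are hypothetical names. Inside Flink ML, the same condition would more naturally be expressed with `Preconditions.checkArgument`, which this file already uses.
   
   ```java
   /**
    * Hypothetical sketch of a fail-fast validation that rejects a global batch size smaller than
    * the parallelism. Names are illustrative only.
    */
   public class BatchSizeCheckSketch {
   
       static void checkBatchSize(int globalBatchSize, int parallelism) {
           if (globalBatchSize < parallelism) {
               throw new IllegalArgumentException(
                       String.format(
                               "Global batch size (%d) must not be smaller than the parallelism (%d).",
                               globalBatchSize, parallelism));
           }
       }
   
       public static void main(String[] args) {
           checkBatchSize(8, 4); // Accepted: every subtask can receive at least one element.
           try {
               checkBatchSize(3, 4); // Rejected.
           } catch (IllegalArgumentException e) {
               System.out.println(e.getMessage());
           }
       }
   }
   ```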







[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r836084920



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,407 @@
+            int parallelism = points.getParallelism();

Review comment:
       As pointed out by @yunfengzhou-hub, a use case could be:
   > I think even if the number of elements in a batch is smaller than parallelism, the training process should still function correctly, because when batches are created according to time window, the operator would have no idea about how many elements are in each batch.
   
   By `consistent user experience`, I mean that given a training set as input, users should be able to run the job on one worker or on an arbitrary number of workers.
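   
   One way to meet this requirement is to have each subtask compute only per-centroid partial sums and counts, and to merge them before the centroids are updated. The result then does not depend on the number of workers, and a subtask that receives no points contributes nothing. The sketch below assumes round-robin partitioning, squared Euclidean distance, and the decayed weighted-average rule described in the `OnlineKMeans` Javadoc. `ParallelismInvariantUpdateSketch`, `closest` and `update` are hypothetical names, not Flink ML APIs.
   
   ```java
   import java.util.Arrays;
   
   /**
    * Hypothetical sketch, not the OnlineKMeans implementation: one decayed mini-batch update of
    * K-Means centroids from per-subtask partial sums that are merged before the update.
    */
   public class ParallelismInvariantUpdateSketch {
   
       /** Returns the index of the centroid closest to {@code point} (squared Euclidean distance). */
       static int closest(double[] point, double[][] centroids) {
           int best = 0;
           double bestDist = Double.MAX_VALUE;
           for (int c = 0; c < centroids.length; c++) {
               double dist = 0;
               for (int j = 0; j < point.length; j++) {
                   double diff = point[j] - centroids[c][j];
                   dist += diff * diff;
               }
               if (dist < bestDist) {
                   bestDist = dist;
                   best = c;
               }
           }
           return best;
       }
   
       /**
        * Splits {@code batch} round-robin over {@code parallelism} simulated subtasks, merges their
        * per-centroid sums and counts, then updates {@code centroids} and {@code weights} in place:
        * newCentroid = (oldCentroid * oldWeight * decay + batchSum) / (oldWeight * decay + batchCount).
        */
       static void update(
               double[][] batch, double[][] centroids, double[] weights, double decay, int parallelism) {
           int k = centroids.length;
           int dim = centroids[0].length;
           double[][] sums = new double[k][dim];
           double[] counts = new double[k];
           for (int task = 0; task < parallelism; task++) {
               // Partial results of one simulated subtask; it may receive no points at all.
               double[][] localSums = new double[k][dim];
               double[] localCounts = new double[k];
               for (int i = task; i < batch.length; i += parallelism) {
                   int c = closest(batch[i], centroids);
                   for (int j = 0; j < dim; j++) {
                       localSums[c][j] += batch[i][j];
                   }
                   localCounts[c]++;
               }
               // Merge step: partial sums and counts are simply added up.
               for (int c = 0; c < k; c++) {
                   counts[c] += localCounts[c];
                   for (int j = 0; j < dim; j++) {
                       sums[c][j] += localSums[c][j];
                   }
               }
           }
           for (int c = 0; c < k; c++) {
               double decayedWeight = weights[c] * decay;
               double newWeight = decayedWeight + counts[c];
               if (newWeight > 0) {
                   for (int j = 0; j < dim; j++) {
                       centroids[c][j] = (centroids[c][j] * decayedWeight + sums[c][j]) / newWeight;
                   }
               }
               weights[c] = newWeight;
           }
       }
   
       public static void main(String[] args) {
           double[][] batch = {{0, 0}, {1, 1}, {0, 1}, {9, 9}, {10, 10}};
           // Parallelism 8 exceeds the batch size of 5, so three simulated subtasks stay idle.
           for (int parallelism : new int[] {1, 3, 8}) {
               double[][] centroids = {{0, 0}, {10, 10}};
               double[] weights = {2, 2};
               update(batch, centroids, weights, 0.5, parallelism);
               System.out.println(
                       parallelism + " -> " + Arrays.deepToString(centroids) + " " + Arrays.toString(weights));
           }
       }
   }
   ```
   
   In this example all three runs move the first centroid to (0.25, 0.5) and the second to (29/3, 29/3), with weights 4 and 3. With non-integer data, the order in which partial sums are merged may introduce tiny floating-point differences, so a test comparing results across parallelisms should use a tolerance.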







[GitHub] [flink-ml] lindong28 commented on a change in pull request #70: [FLINK-26313] Support Streaming KMeans in Flink ML

Posted by GitBox <gi...@apache.org>.
lindong28 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r829629945



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/StreamingKMeans.java
##########
@@ -0,0 +1,404 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * StreamingKMeans extends the function of {@link KMeans}, supporting training a K-Means model
+ * continuously on an unbounded stream of training data.
+ */
+public class StreamingKMeans
+        implements Estimator<StreamingKMeans, StreamingKMeansModel>,
+                StreamingKMeansParams<StreamingKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public StreamingKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public StreamingKMeans(Table... initModelDataTables) {
+        Preconditions.checkArgument(initModelDataTables.length == 1);
+        this.initModelDataTable = initModelDataTables[0];
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public StreamingKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        Preconditions.checkArgument(HasBatchStrategy.COUNT_STRATEGY.equals(getBatchStrategy()));
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+        points.getTransformation().setParallelism(1);
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            initModelDataStream = createRandomCentroids(env, getDims(), getK(), getSeed());
+        } else {
+            initModelDataStream = KMeansModelData.getModelDataStream(initModelDataTable);
+        }
+        DataStream<Tuple2<KMeansModelData, DenseVector>> initModelDataWithWeightsStream =
+                initModelDataStream.map(new InitWeightAssigner(getInitWeights()));
+        initModelDataWithWeightsStream.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new StreamingKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getBatchSize(),
+                        getK());
+
+        DataStream<KMeansModelData> finalModelDataStream =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelDataWithWeightsStream),
+                                DataStreamList.of(points),
+                                body)
+                        .get(0);
+        finalModelDataStream = finalModelDataStream.union(initModelDataStream);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelDataStream);
+        StreamingKMeansModel model = new StreamingKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class InitWeightAssigner
+            implements MapFunction<KMeansModelData, Tuple2<KMeansModelData, DenseVector>> {
+        private final double[] initWeights;
+
+        private InitWeightAssigner(Double[] initWeights) {
+            this.initWeights = ArrayUtils.toPrimitive(initWeights);
+        }
+
+        @Override
+        public Tuple2<KMeansModelData, DenseVector> map(KMeansModelData modelData)
+                throws Exception {
+            return Tuple2.of(modelData, Vectors.dense(initWeights));
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static StreamingKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        StreamingKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        Path initModelDataPath = Paths.get(path, "data");
+        if (Files.exists(initModelDataPath)) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+            kMeans.setInitMode("direct");
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class StreamingKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+        private final int k;
+
+        public StreamingKMeansIterationBody(
+                DistanceMeasure distanceMeasure, double decayFactor, int batchSize, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+            this.k = k;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<Tuple2<KMeansModelData, DenseVector>> modelDataWithWeights =
+                    variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            DataStream<Tuple2<KMeansModelData, DenseVector>> newModelDataWithWeights =
+                    points.countWindowAll(batchSize)
+                            .aggregate(new MiniBatchCreator())
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "UpdateModelData",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new UpdateModelDataOperator(distanceMeasure, decayFactor, k))
+                            .setParallelism(1);
+
+            DataStream<KMeansModelData> newModelData =
+                    newModelDataWithWeights.map(
+                            (MapFunction<Tuple2<KMeansModelData, DenseVector>, KMeansModelData>)
+                                    x -> x.f0);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelDataWithWeights), DataStreamList.of(newModelData));
+        }
+    }
+
+    // TODO: change this single-threaded implementation to support training in a distributed way
+    // after the model data version mechanism is implemented.
+    private static class UpdateModelDataOperator
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, DenseVector>>
+            implements TwoInputStreamOperator<
+                    DenseVector[],
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int k;
+        private ListState<DenseVector[]> miniBatchState;
+        private ListState<KMeansModelData> modelDataState;
+        private ListState<DenseVector> weightsState;
+
+        public UpdateModelDataOperator(DistanceMeasure distanceMeasure, double decayFactor, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.k = k;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            miniBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("miniBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+
+            weightsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "weights", DenseVectorTypeInfo.INSTANCE));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> streamRecord) throws Exception {
+            miniBatchState.add(streamRecord.getValue());
+            processElement();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<Tuple2<KMeansModelData, DenseVector>> streamRecord)
+                throws Exception {
+            modelDataState.add(streamRecord.getValue().f0);
+            weightsState.add(streamRecord.getValue().f1);
+            processElement();
+        }
+
+        private void processElement() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !miniBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            // Retrieves data from states.
+            List<KMeansModelData> modelDataList =
+                    IteratorUtils.toList(modelDataState.get().iterator());
+            if (modelDataList.size() != 1) {

Review comment:
       Maybe use `OperatorStateUtils.getUniqueElement` for simplicity?
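       For context, a helper like this would fold the "read the list state, check that it holds exactly one entry, return it" boilerplate in `processElement()` into a single call. Below is a minimal sketch over a plain `Iterable`. It assumes the helper returns an empty `Optional` when the state is empty and fails when there is more than one element. The class name `UniqueElementSketch` is hypothetical, and this is not the actual `OperatorStateUtils` signature:

```java
import java.util.Iterator;
import java.util.List;
import java.util.Optional;

/** Hypothetical stand-in for a "get the unique element of a list state" helper. */
public class UniqueElementSketch {

    /**
     * Returns the single element of {@code values}, or {@code Optional.empty()} if there is none.
     * Throws if more than one element is present, mirroring the size check in processElement().
     */
    public static <T> Optional<T> getUniqueElement(Iterable<T> values, String stateName) {
        Iterator<T> it = values.iterator();
        if (!it.hasNext()) {
            return Optional.empty();
        }
        T first = it.next();
        if (it.hasNext()) {
            throw new IllegalStateException(
                    "The state " + stateName + " should contain at most one element.");
        }
        return Optional.of(first);
    }

    public static void main(String[] args) {
        System.out.println(getUniqueElement(List.of("model-v1"), "modelData")); // Optional[model-v1]
        System.out.println(getUniqueElement(List.of(), "modelData")); // Optional.empty
    }
}
```

       With such a helper, the two `size() != 1` checks would each collapse to a single call followed by `orElseThrow`.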




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [flink-ml] lindong28 commented on a change in pull request #70: [FLINK-26313] Support Online KMeans in Flink ML

Posted by GitBox <gi...@apache.org>.
lindong28 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r830429261



##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/util/MockMessageQueues.java
##########
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
+
+/** A class that manages global message queues used in unit tests. */
+@SuppressWarnings({"unchecked", "rawtypes"})
+public class MockMessageQueues {
+    private static final Map<String, BlockingQueue> queueMap = new HashMap<>();
+    private static long counter = 0;
+
+    public static synchronized String createMessageQueue() {
+        String id = String.valueOf(counter);
+        queueMap.put(id, new LinkedBlockingQueue<>());
+        counter++;
+        return id;
+    }
+
+    @SafeVarargs
+    public static <T> void offerAll(String id, T... values) throws InterruptedException {
+        for (T value : values) {
+            offer(id, value);
+        }
+    }
+
+    public static <T> void offer(String id, T value) throws InterruptedException {
+        offer(id, value, 1, TimeUnit.MINUTES);
+    }
+
+    public static <T> void offer(String id, T value, long timeout, TimeUnit unit)
+            throws InterruptedException {
+        boolean success = queueMap.get(id).offer(value, timeout, unit);
+        if (!success) {
+            throw new RuntimeException(
+                    "Failed to offer " + value + " to blocking queue " + id + ".");
+        }
+    }
+
+    public static <T> List<T> poll(String id, int num) throws InterruptedException {
+        List<T> result = new ArrayList<>();
+        for (int i = 0; i < num; i++) {
+            result.add(poll(id));
+        }
+        return result;
+    }
+
+    public static <T> T poll(String id) throws InterruptedException {
+        return poll(id, 1, TimeUnit.MINUTES);
+    }
+
+    public static <T> T poll(String id, long timeout, TimeUnit unit) throws InterruptedException {
+        T value = (T) queueMap.get(id).poll(timeout, unit);
+        if (value == null) {
+            throw new RuntimeException("Failed to poll next value from blocking queue " + id + ".");
+        }
+        return value;
+    }
+
+    public static void deleteBlockingQueue(String id) {

Review comment:
       I see. Do you think the logic would be simpler if we introduced a `group` concept, similar to Flink's metric groups?
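       A rough sketch of what a `group` concept could look like for the mock queues follows, loosely analogous to how metric groups scope metrics. Each test creates one group and creates its queues inside it. Tearing the group down drops every queue it owns, so `after()` no longer needs to track individual queue ids. All names here (`MockQueueGroups`, `createGroup`, `createQueue`, `getQueue`, `deleteGroup`, `hasGroup`) are hypothetical and not part of this PR:

```java
import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicLong;

/** Hypothetical registry that scopes mock message queues by group (one group per test). */
public class MockQueueGroups {
    // groupId -> (queueId -> queue)
    private static final Map<String, Map<String, BlockingQueue<Object>>> GROUPS =
            new ConcurrentHashMap<>();
    private static final AtomicLong COUNTER = new AtomicLong();

    /** Creates an empty group and returns its id. */
    public static String createGroup() {
        String groupId = "group-" + COUNTER.getAndIncrement();
        GROUPS.put(groupId, new ConcurrentHashMap<>());
        return groupId;
    }

    /** Creates a queue owned by {@code groupId} and returns the queue's id. */
    public static String createQueue(String groupId) {
        String queueId = "queue-" + COUNTER.getAndIncrement();
        GROUPS.get(groupId).put(queueId, new LinkedBlockingQueue<>());
        return queueId;
    }

    /** Returns the queue, or throws if the group or the queue does not exist. */
    public static BlockingQueue<Object> getQueue(String groupId, String queueId) {
        Map<String, BlockingQueue<Object>> group = GROUPS.get(groupId);
        if (group == null || !group.containsKey(queueId)) {
            throw new IllegalStateException("Unknown queue " + groupId + "/" + queueId);
        }
        return group.get(queueId);
    }

    /** Drops the group together with every queue it owns; a single call in the test's @After. */
    public static void deleteGroup(String groupId) {
        GROUPS.remove(groupId);
    }

    /** Returns true if the group still exists. */
    public static boolean hasGroup(String groupId) {
        return GROUPS.containsKey(groupId);
    }
}
```

       In the test, `before()` would call `createGroup()` once and `after()` would call `deleteGroup(...)` once, replacing the `queueIds` bookkeeping loop.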







[GitHub] [flink-ml] lindong28 commented on a change in pull request #70: [FLINK-26313] Support Online KMeans in Flink ML

Posted by GitBox <gi...@apache.org>.
lindong28 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r830429072



##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/util/MockKVStore.java
##########
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util;
+
+import java.util.HashMap;
+import java.util.Map;
+
+/** Class that manages global key-value pairs used in unit tests. */
+@SuppressWarnings({"unchecked"})
+public class MockKVStore {

Review comment:
       I see. Do you think we can replace `TestMetricReporter` with Flink's `InMemoryReporter`? I left more details in another comment.







[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r831989288



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/StreamingKMeans.java
##########
@@ -0,0 +1,404 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * StreamingKMeans extends the functionality of {@link KMeans} to support training a K-Means model
+ * continuously on an unbounded stream of training data.
+ */
+public class StreamingKMeans
+        implements Estimator<StreamingKMeans, StreamingKMeansModel>,
+                StreamingKMeansParams<StreamingKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public StreamingKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public StreamingKMeans(Table... initModelDataTables) {
+        Preconditions.checkArgument(initModelDataTables.length == 1);
+        this.initModelDataTable = initModelDataTables[0];
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public StreamingKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        Preconditions.checkArgument(HasBatchStrategy.COUNT_STRATEGY.equals(getBatchStrategy()));
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+        points.getTransformation().setParallelism(1);
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            initModelDataStream = createRandomCentroids(env, getDims(), getK(), getSeed());
+        } else {
+            initModelDataStream = KMeansModelData.getModelDataStream(initModelDataTable);
+        }
+        DataStream<Tuple2<KMeansModelData, DenseVector>> initModelDataWithWeightsStream =
+                initModelDataStream.map(new InitWeightAssigner(getInitWeights()));
+        initModelDataWithWeightsStream.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new StreamingKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getBatchSize(),
+                        getK());
+
+        DataStream<KMeansModelData> finalModelDataStream =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelDataWithWeightsStream),
+                                DataStreamList.of(points),
+                                body)
+                        .get(0);
+        finalModelDataStream = finalModelDataStream.union(initModelDataStream);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelDataStream);
+        StreamingKMeansModel model = new StreamingKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class InitWeightAssigner
+            implements MapFunction<KMeansModelData, Tuple2<KMeansModelData, DenseVector>> {
+        private final double[] initWeights;
+
+        private InitWeightAssigner(Double[] initWeights) {
+            this.initWeights = ArrayUtils.toPrimitive(initWeights);
+        }
+
+        @Override
+        public Tuple2<KMeansModelData, DenseVector> map(KMeansModelData modelData)
+                throws Exception {
+            return Tuple2.of(modelData, Vectors.dense(initWeights));
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static StreamingKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        StreamingKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        Path initModelDataPath = Paths.get(path, "data");
+        if (Files.exists(initModelDataPath)) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+            kMeans.setInitMode("direct");
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class StreamingKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+        private final int k;
+
+        public StreamingKMeansIterationBody(
+                DistanceMeasure distanceMeasure, double decayFactor, int batchSize, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+            this.k = k;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<Tuple2<KMeansModelData, DenseVector>> modelDataWithWeights =
+                    variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            DataStream<Tuple2<KMeansModelData, DenseVector>> newModelDataWithWeights =
+                    points.countWindowAll(batchSize)
+                            .aggregate(new MiniBatchCreator())
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "UpdateModelData",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new UpdateModelDataOperator(distanceMeasure, decayFactor, k))
+                            .setParallelism(1);
+
+            DataStream<KMeansModelData> newModelData =
+                    newModelDataWithWeights.map(
+                            (MapFunction<Tuple2<KMeansModelData, DenseVector>, KMeansModelData>)
+                                    x -> x.f0);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelDataWithWeights), DataStreamList.of(newModelData));
+        }
+    }
+
+    // TODO: change this single-threaded implementation to support training in a distributed way,
+    // after the model data version mechanism is implemented.
+    private static class UpdateModelDataOperator
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, DenseVector>>
+            implements TwoInputStreamOperator<
+                    DenseVector[],
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int k;
+        private ListState<DenseVector[]> miniBatchState;
+        private ListState<KMeansModelData> modelDataState;
+        private ListState<DenseVector> weightsState;
+
+        public UpdateModelDataOperator(DistanceMeasure distanceMeasure, double decayFactor, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.k = k;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            miniBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("miniBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+
+            weightsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "weights", DenseVectorTypeInfo.INSTANCE));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> streamRecord) throws Exception {
+            miniBatchState.add(streamRecord.getValue());
+            processElement();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<Tuple2<KMeansModelData, DenseVector>> streamRecord)
+                throws Exception {
+            modelDataState.add(streamRecord.getValue().f0);
+            weightsState.add(streamRecord.getValue().f1);
+            processElement();
+        }
+
+        private void processElement() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !miniBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            // Retrieves data from states.
+            List<KMeansModelData> modelDataList =
+                    IteratorUtils.toList(modelDataState.get().iterator());
+            if (modelDataList.size() != 1) {
+                throw new RuntimeException(
+                        "The operator received "
+                                + modelDataList.size()
+                                + " list of model data in this round");
+            }
+            DenseVector[] centroids = modelDataList.get(0).centroids;
+            modelDataState.clear();
+
+            List<DenseVector> weightsList = IteratorUtils.toList(weightsState.get().iterator());
+            if (weightsList.size() != 1) {
+                throw new RuntimeException(
+                        "The operator received "
+                                + weightsList.size()
+                                + " list of weights in this round");
+            }
+            DenseVector weights = weightsList.get(0);
+            weightsState.clear();
+
+            List<DenseVector[]> pointsList = IteratorUtils.toList(miniBatchState.get().iterator());
+            DenseVector[] points = pointsList.get(0);
+            pointsList.remove(0);
+            miniBatchState.clear();
+            miniBatchState.addAll(pointsList);

Review comment:
       As in my previous comment, having a single-threaded operator distribute the mini-batches resolves this issue.
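       The following is a minimal, framework-free sketch of what such a single-threaded operator does in each round. It buffers incoming mini-batches in arrival order, waits until the current model (centroids plus per-centroid weights) is available, and then consumes only the oldest mini-batch. The quoted diff stops before the actual update step. The update here therefore assumes the decayed weighted-average rule used by Spark MLlib's StreamingKMeans: each old weight is scaled by `decayFactor`, and the old centroid and the newly assigned points are then averaged by weight. That rule is an assumption, as are the names `MiniBatchUpdaterSketch`, `Model`, `onMiniBatch`, and `onModel`. Squared Euclidean distance stands in for the configurable `DistanceMeasure`.

```java
import java.util.ArrayDeque;
import java.util.Deque;

/** Hypothetical single-threaded streaming k-means updater (no Flink dependencies). */
public class MiniBatchUpdaterSketch {
    /** Model snapshot: k centroids plus one weight per centroid. */
    public static final class Model {
        public final double[][] centroids;
        public final double[] weights;

        public Model(double[][] centroids, double[] weights) {
            this.centroids = centroids;
            this.weights = weights;
        }
    }

    private final double decayFactor;
    private final Deque<double[][]> pendingBatches = new ArrayDeque<>();
    private Model currentModel; // null until a model arrives, and again after it is consumed

    public MiniBatchUpdaterSketch(double decayFactor) {
        this.decayFactor = decayFactor;
    }

    /** Buffers a mini-batch in arrival order; returns a new model if one can be produced now. */
    public Model onMiniBatch(double[][] batch) {
        pendingBatches.addLast(batch);
        return tryUpdate();
    }

    /** Receives the latest model (e.g. fed back through the iteration); may trigger an update. */
    public Model onModel(Model model) {
        currentModel = model;
        return tryUpdate();
    }

    private Model tryUpdate() {
        if (currentModel == null || pendingBatches.isEmpty()) {
            return null; // both a model and a mini-batch are required
        }
        Model result = update(currentModel, pendingBatches.pollFirst(), decayFactor);
        currentModel = null; // wait for the updated model to come back before the next round
        return result;
    }

    /** Assumed Spark-style rule: newWeight = w * decay + count, centroid = weighted average. */
    static Model update(Model model, double[][] batch, double decayFactor) {
        int k = model.centroids.length;
        int dims = model.centroids[0].length;
        double[][] sums = new double[k][dims];
        int[] counts = new int[k];
        for (double[] point : batch) {
            int c = closest(model.centroids, point);
            counts[c]++;
            for (int d = 0; d < dims; d++) {
                sums[c][d] += point[d];
            }
        }
        double[][] newCentroids = new double[k][dims];
        double[] newWeights = new double[k];
        for (int i = 0; i < k; i++) {
            double decayed = model.weights[i] * decayFactor;
            newWeights[i] = decayed + counts[i];
            for (int d = 0; d < dims; d++) {
                // A centroid with zero total weight (no history, no points) is left unchanged.
                newCentroids[i][d] =
                        newWeights[i] == 0
                                ? model.centroids[i][d]
                                : (model.centroids[i][d] * decayed + sums[i][d]) / newWeights[i];
            }
        }
        return new Model(newCentroids, newWeights);
    }

    /** Index of the centroid closest to {@code point} under squared Euclidean distance. */
    private static int closest(double[][] centroids, double[] point) {
        int best = 0;
        double bestDist = Double.MAX_VALUE;
        for (int i = 0; i < centroids.length; i++) {
            double dist = 0;
            for (int d = 0; d < point.length; d++) {
                double diff = centroids[i][d] - point[d];
                dist += diff * diff;
            }
            if (dist < bestDist) {
                bestDist = dist;
                best = i;
            }
        }
        return best;
    }
}
```

       With a single instance, the deque alone keeps mini-batches in arrival order. Distributing the work across subtasks would additionally require the model data version mechanism mentioned in the TODO above the operator.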







[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Support Streaming KMeans in Flink ML

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r829707954



##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/StreamingKMeansTest.java
##########
@@ -0,0 +1,527 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.core.execution.JobClient;
+import org.apache.flink.ml.clustering.kmeans.KMeans;
+import org.apache.flink.ml.clustering.kmeans.KMeansModel;
+import org.apache.flink.ml.clustering.kmeans.KMeansModelData;
+import org.apache.flink.ml.clustering.kmeans.StreamingKMeans;
+import org.apache.flink.ml.clustering.kmeans.StreamingKMeansModel;
+import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.util.MockKVStore;
+import org.apache.flink.ml.util.MockMessageQueues;
+import org.apache.flink.ml.util.MockSinkFunction;
+import org.apache.flink.ml.util.MockSourceFunction;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.ml.util.TestMetricReporter;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.CollectionUtils;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import static org.apache.flink.ml.clustering.KMeansTest.groupFeaturesByPrediction;
+
+/** Tests {@link StreamingKMeans} and {@link StreamingKMeansModel}. */
+public class StreamingKMeansTest {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private static final DenseVector[] trainData1 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 0.0),
+                Vectors.dense(10.0, 0.3),
+                Vectors.dense(10.3, 0.0),
+                Vectors.dense(-10.0, 0.0),
+                Vectors.dense(-10.0, 0.6),
+                Vectors.dense(-10.6, 0.0)
+            };
+    private static final DenseVector[] trainData2 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 100.0),
+                Vectors.dense(10.0, 100.3),
+                Vectors.dense(10.3, 100.0),
+                Vectors.dense(-10.0, -100.0),
+                Vectors.dense(-10.0, -100.6),
+                Vectors.dense(-10.6, -100.0)
+            };
+    private static final DenseVector[] predictData =
+            new DenseVector[] {
+                Vectors.dense(10.0, 10.0),
+                Vectors.dense(10.3, 10.0),
+                Vectors.dense(10.0, 10.3),
+                Vectors.dense(-10.0, 10.0),
+                Vectors.dense(-10.3, 10.0),
+                Vectors.dense(-10.0, 10.3)
+            };
+    private static final List<Set<DenseVector>> expectedGroups1 =
+            Arrays.asList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3))),
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+    private static final List<Set<DenseVector>> expectedGroups2 =
+            Collections.singletonList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3),
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+
+    private String trainId;
+    private String predictId;
+    private String outputId;
+    private String modelDataId;
+    private String metricReporterPrefix;
+    private String modelDataVersionGaugeKey;
+    private String currentModelDataVersion;
+    private List<String> queueIds;
+    private List<String> kvStoreKeys;
+    private List<JobClient> clients;
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table offlineTrainTable;
+    private Table trainTable;
+    private Table predictTable;
+
+    @Before
+    public void before() {
+        metricReporterPrefix = MockKVStore.createNonDuplicatePrefix();
+        kvStoreKeys = new ArrayList<>();
+        kvStoreKeys.add(metricReporterPrefix);
+
+        modelDataVersionGaugeKey =
+                TestMetricReporter.getKey(metricReporterPrefix, "modelDataVersion");
+
+        currentModelDataVersion = "0";
+
+        trainId = MockMessageQueues.createMessageQueue();
+        predictId = MockMessageQueues.createMessageQueue();
+        outputId = MockMessageQueues.createMessageQueue();
+        modelDataId = MockMessageQueues.createMessageQueue();
+        queueIds = new ArrayList<>();
+        queueIds.addAll(Arrays.asList(trainId, predictId, outputId, modelDataId));
+
+        clients = new ArrayList<>();
+
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        config.setString("metrics.reporters", "test_reporter");
+        config.setString(
+                "metrics.reporter.test_reporter.class", TestMetricReporter.class.getName());
+        config.setString("metrics.reporter.test_reporter.interval", "100 MILLISECONDS");
+        config.setString("metrics.reporter.test_reporter.prefix", metricReporterPrefix);
+
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        Schema schema = Schema.newBuilder().column("f0", DataTypes.of(DenseVector.class)).build();
+
+        offlineTrainTable =
+                tEnv.fromDataStream(env.fromElements(trainData1), schema).as("features");
+        trainTable =
+                tEnv.fromDataStream(
+                                env.addSource(
+                                        new MockSourceFunction<>(trainId),
+                                        DenseVectorTypeInfo.INSTANCE),
+                                schema)
+                        .as("features");
+        predictTable =
+                tEnv.fromDataStream(
+                                env.addSource(
+                                        new MockSourceFunction<>(predictId),
+                                        DenseVectorTypeInfo.INSTANCE),
+                                schema)
+                        .as("features");
+    }
+
+    @After
+    public void after() {
+        for (JobClient client : clients) {
+            try {
+                client.cancel();
+            } catch (IllegalStateException e) {
+                if (!e.getMessage()
+                        .equals("MiniCluster is not yet running or has already been shut down.")) {
+                    throw e;
+                }
+            }
+        }
+        clients.clear();
+
+        for (String queueId : queueIds) {
+            MockMessageQueues.deleteBlockingQueue(queueId);
+        }
+        queueIds.clear();
+
+        for (String key : kvStoreKeys) {
+            MockKVStore.remove(key);
+        }
+        kvStoreKeys.clear();
+    }
+
+    /** Adds sinks for StreamingKMeansModel's transform output and model data. */
+    private void configModelSink(StreamingKMeansModel streamingModel) {
+        Table outputTable = streamingModel.transform(predictTable)[0];
+        DataStream<Row> output = tEnv.toDataStream(outputTable);
+        output.addSink(new MockSinkFunction<>(outputId));
+
+        Table modelDataTable = streamingModel.getModelData()[0];
+        DataStream<KMeansModelData> modelDataStream =
+                KMeansModelData.getModelDataStream(modelDataTable);
+        modelDataStream.addSink(new MockSinkFunction<>(modelDataId));
+    }
+
+    /** Blocks the thread until the Model has set up the initial model data. */
+    private void waitInitModelDataSetup() {
+        while (!MockKVStore.containsKey(modelDataVersionGaugeKey)) {
+            Thread.yield();
+        }
+        waitModelDataUpdate();
+    }
+
+    /** Blocks the thread until the Model has received the next model-data-update event. */
+    private void waitModelDataUpdate() {
+        do {
+            String tmpModelDataVersion = MockKVStore.get(modelDataVersionGaugeKey);
+            if (tmpModelDataVersion.equals(currentModelDataVersion)) {
+                Thread.yield();
+            } else {
+                currentModelDataVersion = tmpModelDataVersion;
+                break;
+            }
+        } while (true);
+    }
+
+    /**
+     * Inserts default predict data to the predict queue, fetches the prediction results, and
+     * asserts that the grouping result is as expected.
+     *
+     * @param expectedGroups The expected grouping result, as a list of sets of features
+     * @param featuresCol Name of the column in the table that contains the features
+     * @param predictionCol Name of the column in the table that contains the prediction result
+     */
+    private void predictAndAssert(
+            List<Set<DenseVector>> expectedGroups, String featuresCol, String predictionCol)
+            throws Exception {
+        MockMessageQueues.offerAll(predictId, StreamingKMeansTest.predictData);
+        List<Row> rawResult =
+                MockMessageQueues.poll(outputId, StreamingKMeansTest.predictData.length);
+        List<Set<DenseVector>> actualGroups =
+                groupFeaturesByPrediction(rawResult, featuresCol, predictionCol);
+        Assert.assertTrue(CollectionUtils.isEqualCollection(expectedGroups, actualGroups));
+    }
+
+    @Test
+    public void testParam() {
+        StreamingKMeans streamingKMeans = new StreamingKMeans();
+        Assert.assertEquals("features", streamingKMeans.getFeaturesCol());
+        Assert.assertEquals("prediction", streamingKMeans.getPredictionCol());
+        Assert.assertEquals(EuclideanDistanceMeasure.NAME, streamingKMeans.getDistanceMeasure());
+        Assert.assertEquals("random", streamingKMeans.getInitMode());
+        Assert.assertEquals(2, streamingKMeans.getK());
+        Assert.assertEquals(1, streamingKMeans.getDims());
+        Assert.assertEquals("count", streamingKMeans.getBatchStrategy());
+        Assert.assertEquals(1, streamingKMeans.getBatchSize());
+        Assert.assertEquals(0., streamingKMeans.getDecayFactor(), 1e-5);
+        Assert.assertEquals(StreamingKMeans.class.getName().hashCode(), streamingKMeans.getSeed());
+
+        streamingKMeans
+                .setK(9)
+                .setFeaturesCol("test_feature")
+                .setPredictionCol("test_prediction")
+                .setK(3)
+                .setDims(5)
+                .setBatchSize(5)
+                .setDecayFactor(0.25)
+                .setInitMode("direct")
+                .setSeed(100);
+
+        Assert.assertEquals("test_feature", streamingKMeans.getFeaturesCol());
+        Assert.assertEquals("test_prediction", streamingKMeans.getPredictionCol());
+        Assert.assertEquals(3, streamingKMeans.getK());
+        Assert.assertEquals(5, streamingKMeans.getDims());
+        Assert.assertEquals("count", streamingKMeans.getBatchStrategy());
+        Assert.assertEquals(5, streamingKMeans.getBatchSize());
+        Assert.assertEquals(0.25, streamingKMeans.getDecayFactor(), 1e-5);
+        Assert.assertEquals("direct", streamingKMeans.getInitMode());
+        Assert.assertEquals(100, streamingKMeans.getSeed());
+    }
+
+    @Test
+    public void testFitAndPredict() throws Exception {
+        StreamingKMeans streamingKMeans =
+                new StreamingKMeans()
+                        .setInitMode("random")
+                        .setDims(2)
+                        .setInitWeights(new Double[] {0., 0.})
+                        .setBatchSize(6)
+                        .setFeaturesCol("features")
+                        .setPredictionCol("prediction");
+        StreamingKMeansModel streamingModel = streamingKMeans.fit(trainTable);
+        configModelSink(streamingModel);
+
+        clients.add(env.executeAsync());
+        waitInitModelDataSetup();
+
+        MockMessageQueues.offerAll(trainId, trainData1);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups1,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+
+        MockMessageQueues.offerAll(trainId, trainData2);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups2,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+    }
+
+    @Test
+    public void testInitWithOfflineKMeans() throws Exception {
+        KMeans offlineKMeans =
+                new KMeans().setFeaturesCol("features").setPredictionCol("prediction");
+        KMeansModel offlineModel = offlineKMeans.fit(offlineTrainTable);
+
+        StreamingKMeans streamingKMeans =
+                new StreamingKMeans(offlineModel.getModelData())
+                        .setInitMode("direct")
+                        .setDims(2)
+                        .setInitWeights(new Double[] {0., 0.})
+                        .setBatchSize(6);
+        ReadWriteUtils.updateExistingParams(streamingKMeans, offlineKMeans.getParamMap());
+
+        StreamingKMeansModel streamingModel = streamingKMeans.fit(trainTable);
+        configModelSink(streamingModel);
+
+        clients.add(env.executeAsync());
+        waitInitModelDataSetup();
+        predictAndAssert(
+                expectedGroups1,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+
+        MockMessageQueues.offerAll(trainId, trainData2);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups2,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+    }
+
+    @Test
+    public void testDecayFactor() throws Exception {
+        StreamingKMeans streamingKMeans =
+                new StreamingKMeans()
+                        .setInitMode("random")
+                        .setDims(2)
+                        .setInitWeights(new Double[] {0., 0.})
+                        .setDecayFactor(100.0)
+                        .setBatchSize(6)
+                        .setFeaturesCol("features")
+                        .setPredictionCol("prediction");
+        StreamingKMeansModel streamingModel = streamingKMeans.fit(trainTable);
+        configModelSink(streamingModel);
+
+        clients.add(env.executeAsync());
+        waitInitModelDataSetup();
+
+        MockMessageQueues.offerAll(trainId, trainData1);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups1,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+
+        MockMessageQueues.offerAll(trainId, trainData2);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups1,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+    }
+
+    @Test
+    public void testSaveAndReload() throws Exception {
+        KMeans offlineKMeans =
+                new KMeans().setFeaturesCol("features").setPredictionCol("prediction");
+        KMeansModel offlineModel = offlineKMeans.fit(offlineTrainTable);
+
+        StreamingKMeans streamingKMeans =
+                new StreamingKMeans(offlineModel.getModelData())
+                        .setDims(2)
+                        .setBatchSize(6)
+                        .setInitWeights(new Double[] {0., 0.});
+        ReadWriteUtils.updateExistingParams(streamingKMeans, offlineKMeans.getParamMap());
+
+        StreamingKMeans loadedKMeans =
+                StageTestUtils.saveAndReload(
+                        env, streamingKMeans, tempFolder.newFolder().getAbsolutePath());
+
+        StreamingKMeansModel streamingModel = loadedKMeans.fit(trainTable);
+
+        String modelDataPassId = MockMessageQueues.createMessageQueue();
+        queueIds.add(modelDataPassId);
+
+        String modelSavePath = tempFolder.newFolder().getAbsolutePath();
+        streamingModel.save(modelSavePath);
+        KMeansModelData.getModelDataStream(streamingModel.getModelData()[0])
+                .addSink(new MockSinkFunction<>(modelDataPassId));
+        clients.add(env.executeAsync());

Review comment:
       In the current `save()` methods of `Pipeline`/`Graph`, I think we have implicitly assumed that calling `execute()` after `save()` always returns and unblocks the caller, which is not true when online algorithms are involved. If we did not rely on this assumption, using `executeAsync()` would not be an issue.
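The `predictAndAssert` helper in the test above compares groups of features rather than raw cluster ids, because the ids a clustering run assigns are arbitrary. Below is a minimal stdlib-only sketch of that idea, assuming features are represented as `List<Double>` (so set membership uses value equality) and predictions as a parallel list of ids. `GroupByPrediction`, `groupFeaturesByPrediction(List, List)` and `sameGroups` here are hypothetical stand-ins; the test's real `groupFeaturesByPrediction` (not shown in this excerpt) takes `List<Row>` and column names, and the test uses `CollectionUtils.isEqualCollection` for the comparison.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

class GroupByPrediction {
    /** Hypothetical helper: groups feature vectors by their predicted cluster id. */
    static List<Set<List<Double>>> groupFeaturesByPrediction(
            List<List<Double>> features, List<Integer> predictions) {
        Map<Integer, Set<List<Double>>> groups = new HashMap<>();
        for (int i = 0; i < features.size(); i++) {
            groups.computeIfAbsent(predictions.get(i), id -> new HashSet<>()).add(features.get(i));
        }
        // Drop the ids: only the grouping itself is meaningful.
        return new ArrayList<>(groups.values());
    }

    /** Returns true if both lists hold the same groups, ignoring list order and cluster ids. */
    static boolean sameGroups(List<Set<List<Double>>> expected, List<Set<List<Double>>> actual) {
        if (expected.size() != actual.size()) {
            return false;
        }
        List<Set<List<Double>>> remaining = new ArrayList<>(actual);
        for (Set<List<Double>> group : expected) {
            // remove(Object) drops one equal group, so repeated groups are counted correctly.
            if (!remaining.remove(group)) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        List<List<Double>> features =
                Arrays.asList(
                        Arrays.asList(-10.0, 10.0),
                        Arrays.asList(-10.3, 10.0),
                        Arrays.asList(10.0, 10.0));
        // Two runs that agree on the grouping but assign different cluster ids.
        List<Set<List<Double>>> run1 = groupFeaturesByPrediction(features, Arrays.asList(0, 0, 1));
        List<Set<List<Double>>> run2 = groupFeaturesByPrediction(features, Arrays.asList(1, 1, 0));
        System.out.println(sameGroups(run1, run2)); // true
    }
}
```

Comparing sets of features keeps the assertions stable even if a run swaps which cluster gets id 0 and which gets id 1.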




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [flink-ml] lindong28 commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
lindong28 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r836004126



##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java
##########
@@ -177,11 +177,20 @@ public void testFewerDistinctPointsThanCluster() {
         KMeans kmeans = new KMeans().setK(2);
         KMeansModel model = kmeans.fit(input);
         Table output = model.transform(input)[0];
-        List<Set<DenseVector>> expectedGroups =
-                Collections.singletonList(Collections.singleton(Vectors.dense(0.0, 0.1)));
-        List<Set<DenseVector>> actualGroups =
-                executeAndCollect(output, kmeans.getFeaturesCol(), kmeans.getPredictionCol());
-        assertTrue(CollectionUtils.isEqualCollection(expectedGroups, actualGroups));
+
+        try {

Review comment:
       Following Spark's implementation, I believe the expected interpretation of `K` is `The max number of clusters to create...`. Allowing the algorithm to automatically handle the case where the number of distinct input points is less than `K` also seems to make users' lives easier.
   
   @zhipeng93 what do you think?
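Under the "maximum number of clusters" reading of `K` discussed above, one way to handle fewer distinct points than `K` is to deduplicate the input before choosing initial centroids and cap the count at the number of distinct points. The sketch below is a hypothetical stdlib-only illustration of that idea, not Spark's or Flink ML's actual initialization code; `MaxKInit` and `selectInitialCentroids` are invented names, and points are plain `double[]`.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Random;
import java.util.Set;

class MaxKInit {
    /**
     * Hypothetical helper: picks up to k initial centroids from the distinct input points.
     * With fewer distinct points than k, every distinct point becomes a centroid, so k acts
     * as an upper bound on the number of clusters rather than an exact count.
     */
    static List<double[]> selectInitialCentroids(List<double[]> points, int k, long seed) {
        // Deduplicate by value: List<Double> has value-based equals/hashCode, double[] does not.
        Set<List<Double>> distinct = new LinkedHashSet<>();
        for (double[] p : points) {
            List<Double> key = new ArrayList<>(p.length);
            for (double v : p) {
                key.add(v);
            }
            distinct.add(key);
        }
        List<List<Double>> candidates = new ArrayList<>(distinct);
        // Deterministic shuffle, then keep at most k candidates.
        Collections.shuffle(candidates, new Random(seed));
        int numClusters = Math.min(k, candidates.size());
        List<double[]> centroids = new ArrayList<>(numClusters);
        for (List<Double> c : candidates.subList(0, numClusters)) {
            double[] arr = new double[c.size()];
            for (int i = 0; i < arr.length; i++) {
                arr[i] = c.get(i);
            }
            centroids.add(arr);
        }
        return centroids;
    }

    public static void main(String[] args) {
        // Hypothetical input where every point is (0.0, 0.1), with k = 2.
        List<double[]> input = Arrays.asList(new double[] {0.0, 0.1}, new double[] {0.0, 0.1});
        System.out.println(selectInitialCentroids(input, 2, 0L).size()); // 1
    }
}
```

With this behavior, the `testFewerDistinctPointsThanCluster` scenario would yield a single cluster instead of an error.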





[GitHub] [flink-ml] lindong28 commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
lindong28 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r836068872



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,407 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Supplier;
+
+/**
+ * OnlineKMeans extends the functionality of {@link KMeans}, supporting continuous training of a
+ * K-Means model from an unbounded stream of training data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids from the weighted average between the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of points assigned to
+ * them. The weight of the original centroids is also the number of points, but additionally
+ * multiplied by the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If the decay
+ * factor is 1, all batches are weighted equally. If the decay factor is 0, new centroids are
+ * determined entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelData =
+                KMeansModelData.getModelDataStream(initModelDataTable);
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getK(),
+                        getDecayFactor(),
+                        getGlobalBatchSize());
+
+        DataStream<KMeansModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Saves the metadata AND bounded model data table (if it exists) to the given path. */
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans onlineKMeans = ReadWriteUtils.loadStageParam(path);
+
+        String initModelDataPath = ReadWriteUtils.getDataPath(path);
+        if (Files.exists(Paths.get(initModelDataPath))) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            onlineKMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+        }
+
+        return onlineKMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final int k;
+        private final double decayFactor;
+        private final int batchSize;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure, int k, double decayFactor, int batchSize) {
+            this.distanceMeasure = distanceMeasure;
+            this.k = k;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<KMeansModelData> modelData = variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();

Review comment:
       @zhipeng93 Do you think there is any reason that a user would actually want to run OnlineKMeans with parallelism > batchSize? If so, could you provide some details?
   
   And if no, could you explain why we should support running algorithm with `parallelism > batchSize`?
   
   I am also wondering what `consistent user experience` means here.
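To make the `parallelism > batchSize` question concrete: if the global batch size returned by `getGlobalBatchSize()` is split evenly across the `points.getParallelism()` subtasks, some subtasks end up with a local batch size of zero once parallelism exceeds the batch size. The even-split rule below is an assumption for illustration only; the PR's actual splitting logic is not visible in this hunk, and `LocalBatchSizes.splitBatchSize` is a hypothetical name.

```java
import java.util.Arrays;

class LocalBatchSizes {
    /**
     * Hypothetical helper: splits a global mini-batch size across parallel subtasks as evenly
     * as possible. The first (globalBatchSize % parallelism) subtasks get one extra element.
     */
    static int[] splitBatchSize(int globalBatchSize, int parallelism) {
        if (globalBatchSize <= 0 || parallelism <= 0) {
            throw new IllegalArgumentException("batch size and parallelism must be positive");
        }
        int[] local = new int[parallelism];
        int base = globalBatchSize / parallelism;
        int remainder = globalBatchSize % parallelism;
        for (int i = 0; i < parallelism; i++) {
            local[i] = base + (i < remainder ? 1 : 0);
        }
        return local;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(splitBatchSize(6, 4))); // [2, 2, 1, 1]
        // parallelism > batchSize: two subtasks would never form a batch on their own.
        System.out.println(Arrays.toString(splitBatchSize(2, 4))); // [1, 1, 0, 0]
    }
}
```

How subtasks with a zero-sized local batch should behave (reject the configuration, or silently forward their points) is exactly the design choice the review comment asks about.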
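The class Javadoc in the `OnlineKMeans.java` hunk above describes the decayed mini-batch update: the new centroid is a weighted average of the old centroid (weight: its point count times the decay factor) and the batch estimate (weight: the number of batch points assigned to it). A minimal per-cluster sketch of that formula follows; `DecayedCentroidUpdate` is a hypothetical name, and returning the old centroid when no batch points are assigned is an assumption, not something the Javadoc states.

```java
import java.util.Arrays;

class DecayedCentroidUpdate {
    /**
     * new = (old * oldWeight * decay + batchMean * batchCount) / (oldWeight * decay + batchCount)
     */
    static double[] updateCentroid(
            double[] oldCentroid,
            double oldWeight,
            double[] batchMean,
            long batchCount,
            double decayFactor) {
        if (batchCount == 0) {
            // Assumption: no batch points were assigned to this cluster, so keep it unchanged.
            return oldCentroid.clone();
        }
        double discountedOldWeight = oldWeight * decayFactor;
        double totalWeight = discountedOldWeight + batchCount;
        double[] updated = new double[oldCentroid.length];
        for (int i = 0; i < updated.length; i++) {
            updated[i] =
                    (oldCentroid[i] * discountedOldWeight + batchMean[i] * batchCount)
                            / totalWeight;
        }
        return updated;
    }

    public static void main(String[] args) {
        double[] old = {0.0, 0.0};
        double[] batchMean = {10.0, 10.0};
        // decayFactor = 1: old and new points are weighted equally.
        System.out.println(Arrays.toString(updateCentroid(old, 4, batchMean, 4, 1.0))); // [5.0, 5.0]
        // decayFactor = 0: the new centroid is determined entirely by the batch.
        System.out.println(Arrays.toString(updateCentroid(old, 4, batchMean, 4, 0.0))); // [10.0, 10.0]
    }
}
```

The Javadoc excerpt does not say how the old centroid's weight is carried into the next batch, so this sketch leaves that bookkeeping out.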






[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r836071635



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeansModel.java
##########
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.co.CoProcessFunction;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * OnlineKMeansModel can be regarded as an advanced {@link KMeansModel} operator which can update
+ * model data in a streaming format, using the model data provided by {@link OnlineKMeans}.
+ */
+public class OnlineKMeansModel
+        implements Model<OnlineKMeansModel>, KMeansModelParams<OnlineKMeansModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public OnlineKMeansModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel setModelData(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        modelDataTable = inputs[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), Types.INT),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol()));
+
+        DataStream<Row> predictionResult =
+                KMeansModelData.getModelDataStream(modelDataTable)
+                        .broadcast()
+                        .connect(tEnv.toDataStream(inputs[0]))
+                        .process(
+                                new PredictLabelFunction(
+                                        getFeaturesCol(),
+                                        DistanceMeasure.getInstance(getDistanceMeasure())),
+                                outputTypeInfo);
+
+        return new Table[] {tEnv.fromDataStream(predictionResult)};
+    }
+
+    /** A utility function used for prediction. */
+    private static class PredictLabelFunction extends CoProcessFunction<KMeansModelData, Row, Row> {
+        private final String featuresCol;
+
+        private final DistanceMeasure distanceMeasure;
+
+        private DenseVector[] centroids;
+
+        // TODO: replace this with a complete solution of reading first model data from unbounded
+        // model data stream before processing the first predict data.
+        private final List<Row> bufferedPoints = new ArrayList<>();

Review comment:
       OK. I'll add a `ListState` to save the buffered points during checkpointing.

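The following is a minimal sketch of the buffering behavior discussed here, written in plain Java rather than against Flink's state API:

- Prediction points that arrive before the first model data are buffered.
- The buffer is flushed once centroids arrive.
- The buffer is what a checkpoint would have to persist. In the operator, this would presumably go through the proposed `ListState`.

`BufferedPredictorSketch`, its methods, and the use of squared Euclidean distance are assumptions for illustration; the PR uses a configurable `DistanceMeasure`. The sketch does not persist the centroids themselves. After a restore, points are therefore buffered again until new model data arrives.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical plain-Java sketch of a predictor that buffers points until the first model data
 * arrives. In a Flink operator the buffer would live in state (e.g. a ListState) so that it
 * survives checkpoints; here snapshot/restore are simulated with plain lists.
 */
final class BufferedPredictorSketch {
    private double[][] centroids; // null until the first model data arrives
    private final List<double[]> bufferedPoints = new ArrayList<>();
    private final List<Integer> predictions = new ArrayList<>();

    /** Receives (new) model data and flushes any points buffered so far. */
    void onModelData(double[][] newCentroids) {
        centroids = newCentroids;
        for (double[] point : bufferedPoints) {
            predictions.add(predict(point));
        }
        bufferedPoints.clear();
    }

    /** Predicts immediately if a model is available, otherwise buffers the point. */
    void onPoint(double[] point) {
        if (centroids == null) {
            bufferedPoints.add(point);
        } else {
            predictions.add(predict(point));
        }
    }

    /** Copy of what a checkpoint would need to persist for the buffered points. */
    List<double[]> snapshotBufferedPoints() {
        return new ArrayList<>(bufferedPoints);
    }

    /** Restores buffered points after a failure; centroids are not restored in this sketch. */
    void restoreBufferedPoints(List<double[]> restored) {
        bufferedPoints.clear();
        bufferedPoints.addAll(restored);
    }

    List<Integer> predictions() {
        return predictions;
    }

    /** Index of the closest centroid by squared Euclidean distance (an assumption). */
    private int predict(double[] point) {
        int best = 0;
        double bestDistance = Double.MAX_VALUE;
        for (int i = 0; i < centroids.length; i++) {
            double distance = 0;
            for (int j = 0; j < point.length; j++) {
                double diff = point[j] - centroids[i][j];
                distance += diff * diff;
            }
            if (distance < bestDistance) {
                bestDistance = distance;
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        BufferedPredictorSketch predictor = new BufferedPredictorSketch();
        predictor.onPoint(new double[] {0.1, 0.0});
        predictor.onPoint(new double[] {9.0, 9.5});
        System.out.println(predictor.predictions()); // []
        predictor.onModelData(new double[][] {{0.0, 0.0}, {10.0, 10.0}});
        System.out.println(predictor.predictions()); // [0, 1]
    }
}
```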






[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r833161087



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,562 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Random;
+
+/**
+ * OnlineKMeans extends the functionality of {@link KMeans} by supporting continuous training of a
+ * K-Means model on an unbounded stream of training data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids from the weighted average between the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of points assigned to
+ * them. The weight of the original centroids is the weight accumulated from previous batches,
+ * multiplied by the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If the decay
+ * factor is 1, all batches are weighted equally. If the decay factor is 0, new centroids are
+ * determined entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public OnlineKMeans(Table initModelDataTable) {
+        this.initModelDataTable = initModelDataTable;
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            Preconditions.checkState(initModelDataTable == null);
+            initModelDataStream = createRandomCentroids(env, getDim(), getK(), getSeed());
+        } else {
+            initModelDataStream = KMeansModelData.getModelDataStream(initModelDataTable);
+        }
+        DataStream<Tuple2<KMeansModelData, DenseVector>> initModelDataWithWeightsStream =
+                initModelDataStream.map(new InitWeightAssigner(getInitWeights()));
+        initModelDataWithWeightsStream.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize(),
+                        getK(),
+                        getDim());
+
+        DataStream<KMeansModelData> finalModelDataStream =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelDataWithWeightsStream),
+                                DataStreamList.of(points),
+                                body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelDataStream);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class InitWeightAssigner
+            implements MapFunction<KMeansModelData, Tuple2<KMeansModelData, DenseVector>> {
+        private final double[] initWeights;
+
+        private InitWeightAssigner(Double[] initWeights) {
+            this.initWeights = ArrayUtils.toPrimitive(initWeights);
+        }
+
+        @Override
+        public Tuple2<KMeansModelData, DenseVector> map(KMeansModelData modelData)
+                throws Exception {
+            return Tuple2.of(modelData, Vectors.dense(initWeights));
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        Path initModelDataPath = Paths.get(path, "data");
+        if (Files.exists(initModelDataPath)) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+            kMeans.setInitMode("direct");
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+        private final int k;
+        private final int dim;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure,
+                double decayFactor,
+                int batchSize,
+                int k,
+                int dim) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+            this.k = k;
+            this.dim = dim;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<Tuple2<KMeansModelData, DenseVector>> modelDataWithWeights =
+                    variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<Tuple2<KMeansModelData, DenseVector>> newModelDataWithWeights =
+                    points.countWindowAll(batchSize)
+                            .aggregate(new MiniBatchCreator())
+                            .flatMap(new MiniBatchDistributor(parallelism))
+                            .rebalance()
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "ModelDataPartialUpdater",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new ModelDataPartialUpdater(distanceMeasure, k))
+                            .setParallelism(parallelism)
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "ModelDataGlobalUpdater",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new ModelDataGlobalUpdater(k, dim, parallelism, decayFactor))
+                            .setParallelism(1);
+
+            DataStream<KMeansModelData> outputModelData =
+                    modelDataWithWeights.map(
+                            (MapFunction<Tuple2<KMeansModelData, DenseVector>, KMeansModelData>)
+                                    x -> x.f0);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelDataWithWeights), DataStreamList.of(outputModelData));
+        }
+    }
+
+    private static class ModelDataGlobalUpdater
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, DenseVector>>
+            implements TwoInputStreamOperator<
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final int k;
+        private final int dim;
+        private final int upstreamParallelism;
+        private final double decayFactor;
+
+        private ListState<Integer> partialModelDataReceivingState;
+        private ListState<Boolean> initModelDataReceivingState;
+        private ListState<KMeansModelData> modelDataState;
+        private ListState<DenseVector> weightsState;
+
+        private ModelDataGlobalUpdater(
+                int k, int dim, int upstreamParallelism, double decayFactor) {
+            this.k = k;
+            this.dim = dim;
+            this.upstreamParallelism = upstreamParallelism;
+            this.decayFactor = decayFactor;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            partialModelDataReceivingState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "partialModelDataReceiving",
+                                            BasicTypeInfo.INT_TYPE_INFO));
+
+            initModelDataReceivingState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "initModelDataReceiving",
+                                            BasicTypeInfo.BOOLEAN_TYPE_INFO));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+
+            weightsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "weights", DenseVectorTypeInfo.INSTANCE));
+
+            initStateValues();
+        }
+
+        private void initStateValues() throws Exception {
+            partialModelDataReceivingState.update(Collections.singletonList(0));
+            initModelDataReceivingState.update(Collections.singletonList(false));
+            DenseVector[] emptyCentroids = new DenseVector[k];
+            for (int i = 0; i < k; i++) {
+                emptyCentroids[i] = new DenseVector(dim);
+            }
+            modelDataState.update(Collections.singletonList(new KMeansModelData(emptyCentroids)));
+            weightsState.update(Collections.singletonList(new DenseVector(k)));
+        }
+
+        @Override
+        public void processElement1(
+                StreamRecord<Tuple2<KMeansModelData, DenseVector>> partialModelDataUpdateRecord)
+                throws Exception {
+            int partialModelDataReceiving =
+                    getUniqueElement(partialModelDataReceivingState, "partialModelDataReceiving");
+            Preconditions.checkState(partialModelDataReceiving < upstreamParallelism);
+            partialModelDataReceivingState.update(
+                    Collections.singletonList(partialModelDataReceiving + 1));
+            processElement(
+                    partialModelDataUpdateRecord.getValue().f0.centroids,
+                    partialModelDataUpdateRecord.getValue().f1,
+                    1.0);
+        }
+
+        @Override
+        public void processElement2(
+                StreamRecord<Tuple2<KMeansModelData, DenseVector>> initModelDataRecord)
+                throws Exception {
+            boolean initModelDataReceiving =
+                    getUniqueElement(initModelDataReceivingState, "initModelDataReceiving");
+            Preconditions.checkState(!initModelDataReceiving);
+            initModelDataReceivingState.update(Collections.singletonList(true));
+            processElement(
+                    initModelDataRecord.getValue().f0.centroids,
+                    initModelDataRecord.getValue().f1,
+                    decayFactor);
+        }
+
+        private void processElement(
+                DenseVector[] newCentroids, DenseVector newWeights, double decayFactor)
+                throws Exception {
+            DenseVector weights = getUniqueElement(weightsState, "weights");
+            DenseVector[] centroids = getUniqueElement(modelDataState, "modelData").centroids;
+
+            for (int i = 0; i < k; i++) {
+                newWeights.values[i] *= decayFactor;
+                for (int j = 0; j < dim; j++) {
+                    centroids[i].values[j] =
+                            (centroids[i].values[j] * weights.values[i]
+                                            + newCentroids[i].values[j] * newWeights.values[i])
+                                    / Math.max(weights.values[i] + newWeights.values[i], 1e-16);
+                }
+                weights.values[i] += newWeights.values[i];
+            }
+
+            if (getUniqueElement(initModelDataReceivingState, "initModelDataReceiving")
+                    && getUniqueElement(partialModelDataReceivingState, "partialModelDataReceiving")
+                            >= upstreamParallelism) {
+                output.collect(
+                        new StreamRecord<>(Tuple2.of(new KMeansModelData(centroids), weights)));
+                initStateValues();
+            } else {
+                modelDataState.update(Collections.singletonList(new KMeansModelData(centroids)));
+                weightsState.update(Collections.singletonList(weights));
+            }
+        }
+    }
+
+    private static <T> T getUniqueElement(ListState<T> state, String stateName) throws Exception {
+        T value = OperatorStateUtils.getUniqueElement(state, stateName).orElse(null);
+        return Objects.requireNonNull(value);
+    }
+
+    private static class ModelDataPartialUpdater
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, DenseVector>>
+            implements TwoInputStreamOperator<
+                    DenseVector[],
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final DistanceMeasure distanceMeasure;
+        private final int k;
+        private ListState<DenseVector[]> miniBatchState;
+        private ListState<KMeansModelData> modelDataState;
+
+        private ModelDataPartialUpdater(DistanceMeasure distanceMeasure, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.k = k;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            miniBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("miniBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> pointsRecord) throws Exception {
+            miniBatchState.add(pointsRecord.getValue());
+            processElement();
+        }
+
+        @Override
+        public void processElement2(
+                StreamRecord<Tuple2<KMeansModelData, DenseVector>> modelDataAndWeightsRecord)
+                throws Exception {
+            modelDataState.add(modelDataAndWeightsRecord.getValue().f0);
+            processElement();
+        }
+
+        private void processElement() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !miniBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            DenseVector[] centroids = getUniqueElement(modelDataState, "modelData").centroids;
+            modelDataState.clear();
+
+            List<DenseVector[]> pointsList = IteratorUtils.toList(miniBatchState.get().iterator());
+            DenseVector[] points = pointsList.get(0);
+            pointsList.remove(0);
+            miniBatchState.clear();
+            miniBatchState.addAll(pointsList);
+
+            // Computes new centroids.
+            DenseVector[] sums = new DenseVector[k];
+            int dim = centroids[0].size();
+            DenseVector counts = new DenseVector(k);
+
+            for (int i = 0; i < k; i++) {
+                sums[i] = new DenseVector(dim);
+                counts.values[i] = 0;
+            }
+            for (DenseVector point : points) {
+                int closestCentroidId =
+                        KMeans.findClosestCentroidId(centroids, point, distanceMeasure);
+                counts.values[closestCentroidId]++;
+                for (int j = 0; j < dim; j++) {
+                    sums[closestCentroidId].values[j] += point.values[j];
+                }
+            }
+
+            for (int i = 0; i < k; i++) {
+                if (counts.values[i] < 1e-5) {
+                    continue;
+                }
+                BLAS.scal(1.0 / counts.values[i], sums[i]);
+            }
+
+            output.collect(new StreamRecord<>(Tuple2.of(new KMeansModelData(sums), counts)));
+        }
+    }
+
+    private static class FeaturesExtractor implements MapFunction<Row, DenseVector> {
+        private final String featuresCol;
+
+        private FeaturesExtractor(String featuresCol) {
+            this.featuresCol = featuresCol;
+        }
+
+        @Override
+        public DenseVector map(Row row) throws Exception {
+            return (DenseVector) row.getField(featuresCol);
+        }
+    }
+
+    private static class MiniBatchDistributor

Review comment:
       Got it.
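For readers following `ModelDataPartialUpdater` above, each subtask does the following with its sub-batch:

- Assigns every point to the closest current centroid.
- Averages the assigned points per cluster.
- Reports the per-cluster counts as weights.

The standalone sketch below reproduces that step with plain arrays. `MiniBatchEstimateSketch` is a hypothetical name, and squared Euclidean distance stands in for the configurable `DistanceMeasure`.

```java
import java.util.Arrays;

/** Hypothetical sketch of the per-sub-batch estimate computed by ModelDataPartialUpdater. */
final class MiniBatchEstimateSketch {
    /** Per-cluster mean of the assigned points, plus the number of points assigned. */
    static final class Estimate {
        final double[][] centroids;
        final double[] counts;

        Estimate(double[][] centroids, double[] counts) {
            this.centroids = centroids;
            this.counts = counts;
        }
    }

    static Estimate estimate(double[][] currentCentroids, double[][] points) {
        int k = currentCentroids.length;
        int dim = currentCentroids[0].length;
        double[][] sums = new double[k][dim];
        double[] counts = new double[k];
        for (double[] point : points) {
            int closest = closestCentroid(currentCentroids, point);
            counts[closest]++;
            for (int j = 0; j < dim; j++) {
                sums[closest][j] += point[j];
            }
        }
        for (int i = 0; i < k; i++) {
            // A cluster with no assigned points keeps a zero vector and weight 0,
            // so it contributes nothing when the estimates are merged.
            if (counts[i] == 0) {
                continue;
            }
            for (int j = 0; j < dim; j++) {
                sums[i][j] /= counts[i];
            }
        }
        return new Estimate(sums, counts);
    }

    /** Squared Euclidean distance stands in for the configurable DistanceMeasure. */
    static int closestCentroid(double[][] centroids, double[] point) {
        int best = 0;
        double bestDistance = Double.MAX_VALUE;
        for (int i = 0; i < centroids.length; i++) {
            double distance = 0;
            for (int j = 0; j < point.length; j++) {
                double diff = point[j] - centroids[i][j];
                distance += diff * diff;
            }
            if (distance < bestDistance) {
                bestDistance = distance;
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        Estimate e =
                estimate(
                        new double[][] {{0.0, 0.0}, {10.0, 10.0}},
                        new double[][] {{1.0, 1.0}, {3.0, 1.0}, {9.0, 9.0}});
        System.out.println(Arrays.toString(e.centroids[0])); // [2.0, 1.0]
        System.out.println(Arrays.toString(e.counts)); // [2.0, 1.0]
    }
}
```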

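`ModelDataGlobalUpdater` above merges the previous model with the mini-batch estimates. For each cluster i, let d be the decay factor. Let c_i and w_i be the previous centroid and weight, and let m_i and n_i be the mini-batch centroid and count. The code computes:

- w_i' = d * w_i + n_i
- c_i' = (d * w_i * c_i + n_i * m_i) / max(w_i', 1e-16)

The sketch below applies that rule to a single mini-batch estimate. `DecayedMergeSketch` is a hypothetical name. The operator in the PR reaches the same result incrementally, by folding in one partial update per upstream subtask.

```java
import java.util.Arrays;

/** Hypothetical sketch of the decayed weighted average applied by ModelDataGlobalUpdater. */
final class DecayedMergeSketch {
    /** Centroids and their per-cluster weights. */
    static final class Model {
        final double[][] centroids;
        final double[] weights;

        Model(double[][] centroids, double[] weights) {
            this.centroids = centroids;
            this.weights = weights;
        }
    }

    /**
     * Per cluster i: w' = d * w + n and c' = (d * w * c + n * m) / max(w', 1e-16). If both
     * weights are 0, the centroid becomes a zero vector, as with the 1e-16 guard in the PR.
     */
    static Model merge(Model previous, Model miniBatch, double decayFactor) {
        int k = previous.centroids.length;
        int dim = previous.centroids[0].length;
        double[][] centroids = new double[k][dim];
        double[] weights = new double[k];
        for (int i = 0; i < k; i++) {
            double oldWeight = previous.weights[i] * decayFactor;
            double newWeight = miniBatch.weights[i];
            weights[i] = oldWeight + newWeight;
            for (int j = 0; j < dim; j++) {
                centroids[i][j] =
                        (previous.centroids[i][j] * oldWeight
                                        + miniBatch.centroids[i][j] * newWeight)
                                / Math.max(weights[i], 1e-16);
            }
        }
        return new Model(centroids, weights);
    }

    public static void main(String[] args) {
        Model previous =
                new Model(new double[][] {{0.0, 0.0}, {10.0, 10.0}}, new double[] {4.0, 2.0});
        // The second cluster received no points in this mini-batch (weight 0).
        Model miniBatch =
                new Model(new double[][] {{3.0, 3.0}, {0.0, 0.0}}, new double[] {1.0, 0.0});
        Model merged = merge(previous, miniBatch, 0.5);
        System.out.println(Arrays.toString(merged.centroids[0])); // [1.0, 1.0]
        System.out.println(Arrays.toString(merged.centroids[1])); // [10.0, 10.0]
        System.out.println(Arrays.toString(merged.weights)); // [3.0, 1.0]
    }
}
```

With d = 1, every batch counts equally. With d = 0, a cluster that received points in the current batch is set to the batch mean.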






[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r836002700



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,407 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Supplier;
+
+/**
+ * OnlineKMeans extends the functionality of {@link KMeans} by supporting continuous training of a
+ * K-Means model on an unbounded stream of training data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids from the weighted average between the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of points assigned to
+ * them. The weight of the original centroids is the weight accumulated from previous batches,
+ * multiplied by the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If the decay
+ * factor is 1, all batches are weighted equally. If the decay factor is 0, new centroids are
+ * determined entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelData =
+                KMeansModelData.getModelDataStream(initModelDataTable);
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getK(),
+                        getDecayFactor(),
+                        getGlobalBatchSize());
+
+        DataStream<KMeansModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Saves the metadata AND bounded model data table (if it exists) to the given path. */
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans onlineKMeans = ReadWriteUtils.loadStageParam(path);
+
+        String initModelDataPath = ReadWriteUtils.getDataPath(path);
+        if (Files.exists(Paths.get(initModelDataPath))) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            onlineKMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+        }
+
+        return onlineKMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final int k;
+        private final double decayFactor;
+        private final int batchSize;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure, int k, double decayFactor, int batchSize) {
+            this.distanceMeasure = distanceMeasure;
+            this.k = k;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<KMeansModelData> modelData = variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<KMeansModelData> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator())
+                            .flatMap(new GlobalBatchSplitter(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(KMeansModelData.class),
+                                    new ModelDataLocalUpdater(distanceMeasure, k, decayFactor))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new ModelDataGlobalReducer());
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelData), DataStreamList.of(modelData));
+        }
+    }
+
+    /**
+     * A reduce function that combines the KMeansModelData collected from each upstream subtask and
+     * outputs the weighted average of the collected model data.
+     */
+    private static class ModelDataGlobalReducer implements ReduceFunction<KMeansModelData> {
+        @Override
+        public KMeansModelData reduce(KMeansModelData modelData, KMeansModelData newModelData) {
+            DenseVector weights = modelData.weights;
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector newWeights = newModelData.weights;
+            DenseVector[] newCentroids = newModelData.centroids;
+
+            int k = newCentroids.length;
+            int dim = newCentroids[0].size();
+
+            for (int i = 0; i < k; i++) {
+                for (int j = 0; j < dim; j++) {
+                    centroids[i].values[j] =
+                            (centroids[i].values[j] * weights.values[i]
+                                            + newCentroids[i].values[j] * newWeights.values[i])
+                                    / Math.max(weights.values[i] + newWeights.values[i], 1e-16);
+                }
+                weights.values[i] += newWeights.values[i];
+            }
+
+            return new KMeansModelData(centroids, weights);
+        }
+    }
+
+    /**
+     * An operator that updates KMeans model data locally. It performs the following steps:
+     *
+     * <ul>
+     *   <li>Finds the closest centroid id (cluster) of each input point.
+     *   <li>Computes the new centroids as the average of the input points that belong to the same
+     *       cluster.
+     *   <li>Computes the weighted average of current and new centroids. The weight of a new
+     *       centroid is the number of input points that belong to this cluster. The weight of a
+     *       current centroid is its original weight scaled by {@code decayFactor / parallelism}.
+     *   <li>Generates new model data from the weighted average of centroids and the sum of
+     *       weights.
+     * </ul>
+     */
+    private static class ModelDataLocalUpdater extends AbstractStreamOperator<KMeansModelData>
+            implements TwoInputStreamOperator<DenseVector[], KMeansModelData, KMeansModelData> {
+        private final DistanceMeasure distanceMeasure;
+        private final int k;
+        private final double decayFactor;
+        private ListState<DenseVector[]> localBatchDataState;
+        private ListState<KMeansModelData> modelDataState;
+
+        private ModelDataLocalUpdater(DistanceMeasure distanceMeasure, int k, double decayFactor) {
+            this.distanceMeasure = distanceMeasure;
+            this.k = k;
+            this.decayFactor = decayFactor;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            localBatchDataState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("localBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> pointsRecord) throws Exception {
+            localBatchDataState.add(pointsRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<KMeansModelData> modelDataRecord)
+                throws Exception {
+            Preconditions.checkArgument(modelDataRecord.getValue().centroids.length == k);
+            modelDataState.add(modelDataRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        private void alignAndComputeModelData() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !localBatchDataState.get().iterator().hasNext()) {
+                return;
+            }
+
+            KMeansModelData modelData =
+                    OperatorStateUtils.getUniqueElement(modelDataState, "modelData")
+                            .orElseThrow((Supplier<Exception>) NullPointerException::new);

Review comment:
       `hasNext() == true` guarantees that there are one or more elements in the state, but `getUniqueElement` requires exactly one element in the state, so `orElseThrow` would still be triggered if there is more than one element.
   
   That said, I think `NullPointerException` is inappropriate here. I'll change it to `IllegalStateException`.
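   A minimal sketch of the distinction discussed above, assuming a hypothetical helper (`UniqueElementSketch`, not Flink's `OperatorStateUtils`) whose contract follows the comment: it yields a value only when the state holds exactly one element. A plain `Iterable` stands in for the operator's `ListState`, and a missing or non-unique element is reported as `IllegalStateException`.

```java
import java.util.Iterator;
import java.util.Optional;

/**
 * Hypothetical helper for illustration only; it is not Flink's OperatorStateUtils. It mirrors the
 * contract described in the comment: a value is returned only if the state holds exactly one element.
 */
final class UniqueElementSketch {

    /** Returns the element if the iterable holds exactly one, otherwise Optional.empty(). */
    static <T> Optional<T> uniqueElementOrEmpty(Iterable<T> state) {
        Iterator<T> it = state.iterator();
        if (!it.hasNext()) {
            return Optional.empty(); // zero elements
        }
        T first = it.next();
        // hasNext() was true, but a second element still makes the state non-unique.
        return it.hasNext() ? Optional.empty() : Optional.of(first);
    }

    /** Reads the single element, reporting a zero- or multi-element state as IllegalStateException. */
    static <T> T requireUniqueElement(Iterable<T> state, String stateName) {
        return uniqueElementOrEmpty(state)
                .orElseThrow(
                        () ->
                                new IllegalStateException(
                                        "State " + stateName + " should hold exactly one element"));
    }
}
```

   Because the lambda supplies an unchecked `IllegalStateException`, the call site needs neither a `(Supplier<Exception>)` cast nor a `throws` clause.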




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [flink-ml] yunfengzhou-hub removed a comment on pull request #70: [FLINK-26313] Support Streaming KMeans in Flink ML

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub removed a comment on pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#issuecomment-1057816374


   Hi @lindong28, I have created the PR for the streaming KMeans algorithm. Would you mind helping review this PR?





[GitHub] [flink-ml] zhipeng93 commented on pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#issuecomment-1081346919


   @lindong28 Thanks for the review. This PR looks good to me~





[GitHub] [flink-ml] lindong28 commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
lindong28 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r836003024



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeansModel.java
##########
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.co.CoProcessFunction;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * OnlineKMeansModel can be regarded as an advanced {@link KMeansModel} operator that updates its
+ * model data dynamically from a stream, using the model data produced by {@link OnlineKMeans}.
+ */
+public class OnlineKMeansModel
+        implements Model<OnlineKMeansModel>, KMeansModelParams<OnlineKMeansModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public OnlineKMeansModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel setModelData(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        modelDataTable = inputs[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), Types.INT),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol()));
+
+        DataStream<Row> predictionResult =
+                KMeansModelData.getModelDataStream(modelDataTable)
+                        .broadcast()
+                        .connect(tEnv.toDataStream(inputs[0]))
+                        .process(
+                                new PredictLabelFunction(
+                                        getFeaturesCol(),
+                                        DistanceMeasure.getInstance(getDistanceMeasure())),
+                                outputTypeInfo);
+
+        return new Table[] {tEnv.fromDataStream(predictionResult)};
+    }
+
+    /** A utility function used for prediction. */
+    private static class PredictLabelFunction extends CoProcessFunction<KMeansModelData, Row, Row> {
+        private final String featuresCol;
+
+        private final DistanceMeasure distanceMeasure;
+
+        private DenseVector[] centroids;
+
+        // TODO: replace this with a complete solution that reads the first model data from the
+        // unbounded model data stream before processing the first prediction data.
+        private final List<Row> bufferedPoints = new ArrayList<>();

Review comment:
       I think our long-term goal is `a technically correct implementation that also works on large datasets in production`.
   
   If our algorithm only works on so-called small datasets, given that there is no guarantee about what counts as `small`, it effectively means our algorithm cannot be used in production. And if it cannot be used in production, it doesn't matter whether we provide checkpointing or not.
   
   Given that we will need to make our algorithm usable in production anyway, the TODO has to be addressed sooner or later. It therefore seems simpler not to add checkpoint code that we will throw away in the near future.
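   To make the trade-off concrete, here is a framework-free sketch of the buffering the TODO describes, assuming plain Java collections in place of Flink's `CoProcessFunction` inputs. `BufferUntilModelSketch` and its methods are hypothetical names, and squared Euclidean distance is assumed for prediction; this is not the PR's `PredictLabelFunction`. The in-memory list grows without bound until the first model arrives and is not checkpointed, which is the production concern raised above.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.IntConsumer;

/**
 * Hypothetical, framework-free sketch: prediction inputs that arrive before the first model are
 * held in memory and flushed once a model is received.
 */
final class BufferUntilModelSketch {
    private double[][] centroids; // latest model; null until the first model arrives
    private final List<double[]> bufferedPoints = new ArrayList<>();
    private final IntConsumer output; // receives the predicted cluster id of each point

    BufferUntilModelSketch(IntConsumer output) {
        this.output = output;
    }

    /** Model input: replace the centroids and flush everything buffered so far. */
    void onModel(double[][] newCentroids) {
        centroids = newCentroids;
        for (double[] p : bufferedPoints) {
            output.accept(predict(p));
        }
        bufferedPoints.clear();
    }

    /** Data input: predict immediately if a model exists, otherwise buffer (without any bound). */
    void onPoint(double[] point) {
        if (centroids == null) {
            bufferedPoints.add(point);
        } else {
            output.accept(predict(point));
        }
    }

    int bufferedCount() {
        return bufferedPoints.size();
    }

    /** Index of the closest centroid under squared Euclidean distance. */
    private int predict(double[] p) {
        int best = 0;
        double bestDist = Double.POSITIVE_INFINITY;
        for (int i = 0; i < centroids.length; i++) {
            double d = 0;
            for (int j = 0; j < p.length; j++) {
                double diff = centroids[i][j] - p[j];
                d += diff * diff;
            }
            if (d < bestDist) {
                bestDist = d;
                best = i;
            }
        }
        return best;
    }
}
```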







[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r836040424



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeansModel.java
##########
@@ -0,0 +1,182 @@
+    private static class PredictLabelFunction extends CoProcessFunction<KMeansModelData, Row, Row> {
+        private final String featuresCol;
+
+        private final DistanceMeasure distanceMeasure;
+
+        private DenseVector[] centroids;
+
+        // TODO: replace this with a complete solution that reads the first model data from the
+        // unbounded model data stream before processing the first prediction data.
+        private final List<Row> bufferedPoints = new ArrayList<>();

Review comment:
       I would suggest that we check in a technically correct implementation at each step. If some of it becomes useless later, we can delete it then.







[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r836084920



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,407 @@
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Supplier;
+
+/**
+ * OnlineKMeans extends {@link KMeans} to support training a K-Means model continuously on an
+ * unbounded stream of training data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids as the weighted average of the original and the
+ * estimated centroids. The weight of an estimated centroid is the number of points in the batch
+ * assigned to it. The weight of an original centroid is its accumulated weight from previous
+ * batches, multiplied by the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If the decay
+ * factor is 1, all batches are weighted equally. If the decay factor is 0, new centroids are
+ * determined entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelData =
+                KMeansModelData.getModelDataStream(initModelDataTable);
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getK(),
+                        getDecayFactor(),
+                        getGlobalBatchSize());
+
+        DataStream<KMeansModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Saves the metadata AND bounded model data table (if it exists) to the given path. */
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans onlineKMeans = ReadWriteUtils.loadStageParam(path);
+
+        String initModelDataPath = ReadWriteUtils.getDataPath(path);
+        if (Files.exists(Paths.get(initModelDataPath))) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            onlineKMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+        }
+
+        return onlineKMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final int k;
+        private final double decayFactor;
+        private final int batchSize;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure, int k, double decayFactor, int batchSize) {
+            this.distanceMeasure = distanceMeasure;
+            this.k = k;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<KMeansModelData> modelData = variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();

Review comment:
       As pointed out by @yunfengzhou-hub , a use case could be:
   > I think even if the number of elements in a batch is smaller than parallelism, the training process should still function correctly, because when batches are created according to time window, the operator would have no idea about how many elements are in each batch.
   
   By `consistent user experience`, I mean that, given an input training set, users should be able to run it on one worker or on an arbitrary number of workers.
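The scenario quoted above can be sketched in plain Java. This is a minimal illustration, not Flink ML code: `BatchSplitSketch` and its `localBatchSize` and `mergedMean` helpers are hypothetical, and points are 1-D for brevity. It assumes a global batch is divided as evenly as possible among the parallel subtasks, so when the batch is smaller than the parallelism some subtasks receive no points at all. Because the per-subtask (sum, count) partials merge by simple addition, the resulting mean is the same for one worker or many (up to floating-point rounding).

```java
import java.util.Arrays;

/** Hypothetical helper, not part of Flink ML: splits one global batch across subtasks. */
final class BatchSplitSketch {
    private BatchSplitSketch() {}

    /**
     * Number of records that subtask {@code subtaskIndex} takes from a global batch. The first
     * {@code globalBatchSize % parallelism} subtasks get one extra record, so when
     * globalBatchSize < parallelism the remaining subtasks get 0 records.
     */
    static int localBatchSize(int globalBatchSize, int parallelism, int subtaskIndex) {
        int base = globalBatchSize / parallelism;
        int remainder = globalBatchSize % parallelism;
        return base + (subtaskIndex < remainder ? 1 : 0);
    }

    /**
     * Mean of a single cluster's 1-D points after distributing them over {@code parallelism}
     * subtasks: each subtask computes a partial (sum, count) and the partials are added up.
     * Empty subtasks contribute (0, 0), so the merged mean does not depend on the parallelism
     * (up to floating-point rounding).
     */
    static double mergedMean(double[] batch, int parallelism) {
        double sum = 0;
        long count = 0;
        int offset = 0;
        for (int i = 0; i < parallelism; i++) {
            int size = localBatchSize(batch.length, parallelism, i);
            double localSum = 0;
            for (int j = offset; j < offset + size; j++) {
                localSum += batch[j];
            }
            offset += size;
            // Merge step: partial sums and counts are simply added.
            sum += localSum;
            count += size;
        }
        return count == 0 ? Double.NaN : sum / count;
    }

    public static void main(String[] args) {
        int[] sizes = new int[4];
        for (int i = 0; i < sizes.length; i++) {
            sizes[i] = localBatchSize(2, 4, i);
        }
        System.out.println(Arrays.toString(sizes)); // [1, 1, 0, 0]
        double[] batch = {10.0, 10.3};
        System.out.println(mergedMean(batch, 1) == mergedMean(batch, 4)); // true
    }
}
```

Splitting by count is only one option; with time-window batching, as in the quoted case, the per-subtask sizes are simply whatever arrived, and the same additive merge still applies.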




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [flink-ml] lindong28 commented on a change in pull request #70: [FLINK-26313] Support Online KMeans in Flink ML

Posted by GitBox <gi...@apache.org>.
lindong28 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r830428974



##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/util/MockKVStore.java
##########
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util;
+
+import java.util.HashMap;
+import java.util.Map;
+
+/** Class that manages global key-value pairs used in unit tests. */
+@SuppressWarnings({"unchecked"})
+public class MockKVStore {
+    private static final Map<String, Object> map = new HashMap<>();
+
+    private static long counter = 0;
+
+    /**
+     * Returns a prefix string to make sure key-value pairs created in different test cases would
+     * not have the same key.
+     */
+    public static synchronized String createNonDuplicatePrefix() {

Review comment:
       In practice, it is extremely unlikely that two randomly generated int64 values collide with each other. Such values are said to be `statistically unique`.
   
   According to the Javadoc of Flink's JobID, a JobID is statistically unique, and Flink already relies on this kind of uniqueness to distinguish multiple Flink jobs in the same Flink session in production. So it seems that there is no need for strict uniqueness in our tests?
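To put a number on "statistically unique", the birthday approximation gives the probability that at least one pair among n independent, uniformly random 64-bit values collides: p ≈ 1 - exp(-(n(n-1)/2) / 2^64). The sketch below evaluates it and shows a counter-free prefix generator. Both `CollisionOdds` and its `createNonDuplicatePrefix` are hypothetical illustrations, not the code in this PR; the prefix uses `java.util.UUID` rather than Flink's `JobID`.

```java
import java.util.UUID;

/** Hypothetical illustration of "statistically unique" random IDs; not Flink ML code. */
final class CollisionOdds {
    private CollisionOdds() {}

    /**
     * Birthday-bound approximation of the probability that at least two of n uniformly random
     * 64-bit values are equal: 1 - exp(-(n(n-1)/2) / 2^64). Math.expm1 keeps tiny results
     * accurate instead of rounding them to 0.
     */
    static double collisionProbability(long n) {
        double pairs = (double) n * (n - 1) / 2.0;
        return -Math.expm1(-pairs / Math.pow(2, 64));
    }

    /**
     * Hypothetical counter-free variant of a prefix generator: a version-4 UUID carries 122
     * random bits, so collisions between test cases are even less likely than with an int64.
     */
    static String createNonDuplicatePrefix() {
        return UUID.randomUUID() + "-";
    }

    public static void main(String[] args) {
        // Two IDs: about 5.4e-20. One million IDs: about 2.7e-8.
        System.out.println(collisionProbability(2));
        System.out.println(collisionProbability(1_000_000));
        System.out.println(createNonDuplicatePrefix());
    }
}
```

Even a million IDs in one test run leaves the collision odds around 2.7e-8, which supports dropping the `synchronized` counter in favor of random prefixes.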
   





[GitHub] [flink-ml] lindong28 commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
lindong28 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r832802109



##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java
##########
@@ -0,0 +1,487 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.RestOptions;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.clustering.kmeans.KMeans;
+import org.apache.flink.ml.clustering.kmeans.KMeansModel;
+import org.apache.flink.ml.clustering.kmeans.KMeansModelData;
+import org.apache.flink.ml.clustering.kmeans.OnlineKMeans;
+import org.apache.flink.ml.clustering.kmeans.OnlineKMeansModel;
+import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.util.InMemorySinkFunction;
+import org.apache.flink.ml.util.InMemorySourceFunction;
+import org.apache.flink.runtime.minicluster.MiniCluster;
+import org.apache.flink.runtime.minicluster.MiniClusterConfiguration;
+import org.apache.flink.runtime.testutils.InMemoryReporter;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.CollectionUtils;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import static org.apache.flink.ml.clustering.KMeansTest.groupFeaturesByPrediction;
+
+/** Tests {@link OnlineKMeans} and {@link OnlineKMeansModel}. */
+public class OnlineKMeansTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private static final DenseVector[] trainData1 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 0.0),
+                Vectors.dense(10.0, 0.3),
+                Vectors.dense(10.3, 0.0),
+                Vectors.dense(-10.0, 0.0),
+                Vectors.dense(-10.0, 0.6),
+                Vectors.dense(-10.6, 0.0)
+            };
+    private static final DenseVector[] trainData2 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 100.0),
+                Vectors.dense(10.0, 100.3),
+                Vectors.dense(10.3, 100.0),
+                Vectors.dense(-10.0, -100.0),
+                Vectors.dense(-10.0, -100.6),
+                Vectors.dense(-10.6, -100.0)
+            };
+    private static final DenseVector[] predictData =
+            new DenseVector[] {
+                Vectors.dense(10.0, 10.0),
+                Vectors.dense(10.3, 10.0),
+                Vectors.dense(10.0, 10.3),
+                Vectors.dense(-10.0, 10.0),
+                Vectors.dense(-10.3, 10.0),
+                Vectors.dense(-10.0, 10.3)
+            };
+    private static final List<Set<DenseVector>> expectedGroups1 =
+            Arrays.asList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3))),
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+    private static final List<Set<DenseVector>> expectedGroups2 =
+            Collections.singletonList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3),
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+
+    private static final int defaultParallelism = 4;
+    private static final int numTaskManagers = 2;
+    private static final int numSlotsPerTaskManager = 2;
+
+    private String currentModelDataVersion;
+
+    private InMemorySourceFunction<DenseVector> trainSource;
+    private InMemorySourceFunction<DenseVector> predictSource;
+    private InMemorySinkFunction<Row> outputSink;
+    private InMemorySinkFunction<KMeansModelData> modelDataSink;
+
+    private InMemoryReporter reporter;
+    private MiniCluster miniCluster;
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+
+    private Table offlineTrainTable;
+    private Table trainTable;
+    private Table predictTable;
+
+    @Before
+    public void before() throws Exception {
+        currentModelDataVersion = "0";
+
+        trainSource = new InMemorySourceFunction<>();
+        predictSource = new InMemorySourceFunction<>();
+        outputSink = new InMemorySinkFunction<>();
+        modelDataSink = new InMemorySinkFunction<>();
+
+        Configuration config = new Configuration();
+        config.set(RestOptions.BIND_PORT, "18081-19091");
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        reporter = InMemoryReporter.createWithRetainedMetrics();
+        reporter.addToConfiguration(config);
+
+        miniCluster =
+                new MiniCluster(
+                        new MiniClusterConfiguration.Builder()
+                                .setConfiguration(config)
+                                .setNumTaskManagers(numTaskManagers)
+                                .setNumSlotsPerTaskManager(numSlotsPerTaskManager)
+                                .build());
+        miniCluster.start();
+
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(defaultParallelism);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        Schema schema = Schema.newBuilder().column("f0", DataTypes.of(DenseVector.class)).build();
+
+        offlineTrainTable =
+                tEnv.fromDataStream(env.fromElements(trainData1), schema).as("features");
+        trainTable =
+                tEnv.fromDataStream(
+                                env.addSource(trainSource, DenseVectorTypeInfo.INSTANCE), schema)
+                        .as("features");
+        predictTable =
+                tEnv.fromDataStream(
+                                env.addSource(predictSource, DenseVectorTypeInfo.INSTANCE), schema)
+                        .as("features");
+    }
+
+    @After
+    public void after() throws Exception {
+        miniCluster.close();
+    }
+
+    /**
+     * Performs transform() on the provided model with predictTable, and adds sinks for
+     * OnlineKMeansModel's transform output and model data.
+     */
+    private void configTransformAndSink(OnlineKMeansModel onlineModel) {
+        Table outputTable = onlineModel.transform(predictTable)[0];
+        DataStream<Row> output = tEnv.toDataStream(outputTable);
+        output.addSink(outputSink);
+
+        Table modelDataTable = onlineModel.getModelData()[0];
+        DataStream<KMeansModelData> modelDataStream =
+                KMeansModelData.getModelDataStream(modelDataTable);
+        modelDataStream.addSink(modelDataSink);
+    }
+
+    /** Blocks the thread until the Model has set up its initial model data. */
+    private void waitInitModelDataSetup() {
+        while (reporter.findMetrics("modelDataVersion").size() < defaultParallelism) {
+            Thread.yield();
+        }
+        waitModelDataUpdate();
+    }
+
+    /** Blocks the thread until the Model has received the next model-data-update event. */
+    @SuppressWarnings("unchecked")
+    private void waitModelDataUpdate() {
+        do {
+            String tmpModelDataVersion =
+                    String.valueOf(
+                            reporter.findMetrics("modelDataVersion").values().stream()
+                                    .map(x -> Integer.parseInt(((Gauge<String>) x).getValue()))
+                                    .min(Integer::compareTo)
+                                    .orElse(Integer.parseInt(currentModelDataVersion)));
+            if (tmpModelDataVersion.equals(currentModelDataVersion)) {
+                Thread.yield();
+            } else {
+                currentModelDataVersion = tmpModelDataVersion;
+                break;
+            }
+        } while (true);
+    }
+
+    /**
+     * Inserts the default predict data into the predict queue, fetches the prediction results, and
+     * asserts that the grouping result is as expected.
+     *
+     * @param expectedGroups A list containing sets of features, which is the expected group result
+     * @param featuresCol Name of the column in the table that contains the features
+     * @param predictionCol Name of the column in the table that contains the prediction result
+     */
+    private void predictAndAssert(
+            List<Set<DenseVector>> expectedGroups, String featuresCol, String predictionCol)
+            throws Exception {
+        predictSource.offerAll(OnlineKMeansTest.predictData);
+        List<Row> rawResult = outputSink.poll(OnlineKMeansTest.predictData.length);
+        List<Set<DenseVector>> actualGroups =
+                groupFeaturesByPrediction(rawResult, featuresCol, predictionCol);
+        Assert.assertTrue(CollectionUtils.isEqualCollection(expectedGroups, actualGroups));
+    }
+
+    @Test
+    public void testParam() {
+        OnlineKMeans onlineKMeans = new OnlineKMeans();
+        Assert.assertEquals("features", onlineKMeans.getFeaturesCol());
+        Assert.assertEquals("prediction", onlineKMeans.getPredictionCol());
+        Assert.assertEquals(EuclideanDistanceMeasure.NAME, onlineKMeans.getDistanceMeasure());
+        Assert.assertEquals("random", onlineKMeans.getInitMode());
+        Assert.assertEquals(2, onlineKMeans.getK());
+        Assert.assertEquals(1, onlineKMeans.getDims());
+        Assert.assertEquals("count", onlineKMeans.getBatchStrategy());
+        Assert.assertEquals(1, onlineKMeans.getBatchSize());
+        Assert.assertEquals(0., onlineKMeans.getDecayFactor(), 1e-5);
+        Assert.assertEquals("random", onlineKMeans.getInitMode());
+        Assert.assertEquals(OnlineKMeans.class.getName().hashCode(), onlineKMeans.getSeed());
+
+        onlineKMeans
+                .setK(9)
+                .setFeaturesCol("test_feature")
+                .setPredictionCol("test_prediction")
+                .setK(3)
+                .setDims(5)
+                .setBatchSize(5)
+                .setDecayFactor(0.25)
+                .setInitMode("direct")
+                .setSeed(100);
+
+        Assert.assertEquals("test_feature", onlineKMeans.getFeaturesCol());
+        Assert.assertEquals("test_prediction", onlineKMeans.getPredictionCol());
+        Assert.assertEquals(3, onlineKMeans.getK());
+        Assert.assertEquals(5, onlineKMeans.getDims());
+        Assert.assertEquals("count", onlineKMeans.getBatchStrategy());
+        Assert.assertEquals(5, onlineKMeans.getBatchSize());
+        Assert.assertEquals(0.25, onlineKMeans.getDecayFactor(), 1e-5);
+        Assert.assertEquals("direct", onlineKMeans.getInitMode());
+        Assert.assertEquals(100, onlineKMeans.getSeed());
+    }
+
+    @Test
+    public void testFitAndPredict() throws Exception {
+        OnlineKMeans onlineKMeans =
+                new OnlineKMeans()
+                        .setInitMode("random")
+                        .setDims(2)
+                        .setInitWeights(new Double[] {0., 0.})
+                        .setBatchSize(6)
+                        .setFeaturesCol("features")
+                        .setPredictionCol("prediction");
+        OnlineKMeansModel onlineModel = onlineKMeans.fit(trainTable);
+        configTransformAndSink(onlineModel);
+
+        miniCluster.submitJob(env.getStreamGraph().getJobGraph());
+        waitInitModelDataSetup();
+
+        trainSource.offerAll(trainData1);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups1, onlineKMeans.getFeaturesCol(), onlineKMeans.getPredictionCol());
+
+        trainSource.offerAll(trainData2);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups2, onlineKMeans.getFeaturesCol(), onlineKMeans.getPredictionCol());
+    }
+
+    @Test
+    public void testInitWithKMeans() throws Exception {
+        KMeans kMeans = new KMeans().setFeaturesCol("features").setPredictionCol("prediction");
+        KMeansModel kMeansModel = kMeans.fit(offlineTrainTable);
+
+        OnlineKMeans onlineKMeans =
+                new OnlineKMeans(kMeansModel.getModelData())
+                        .setFeaturesCol("features")
+                        .setPredictionCol("prediction")
+                        .setInitMode("direct")
+                        .setDims(2)
+                        .setInitWeights(new Double[] {0., 0.})
+                        .setBatchSize(6);
+
+        OnlineKMeansModel onlineModel = onlineKMeans.fit(trainTable);
+        configTransformAndSink(onlineModel);
+
+        miniCluster.submitJob(env.getStreamGraph().getJobGraph());
+        waitInitModelDataSetup();
+        predictAndAssert(
+                expectedGroups1, onlineKMeans.getFeaturesCol(), onlineKMeans.getPredictionCol());
+
+        trainSource.offerAll(trainData2);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups2, onlineKMeans.getFeaturesCol(), onlineKMeans.getPredictionCol());
+    }
+
+    @Test
+    public void testDecayFactor() throws Exception {
+        KMeans kMeans = new KMeans().setFeaturesCol("features").setPredictionCol("prediction");
+        KMeansModel kMeansModel = kMeans.fit(offlineTrainTable);
+
+        OnlineKMeans onlineKMeans =
+                new OnlineKMeans(kMeansModel.getModelData())
+                        .setDims(2)
+                        .setInitWeights(new Double[] {3., 3.})
+                        .setDecayFactor(0.5)
+                        .setBatchSize(6)
+                        .setFeaturesCol("features")
+                        .setPredictionCol("prediction");
+        OnlineKMeansModel onlineModel = onlineKMeans.fit(trainTable);
+        configTransformAndSink(onlineModel);
+
+        miniCluster.submitJob(env.getStreamGraph().getJobGraph());
+        modelDataSink.poll();
+
+        trainSource.offerAll(trainData2);
+        KMeansModelData actualModelData = modelDataSink.poll();
+
+        KMeansModelData expectedModelData =
+                new KMeansModelData(
+                        new DenseVector[] {
+                            Vectors.dense(10.1, 200.3 / 3), Vectors.dense(-10.2, -200.2 / 3)
+                        });
+
+        Assert.assertEquals(expectedModelData.centroids.length, actualModelData.centroids.length);
+        Arrays.sort(
+                actualModelData.centroids,
+                (o1, o2) -> Double.compare(o2.values[0], o1.values[0]));
+        for (int i = 0; i < expectedModelData.centroids.length; i++) {
+            Assert.assertArrayEquals(
+                    expectedModelData.centroids[i].values,
+                    actualModelData.centroids[i].values,
+                    1e-5);
+        }
+    }
+
+    @Test
+    public void testSaveAndReload() throws Exception {
+        KMeans kMeans = new KMeans().setFeaturesCol("features").setPredictionCol("prediction");
+        KMeansModel kMeansModel = kMeans.fit(offlineTrainTable);
+
+        OnlineKMeans onlineKMeans =
+                new OnlineKMeans(kMeansModel.getModelData())
+                        .setFeaturesCol("features")
+                        .setPredictionCol("prediction")
+                        .setInitMode("direct")
+                        .setDims(2)
+                        .setInitWeights(new Double[] {0., 0.})
+                        .setBatchSize(6);
+
+        String savePath = tempFolder.newFolder().getAbsolutePath();
+        onlineKMeans.save(savePath);
+        miniCluster.executeJobBlocking(env.getStreamGraph().getJobGraph());
+        OnlineKMeans loadedKMeans = OnlineKMeans.load(env, savePath);
+
+        OnlineKMeansModel onlineModel = loadedKMeans.fit(trainTable);

Review comment:
       Thanks for the detailed explanation. Sounds good.
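The expected centroids in `testDecayFactor` can be checked by hand against the decayed mini-batch rule described in the `OnlineKMeans` Javadoc. A minimal sketch, assuming the update newCentroid = (c * w * a + batchSum) / (w * a + n), where c is the old centroid, w its weight, a the decay factor, and n the number of batch points assigned to it. `DecayedUpdateSketch` is hypothetical and is not the PR's implementation. It also assumes the offline KMeans fit on `trainData1` yields the two cluster means (10.1, 0.1) and (-10.2, 0.2), and that each half of `trainData2` is assigned to the nearer of them. With init weights 3 and decay factor 0.5, it reproduces (10.1, 200.3 / 3) and (-10.2, -200.2 / 3).

```java
/** Hypothetical sketch of the decayed mini-batch centroid update; not Flink ML code. */
final class DecayedUpdateSketch {
    private DecayedUpdateSketch() {}

    /**
     * Weighted average of the old centroid (weight oldWeight * decayFactor) and the mean of the
     * points assigned to it in this batch (weight = number of assigned points).
     */
    static double[] updateCentroid(
            double[] centroid, double oldWeight, double decayFactor, double[][] assignedPoints) {
        int count = assignedPoints.length;
        if (count == 0) {
            // No points assigned in this batch: keep the old centroid unchanged.
            return centroid.clone();
        }
        double decayedWeight = oldWeight * decayFactor;
        double[] updated = new double[centroid.length];
        for (int d = 0; d < centroid.length; d++) {
            double batchSum = 0;
            for (double[] p : assignedPoints) {
                batchSum += p[d];
            }
            // count * (batchSum / count) == batchSum, so the batch mean never needs dividing.
            updated[d] = (decayedWeight * centroid[d] + batchSum) / (decayedWeight + count);
        }
        return updated;
    }

    /** Assumed weight bookkeeping for the next batch: decayed old weight plus new count. */
    static double updateWeight(double oldWeight, double decayFactor, int count) {
        return oldWeight * decayFactor + count;
    }

    public static void main(String[] args) {
        double[][] positive = {{10.0, 100.0}, {10.0, 100.3}, {10.3, 100.0}};
        double[][] negative = {{-10.0, -100.0}, {-10.0, -100.6}, {-10.6, -100.0}};
        double[] c0 = updateCentroid(new double[] {10.1, 0.1}, 3, 0.5, positive);
        double[] c1 = updateCentroid(new double[] {-10.2, 0.2}, 3, 0.5, negative);
        System.out.println(Math.abs(c0[1] - 200.3 / 3) < 1e-9); // true
        System.out.println(Math.abs(c1[1] + 200.2 / 3) < 1e-9); // true
    }
}
```

For example, the second coordinate of the first centroid is (1.5 * 0.1 + 300.3) / (1.5 + 3) = 300.45 / 4.5 = 200.3 / 3. With a decay factor of 0 the old centroid drops out and the result is just the batch mean, matching the Javadoc's description of full forgetting.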





[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r832772257



##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/util/InMemorySinkFunction.java
##########
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util;
+
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.streaming.api.functions.sink.RichSinkFunction;
+import org.apache.flink.streaming.api.functions.sink.SinkFunction;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.UUID;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
+
+/** A {@link SinkFunction} implementation that makes all collected records available for tests. */
+@SuppressWarnings({"unchecked", "rawtypes"})
+public class InMemorySinkFunction<T> extends RichSinkFunction<T> {
+    private static final Map<UUID, BlockingQueue> queueMap = new ConcurrentHashMap<>();
+    private final UUID id;
+    private BlockingQueue<T> queue;
+
+    public InMemorySinkFunction() {
+        id = UUID.randomUUID();
+        queue = new LinkedBlockingQueue();
+        queueMap.put(id, queue);
+    }
+
+    @Override
+    public void open(Configuration parameters) throws Exception {
+        super.open(parameters);
+        queue = queueMap.get(id);
+    }
+
+    @Override
+    public void close() throws Exception {
+        super.close();
+        queueMap.remove(id);
+    }
+
+    @Override
+    public void invoke(T value, Context context) {
+        if (!queue.offer(value)) {
+            throw new RuntimeException(
+                    "Failed to offer " + value + " to blocking queue " + id + ".");
+        }
+    }
+
+    public List<T> poll(int num) throws InterruptedException {

Review comment:
       As far as I understand, `take()` does not throw an exception unless the thread itself is interrupted. Thus this method would block indefinitely if the Flink job fails to output the expected number of records to this sink, which means the relevant tests might never terminate. `poll()` at least unblocks the thread after the timeout, and the returned null value then causes an exception later on. While in `InMemorySourceFunction` we could add dummy values to the queue to unblock `take()`, we cannot expect every developer to be this careful in their test cases.
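The trade-off above can be made explicit by failing fast on a timeout. A minimal sketch, assuming a hypothetical `TimedQueueReader` rather than the PR's `InMemorySinkFunction`: it uses the standard `BlockingQueue.poll(long, TimeUnit)`, which returns `null` when the wait expires, and turns that `null` into a descriptive `TimeoutException` right away. A missing record then neither hangs the test forever (as `take()` would) nor surfaces later as an unrelated failure caused by a `null`.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

/** Hypothetical test helper; not part of the PR. */
final class TimedQueueReader<T> {
    private final BlockingQueue<T> queue = new LinkedBlockingQueue<>();

    /** Called by the producer side, e.g. from a sink's invoke(). */
    void offer(T value) {
        if (!queue.offer(value)) {
            throw new IllegalStateException("Failed to offer " + value);
        }
    }

    /**
     * Collects exactly {@code num} records, waiting at most {@code timeout} for each one.
     * Unlike take(), this cannot hang a test forever; unlike a bare poll(), a missing record
     * is reported immediately instead of surfacing later as a null.
     */
    List<T> poll(int num, long timeout, TimeUnit unit)
            throws InterruptedException, TimeoutException {
        List<T> result = new ArrayList<>(num);
        for (int i = 0; i < num; i++) {
            T value = queue.poll(timeout, unit);
            if (value == null) {
                throw new TimeoutException(
                        "Received " + i + " of " + num + " records before timing out.");
            }
            result.add(value);
        }
        return result;
    }

    public static void main(String[] args) throws Exception {
        TimedQueueReader<String> reader = new TimedQueueReader<>();
        reader.offer("a");
        reader.offer("b");
        System.out.println(reader.poll(2, 1, TimeUnit.SECONDS)); // [a, b]
        try {
            reader.poll(1, 100, TimeUnit.MILLISECONDS);
        } catch (TimeoutException e) {
            System.out.println(e.getMessage()); // Received 0 of 1 records before timing out.
        }
    }
}
```

The timeout here applies per record, so the worst-case wait is num * timeout; computing one deadline up front with `System.nanoTime()` would bound the total wait instead.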





[GitHub] [flink-ml] lindong28 commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
lindong28 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r832897095



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,562 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Random;
+
+/**
+ * OnlineKMeans extends the functionality of {@link KMeans} by supporting continuous training of a
+ * K-Means model on an unbounded stream of training data.
+ *
+ * <p>OnlineKMeans updates the model with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids of the current batch are estimated, OnlineKMeans
+ * computes each new centroid as the weighted average of the original centroid and the estimated
+ * centroid. The weight of an estimated centroid is the number of points in the current batch that
+ * are assigned to it. The weight of an original centroid is the number of points previously
+ * assigned to it, multiplied by the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters estimated so far. If the decay
+ * factor is 1, all batches are weighted equally. If the decay factor is 0, new centroids are
+ * determined entirely by the most recent batch. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public OnlineKMeans(Table initModelDataTable) {
+        this.initModelDataTable = initModelDataTable;
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            Preconditions.checkState(initModelDataTable == null);
+            initModelDataStream = createRandomCentroids(env, getDim(), getK(), getSeed());
+        } else {
+            initModelDataStream = KMeansModelData.getModelDataStream(initModelDataTable);
+        }
+        DataStream<Tuple2<KMeansModelData, DenseVector>> initModelDataWithWeightsStream =
+                initModelDataStream.map(new InitWeightAssigner(getInitWeights()));
+        initModelDataWithWeightsStream.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize(),
+                        getK(),
+                        getDim());
+
+        DataStream<KMeansModelData> finalModelDataStream =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelDataWithWeightsStream),
+                                DataStreamList.of(points),
+                                body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelDataStream);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class InitWeightAssigner
+            implements MapFunction<KMeansModelData, Tuple2<KMeansModelData, DenseVector>> {
+        private final double[] initWeights;
+
+        private InitWeightAssigner(Double[] initWeights) {
+            this.initWeights = ArrayUtils.toPrimitive(initWeights);
+        }
+
+        @Override
+        public Tuple2<KMeansModelData, DenseVector> map(KMeansModelData modelData)
+                throws Exception {
+            return Tuple2.of(modelData, Vectors.dense(initWeights));
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        Path initModelDataPath = Paths.get(path, "data");
+        if (Files.exists(initModelDataPath)) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+            kMeans.setInitMode("direct");
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+        private final int k;
+        private final int dim;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure,
+                double decayFactor,
+                int batchSize,
+                int k,
+                int dim) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+            this.k = k;
+            this.dim = dim;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<Tuple2<KMeansModelData, DenseVector>> modelDataWithWeights =
+                    variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<Tuple2<KMeansModelData, DenseVector>> newModelDataWithWeights =
+                    points.countWindowAll(batchSize)
+                            .aggregate(new MiniBatchCreator())
+                            .flatMap(new MiniBatchDistributor(parallelism))
+                            .rebalance()
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "ModelDataPartialUpdater",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new ModelDataPartialUpdater(distanceMeasure, k))
+                            .setParallelism(parallelism)
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "ModelDataGlobalUpdater",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new ModelDataGlobalUpdater(k, dim, parallelism, decayFactor))

Review comment:
       Currently `ModelDataGlobalUpdater` takes two inputs: the partial model data delta computed by each partition in this round, and the global model data from the last round. It outputs the global model data for this round.
   
   Given that `ModelDataGlobalUpdater` already stores global model data for the last round in its state, would it be simpler to have it take only the partial model data as input?
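A minimal sketch of the suggested single-input design, in plain Java with no Flink dependencies. `GlobalModelMerger` is a hypothetical name, not part of this PR. Plain arrays stand in for `KMeansModelData`/`DenseVector`, and operator state is modeled as ordinary fields. The merger keeps the previous round's global centroids and weights itself, accumulates one partial update per upstream subtask, and applies the decayed weighted-average rule from the `OnlineKMeans` Javadoc once all `upstreamParallelism` partials have arrived:

```java
import java.util.Arrays;

/**
 * Hypothetical single-input global updater: it owns the previous round's global model and only
 * consumes partial (per-subtask) updates. Plain arrays stand in for KMeansModelData/DenseVector.
 */
class GlobalModelMerger {
    private final int upstreamParallelism;
    private final double decayFactor;

    // Global model emitted in the previous round (would be operator state in Flink).
    private final double[][] centroids;
    private final double[] weights;

    // Accumulators for the round in progress.
    private final double[][] weightedSums;
    private final double[] partialWeights;
    private int received;

    GlobalModelMerger(
            double[][] initCentroids, double[] initWeights, int upstreamParallelism, double decayFactor) {
        this.upstreamParallelism = upstreamParallelism;
        this.decayFactor = decayFactor;
        this.centroids = new double[initCentroids.length][];
        for (int i = 0; i < initCentroids.length; i++) {
            this.centroids[i] = initCentroids[i].clone();
        }
        this.weights = initWeights.clone();
        int k = initCentroids.length;
        int dim = k == 0 ? 0 : initCentroids[0].length;
        this.weightedSums = new double[k][dim];
        this.partialWeights = new double[k];
    }

    /** Adds one subtask's partial result; returns the new global centroids once all have arrived. */
    double[][] addPartial(double[][] partialCentroids, double[] partialCounts) {
        for (int i = 0; i < centroids.length; i++) {
            partialWeights[i] += partialCounts[i];
            for (int j = 0; j < centroids[i].length; j++) {
                weightedSums[i][j] += partialCentroids[i][j] * partialCounts[i];
            }
        }
        if (++received < upstreamParallelism) {
            return null; // wait for the remaining subtasks
        }
        for (int i = 0; i < centroids.length; i++) {
            double oldWeight = weights[i] * decayFactor; // forget part of the history
            double total = oldWeight + partialWeights[i];
            if (total > 0) { // a cluster with zero total weight keeps its centroid
                for (int j = 0; j < centroids[i].length; j++) {
                    centroids[i][j] = (centroids[i][j] * oldWeight + weightedSums[i][j]) / total;
                }
            }
            weights[i] = total;
            // Reset the accumulators for the next round.
            partialWeights[i] = 0;
            Arrays.fill(weightedSums[i], 0.0);
        }
        received = 0;
        double[][] result = new double[centroids.length][];
        for (int i = 0; i < centroids.length; i++) {
            result[i] = centroids[i].clone();
        }
        return result;
    }

    double[] weights() {
        return weights.clone();
    }

    public static void main(String[] args) {
        // k = 1, dim = 1, two upstream subtasks, decay factor 0.5.
        GlobalModelMerger merger =
                new GlobalModelMerger(new double[][] {{0.0}}, new double[] {2.0}, 2, 0.5);
        System.out.println(merger.addPartial(new double[][] {{4.0}}, new double[] {1.0})); // null
        double[][] next = merger.addPartial(new double[][] {{10.0}}, new double[] {3.0});
        System.out.println(Arrays.toString(next[0]) + " " + Arrays.toString(merger.weights())); // [6.8] [5.0]
    }
}
```

With this shape the operator no longer needs the second broadcast input or the `initModelDataReceiving` flag. The initial model still has to reach the operator once, for example to seed its state, and this sketch assumes that happens through the constructor. One small difference from the PR: the current formula divides by `Math.max(total, 1e-16)`, which turns a zero-weight centroid into the zero vector. The sketch instead leaves such a centroid unchanged.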

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,562 @@
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        Path initModelDataPath = Paths.get(path, "data");

Review comment:
       Could we make `ReadWriteUtils.getDataPath(...)` public and use it here?
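To illustrate the suggestion, here is a sketch in plain Java. `StagePaths`, `DATA_DIR` and `hasModelData` are hypothetical names, and the real `ReadWriteUtils.getDataPath(...)` signature may differ; the comment only implies that such a helper already exists but is not public. The point is that `load()` would resolve the model-data directory through the same helper used when saving, instead of repeating the `"data"` literal:

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

/** Hypothetical helper standing in for a public ReadWriteUtils.getDataPath(...). */
final class StagePaths {
    // Sub-directory that holds a stage's model data; mirrors the "data" literal used in load().
    static final String DATA_DIR = "data";

    private StagePaths() {}

    /** Resolves the model-data directory of a stage saved under {@code stagePath}. */
    static Path getDataPath(String stagePath) {
        return Paths.get(stagePath, DATA_DIR);
    }

    /** True if model data was written for the stage saved under {@code stagePath}. */
    static boolean hasModelData(String stagePath) {
        return Files.exists(getDataPath(stagePath));
    }

    public static void main(String[] args) throws IOException {
        Path stage = Files.createTempDirectory("online-kmeans");
        System.out.println(hasModelData(stage.toString())); // false
        Files.createDirectories(getDataPath(stage.toString()));
        System.out.println(hasModelData(stage.toString())); // true
    }
}
```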

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,562 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Random;
+
+/**
+ * OnlineKMeans extends the function of {@link KMeans}, supporting to train a K-Means model
+ * continuously according to an unbounded stream of train data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids from the weighted average between the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of points assigned to
+ * them. The weight of the original centroids is also the number of points, but additionally
+ * multiplying with the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If decay
+ * factor is 1, all batches are weighted equally. If decay factor is 0, new centroids are determined
+ * entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public OnlineKMeans(Table initModelDataTable) {
+        this.initModelDataTable = initModelDataTable;
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            Preconditions.checkState(initModelDataTable == null);
+            initModelDataStream = createRandomCentroids(env, getDim(), getK(), getSeed());
+        } else {
+            initModelDataStream = KMeansModelData.getModelDataStream(initModelDataTable);
+        }
+        DataStream<Tuple2<KMeansModelData, DenseVector>> initModelDataWithWeightsStream =
+                initModelDataStream.map(new InitWeightAssigner(getInitWeights()));
+        initModelDataWithWeightsStream.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize(),
+                        getK(),
+                        getDim());
+
+        DataStream<KMeansModelData> finalModelDataStream =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelDataWithWeightsStream),
+                                DataStreamList.of(points),
+                                body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelDataStream);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class InitWeightAssigner
+            implements MapFunction<KMeansModelData, Tuple2<KMeansModelData, DenseVector>> {
+        private final double[] initWeights;
+
+        private InitWeightAssigner(Double[] initWeights) {
+            this.initWeights = ArrayUtils.toPrimitive(initWeights);
+        }
+
+        @Override
+        public Tuple2<KMeansModelData, DenseVector> map(KMeansModelData modelData)
+                throws Exception {
+            return Tuple2.of(modelData, Vectors.dense(initWeights));
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        Path initModelDataPath = Paths.get(path, "data");
+        if (Files.exists(initModelDataPath)) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+            kMeans.setInitMode("direct");
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+        private final int k;
+        private final int dim;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure,
+                double decayFactor,
+                int batchSize,
+                int k,
+                int dim) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+            this.k = k;
+            this.dim = dim;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<Tuple2<KMeansModelData, DenseVector>> modelDataWithWeights =
+                    variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<Tuple2<KMeansModelData, DenseVector>> newModelDataWithWeights =
+                    points.countWindowAll(batchSize)
+                            .aggregate(new MiniBatchCreator())
+                            .flatMap(new MiniBatchDistributor(parallelism))
+                            .rebalance()
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "ModelDataPartialUpdater",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new ModelDataPartialUpdater(distanceMeasure, k))
+                            .setParallelism(parallelism)
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "ModelDataGlobalUpdater",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new ModelDataGlobalUpdater(k, dim, parallelism, decayFactor))
+                            .setParallelism(1);
+
+            DataStream<KMeansModelData> outputModelData =
+                    modelDataWithWeights.map(
+                            (MapFunction<Tuple2<KMeansModelData, DenseVector>, KMeansModelData>)
+                                    x -> x.f0);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelDataWithWeights), DataStreamList.of(outputModelData));
+        }
+    }
+
+    private static class ModelDataGlobalUpdater
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, DenseVector>>
+            implements TwoInputStreamOperator<
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final int k;
+        private final int dim;
+        private final int upstreamParallelism;
+        private final double decayFactor;
+
+        private ListState<Integer> partialModelDataReceivingState;
+        private ListState<Boolean> initModelDataReceivingState;
+        private ListState<KMeansModelData> modelDataState;
+        private ListState<DenseVector> weightsState;
+
+        private ModelDataGlobalUpdater(
+                int k, int dim, int upstreamParallelism, double decayFactor) {
+            this.k = k;
+            this.dim = dim;
+            this.upstreamParallelism = upstreamParallelism;
+            this.decayFactor = decayFactor;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            partialModelDataReceivingState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "partialModelDataReceiving",
+                                            BasicTypeInfo.INT_TYPE_INFO));
+
+            initModelDataReceivingState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "initModelDataReceiving",
+                                            BasicTypeInfo.BOOLEAN_TYPE_INFO));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+
+            weightsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "weights", DenseVectorTypeInfo.INSTANCE));
+
+            initStateValues();
+        }
+
+        private void initStateValues() throws Exception {
+            partialModelDataReceivingState.update(Collections.singletonList(0));
+            initModelDataReceivingState.update(Collections.singletonList(false));
+            DenseVector[] emptyCentroids = new DenseVector[k];
+            for (int i = 0; i < k; i++) {
+                emptyCentroids[i] = new DenseVector(dim);
+            }
+            modelDataState.update(Collections.singletonList(new KMeansModelData(emptyCentroids)));
+            weightsState.update(Collections.singletonList(new DenseVector(k)));
+        }
+
+        @Override
+        public void processElement1(
+                StreamRecord<Tuple2<KMeansModelData, DenseVector>> partialModelDataUpdateRecord)
+                throws Exception {
+            int partialModelDataReceiving =
+                    getUniqueElement(partialModelDataReceivingState, "partialModelDataReceiving");
+            Preconditions.checkState(partialModelDataReceiving < upstreamParallelism);
+            partialModelDataReceivingState.update(
+                    Collections.singletonList(partialModelDataReceiving + 1));
+            processElement(
+                    partialModelDataUpdateRecord.getValue().f0.centroids,
+                    partialModelDataUpdateRecord.getValue().f1,
+                    1.0);
+        }
+
+        @Override
+        public void processElement2(
+                StreamRecord<Tuple2<KMeansModelData, DenseVector>> initModelDataRecord)
+                throws Exception {
+            boolean initModelDataReceiving =
+                    getUniqueElement(initModelDataReceivingState, "initModelDataReceiving");
+            Preconditions.checkState(!initModelDataReceiving);
+            initModelDataReceivingState.update(Collections.singletonList(true));
+            processElement(
+                    initModelDataRecord.getValue().f0.centroids,
+                    initModelDataRecord.getValue().f1,
+                    decayFactor);
+        }
+
+        private void processElement(
+                DenseVector[] newCentroids, DenseVector newWeights, double decayFactor)
+                throws Exception {
+            DenseVector weights = getUniqueElement(weightsState, "weights");
+            DenseVector[] centroids = getUniqueElement(modelDataState, "modelData").centroids;
+
+            for (int i = 0; i < k; i++) {
+                newWeights.values[i] *= decayFactor;
+                for (int j = 0; j < dim; j++) {
+                    centroids[i].values[j] =
+                            (centroids[i].values[j] * weights.values[i]
+                                            + newCentroids[i].values[j] * newWeights.values[i])
+                                    / Math.max(weights.values[i] + newWeights.values[i], 1e-16);
+                }
+                weights.values[i] += newWeights.values[i];
+            }
+
+            if (getUniqueElement(initModelDataReceivingState, "initModelDataReceiving")
+                    && getUniqueElement(partialModelDataReceivingState, "partialModelDataReceiving")
+                            >= upstreamParallelism) {
+                output.collect(
+                        new StreamRecord<>(Tuple2.of(new KMeansModelData(centroids), weights)));
+                initStateValues();
+            } else {
+                modelDataState.update(Collections.singletonList(new KMeansModelData(centroids)));
+                weightsState.update(Collections.singletonList(weights));
+            }
+        }
+    }
+
+    private static <T> T getUniqueElement(ListState<T> state, String stateName) throws Exception {
+        T value = OperatorStateUtils.getUniqueElement(state, stateName).orElse(null);
+        return Objects.requireNonNull(value);
+    }
+
+    private static class ModelDataPartialUpdater
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, DenseVector>>
+            implements TwoInputStreamOperator<
+                    DenseVector[],
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final DistanceMeasure distanceMeasure;
+        private final int k;
+        private ListState<DenseVector[]> miniBatchState;
+        private ListState<KMeansModelData> modelDataState;
+
+        private ModelDataPartialUpdater(DistanceMeasure distanceMeasure, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.k = k;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            miniBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("miniBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> pointsRecord) throws Exception {
+            miniBatchState.add(pointsRecord.getValue());
+            processElement();
+        }
+
+        @Override
+        public void processElement2(
+                StreamRecord<Tuple2<KMeansModelData, DenseVector>> modelDataAndWeightsRecord)
+                throws Exception {
+            modelDataState.add(modelDataAndWeightsRecord.getValue().f0);
+            processElement();
+        }
+
+        private void processElement() throws Exception {

Review comment:
       Nit: could we give it a more meaningful name, e.g. `maybeEmitModelData`?
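   For illustration, the suggested name fits the pattern these two-input operators appear to follow. Each `processElementN` call only buffers its input into state, and the shared helper emits output only once both inputs are available. The helper's body is not shown in the hunk above, so what follows is a minimal, framework-free sketch under that assumption. `TwoInputBuffer`, its `BiFunction` combiner, and the in-memory queues are hypothetical stand-ins for the operator and its `ListState`s:

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;
import java.util.function.BiFunction;

/** Hypothetical stand-in for a two-input operator: buffers both inputs, emits when both are ready. */
class TwoInputBuffer<A, B, O> {
    // In-memory queues standing in for the operator's ListState fields (assumption).
    private final Deque<A> firstInputs = new ArrayDeque<>();
    private final Deque<B> secondInputs = new ArrayDeque<>();
    private final BiFunction<A, B, O> combiner;
    private final List<O> output = new ArrayList<>();

    TwoInputBuffer(BiFunction<A, B, O> combiner) {
        this.combiner = combiner;
    }

    /** Analogous to processElement1: buffer the element, then try to emit. */
    void processElement1(A value) {
        firstInputs.add(value);
        maybeEmit();
    }

    /** Analogous to processElement2: buffer the element, then try to emit. */
    void processElement2(B value) {
        secondInputs.add(value);
        maybeEmit();
    }

    /** Emits only while both inputs have a buffered element; otherwise does nothing. */
    private void maybeEmit() {
        while (!firstInputs.isEmpty() && !secondInputs.isEmpty()) {
            output.add(combiner.apply(firstInputs.poll(), secondInputs.poll()));
        }
    }

    List<O> getOutput() {
        return output;
    }

    public static void main(String[] args) {
        TwoInputBuffer<String, Integer, String> buffer =
                new TwoInputBuffer<>((batch, version) -> batch + "@v" + version);
        buffer.processElement1("batch-0"); // nothing emitted yet
        buffer.processElement1("batch-1"); // still nothing: no model data has arrived
        buffer.processElement2(1); // pairs with "batch-0" and emits
        System.out.println(buffer.getOutput()); // prints [batch-0@v1]
    }
}
```

   The "maybe" prefix signals to readers that calling the helper does not guarantee any output, which a generic `processElement()` name hides.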

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,562 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Random;
+
+/**
+ * OnlineKMeans extends the functionality of {@link KMeans} by supporting continuous training of a
+ * K-Means model on an unbounded stream of training data.
+ *
+ * <p>OnlineKMeans updates the model with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). Once the centroids estimated on the current batch are available,
+ * OnlineKMeans computes the new centroids as the weighted average of the original and the
+ * estimated centroids. The weight of each estimated centroid is the number of points assigned to
+ * it in the current batch. The weight of each original centroid is its accumulated weight from
+ * previous batches, multiplied by the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters estimated so far. If the decay
+ * factor is 1, all batches are weighted equally. If the decay factor is 0, the new centroids are
+ * determined entirely by the most recent batch. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public OnlineKMeans(Table initModelDataTable) {
+        this.initModelDataTable = initModelDataTable;
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            Preconditions.checkState(initModelDataTable == null);
+            initModelDataStream = createRandomCentroids(env, getDim(), getK(), getSeed());
+        } else {
+            initModelDataStream = KMeansModelData.getModelDataStream(initModelDataTable);
+        }
+        DataStream<Tuple2<KMeansModelData, DenseVector>> initModelDataWithWeightsStream =
+                initModelDataStream.map(new InitWeightAssigner(getInitWeights()));
+        initModelDataWithWeightsStream.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize(),
+                        getK(),
+                        getDim());
+
+        DataStream<KMeansModelData> finalModelDataStream =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelDataWithWeightsStream),
+                                DataStreamList.of(points),
+                                body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelDataStream);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class InitWeightAssigner
+            implements MapFunction<KMeansModelData, Tuple2<KMeansModelData, DenseVector>> {
+        private final double[] initWeights;
+
+        private InitWeightAssigner(Double[] initWeights) {
+            this.initWeights = ArrayUtils.toPrimitive(initWeights);
+        }
+
+        @Override
+        public Tuple2<KMeansModelData, DenseVector> map(KMeansModelData modelData)
+                throws Exception {
+            return Tuple2.of(modelData, Vectors.dense(initWeights));
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        Path initModelDataPath = Paths.get(path, "data");
+        if (Files.exists(initModelDataPath)) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+            kMeans.setInitMode("direct");
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+        private final int k;
+        private final int dim;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure,
+                double decayFactor,
+                int batchSize,
+                int k,
+                int dim) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+            this.k = k;
+            this.dim = dim;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<Tuple2<KMeansModelData, DenseVector>> modelDataWithWeights =
+                    variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<Tuple2<KMeansModelData, DenseVector>> newModelDataWithWeights =
+                    points.countWindowAll(batchSize)
+                            .aggregate(new MiniBatchCreator())
+                            .flatMap(new MiniBatchDistributor(parallelism))
+                            .rebalance()
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "ModelDataPartialUpdater",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new ModelDataPartialUpdater(distanceMeasure, k))
+                            .setParallelism(parallelism)
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "ModelDataGlobalUpdater",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new ModelDataGlobalUpdater(k, dim, parallelism, decayFactor))
+                            .setParallelism(1);
+
+            DataStream<KMeansModelData> outputModelData =
+                    modelDataWithWeights.map(
+                            (MapFunction<Tuple2<KMeansModelData, DenseVector>, KMeansModelData>)
+                                    x -> x.f0);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelDataWithWeights), DataStreamList.of(outputModelData));
+        }
+    }
+
+    private static class ModelDataGlobalUpdater
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, DenseVector>>
+            implements TwoInputStreamOperator<
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final int k;
+        private final int dim;
+        private final int upstreamParallelism;
+        private final double decayFactor;
+
+        private ListState<Integer> partialModelDataReceivingState;
+        private ListState<Boolean> initModelDataReceivingState;
+        private ListState<KMeansModelData> modelDataState;
+        private ListState<DenseVector> weightsState;
+
+        private ModelDataGlobalUpdater(
+                int k, int dim, int upstreamParallelism, double decayFactor) {
+            this.k = k;
+            this.dim = dim;
+            this.upstreamParallelism = upstreamParallelism;
+            this.decayFactor = decayFactor;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            partialModelDataReceivingState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "partialModelDataReceiving",
+                                            BasicTypeInfo.INT_TYPE_INFO));
+
+            initModelDataReceivingState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "initModelDataReceiving",
+                                            BasicTypeInfo.BOOLEAN_TYPE_INFO));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+
+            weightsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "weights", DenseVectorTypeInfo.INSTANCE));
+
+            initStateValues();
+        }
+
+        private void initStateValues() throws Exception {
+            partialModelDataReceivingState.update(Collections.singletonList(0));
+            initModelDataReceivingState.update(Collections.singletonList(false));
+            DenseVector[] emptyCentroids = new DenseVector[k];
+            for (int i = 0; i < k; i++) {
+                emptyCentroids[i] = new DenseVector(dim);
+            }
+            modelDataState.update(Collections.singletonList(new KMeansModelData(emptyCentroids)));
+            weightsState.update(Collections.singletonList(new DenseVector(k)));
+        }
+
+        @Override
+        public void processElement1(
+                StreamRecord<Tuple2<KMeansModelData, DenseVector>> partialModelDataUpdateRecord)
+                throws Exception {
+            int partialModelDataReceiving =
+                    getUniqueElement(partialModelDataReceivingState, "partialModelDataReceiving");
+            Preconditions.checkState(partialModelDataReceiving < upstreamParallelism);
+            partialModelDataReceivingState.update(
+                    Collections.singletonList(partialModelDataReceiving + 1));
+            processElement(
+                    partialModelDataUpdateRecord.getValue().f0.centroids,
+                    partialModelDataUpdateRecord.getValue().f1,
+                    1.0);
+        }
+
+        @Override
+        public void processElement2(
+                StreamRecord<Tuple2<KMeansModelData, DenseVector>> initModelDataRecord)
+                throws Exception {
+            boolean initModelDataReceiving =
+                    getUniqueElement(initModelDataReceivingState, "initModelDataReceiving");
+            Preconditions.checkState(!initModelDataReceiving);
+            initModelDataReceivingState.update(Collections.singletonList(true));
+            processElement(
+                    initModelDataRecord.getValue().f0.centroids,
+                    initModelDataRecord.getValue().f1,
+                    decayFactor);
+        }
+
+        private void processElement(
+                DenseVector[] newCentroids, DenseVector newWeights, double decayFactor)
+                throws Exception {
+            DenseVector weights = getUniqueElement(weightsState, "weights");
+            DenseVector[] centroids = getUniqueElement(modelDataState, "modelData").centroids;
+
+            for (int i = 0; i < k; i++) {
+                newWeights.values[i] *= decayFactor;
+                for (int j = 0; j < dim; j++) {
+                    centroids[i].values[j] =
+                            (centroids[i].values[j] * weights.values[i]
+                                            + newCentroids[i].values[j] * newWeights.values[i])
+                                    / Math.max(weights.values[i] + newWeights.values[i], 1e-16);
+                }
+                weights.values[i] += newWeights.values[i];
+            }
+
+            if (getUniqueElement(initModelDataReceivingState, "initModelDataReceiving")
+                    && getUniqueElement(partialModelDataReceivingState, "partialModelDataReceiving")
+                            >= upstreamParallelism) {
+                output.collect(
+                        new StreamRecord<>(Tuple2.of(new KMeansModelData(centroids), weights)));
+                initStateValues();
+            } else {
+                modelDataState.update(Collections.singletonList(new KMeansModelData(centroids)));
+                weightsState.update(Collections.singletonList(weights));
+            }
+        }
+    }
+
+    private static <T> T getUniqueElement(ListState<T> state, String stateName) throws Exception {
+        T value = OperatorStateUtils.getUniqueElement(state, stateName).orElse(null);
+        return Objects.requireNonNull(value);
+    }
+
+    private static class ModelDataPartialUpdater

Review comment:
       Nit: would it be more intuitive to rename this class to `ModelDataLocalUpdater`?
   
   The name would match the local batch vs. global batch distinction. Also note that the model data emitted by this operator is actually self-contained, not partial; "partial" means "existing only in part; incomplete".
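   To make the local/global distinction concrete, here is a simplified, framework-free sketch of the two stages. It is not the PR's implementation. The local step assigns each point of a mini-batch to its nearest centroid and returns per-cluster means plus point counts, which is a self-contained model with weights. Squared Euclidean distance is assumed here, whereas the operator uses a configurable `DistanceMeasure`. The global step folds the previous model, with its weights scaled by the decay factor, and every local result into a running weighted average. It uses the same formula as `processElement` in `ModelDataGlobalUpdater` above. `MiniBatchKMeansSketch`, `LocalUpdate`, `localUpdate`, and `globalUpdate` are hypothetical names, and plain `double[]` arrays stand in for `DenseVector`:

```java
import java.util.Arrays;

/** Hypothetical, framework-free sketch of the local and global OnlineKMeans update steps. */
class MiniBatchKMeansSketch {

    /** Centroids plus per-cluster weights; used both for local results and for the model. */
    static final class LocalUpdate {
        final double[][] centroids;
        final double[] weights;

        LocalUpdate(double[][] centroids, double[] weights) {
            this.centroids = centroids;
            this.weights = weights;
        }
    }

    /** Assigns each point to its nearest centroid (squared Euclidean) and averages per cluster. */
    static LocalUpdate localUpdate(double[][] centroids, double[][] points) {
        int k = centroids.length;
        int dim = centroids[0].length;
        double[][] sums = new double[k][dim];
        double[] counts = new double[k];
        for (double[] p : points) {
            int best = 0;
            double bestDist = Double.POSITIVE_INFINITY;
            for (int i = 0; i < k; i++) {
                double d = 0;
                for (int j = 0; j < dim; j++) {
                    double diff = p[j] - centroids[i][j];
                    d += diff * diff;
                }
                if (d < bestDist) {
                    bestDist = d;
                    best = i;
                }
            }
            counts[best] += 1;
            for (int j = 0; j < dim; j++) {
                sums[best][j] += p[j];
            }
        }
        // Turn sums into means; clusters with no points keep weight 0 and are ignored when merging.
        for (int i = 0; i < k; i++) {
            for (int j = 0; j < dim && counts[i] > 0; j++) {
                sums[i][j] /= counts[i];
            }
        }
        return new LocalUpdate(sums, counts);
    }

    /** Merges the decayed previous model with all local results into a new model. */
    static LocalUpdate globalUpdate(LocalUpdate previous, double decayFactor, LocalUpdate... locals) {
        int k = previous.weights.length;
        double[][] centroids = new double[k][previous.centroids[0].length];
        double[] weights = new double[k];
        merge(centroids, weights, previous, decayFactor);
        for (LocalUpdate local : locals) {
            merge(centroids, weights, local, 1.0);
        }
        return new LocalUpdate(centroids, weights);
    }

    /** Per cluster: c = (c * w + c' * w' * factor) / max(w + w' * factor, 1e-16); w += w' * factor. */
    private static void merge(double[][] centroids, double[] weights, LocalUpdate u, double factor) {
        for (int i = 0; i < weights.length; i++) {
            double w = u.weights[i] * factor;
            for (int j = 0; j < centroids[i].length; j++) {
                centroids[i][j] =
                        (centroids[i][j] * weights[i] + u.centroids[i][j] * w)
                                / Math.max(weights[i] + w, 1e-16);
            }
            weights[i] += w;
        }
    }

    public static void main(String[] args) {
        // Previous model: two 1-D clusters at 0 and 10, each with accumulated weight 4.
        LocalUpdate previous = new LocalUpdate(new double[][] {{0.0}, {10.0}}, new double[] {4, 4});
        // One subtask's mini-batch: two points near each centroid.
        LocalUpdate local = localUpdate(previous.centroids, new double[][] {{1}, {3}, {11}, {13}});
        LocalUpdate merged = globalUpdate(previous, 0.5, local);
        System.out.println(
                Arrays.deepToString(merged.centroids) + " " + Arrays.toString(merged.weights));
    }
}
```

   Running `main` prints `[[1.0], [11.0]] [4.0, 4.0]`. With a decay factor of 0.5, the previous model contributes weight 2 per cluster and the batch contributes 2 points per cluster, so each centroid moves halfway toward its batch mean. The local result already carries its own centroids and weights, which illustrates why "local" describes it better than "partial".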







[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r833149295



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,562 @@
+            this.decayFactor = decayFactor;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            partialModelDataReceivingState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "partialModelDataReceiving",
+                                            BasicTypeInfo.INT_TYPE_INFO));
+
+            initModelDataReceivingState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "initModelDataReceiving",
+                                            BasicTypeInfo.BOOLEAN_TYPE_INFO));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+
+            weightsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "weights", DenseVectorTypeInfo.INSTANCE));
+
+            initStateValues();
+        }
+
+        private void initStateValues() throws Exception {
+            partialModelDataReceivingState.update(Collections.singletonList(0));
+            initModelDataReceivingState.update(Collections.singletonList(false));
+            DenseVector[] emptyCentroids = new DenseVector[k];
+            for (int i = 0; i < k; i++) {
+                emptyCentroids[i] = new DenseVector(dim);
+            }
+            modelDataState.update(Collections.singletonList(new KMeansModelData(emptyCentroids)));
+            weightsState.update(Collections.singletonList(new DenseVector(k)));
+        }
+
+        @Override
+        public void processElement1(
+                StreamRecord<Tuple2<KMeansModelData, DenseVector>> partialModelDataUpdateRecord)
+                throws Exception {
+            int partialModelDataReceiving =
+                    getUniqueElement(partialModelDataReceivingState, "partialModelDataReceiving");
+            Preconditions.checkState(partialModelDataReceiving < upstreamParallelism);
+            partialModelDataReceivingState.update(
+                    Collections.singletonList(partialModelDataReceiving + 1));
+            processElement(
+                    partialModelDataUpdateRecord.getValue().f0.centroids,
+                    partialModelDataUpdateRecord.getValue().f1,
+                    1.0);
+        }
+
+        @Override
+        public void processElement2(
+                StreamRecord<Tuple2<KMeansModelData, DenseVector>> initModelDataRecord)
+                throws Exception {
+            boolean initModelDataReceiving =
+                    getUniqueElement(initModelDataReceivingState, "initModelDataReceiving");
+            Preconditions.checkState(!initModelDataReceiving);
+            initModelDataReceivingState.update(Collections.singletonList(true));
+            processElement(
+                    initModelDataRecord.getValue().f0.centroids,
+                    initModelDataRecord.getValue().f1,
+                    decayFactor);
+        }
+
+        private void processElement(
+                DenseVector[] newCentroids, DenseVector newWeights, double decayFactor)
+                throws Exception {
+            DenseVector weights = getUniqueElement(weightsState, "weights");
+            DenseVector[] centroids = getUniqueElement(modelDataState, "modelData").centroids;
+
+            for (int i = 0; i < k; i++) {
+                newWeights.values[i] *= decayFactor;
+                for (int j = 0; j < dim; j++) {
+                    centroids[i].values[j] =
+                            (centroids[i].values[j] * weights.values[i]
+                                            + newCentroids[i].values[j] * newWeights.values[i])
+                                    / Math.max(weights.values[i] + newWeights.values[i], 1e-16);
+                }
+                weights.values[i] += newWeights.values[i];
+            }
+
+            if (getUniqueElement(initModelDataReceivingState, "initModelDataReceiving")
+                    && getUniqueElement(partialModelDataReceivingState, "partialModelDataReceiving")
+                            >= upstreamParallelism) {
+                output.collect(
+                        new StreamRecord<>(Tuple2.of(new KMeansModelData(centroids), weights)));
+                initStateValues();
+            } else {
+                modelDataState.update(Collections.singletonList(new KMeansModelData(centroids)));
+                weightsState.update(Collections.singletonList(weights));
+            }
+        }
+    }
+
+    private static <T> T getUniqueElement(ListState<T> state, String stateName) throws Exception {
+        T value = OperatorStateUtils.getUniqueElement(state, stateName).orElse(null);
+        return Objects.requireNonNull(value);
+    }
+
+    private static class ModelDataPartialUpdater

Review comment:
       I agree. I'll make the change.
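The centroid update in `ModelDataGlobalUpdater.processElement` quoted above is a per-centroid weighted average, which is also how the OnlineKMeans class Javadoc describes it. Below is a minimal plain-Java sketch of that rule with no Flink dependency. It assumes, as the Javadoc states, that the decay factor discounts the weight of the previously accumulated model. `DecayedCentroidMerge` and `merge` are hypothetical names and do not exist in the PR.

```java
import java.util.Arrays;

/**
 * Hypothetical, Flink-free sketch of the decayed mini-batch centroid update.
 * Mirrors the weighted average in ModelDataGlobalUpdater.processElement,
 * including the Math.max guard against dividing by a zero total weight.
 */
public class DecayedCentroidMerge {

    /**
     * Merges a batch estimate into the running model, updating both arrays in place.
     *
     * @param centroids running centroids, k x dim
     * @param weights running weight of each centroid
     * @param batchCentroids centroids estimated on the current batch, k x dim
     * @param batchWeights number of batch points assigned to each centroid
     * @param decayFactor 1 weights all batches equally, 0 keeps only the current batch
     */
    static void merge(
            double[][] centroids,
            double[] weights,
            double[][] batchCentroids,
            double[] batchWeights,
            double decayFactor) {
        for (int i = 0; i < centroids.length; i++) {
            // The previous model's weight is discounted by the decay factor.
            double oldWeight = weights[i] * decayFactor;
            double totalWeight = oldWeight + batchWeights[i];
            for (int j = 0; j < centroids[i].length; j++) {
                centroids[i][j] =
                        (centroids[i][j] * oldWeight + batchCentroids[i][j] * batchWeights[i])
                                / Math.max(totalWeight, 1e-16);
            }
            weights[i] = totalWeight;
        }
    }

    public static void main(String[] args) {
        double[][] centroids = {{0.0, 0.0}};
        double[] weights = {2.0};
        merge(centroids, weights, new double[][] {{3.0, 6.0}}, new double[] {1.0}, 0.5);
        // oldWeight = 1.0 and totalWeight = 2.0, so the centroid moves halfway to the batch estimate.
        System.out.println(Arrays.toString(centroids[0]) + " weight=" + weights[0]); // [1.5, 3.0] weight=2.0
    }
}
```

With `decayFactor = 0` the old weight vanishes, so the new centroid equals the batch estimate. This matches "determined entirely by recent data" in the Javadoc.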




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [flink-ml] lindong28 commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
lindong28 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r833316743



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,562 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Random;
+
+/**
+ * OnlineKMeans extends the functionality of {@link KMeans} to support training a K-Means model
+ * continuously on an unbounded stream of training data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids as the weighted average of the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of points assigned to
+ * them. The weight of the original centroids is also the number of points assigned to them, but
+ * multiplied by the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If the decay
+ * factor is 1, all batches are weighted equally. If the decay factor is 0, new centroids are
+ * determined entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public OnlineKMeans(Table initModelDataTable) {
+        this.initModelDataTable = initModelDataTable;
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            Preconditions.checkState(initModelDataTable == null);
+            initModelDataStream = createRandomCentroids(env, getDim(), getK(), getSeed());
+        } else {
+            initModelDataStream = KMeansModelData.getModelDataStream(initModelDataTable);
+        }
+        DataStream<Tuple2<KMeansModelData, DenseVector>> initModelDataWithWeightsStream =
+                initModelDataStream.map(new InitWeightAssigner(getInitWeights()));
+        initModelDataWithWeightsStream.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize(),
+                        getK(),
+                        getDim());
+
+        DataStream<KMeansModelData> finalModelDataStream =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelDataWithWeightsStream),
+                                DataStreamList.of(points),
+                                body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelDataStream);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class InitWeightAssigner
+            implements MapFunction<KMeansModelData, Tuple2<KMeansModelData, DenseVector>> {
+        private final double[] initWeights;
+
+        private InitWeightAssigner(Double[] initWeights) {
+            this.initWeights = ArrayUtils.toPrimitive(initWeights);
+        }
+
+        @Override
+        public Tuple2<KMeansModelData, DenseVector> map(KMeansModelData modelData)
+                throws Exception {
+            return Tuple2.of(modelData, Vectors.dense(initWeights));
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        Path initModelDataPath = Paths.get(path, "data");
+        if (Files.exists(initModelDataPath)) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+            kMeans.setInitMode("direct");
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+        private final int k;
+        private final int dim;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure,
+                double decayFactor,
+                int batchSize,
+                int k,
+                int dim) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+            this.k = k;
+            this.dim = dim;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<Tuple2<KMeansModelData, DenseVector>> modelDataWithWeights =
+                    variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<Tuple2<KMeansModelData, DenseVector>> newModelDataWithWeights =
+                    points.countWindowAll(batchSize)
+                            .aggregate(new MiniBatchCreator())
+                            .flatMap(new MiniBatchDistributor(parallelism))
+                            .rebalance()
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "ModelDataPartialUpdater",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new ModelDataPartialUpdater(distanceMeasure, k))
+                            .setParallelism(parallelism)
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "ModelDataGlobalUpdater",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new ModelDataGlobalUpdater(k, dim, parallelism, decayFactor))

Review comment:
       After thinking about this more, I think we can change the algorithm used in `ModelDataLocalUpdater` and `ModelDataGlobalUpdater` so that `ModelDataGlobalUpdater` only needs to read from `ModelDataLocalUpdater`'s output and still calculates the correct result. We might need to change the type of output emitted by `ModelDataLocalUpdater`.
   
   Here is the idea. Currently `ModelDataGlobalUpdater` calculates the weight for the first centroid as `weight_from_last_iteration + sum_of_weights_from_local_updater`. We can change `ModelDataLocalUpdater` to emit `weight_from_last_iteration / parallelism + weights_from_local_batch`. Then `ModelDataGlobalUpdater` can derive the weight for the first centroid as `sum_of_outputs_from_local_updater`, which depends only on the output from `ModelDataLocalUpdater`.
   
   This approach introduces more complexity in the operator, but it could make the job graph simpler and more performant.
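Here is a hypothetical plain-Java sketch of the weight arithmetic proposed above, showing that summing the per-subtask outputs yields the same total as the current scheme. `localOutput` and `globalSum` are illustrative names, not methods in the PR. The sketch covers only the weights, not the centroid coordinates, and leaves out the decay factor. A decayed last weight can be folded in by scaling `lastWeights` before splitting it.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch of the weight split proposed in the review: every local updater adds
 * lastWeight / parallelism to its own batch weight, so the global updater only sums outputs.
 */
public class WeightDecompositionSketch {

    /** Weight vector one local updater would emit (one entry per centroid). */
    static double[] localOutput(double[] lastWeights, double[] localBatchWeights, int parallelism) {
        double[] out = new double[lastWeights.length];
        for (int i = 0; i < out.length; i++) {
            out[i] = lastWeights[i] / parallelism + localBatchWeights[i];
        }
        return out;
    }

    /** The global updater only needs the element-wise sum of all local outputs. */
    static double[] globalSum(double[][] localOutputs) {
        double[] sum = new double[localOutputs[0].length];
        for (double[] out : localOutputs) {
            for (int i = 0; i < sum.length; i++) {
                sum[i] += out[i];
            }
        }
        return sum;
    }

    public static void main(String[] args) {
        int parallelism = 4;
        double[] lastWeights = {10.0, 6.0};
        // Points assigned to each centroid by each subtask in the current mini-batch.
        double[][] localBatchWeights = {{1, 2}, {0, 3}, {2, 2}, {1, 1}};

        double[][] outputs = new double[parallelism][];
        for (int p = 0; p < parallelism; p++) {
            outputs[p] = localOutput(lastWeights, localBatchWeights[p], parallelism);
        }
        // Current scheme: {10 + 4, 6 + 8}. Proposed scheme: sum of local outputs.
        System.out.println(Arrays.toString(globalSum(outputs))); // [14.0, 14.0]
    }
}
```

If the parallelism does not divide a weight exactly, the two schemes can differ by floating-point rounding.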
   







[GitHub] [flink-ml] lindong28 commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
lindong28 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r835764832



##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/util/InMemorySinkFunction.java
##########
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util;
+
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.streaming.api.functions.sink.RichSinkFunction;
+import org.apache.flink.streaming.api.functions.sink.SinkFunction;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.UUID;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
+
+/** A {@link SinkFunction} implementation that makes all collected records available for tests. */
+@SuppressWarnings({"unchecked", "rawtypes"})
+public class InMemorySinkFunction<T> extends RichSinkFunction<T> {
+    private static final Map<UUID, BlockingQueue> queueMap = new ConcurrentHashMap<>();
+    private final UUID id;
+    private BlockingQueue<T> queue;
+
+    public InMemorySinkFunction() {
+        id = UUID.randomUUID();
+        queue = new LinkedBlockingQueue();
+        queueMap.put(id, queue);
+    }
+
+    @Override
+    public void open(Configuration parameters) throws Exception {
+        super.open(parameters);
+        queue = queueMap.get(id);
+    }
+
+    @Override
+    public void close() throws Exception {
+        super.close();
+        queueMap.remove(id);
+    }
+
+    @Override
+    public void invoke(T value, Context context) {
+        if (!queue.offer(value)) {
+            throw new RuntimeException(
+                    "Failed to offer " + value + " to blocking queue " + id + ".");
+        }
+    }
+
+    public List<T> poll(int num) throws InterruptedException {
+        List<T> result = new ArrayList<>();
+        for (int i = 0; i < num; i++) {
+            result.add(poll());
+        }
+        return result;
+    }
+
+    public T poll() throws InterruptedException {
+        return poll(1, TimeUnit.MINUTES);
+    }
+
+    public T poll(long timeout, TimeUnit unit) throws InterruptedException {

Review comment:
       nit: Would it be simpler to remove this method and move its body into `T poll()`?
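Below is a plain-Java sketch, with no Flink dependency, of the suggested simplification: the timeout logic is folded directly into `poll()`. The diff does not show the body of the original `poll(long, TimeUnit)`, so throwing on timeout instead of returning `null` is an assumption, and so is the name `RegistryBackedSink`. The sketch keeps the pattern from the diff: the queue lives in a static map keyed by a `UUID`. A deserialized copy of the function, like the one Flink creates for a task, therefore still writes into the queue the test polls.

```java
import java.io.Serializable;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;

/** Hypothetical sketch of InMemorySinkFunction's registry pattern with a single poll(). */
public class RegistryBackedSink<T> implements Serializable {
    // Static, so it is shared by every copy of the sink inside the same JVM.
    private static final Map<UUID, BlockingQueue<Object>> QUEUES = new ConcurrentHashMap<>();
    // Serialized with the instance, so a deserialized copy looks up the same queue.
    private final UUID id = UUID.randomUUID();

    public RegistryBackedSink() {
        QUEUES.put(id, new LinkedBlockingQueue<>());
    }

    /** Called by the (possibly deserialized) sink instance for every record. */
    public void invoke(T value) {
        if (!QUEUES.get(id).offer(value)) {
            throw new RuntimeException("Failed to offer " + value + " to blocking queue " + id + ".");
        }
    }

    /** Single poll() with the timeout inlined; assumes a timeout should fail loudly. */
    @SuppressWarnings("unchecked")
    public T poll() throws InterruptedException {
        Object value = QUEUES.get(id).poll(1, TimeUnit.MINUTES);
        if (value == null) {
            throw new RuntimeException("Timed out polling blocking queue " + id + ".");
        }
        return (T) value;
    }
}
```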

##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/util/InMemorySourceFunction.java
##########
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util;
+
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.streaming.api.functions.source.RichSourceFunction;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.util.Preconditions;
+
+import java.util.Map;
+import java.util.Optional;
+import java.util.UUID;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.LinkedBlockingQueue;
+
+/** A {@link SourceFunction} implementation that can directly receive records from tests. */
+@SuppressWarnings({"unchecked", "rawtypes"})
+public class InMemorySourceFunction<T> extends RichSourceFunction<T> {
+    private static final Map<UUID, BlockingQueue> queueMap = new ConcurrentHashMap<>();
+    private final UUID id;
+    private BlockingQueue<Optional<T>> queue;
+    private volatile boolean isRunning = true;
+
+    public InMemorySourceFunction() {
+        id = UUID.randomUUID();
+        queue = new LinkedBlockingQueue();
+        queueMap.put(id, queue);
+    }
+
+    @Override
+    public void open(Configuration parameters) throws Exception {
+        super.open(parameters);
+        queue = queueMap.get(id);
+    }
+
+    @Override
+    public void close() throws Exception {
+        super.close();
+        queueMap.remove(id);
+    }
+
+    @Override
+    public void run(SourceContext<T> context) throws InterruptedException {
+        while (isRunning) {
+            Optional<T> maybeValue = queue.take();
+            if (!maybeValue.isPresent()) {
+                continue;

Review comment:
       Given that this can only happen after `cancel()` is invoked, should it be `break`?
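Below is a plain-Java model of the `run()`/`cancel()` pair with the suggested `break`. The quoted diff does not include the `cancel()` body. This sketch assumes `cancel()` clears `isRunning` and then enqueues `Optional.empty()` to wake a `run()` thread blocked in `queue.take()`, which is what the review comment implies. `SentinelDrainLoop` is a hypothetical name, and `Consumer` stands in for Flink's `SourceContext`.

```java
import java.util.Optional;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.function.Consumer;

/** Hypothetical model of InMemorySourceFunction's run loop, using an empty Optional as a stop signal. */
public class SentinelDrainLoop<T> {
    private final BlockingQueue<Optional<T>> queue = new LinkedBlockingQueue<>();
    private volatile boolean isRunning = true;

    /** What a test would call to feed a record into the source. */
    public void offer(T value) {
        queue.add(Optional.of(value));
    }

    public void run(Consumer<T> emit) throws InterruptedException {
        while (isRunning) {
            Optional<T> maybeValue = queue.take();
            if (!maybeValue.isPresent()) {
                // The empty sentinel only comes from cancel(), so stop instead of looping again.
                break;
            }
            emit.accept(maybeValue.get());
        }
    }

    public void cancel() {
        isRunning = false;
        // Wake up a run() thread that may be blocked in queue.take().
        queue.add(Optional.empty());
    }
}
```

Because `isRunning` is cleared before the sentinel is enqueued, `continue` would also end the loop. `break` states the intent directly, though, and keeps working even if the two writes in `cancel()` were reordered.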

##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java
##########
@@ -177,11 +177,20 @@ public void testFewerDistinctPointsThanCluster() {
         KMeans kmeans = new KMeans().setK(2);
         KMeansModel model = kmeans.fit(input);
         Table output = model.transform(input)[0];
-        List<Set<DenseVector>> expectedGroups =
-                Collections.singletonList(Collections.singleton(Vectors.dense(0.0, 0.1)));
-        List<Set<DenseVector>> actualGroups =
-                executeAndCollect(output, kmeans.getFeaturesCol(), kmeans.getPredictionCol());
-        assertTrue(CollectionUtils.isEqualCollection(expectedGroups, actualGroups));
+
+        try {

Review comment:
       Hmm... it appears that the behavior of the KMeans algorithm has changed.
   
   Spark's `org.apache.spark.mllib.clustering` has a test named `fewer distinct points than clusters`, in which KMeans does not throw an exception when there are fewer distinct points than `K`. Could we keep the same behavior here?
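For reference, here is a minimal Lloyd-style update step that tolerates fewer distinct points than `K`: a centroid that receives no points keeps its previous position instead of causing an error. This is one common way to get the behavior the comment asks for. It is not the code of Spark's or Flink ML's KMeans. `EmptyClusterTolerantStep` and its methods are hypothetical, and squared Euclidean distance is assumed.

```java
import java.util.Arrays;

/** Hypothetical sketch: one KMeans round in which empty clusters keep their old centroid. */
public class EmptyClusterTolerantStep {

    /** Squared Euclidean distance between two points of equal dimension. */
    static double squaredDistance(double[] a, double[] b) {
        double sum = 0.0;
        for (int j = 0; j < a.length; j++) {
            double d = a[j] - b[j];
            sum += d * d;
        }
        return sum;
    }

    /** Index of the nearest centroid; ties go to the lowest index. */
    static int closest(double[][] centroids, double[] point) {
        int best = 0;
        double bestDistance = squaredDistance(centroids[0], point);
        for (int i = 1; i < centroids.length; i++) {
            double d = squaredDistance(centroids[i], point);
            if (d < bestDistance) {
                best = i;
                bestDistance = d;
            }
        }
        return best;
    }

    /** One assignment + update round. */
    static double[][] step(double[][] centroids, double[][] points) {
        int k = centroids.length;
        int dim = centroids[0].length;
        double[][] sums = new double[k][dim];
        int[] counts = new int[k];
        for (double[] p : points) {
            int c = closest(centroids, p);
            counts[c]++;
            for (int j = 0; j < dim; j++) {
                sums[c][j] += p[j];
            }
        }
        double[][] next = new double[k][];
        for (int i = 0; i < k; i++) {
            if (counts[i] == 0) {
                // No points assigned: keep the old centroid rather than failing.
                next[i] = centroids[i].clone();
            } else {
                next[i] = new double[dim];
                for (int j = 0; j < dim; j++) {
                    next[i][j] = sums[i][j] / counts[i];
                }
            }
        }
        return next;
    }

    public static void main(String[] args) {
        // Three copies of one point with K = 2: the second centroid receives no points.
        double[][] points = {{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}};
        double[][] centroids = {{1.0, 2.0}, {1.0, 2.0}};
        System.out.println(Arrays.deepToString(step(centroids, points))); // [[1.0, 2.0], [1.0, 2.0]]
    }
}
```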

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,407 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Supplier;
+
+/**
+ * OnlineKMeans extends the functionality of {@link KMeans} to support training a K-Means model
+ * continuously on an unbounded stream of training data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids as the weighted average of the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of points assigned to
+ * them. The weight of the original centroids is also the number of points assigned to them, but
+ * multiplied by the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If the decay
+ * factor is 1, all batches are weighted equally. If the decay factor is 0, new centroids are
+ * determined entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelData =
+                KMeansModelData.getModelDataStream(initModelDataTable);
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getK(),
+                        getDecayFactor(),
+                        getGlobalBatchSize());
+
+        DataStream<KMeansModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Saves the metadata AND the bounded model data table (if it exists) to the given path. */
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans onlineKMeans = ReadWriteUtils.loadStageParam(path);
+
+        String initModelDataPath = ReadWriteUtils.getDataPath(path);
+        if (Files.exists(Paths.get(initModelDataPath))) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            onlineKMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+        }
+
+        return onlineKMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final int k;
+        private final double decayFactor;
+        private final int batchSize;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure, int k, double decayFactor, int batchSize) {
+            this.distanceMeasure = distanceMeasure;
+            this.k = k;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<KMeansModelData> modelData = variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();

Review comment:
       Would it be useful to explicitly check and throw an exception if `batchSize < parallelism`?

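A minimal sketch of such a fail-fast check, assuming the intent is to reject global batches that cannot give every subtask at least one point. The body of `GlobalBatchSplitter` is not shown in this diff, so the contiguous, near-even split below is only one plausible strategy; `GlobalBatchSplitSketch`, `checkBatchSize` and `split` are hypothetical names, and plain `double[]` rows stand in for `DenseVector`.

```java
import java.util.Arrays;

/**
 * Hypothetical helpers illustrating a fail-fast batch-size check and an even split of one global
 * batch across subtasks. Not part of the Flink ML API.
 */
public class GlobalBatchSplitSketch {

    /** Rejects global batches that cannot give every subtask at least one point. */
    static void checkBatchSize(int globalBatchSize, int parallelism) {
        if (globalBatchSize < parallelism) {
            throw new IllegalArgumentException(
                    "globalBatchSize ("
                            + globalBatchSize
                            + ") must be at least the parallelism ("
                            + parallelism
                            + ")");
        }
    }

    /** Splits a batch into {@code parallelism} contiguous parts whose sizes differ by at most one. */
    static double[][][] split(double[][] batch, int parallelism) {
        double[][][] parts = new double[parallelism][][];
        int base = batch.length / parallelism;
        int remainder = batch.length % parallelism;
        int offset = 0;
        for (int i = 0; i < parallelism; i++) {
            // The first `remainder` parts take one extra point each.
            int size = base + (i < remainder ? 1 : 0);
            parts[i] = Arrays.copyOfRange(batch, offset, offset + size);
            offset += size;
        }
        return parts;
    }

    public static void main(String[] args) {
        double[][] batch = {{0.0}, {1.0}, {2.0}, {3.0}, {4.0}};
        checkBatchSize(batch.length, 2);
        double[][][] parts = split(batch, 2);
        System.out.println(parts[0].length + " " + parts[1].length); // prints "3 2"
        try {
            checkBatchSize(1, 4);
        } catch (IllegalArgumentException e) {
            System.out.println("rejected: " + e.getMessage());
        }
    }
}
```

Doing the check once when the iteration body is built (where both `batchSize` and `parallelism` are known) would surface a misconfiguration before the job starts, rather than as empty sub-batches at runtime.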
##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansUtils.java
##########
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+
+import java.util.Arrays;
+import java.util.Random;
+
+/** Utility methods for KMeans algorithm. */
+public class KMeansUtils {
+    /**
+     * Generates a Table containing a {@link KMeansModelData} instance with randomly generated
+     * centroids.
+     *
+     * @param env The environment in which to create the table.
+     * @param k The number of centroids to generate.
+     * @param dim The dimension of the generated centroids.
+     * @param weight The weight of the centroids.
+     * @param seed The random seed.
+     */
+    public static Table generateRandomModelData(

Review comment:
       Would it be simpler to move this method to `KMeansModelData`?

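The body of `generateRandomModelData` is not included in this diff. As a rough, Flink-free illustration of what its Javadoc describes (`k` centroids of dimension `dim`, each with the same `weight`, reproducible from `seed`), here is a plain-Java sketch; drawing coordinates with `Random.nextGaussian()` is an assumption, and `RandomCentroidsSketch`/`ModelData` are hypothetical stand-ins for `KMeansModelData`.

```java
import java.util.Arrays;
import java.util.Random;

/** Hypothetical plain-Java sketch; not the actual KMeansUtils implementation. */
public class RandomCentroidsSketch {

    /** Holds k centroids of dimension dim plus one weight per centroid. */
    static final class ModelData {
        final double[][] centroids;
        final double[] weights;

        ModelData(double[][] centroids, double[] weights) {
            this.centroids = centroids;
            this.weights = weights;
        }
    }

    static ModelData generateRandom(int k, int dim, double weight, long seed) {
        Random random = new Random(seed);
        double[][] centroids = new double[k][dim];
        for (int i = 0; i < k; i++) {
            for (int j = 0; j < dim; j++) {
                // Assumed distribution: standard normal coordinates.
                centroids[i][j] = random.nextGaussian();
            }
        }
        double[] weights = new double[k];
        Arrays.fill(weights, weight);
        return new ModelData(centroids, weights);
    }

    public static void main(String[] args) {
        ModelData a = generateRandom(2, 3, 0.0, 42L);
        ModelData b = generateRandom(2, 3, 0.0, 42L);
        // The same seed yields the same centroids.
        System.out.println(Arrays.deepEquals(a.centroids, b.centroids)); // prints "true"
    }
}
```

Since the logic only depends on the shape of the model data, it would sit naturally next to `KMeansModelData` if the method is moved there.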
##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeansParams.java
##########
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.common.param.HasDecayFactor;
+import org.apache.flink.ml.common.param.HasGlobalBatchSize;
+import org.apache.flink.ml.common.param.HasSeed;
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+
+/**
+ * Params of {@link OnlineKMeans}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface OnlineKMeansParams<T>
+        extends HasBatchStrategy<T>,
+                HasGlobalBatchSize<T>,
+                HasDecayFactor<T>,
+                HasSeed<T>,
+                KMeansModelParams<T> {
+    Param<Integer> DIM =

Review comment:
       We can remove `DIM` and `INIT_WEIGHT`.

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,407 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Supplier;
+
+/**
+ * OnlineKMeans extends the functionality of {@link KMeans} by supporting continuous training of a
+ * K-Means model on an unbounded stream of training data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids from the weighted average between the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of points assigned to
+ * them. The weight of the original centroids is also the number of points, additionally multiplied
+ * by the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If the decay
+ * factor is 1, all batches are weighted equally. If the decay factor is 0, new centroids are
+ * determined entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelData =
+                KMeansModelData.getModelDataStream(initModelDataTable);
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getK(),
+                        getDecayFactor(),
+                        getGlobalBatchSize());
+
+        DataStream<KMeansModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Saves the metadata AND the bounded model data table (if it exists) to the given path. */
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans onlineKMeans = ReadWriteUtils.loadStageParam(path);
+
+        String initModelDataPath = ReadWriteUtils.getDataPath(path);
+        if (Files.exists(Paths.get(initModelDataPath))) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            onlineKMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+        }
+
+        return onlineKMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final int k;
+        private final double decayFactor;
+        private final int batchSize;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure, int k, double decayFactor, int batchSize) {
+            this.distanceMeasure = distanceMeasure;
+            this.k = k;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<KMeansModelData> modelData = variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<KMeansModelData> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator())
+                            .flatMap(new GlobalBatchSplitter(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(KMeansModelData.class),
+                                    new ModelDataLocalUpdater(distanceMeasure, k, decayFactor))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new ModelDataGlobalReducer());
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelData), DataStreamList.of(modelData));
+        }
+    }
+
+    /**
+     * Operator that collects a KMeansModelData from each upstream subtask and outputs the weighted
+     * average of the collected model data.
+     */
+    private static class ModelDataGlobalReducer implements ReduceFunction<KMeansModelData> {
+        @Override
+        public KMeansModelData reduce(KMeansModelData modelData, KMeansModelData newModelData) {
+            DenseVector weights = modelData.weights;
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector newWeights = newModelData.weights;
+            DenseVector[] newCentroids = newModelData.centroids;
+
+            int k = newCentroids.length;
+            int dim = newCentroids[0].size();
+
+            for (int i = 0; i < k; i++) {
+                for (int j = 0; j < dim; j++) {
+                    centroids[i].values[j] =
+                            (centroids[i].values[j] * weights.values[i]
+                                            + newCentroids[i].values[j] * newWeights.values[i])
+                                    / Math.max(weights.values[i] + newWeights.values[i], 1e-16);
+                }
+                weights.values[i] += newWeights.values[i];
+            }
+
+            return new KMeansModelData(centroids, weights);
+        }
+    }
+
+    /**
+     * An operator that updates KMeans model data locally. It performs the following steps:
+     *
+     * <ul>
+     *   <li>Finds the closest centroid (cluster) for each input point.
+     *   <li>Computes the new centroids from the average of the input points that belong to the same
+     *       cluster.
+     *   <li>Computes the weighted average of the current and new centroids. The weight of a new
+     *       centroid is the number of input points that belong to its cluster. The weight of a
+     *       current centroid is its original weight scaled by {@code decayFactor / parallelism}.
+     *   <li>Generates new model data from the weighted average of centroids, and the sum of
+     *       weights.
+     * </ul>
+     */
+    private static class ModelDataLocalUpdater extends AbstractStreamOperator<KMeansModelData>
+            implements TwoInputStreamOperator<DenseVector[], KMeansModelData, KMeansModelData> {
+        private final DistanceMeasure distanceMeasure;
+        private final int k;
+        private final double decayFactor;
+        private ListState<DenseVector[]> localBatchDataState;
+        private ListState<KMeansModelData> modelDataState;
+
+        private ModelDataLocalUpdater(DistanceMeasure distanceMeasure, int k, double decayFactor) {
+            this.distanceMeasure = distanceMeasure;
+            this.k = k;
+            this.decayFactor = decayFactor;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            localBatchDataState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("localBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> pointsRecord) throws Exception {
+            localBatchDataState.add(pointsRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<KMeansModelData> modelDataRecord)
+                throws Exception {
+            Preconditions.checkArgument(modelDataRecord.getValue().centroids.length == k);
+            modelDataState.add(modelDataRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        private void alignAndComputeModelData() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !localBatchDataState.get().iterator().hasNext()) {
+                return;
+            }
+
+            KMeansModelData modelData =
+                    OperatorStateUtils.getUniqueElement(modelDataState, "modelData")
+                            .orElseThrow((Supplier<Exception>) NullPointerException::new);

Review comment:
       Given that `modelDataState.get().iterator().hasNext() == true`, the `orElseThrow()` should not be triggered, right?
   
   Would it be simpler to just call `get()` here?

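For readers following the update logic of `ModelDataLocalUpdater` and `ModelDataGlobalReducer` quoted above, here is a Flink-free sketch of the same arithmetic. Each subtask scales the old weight `w` by `decayFactor / parallelism` and adds one unit of weight per assigned point; the reducer then takes a weighted average. Summed over all subtasks this gives `(decayFactor * w * c + sum of assigned points) / (decayFactor * w + n)` with new weight `decayFactor * w + n`, i.e. the decayed mini-batch rule from the class Javadoc. Squared Euclidean distance is an assumption here (the operator uses a configurable `DistanceMeasure`), and all names are hypothetical.

```java
import java.util.Arrays;

/** Plain-Java sketch of a decayed mini-batch local update plus weighted-average merge. */
public class DecayedMiniBatchSketch {

    /** Centroids plus one weight per centroid. */
    static final class Model {
        final double[][] centroids;
        final double[] weights;

        Model(double[][] centroids, double[] weights) {
            this.centroids = centroids;
            this.weights = weights;
        }
    }

    /** Returns the index of the centroid closest to {@code point} (squared Euclidean distance). */
    static int closest(double[][] centroids, double[] point) {
        int best = 0;
        double bestDist = Double.POSITIVE_INFINITY;
        for (int i = 0; i < centroids.length; i++) {
            double dist = 0.0;
            for (int j = 0; j < point.length; j++) {
                double diff = centroids[i][j] - point[j];
                dist += diff * diff;
            }
            if (dist < bestDist) {
                bestDist = dist;
                best = i;
            }
        }
        return best;
    }

    /** One subtask: old weights are scaled by decayFactor / parallelism; each point adds weight 1. */
    static Model localUpdate(Model current, double[][] points, double decayFactor, int parallelism) {
        int k = current.centroids.length;
        int dim = current.centroids[0].length;
        double[][] sums = new double[k][dim];
        double[] counts = new double[k];
        for (double[] p : points) {
            int c = closest(current.centroids, p);
            counts[c] += 1.0;
            for (int j = 0; j < dim; j++) {
                sums[c][j] += p[j];
            }
        }
        double[][] centroids = new double[k][dim];
        double[] weights = new double[k];
        for (int i = 0; i < k; i++) {
            double oldWeight = current.weights[i] * decayFactor / parallelism;
            weights[i] = oldWeight + counts[i];
            for (int j = 0; j < dim; j++) {
                // Weighted average of the old centroid and the points assigned to cluster i.
                centroids[i][j] =
                        (current.centroids[i][j] * oldWeight + sums[i][j])
                                / Math.max(weights[i], 1e-16);
            }
        }
        return new Model(centroids, weights);
    }

    /** Merges two partial results by weighted average and sums their weights. */
    static Model merge(Model a, Model b) {
        int k = a.centroids.length;
        int dim = a.centroids[0].length;
        double[][] centroids = new double[k][dim];
        double[] weights = new double[k];
        for (int i = 0; i < k; i++) {
            weights[i] = a.weights[i] + b.weights[i];
            for (int j = 0; j < dim; j++) {
                centroids[i][j] =
                        (a.centroids[i][j] * a.weights[i] + b.centroids[i][j] * b.weights[i])
                                / Math.max(weights[i], 1e-16);
            }
        }
        return new Model(centroids, weights);
    }

    public static void main(String[] args) {
        Model init = new Model(new double[][] {{0.0}, {10.0}}, new double[] {4.0, 4.0});
        // Two subtasks (parallelism = 2), decayFactor = 0.5.
        Model a = localUpdate(init, new double[][] {{1.0}, {2.0}}, 0.5, 2);
        Model b = localUpdate(init, new double[][] {{9.0}, {11.0}}, 0.5, 2);
        Model merged = merge(a, b);
        // prints "[[0.75], [10.0]] [4.0, 4.0]"
        System.out.println(
                Arrays.deepToString(merged.centroids) + " " + Arrays.toString(merged.weights));
    }
}
```

In the example, the single-machine rule gives the same result: the decayed old weight is 4 * 0.5 = 2, so centroid 0 becomes (0 * 2 + 1 + 2) / (2 + 2) = 0.75 with weight 4.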
##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,407 @@
+/**
+ * OnlineKMeans extends the functionality of {@link KMeans} by supporting continuous training of a
+ * K-Means model on an unbounded stream of training data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids from the weighted average between the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of points assigned to
+ * them. The weight of the original centroids is also the number of points, additionally multiplied
+ * by the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If the decay
+ * factor is 1, all batches are weighted equally. If the decay factor is 0, new centroids are
+ * determined entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelData =
+                KMeansModelData.getModelDataStream(initModelDataTable);
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getK(),
+                        getDecayFactor(),
+                        getGlobalBatchSize());
+
+        DataStream<KMeansModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Saves the metadata AND the bounded model data table (if it exists) to the given path. */
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {

Review comment:
       Currently `fit()` requires users to explicitly call `setInitialModelData()` before `fit()` is called.
   
   Would it be simpler to also require users to call `setInitialModelData()` before calling `save()`? Otherwise, it would not be clear whether users can call `fit()` right after an `OnlineKMeans` is loaded.
   
   The same applies to `load()`.

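A minimal sketch of the contract proposed above, in which `fit()` and `save()` share one fail-fast check. `InitModelDataContractSketch` is a hypothetical stand-in rather than the actual `OnlineKMeans` class, and a plain `Object` stands in for the `Table` of initial model data.

```java
import java.util.Objects;

/**
 * Hypothetical estimator skeleton: fit() and save() both fail fast unless setInitialModelData()
 * has been called. Names are illustrative only.
 */
public class InitModelDataContractSketch {
    // Stands in for the Table holding the initial KMeansModelData.
    private Object initModelData;

    public InitModelDataContractSketch setInitialModelData(Object modelData) {
        this.initModelData = Objects.requireNonNull(modelData, "modelData");
        return this;
    }

    private void checkInitialModelDataSet(String method) {
        if (initModelData == null) {
            throw new IllegalStateException(
                    "setInitialModelData() must be called before " + method + "()");
        }
    }

    public void fit() {
        checkInitialModelDataSet("fit");
        // ... build the unbounded iteration from initModelData ...
    }

    public void save(String path) {
        checkInitialModelDataSet("save");
        // ... write metadata and the initial model data under `path` ...
    }

    public static void main(String[] args) {
        InitModelDataContractSketch estimator = new InitModelDataContractSketch();
        try {
            estimator.save("/tmp/unused");
        } catch (IllegalStateException e) {
            // prints "setInitialModelData() must be called before save()"
            System.out.println(e.getMessage());
        }
        estimator.setInitialModelData(new Object());
        estimator.save("/tmp/unused"); // succeeds; this sketch writes nothing
    }
}
```

With this contract, anything that was successfully saved always carries initial model data, so a loaded instance can be passed straight to `fit()`.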
##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,407 @@
+/**
+ * OnlineKMeans extends the functionality of {@link KMeans} by supporting continuous training of a
+ * K-Means model on an unbounded stream of training data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids from the weighted average between the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of points assigned to
+ * them. The weight of the original centroids is also the number of points, additionally multiplied
+ * by the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If the decay
+ * factor is 1, all batches are weighted equally. If the decay factor is 0, new centroids are
+ * determined entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelData =
+                KMeansModelData.getModelDataStream(initModelDataTable);
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getK(),
+                        getDecayFactor(),
+                        getGlobalBatchSize());
+
+        DataStream<KMeansModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Saves the metadata AND bounded model data table (if exists) to the given path. */
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans onlineKMeans = ReadWriteUtils.loadStageParam(path);
+
+        String initModelDataPath = ReadWriteUtils.getDataPath(path);
+        if (Files.exists(Paths.get(initModelDataPath))) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            onlineKMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+        }
+
+        return onlineKMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final int k;
+        private final double decayFactor;
+        private final int batchSize;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure, int k, double decayFactor, int batchSize) {
+            this.distanceMeasure = distanceMeasure;
+            this.k = k;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<KMeansModelData> modelData = variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<KMeansModelData> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator())
+                            .flatMap(new GlobalBatchSplitter(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(KMeansModelData.class),
+                                    new ModelDataLocalUpdater(distanceMeasure, k, decayFactor))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new ModelDataGlobalReducer());
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelData), DataStreamList.of(modelData));
+        }
+    }
+
+    /**
+     * Operator that collects a KMeansModelData from each upstream subtask, and outputs the weighted
+     * average of collected model data.
+     */
+    private static class ModelDataGlobalReducer implements ReduceFunction<KMeansModelData> {
+        @Override
+        public KMeansModelData reduce(KMeansModelData modelData, KMeansModelData newModelData) {
+            DenseVector weights = modelData.weights;
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector newWeights = newModelData.weights;
+            DenseVector[] newCentroids = newModelData.centroids;
+
+            int k = newCentroids.length;
+            int dim = newCentroids[0].size();
+
+            for (int i = 0; i < k; i++) {
+                for (int j = 0; j < dim; j++) {
+                    centroids[i].values[j] =
+                            (centroids[i].values[j] * weights.values[i]
+                                            + newCentroids[i].values[j] * newWeights.values[i])
+                                    / Math.max(weights.values[i] + newWeights.values[i], 1e-16);
+                }
+                weights.values[i] += newWeights.values[i];
+            }
+
+            return new KMeansModelData(centroids, weights);
+        }
+    }
+
+    /**
+     * An operator that updates KMeans model data locally. It mainly does the following operations.
+     *
+     * <ul>
+     *   <li>Finds the closest centroid id (cluster) of the input points

Review comment:
       nits: `.` is missing at the end of the sentence. Same for the following sentence.
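The decayed mini-batch rule in the OnlineKMeans javadoc, and the per-cluster weighted average computed by `ModelDataGlobalReducer`, can be illustrated with a minimal plain-Java sketch. It is not the Flink ML implementation: `MiniBatchKMeansSketch` and its methods are hypothetical, a single process is assumed to see the whole batch, and squared Euclidean distance stands in for the configurable `DistanceMeasure`.

```java
/**
 * Hypothetical, single-process sketch of the decayed mini-batch K-Means rule, plus a weighted
 * merge of two partial models. Plain Java; squared Euclidean distance is assumed.
 */
final class MiniBatchKMeansSketch {

    /** Index of the centroid closest to {@code point} by squared Euclidean distance. */
    static int closest(double[][] centroids, double[] point) {
        int best = 0;
        double bestDist = Double.POSITIVE_INFINITY;
        for (int i = 0; i < centroids.length; i++) {
            double dist = 0.0;
            for (int j = 0; j < point.length; j++) {
                double diff = centroids[i][j] - point[j];
                dist += diff * diff;
            }
            if (dist < bestDist) {
                bestDist = dist;
                best = i;
            }
        }
        return best;
    }

    /**
     * Applies one mini-batch in place. The old centroid is weighted by weights[i] * decayFactor,
     * the batch mean by the number of batch points assigned to cluster i.
     */
    static void update(double[][] centroids, double[] weights, double[][] batch, double decayFactor) {
        int k = centroids.length;
        int dim = centroids[0].length;
        double[][] sums = new double[k][dim];
        double[] counts = new double[k];
        for (double[] point : batch) {
            int c = closest(centroids, point);
            counts[c] += 1.0;
            for (int j = 0; j < dim; j++) {
                sums[c][j] += point[j];
            }
        }
        for (int i = 0; i < k; i++) {
            double decayedWeight = weights[i] * decayFactor;
            double newWeight = decayedWeight + counts[i];
            if (counts[i] > 0) {
                for (int j = 0; j < dim; j++) {
                    // (old * decayedWeight + batchMean * count) / newWeight, with batchMean * count == sum.
                    centroids[i][j] = (centroids[i][j] * decayedWeight + sums[i][j]) / newWeight;
                }
            }
            weights[i] = newWeight;
        }
    }

    /** Merges a partial model into (centroids, weights) by a per-cluster weighted average. */
    static void merge(
            double[][] centroids, double[] weights, double[][] otherCentroids, double[] otherWeights) {
        for (int i = 0; i < centroids.length; i++) {
            double total = weights[i] + otherWeights[i];
            if (total > 0) {
                for (int j = 0; j < centroids[i].length; j++) {
                    centroids[i][j] =
                            (centroids[i][j] * weights[i] + otherCentroids[i][j] * otherWeights[i])
                                    / total;
                }
            }
            // A cluster with zero total weight keeps its current centroid.
            weights[i] = total;
        }
    }
}
```

Setting `decayFactor` to 0 makes each updated centroid the plain mean of its batch points, and 1 weights all batches equally, matching the javadoc. One deliberate difference from the quoted reducer: it divides by `Math.max(weight + newWeight, 1e-16)`, so a cluster whose total weight is 0 ends up at the zero vector, whereas `merge` here leaves such a centroid unchanged.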

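In the quoted `OnlineKMeansIterationBody`, `countWindowAll(batchSize)` with `GlobalBatchCreator` gathers a global batch, `GlobalBatchSplitter(parallelism)` breaks it into pieces, and `.rebalance()` spreads those pieces round-robin over the subtasks running `ModelDataLocalUpdater`. The body of `GlobalBatchSplitter` is not shown, so the following is only a guess at one reasonable splitting strategy: contiguous chunks whose sizes differ by at most one. `GlobalBatchSplitSketch` is a hypothetical name.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical splitter; not the PR's GlobalBatchSplitter, whose body is not shown. */
final class GlobalBatchSplitSketch {

    /**
     * Splits a global batch into {@code numParts} contiguous local batches whose sizes differ by
     * at most one. Earlier parts receive the extra elements when the batch does not divide evenly.
     */
    static <T> List<List<T>> split(List<T> globalBatch, int numParts) {
        if (numParts <= 0) {
            throw new IllegalArgumentException("numParts must be positive: " + numParts);
        }
        int base = globalBatch.size() / numParts;
        int remainder = globalBatch.size() % numParts;
        List<List<T>> parts = new ArrayList<>(numParts);
        int start = 0;
        for (int i = 0; i < numParts; i++) {
            int size = base + (i < remainder ? 1 : 0);
            // Copy so each local batch is independent of the global list.
            parts.add(new ArrayList<>(globalBatch.subList(start, start + size)));
            start += size;
        }
        return parts;
    }
}
```

With a parallelism of 3, a global batch of 7 points yields local batches of sizes 3, 2 and 2, so the work stays roughly balanced even when `globalBatchSize` is not a multiple of the parallelism.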
##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeansModel.java
##########
@@ -0,0 +1,199 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.co.CoProcessFunction;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * OnlineKMeansModel can be regarded as an advanced {@link KMeansModel} operator which can update
+ * model data in a streaming format, using the model data provided by {@link OnlineKMeans}.
+ */
+public class OnlineKMeansModel
+        implements Model<OnlineKMeansModel>, KMeansModelParams<OnlineKMeansModel> {
+    public static final String MODEL_DATA_VERSION_GAUGE_KEY = "modelDataVersion";
+
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public OnlineKMeansModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel setModelData(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        modelDataTable = inputs[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), Types.INT),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol()));
+
+        DataStream<Row> predictionResult =
+                KMeansModelData.getModelDataStream(modelDataTable)
+                        .broadcast()
+                        .connect(tEnv.toDataStream(inputs[0]))
+                        .process(
+                                new PredictLabelFunction(
+                                        getFeaturesCol(),
+                                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                                        getK()),
+                                outputTypeInfo);
+
+        return new Table[] {tEnv.fromDataStream(predictionResult)};
+    }
+
+    /** A utility function used for prediction. */
+    private static class PredictLabelFunction extends CoProcessFunction<KMeansModelData, Row, Row> {
+        private final String featuresCol;
+
+        private final DistanceMeasure distanceMeasure;
+
+        private final int k;
+
+        private DenseVector[] centroids;
+
+        // TODO: replace this with a complete solution of reading first model data from unbounded
+        // model data stream before processing the first predict data.
+        private final List<Row> bufferedPoints = new ArrayList<>();
+
+        /**
+         * Basic implementation of the model data version with the following rules.
+         *
+         * <ul>
+         *   <li>Negative value is regarded as illegal value.
+         *   <li>Zero value means the version has not been initialized yet.
+         *   <li>Positive value represents valid version.
+         *   <li>A larger value represents a newer version.

Review comment:
       Given that this is `version`, it seems unnecessary to mention `A larger value represents a newer version`. Would it be simpler to remove this sentence?
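The version rules in the quoted javadoc amount to a counter whose default 0 means "no model data yet", with negative values rejected and positive values treated as valid. A minimal sketch, assuming the version is incremented by one for every received `KMeansModelData`; `ModelDataVersionSketch` is a hypothetical name, and the metric registration the PR appears to use (given its `Gauge` import and `MODEL_DATA_VERSION_GAUGE_KEY` constant) is left out.

```java
/**
 * Hypothetical sketch of the model data version rules: 0 means "not initialized yet", positive
 * values are valid, larger means newer, and negative values are illegal.
 */
final class ModelDataVersionSketch {
    // Default value 0: no model data has been received yet.
    private int version = 0;

    /** Assumed to be called once per received model data; returns the new version. */
    int onModelData() {
        if (version == Integer.MAX_VALUE) {
            throw new IllegalStateException("Model data version overflow");
        }
        version++;
        return version;
    }

    boolean isInitialized() {
        return version > 0;
    }

    int current() {
        return version;
    }

    /** Rejects values that the rules regard as illegal. */
    static void checkLegal(int version) {
        if (version < 0) {
            throw new IllegalArgumentException("Illegal model data version: " + version);
        }
    }
}
```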




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [flink-ml] lindong28 commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
lindong28 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r835771045



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeansModel.java
##########
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.co.CoProcessFunction;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * OnlineKMeansModel can be regarded as an advanced {@link KMeansModel} operator which can update
+ * model data in a streaming format, using the model data provided by {@link OnlineKMeans}.
+ */
+public class OnlineKMeansModel
+        implements Model<OnlineKMeansModel>, KMeansModelParams<OnlineKMeansModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public OnlineKMeansModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel setModelData(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        modelDataTable = inputs[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), Types.INT),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol()));
+
+        DataStream<Row> predictionResult =
+                KMeansModelData.getModelDataStream(modelDataTable)
+                        .broadcast()
+                        .connect(tEnv.toDataStream(inputs[0]))
+                        .process(
+                                new PredictLabelFunction(
+                                        getFeaturesCol(),
+                                        DistanceMeasure.getInstance(getDistanceMeasure())),
+                                outputTypeInfo);
+
+        return new Table[] {tEnv.fromDataStream(predictionResult)};
+    }
+
+    /** A utility function used for prediction. */
+    private static class PredictLabelFunction extends CoProcessFunction<KMeansModelData, Row, Row> {
+        private final String featuresCol;
+
+        private final DistanceMeasure distanceMeasure;
+
+        private DenseVector[] centroids;
+
+        // TODO: replace this with a complete solution of reading first model data from unbounded
+        // model data stream before processing the first predict data.
+        private final List<Row> bufferedPoints = new ArrayList<>();

Review comment:
       The long-term solution is to `read first model data from unbounded model data stream before processing the first predict data`. If we cannot achieve this goal, OnlineKMeansModel won't be reliable for production usage anyway, due to OOM. So it does not seem helpful to checkpoint the bufferedPoints here.
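The buffering behavior discussed above can be sketched without Flink: points that arrive before the first model data are held in a list and labeled as soon as centroids arrive. This is a hypothetical, single-threaded illustration (`BufferingPredictorSketch` and its methods are not part of the PR) that uses squared Euclidean distance in place of `DistanceMeasure`; its two entry points roughly mirror `processElement1` (model data) and `processElement2` (points) of the quoted `CoProcessFunction<KMeansModelData, Row, Row>`.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical predictor that buffers points until the first model data arrives. */
final class BufferingPredictorSketch {
    // Null until the first model data has been received.
    private double[][] centroids;
    // Points received before any model data; nothing bounds this list.
    private final List<double[]> bufferedPoints = new ArrayList<>();
    // Predicted cluster ids, in emission order.
    private final List<Integer> predictions = new ArrayList<>();

    /** Roughly mirrors processElement1: install new centroids, then flush the buffer. */
    void onModelData(double[][] newCentroids) {
        centroids = newCentroids;
        for (double[] point : bufferedPoints) {
            predictions.add(closest(point));
        }
        bufferedPoints.clear();
    }

    /** Roughly mirrors processElement2: predict now, or buffer if no model is available yet. */
    void onPoint(double[] point) {
        if (centroids == null) {
            bufferedPoints.add(point);
        } else {
            predictions.add(closest(point));
        }
    }

    List<Integer> predictions() {
        return predictions;
    }

    int bufferedCount() {
        return bufferedPoints.size();
    }

    private int closest(double[] point) {
        int best = 0;
        double bestDist = Double.POSITIVE_INFINITY;
        for (int i = 0; i < centroids.length; i++) {
            double dist = 0.0;
            for (int j = 0; j < point.length; j++) {
                double diff = centroids[i][j] - point[j];
                dist += diff * diff;
            }
            if (dist < bestDist) {
                bestDist = dist;
                best = i;
            }
        }
        return best;
    }
}
```

Because nothing bounds `bufferedPoints`, a model data stream that stays silent keeps this list growing, which is the OOM risk the review points out.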







[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r834915786



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeansModel.java
##########
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.co.CoProcessFunction;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * OnlineKMeansModel can be regarded as an advanced {@link KMeansModel} operator which can update
+ * model data in a streaming format, using the model data provided by {@link OnlineKMeans}.
+ */
+public class OnlineKMeansModel
+        implements Model<OnlineKMeansModel>, KMeansModelParams<OnlineKMeansModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public OnlineKMeansModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel setModelData(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        modelDataTable = inputs[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), Types.INT),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol()));
+
+        DataStream<Row> predictionResult =
+                KMeansModelData.getModelDataStream(modelDataTable)
+                        .broadcast()
+                        .connect(tEnv.toDataStream(inputs[0]))
+                        .process(
+                                new PredictLabelFunction(
+                                        getFeaturesCol(),
+                                        DistanceMeasure.getInstance(getDistanceMeasure())),
+                                outputTypeInfo);
+
+        return new Table[] {tEnv.fromDataStream(predictionResult)};
+    }
+
+    /** A utility function used for prediction. */
+    private static class PredictLabelFunction extends CoProcessFunction<KMeansModelData, Row, Row> {
+        private final String featuresCol;
+
+        private final DistanceMeasure distanceMeasure;
+
+        private DenseVector[] centroids;
+
+        // TODO: replace this with a complete solution of reading first model data from unbounded
+        // model data stream before processing the first predict data.
+        private final List<Row> bufferedPoints = new ArrayList<>();
+
+        // TODO: replace this simple implementation of model data version with the formal API to
+        // track model version after its design is settled.
+        private int modelDataVersion;

Review comment:
       Got it. I'll add a default value, plus a javadoc explaining what this value means and the rules it follows.







[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r833839951



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,562 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Random;
+
+/**
+ * OnlineKMeans extends the functionality of {@link KMeans} to support training a K-Means model
+ * continuously on an unbounded stream of training data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids as the weighted average of the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of points assigned to
+ * them. The weight of the original centroids is likewise a number of points, but multiplied by the
+ * decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If the decay
+ * factor is 1, all batches are weighted equally. If the decay factor is 0, new centroids are
+ * determined entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public OnlineKMeans(Table initModelDataTable) {
+        this.initModelDataTable = initModelDataTable;
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            Preconditions.checkState(initModelDataTable == null);
+            initModelDataStream = createRandomCentroids(env, getDim(), getK(), getSeed());
+        } else {
+            initModelDataStream = KMeansModelData.getModelDataStream(initModelDataTable);
+        }
+        DataStream<Tuple2<KMeansModelData, DenseVector>> initModelDataWithWeightsStream =
+                initModelDataStream.map(new InitWeightAssigner(getInitWeights()));
+        initModelDataWithWeightsStream.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize(),
+                        getK(),
+                        getDim());
+
+        DataStream<KMeansModelData> finalModelDataStream =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelDataWithWeightsStream),
+                                DataStreamList.of(points),
+                                body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelDataStream);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class InitWeightAssigner
+            implements MapFunction<KMeansModelData, Tuple2<KMeansModelData, DenseVector>> {
+        private final double[] initWeights;
+
+        private InitWeightAssigner(Double[] initWeights) {
+            this.initWeights = ArrayUtils.toPrimitive(initWeights);
+        }
+
+        @Override
+        public Tuple2<KMeansModelData, DenseVector> map(KMeansModelData modelData)
+                throws Exception {
+            return Tuple2.of(modelData, Vectors.dense(initWeights));
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        Path initModelDataPath = Paths.get(path, "data");
+        if (Files.exists(initModelDataPath)) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+            kMeans.setInitMode("direct");
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+        private final int k;
+        private final int dim;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure,
+                double decayFactor,
+                int batchSize,
+                int k,
+                int dim) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+            this.k = k;
+            this.dim = dim;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<Tuple2<KMeansModelData, DenseVector>> modelDataWithWeights =
+                    variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<Tuple2<KMeansModelData, DenseVector>> newModelDataWithWeights =
+                    points.countWindowAll(batchSize)
+                            .aggregate(new MiniBatchCreator())
+                            .flatMap(new MiniBatchDistributor(parallelism))
+                            .rebalance()
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "ModelDataPartialUpdater",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new ModelDataPartialUpdater(distanceMeasure, k))
+                            .setParallelism(parallelism)
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "ModelDataGlobalUpdater",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new ModelDataGlobalUpdater(k, dim, parallelism, decayFactor))

Review comment:
       I agree. I'll make the change.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [flink-ml] yunfengzhou-hub commented on pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#issuecomment-1076988591


   Thanks for the comments. I have updated the PR accordingly.





[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r834912447



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java
##########
@@ -159,68 +161,78 @@ public IterationBodyResult process(
                                             DenseVectorTypeInfo.INSTANCE),
                                     new SelectNearestCentroidOperator(distanceMeasure));
 
-            AllWindowFunction<DenseVector, DenseVector[], TimeWindow> toList =
-                    new AllWindowFunction<DenseVector, DenseVector[], TimeWindow>() {
-                        @Override
-                        public void apply(
-                                TimeWindow timeWindow,
-                                Iterable<DenseVector> iterable,
-                                Collector<DenseVector[]> out) {
-                            List<DenseVector> centroids = IteratorUtils.toList(iterable.iterator());
-                            out.collect(centroids.toArray(new DenseVector[0]));
-                        }
-                    };
-
             PerRoundSubBody perRoundSubBody =
                     new PerRoundSubBody() {
                         @Override
                         public DataStreamList process(DataStreamList inputs) {
                             DataStream<Tuple2<Integer, DenseVector>> centroidIdAndPoints =
                                     inputs.get(0);
-                            DataStream<DenseVector[]> newCentroids =
+                            DataStream<KMeansModelData> modelDataStream =
                                     centroidIdAndPoints
                                             .map(new CountAppender())
                                             .keyBy(t -> t.f0)
                                             .window(EndOfStreamWindows.get())
                                             .reduce(new CentroidAccumulator())
                                             .map(new CentroidAverager())
                                             .windowAll(EndOfStreamWindows.get())
-                                            .apply(toList);
-                            return DataStreamList.of(newCentroids);
+                                            .apply(new ModelDataGenerator());
+                            return DataStreamList.of(modelDataStream);
                         }
                     };
-
-            DataStream<DenseVector[]> newCentroids =
+            DataStream<KMeansModelData> newModelData =
                     IterationBody.forEachRound(
                                     DataStreamList.of(centroidIdAndPoints), perRoundSubBody)
                             .get(0);
-            DataStream<DenseVector[]> finalCentroids =
-                    newCentroids.flatMap(new ForwardInputsOfLastRound<>());
+
+            DataStream<DenseVector[]> newCentroids =
+                    newModelData.map(x -> x.centroids).setParallelism(1);
+
+            DataStream<KMeansModelData> finalModelData =
+                    newModelData.flatMap(new ForwardInputsOfLastRound<>());
 
             return new IterationBodyResult(
                     DataStreamList.of(newCentroids),
-                    DataStreamList.of(finalCentroids),
+                    DataStreamList.of(finalModelData),
                     terminationCriteria);
         }
     }
 
+    private static class ModelDataGenerator
+            implements AllWindowFunction<Tuple2<DenseVector, Double>, KMeansModelData, TimeWindow> {
+        @Override
+        public void apply(
+                TimeWindow timeWindow,
+                Iterable<Tuple2<DenseVector, Double>> iterable,
+                Collector<KMeansModelData> collector) {
+            List<Tuple2<DenseVector, Double>> centroidsAndWeights =

Review comment:
       Could we pass `k` (number of clusters) as a parameter for `ModelDataGenerator`, such that we can avoid creating a list of centroids? This could be more memory-efficient if `k` is large.
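   Below is a minimal sketch of the suggestion above. It assumes the window function receives `k` up front and writes the (centroid, weight) pairs directly into arrays of length `k`, with no intermediate `List`. Plain `double[]` arrays stand in for `DenseVector`. The names `KnownKModelDataAssembly`, `CentroidAndWeight` and `ModelData` are hypothetical and are not Flink ML API. If fewer than `k` pairs arrive, the sketch trims the arrays. Whether that can actually happen in the real pipeline, for example when a cluster receives no points, is left open here.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch: assembling model data from (centroid, weight) pairs when the number of
 * clusters k is known in advance. double[] stands in for DenseVector.
 */
public class KnownKModelDataAssembly {

    /** Stand-in for Tuple2<DenseVector, Double>. */
    static final class CentroidAndWeight {
        final double[] centroid;
        final double weight;

        CentroidAndWeight(double[] centroid, double weight) {
            this.centroid = centroid;
            this.weight = weight;
        }
    }

    /** Stand-in for KMeansModelData: centroids plus one weight per centroid. */
    static final class ModelData {
        final double[][] centroids;
        final double[] weights;

        ModelData(double[][] centroids, double[] weights) {
            this.centroids = centroids;
            this.weights = weights;
        }
    }

    /**
     * Writes each element straight into arrays sized by k, instead of first copying the whole
     * iterable into a List and then converting that List into an array.
     */
    static ModelData assemble(Iterable<CentroidAndWeight> input, int k) {
        double[][] centroids = new double[k][];
        double[] weights = new double[k];
        int count = 0;
        for (CentroidAndWeight element : input) {
            if (count == k) {
                throw new IllegalStateException("Received more than k = " + k + " centroids");
            }
            centroids[count] = element.centroid;
            weights[count] = element.weight;
            count++;
        }
        if (count < k) {
            // Fewer elements than expected: trim so that no null centroid is emitted.
            centroids = Arrays.copyOf(centroids, count);
            weights = Arrays.copyOf(weights, count);
        }
        return new ModelData(centroids, weights);
    }

    public static void main(String[] args) {
        ModelData data =
                assemble(
                        Arrays.asList(
                                new CentroidAndWeight(new double[] {0.0, 0.0}, 2.0),
                                new CentroidAndWeight(new double[] {5.0, 5.0}, 3.0)),
                        2);
        System.out.println(
                Arrays.deepToString(data.centroids) + " " + Arrays.toString(data.weights));
    }
}
```

   The saving is the per-window `List` allocation and copy. That is the part that grows with `k`.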







[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r834918943



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,437 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.function.Supplier;
+
+/**
+ * OnlineKMeans extends the functionality of {@link KMeans} to support training a K-Means model
+ * continuously from an unbounded stream of training data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids as the weighted average of the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of points assigned to
+ * them. The weight of the original centroids is the number of points assigned to them so far,
+ * multiplied by the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If the decay
+ * factor is 1, all batches are weighted equally. If the decay factor is 0, new centroids are
+ * determined entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelData =
+                KMeansModelData.getModelDataStream(initModelDataTable);
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize());
+
+        DataStream<KMeansModelData> finalModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelData);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Saves the metadata AND bounded model data table (if it exists) to the given path. */
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        String initModelDataPath = ReadWriteUtils.getDataPath(path);
+        if (Files.exists(Paths.get(initModelDataPath))) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure, double decayFactor, int batchSize) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<KMeansModelData> modelData = variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<KMeansModelData> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator())
+                            .flatMap(new GlobalBatchSplitter(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(KMeansModelData.class),
+                                    new ModelDataLocalUpdater(distanceMeasure, decayFactor))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new ModelDataGlobalReducer());
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelData), DataStreamList.of(modelData));
+        }
+    }
+
+    private static class ModelDataGlobalReducer implements ReduceFunction<KMeansModelData> {
+        @Override
+        public KMeansModelData reduce(KMeansModelData modelData, KMeansModelData newModelData) {
+            DenseVector weights = modelData.weights;
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector newWeights = newModelData.weights;
+            DenseVector[] newCentroids = newModelData.centroids;
+
+            int k = newCentroids.length;
+            int dim = newCentroids[0].size();
+
+            for (int i = 0; i < k; i++) {
+                for (int j = 0; j < dim; j++) {
+                    centroids[i].values[j] =
+                            (centroids[i].values[j] * weights.values[i]
+                                            + newCentroids[i].values[j] * newWeights.values[i])
+                                    / Math.max(weights.values[i] + newWeights.values[i], 1e-16);
+                }
+                weights.values[i] += newWeights.values[i];
+            }
+
+            return new KMeansModelData(centroids, weights);
+        }
+    }
+
+    private static class ModelDataLocalUpdater extends AbstractStreamOperator<KMeansModelData>
+            implements TwoInputStreamOperator<DenseVector[], KMeansModelData, KMeansModelData> {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private ListState<DenseVector[]> localBatchState;
+        private ListState<KMeansModelData> modelDataState;
+
+        private ModelDataLocalUpdater(DistanceMeasure distanceMeasure, double decayFactor) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            localBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("localBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> pointsRecord) throws Exception {
+            localBatchState.add(pointsRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<KMeansModelData> modelDataRecord)
+                throws Exception {
+            modelDataState.add(modelDataRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        private void alignAndComputeModelData() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !localBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            KMeansModelData modelData =
+                    OperatorStateUtils.getUniqueElement(modelDataState, "modelData")
+                            .orElseThrow((Supplier<Exception>) NullPointerException::new);
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector weights = modelData.weights;
+            modelDataState.clear();
+
+            List<DenseVector[]> pointsList = IteratorUtils.toList(localBatchState.get().iterator());
+            DenseVector[] points = pointsList.remove(0);
+            localBatchState.update(pointsList);
+
+            int dim = centroids[0].size();
+            int k = centroids.length;
+            int parallelism = getRuntimeContext().getNumberOfParallelSubtasks();
+
+            // Computes new centroids.
+            DenseVector[] sums = new DenseVector[k];
+            int[] counts = new int[k];
+
+            for (int i = 0; i < k; i++) {
+                sums[i] = new DenseVector(dim);
+                counts[i] = 0;
+            }
+            for (DenseVector point : points) {
+                int closestCentroidId =
+                        KMeans.findClosestCentroidId(centroids, point, distanceMeasure);
+                counts[closestCentroidId]++;
+                for (int j = 0; j < dim; j++) {
+                    sums[closestCentroidId].values[j] += point.values[j];
+                }
+            }
+
+            // Considers weight and decay factor when updating centroids.
+            BLAS.scal(decayFactor / parallelism, weights);
+            for (int i = 0; i < k; i++) {
+                if (counts[i] == 0) {
+                    continue;
+                }
+
+                DenseVector centroid = centroids[i];
+                weights.values[i] = weights.values[i] + counts[i];
+                double lambda = counts[i] / weights.values[i];
+
+                BLAS.scal(1.0 - lambda, centroid);
+                BLAS.axpy(lambda / counts[i], sums[i], centroid);
+            }
+
+            output.collect(new StreamRecord<>(new KMeansModelData(centroids, weights)));
+        }
+    }
+
+    private static class FeaturesExtractor implements MapFunction<Row, DenseVector> {
+        private final String featuresCol;
+
+        private FeaturesExtractor(String featuresCol) {
+            this.featuresCol = featuresCol;
+        }
+
+        @Override
+        public DenseVector map(Row row) throws Exception {
+            return (DenseVector) row.getField(featuresCol);
+        }
+    }
+
+    // An operator that splits a global batch into evenly-sized local batches and distributes them
+    // to the downstream operator.
+    private static class GlobalBatchSplitter
+            implements FlatMapFunction<DenseVector[], DenseVector[]> {
+        private final int downStreamParallelism;
+
+        private GlobalBatchSplitter(int downStreamParallelism) {
+            this.downStreamParallelism = downStreamParallelism;
+        }
+
+        @Override
+        public void flatMap(DenseVector[] values, Collector<DenseVector[]> collector) {
+            // Calculate the batch sizes to be distributed on each subtask.
+            List<Integer> sizes = new ArrayList<>();
+            for (int i = 0; i < downStreamParallelism; i++) {
+                int start = i * values.length / downStreamParallelism;
+                int end = (i + 1) * values.length / downStreamParallelism;
+                sizes.add(end - start);
+            }
+
+            int offset = 0;
+            for (Integer size : sizes) {
+                collector.collect(Arrays.copyOfRange(values, offset, offset + size));
+                offset += size;
+            }
+        }
+    }
+
+    private static class GlobalBatchCreator
+            implements AllWindowFunction<DenseVector, DenseVector[], GlobalWindow> {
+        @Override
+        public void apply(
+                GlobalWindow timeWindow,
+                Iterable<DenseVector> iterable,
+                Collector<DenseVector[]> collector) {
+            List<DenseVector> points = IteratorUtils.toList(iterable.iterator());
+            collector.collect(points.toArray(new DenseVector[0]));
+        }
+    }
+
+    /**
+     * Sets the initial model data of the online training process with the provided model data
+     * table.
+     *
+     * <p>This would override the effect of previously invoked {@link #setInitialModelData(Table)}
+     * or {@link #setRandomCentroids(StreamTableEnvironment, int, int, double)}.
+     */
+    public OnlineKMeans setInitialModelData(Table initModelDataTable) {
+        this.initModelDataTable = initModelDataTable;
+        return this;
+    }
+
+    /**
+     * Sets the initial model data of the online training process with randomly created centroids.
+     *
+     * <p>This would override the effect of previously invoked {@link #setInitialModelData(Table)}
+     * or {@link #setRandomCentroids(StreamTableEnvironment, int, int, double)}.
+     *
+     * @param tEnv The stream table environment to create the centroids in.
+     * @param dim The dimension of the centroids to create.
+     * @param k The number of centroids to create.
+     * @param weight The weight of the centroids to create.
+     */
+    public OnlineKMeans setRandomCentroids(
+            StreamTableEnvironment tEnv, int dim, int k, double weight) {
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();

Review comment:
       OK. I'll make the change.
   
   One small concern is that Flink ML is supposed to depend on the Table API, but it now seems that Flink ML also depends on the DataStream API.
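   The decayed mini-batch rule described in the `OnlineKMeans` Javadoc in the diff above can be sketched in a single process as follows. This illustrates the rule under stated assumptions and is not the PR's implementation. Plain `double[]` arrays stand in for `DenseVector`. Squared Euclidean distance stands in for the configurable `DistanceMeasure`. The whole batch is processed in one place, and the class name `DecayedMiniBatchUpdate` is hypothetical. The diff's `ModelDataLocalUpdater` also divides the decay factor by the parallelism on each subtask before `ModelDataGlobalReducer` merges the per-subtask results. This sketch leaves that step out.

```java
import java.util.Arrays;

/** Single-process sketch of the decayed mini-batch K-Means update (hypothetical names). */
public class DecayedMiniBatchUpdate {

    /** Index of the centroid closest to the point under squared Euclidean distance. */
    static int closest(double[][] centroids, double[] point) {
        int best = 0;
        double bestDist = Double.MAX_VALUE;
        for (int i = 0; i < centroids.length; i++) {
            double d = 0.0;
            for (int j = 0; j < point.length; j++) {
                double diff = centroids[i][j] - point[j];
                d += diff * diff;
            }
            if (d < bestDist) {
                bestDist = d;
                best = i;
            }
        }
        return best;
    }

    /** Updates centroids and weights in place with one mini-batch of points. */
    static void update(double[][] centroids, double[] weights, double[][] batch, double decayFactor) {
        int k = centroids.length;
        int dim = centroids[0].length;
        double[][] sums = new double[k][dim];
        int[] counts = new int[k];
        // Assign each point to its closest centroid and accumulate per-cluster sums.
        for (double[] point : batch) {
            int c = closest(centroids, point);
            counts[c]++;
            for (int j = 0; j < dim; j++) {
                sums[c][j] += point[j];
            }
        }
        for (int i = 0; i < k; i++) {
            // The old weight (points assigned so far) is scaled by the decay factor.
            weights[i] *= decayFactor;
            if (counts[i] == 0) {
                continue; // No new points: the centroid stays where it is.
            }
            weights[i] += counts[i];
            // Fraction of the new weight contributed by this batch.
            double lambda = counts[i] / weights[i];
            for (int j = 0; j < dim; j++) {
                // Weighted average of the old centroid and the batch mean.
                centroids[i][j] = (1.0 - lambda) * centroids[i][j] + lambda * sums[i][j] / counts[i];
            }
        }
    }

    public static void main(String[] args) {
        double[][] centroids = {{0.0, 0.0}, {10.0, 10.0}};
        double[] weights = {2.0, 2.0};
        double[][] batch = {{1.0, 1.0}, {3.0, 3.0}};
        update(centroids, weights, batch, 0.5);
        System.out.println(Arrays.deepToString(centroids) + " " + Arrays.toString(weights));
    }
}
```

   With a decay factor of 1, the old weights are kept in full, so every batch counts equally. With a decay factor of 0, the old weights vanish and `lambda` becomes 1 for every centroid that received points. Both cases match the Javadoc's description.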

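   The even split that `GlobalBatchSplitter` performs in the diff above can also be written more compactly. In this sketch, slice `i` covers the index range from `i * n / p` to `(i + 1) * n / p`, where `n` is the batch size and `p` is the parallelism. The class name `EvenBatchSplit` is hypothetical. Because the slice sizes telescope, taking `start` and `end` directly yields the same slices as the diff's two-pass approach of collecting sizes and then accumulating an offset.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Hypothetical sketch of splitting a global batch into evenly-sized local batches. */
public class EvenBatchSplit {

    /**
     * Splits values into `parallelism` contiguous slices. Slice i covers
     * [i * n / p, (i + 1) * n / p), so slice sizes differ by at most one.
     */
    static <T> List<T[]> split(T[] values, int parallelism) {
        List<T[]> slices = new ArrayList<>();
        for (int i = 0; i < parallelism; i++) {
            int start = i * values.length / parallelism;
            int end = (i + 1) * values.length / parallelism;
            slices.add(Arrays.copyOfRange(values, start, end));
        }
        return slices;
    }

    public static void main(String[] args) {
        Integer[] values = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
        for (Integer[] slice : split(values, 4)) {
            System.out.println(Arrays.toString(slice));
        }
    }
}
```

   For a batch of 10 points and a parallelism of 4, the slice sizes are 2, 3, 2 and 3. When the batch is smaller than the parallelism, some slices are empty, so some downstream subtasks receive an empty local batch.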






[GitHub] [flink-ml] lindong28 commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
lindong28 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r834996396



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModelParams.java
##########
@@ -21,27 +21,11 @@
 import org.apache.flink.ml.common.param.HasDistanceMeasure;
 import org.apache.flink.ml.common.param.HasFeaturesCol;
 import org.apache.flink.ml.common.param.HasPredictionCol;
-import org.apache.flink.ml.param.IntParam;
-import org.apache.flink.ml.param.Param;
-import org.apache.flink.ml.param.ParamValidators;
 
 /**
- * Params of {@link KMeansModel}.
+ * Params of {@link KMeansModel} and {@link OnlineKMeansModel}.
  *
  * @param <T> The class type of this instance.
  */
 public interface KMeansModelParams<T>
-        extends HasDistanceMeasure<T>, HasFeaturesCol<T>, HasPredictionCol<T> {
-
-    Param<Integer> K =

Review comment:
       Sounds good. Let's do the check in both `KMeansModel` and `OnlineKMeansModel`.







[GitHub] [flink-ml] lindong28 commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
lindong28 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r834995617



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModelData.java
##########
@@ -47,8 +47,16 @@
 
     public DenseVector[] centroids;
 
+    public DenseVector weights;
+
+    public KMeansModelData(DenseVector[] centroids, DenseVector weights) {
+        this.centroids = centroids;
+        this.weights = weights;
+    }
+
     public KMeansModelData(DenseVector[] centroids) {
         this.centroids = centroids;
+        this.weights = new DenseVector(centroids.length);

Review comment:
       If this model data is guaranteed to be used only by `OnlineKMeansModel`, I agree we can do that. 
   
   For code that constructs a KMeansModelData instance, if it is not guaranteed that the instance is used only by `OnlineKMeansModel`, it seems simpler to make the model data self-contained, with valid weights.
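   To show why the weights matter, here is a one-dimensional sketch of the per-centroid update in the diff's `ModelDataLocalUpdater`, with the division of the decay factor by the parallelism omitted. Suppose a `KMeansModelData` built with the single-argument constructor, whose weights are initialized to a zero vector, is used as the initial model of an online run. Then `lambda` equals 1 for every centroid that receives points, and the first batch discards those centroids entirely. The thread does not say what the "valid weights" should be, and the names `ZeroWeightEffect` and `updatedCentroid` are hypothetical.

```java
/** Hypothetical one-dimensional sketch of the decayed per-centroid update. */
public class ZeroWeightEffect {

    /**
     * Returns the updated centroid after batchCount points with coordinate sum batchSum are
     * assigned to it. Requires batchCount > 0.
     */
    static double updatedCentroid(
            double centroid, double weight, double batchSum, int batchCount, double decayFactor) {
        double newWeight = decayFactor * weight + batchCount;
        double lambda = batchCount / newWeight;
        return (1.0 - lambda) * centroid + lambda * batchSum / batchCount;
    }

    public static void main(String[] args) {
        // Weight 0: lambda is 1, so the old centroid (100.0) is discarded and the result is
        // the batch mean, 6.0 / 3 = 2.0.
        System.out.println(updatedCentroid(100.0, 0.0, 6.0, 3, 0.5));
        // Weight 3: the old centroid keeps part of its influence.
        System.out.println(updatedCentroid(100.0, 3.0, 6.0, 3, 0.5));
    }
}
```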
   
   







[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r834912599



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java
##########
@@ -159,68 +161,78 @@ public IterationBodyResult process(
                                             DenseVectorTypeInfo.INSTANCE),
                                     new SelectNearestCentroidOperator(distanceMeasure));
 
-            AllWindowFunction<DenseVector, DenseVector[], TimeWindow> toList =
-                    new AllWindowFunction<DenseVector, DenseVector[], TimeWindow>() {
-                        @Override
-                        public void apply(
-                                TimeWindow timeWindow,
-                                Iterable<DenseVector> iterable,
-                                Collector<DenseVector[]> out) {
-                            List<DenseVector> centroids = IteratorUtils.toList(iterable.iterator());
-                            out.collect(centroids.toArray(new DenseVector[0]));
-                        }
-                    };
-
             PerRoundSubBody perRoundSubBody =
                     new PerRoundSubBody() {
                         @Override
                         public DataStreamList process(DataStreamList inputs) {
                             DataStream<Tuple2<Integer, DenseVector>> centroidIdAndPoints =
                                     inputs.get(0);
-                            DataStream<DenseVector[]> newCentroids =
+                            DataStream<KMeansModelData> modelDataStream =
                                     centroidIdAndPoints
                                             .map(new CountAppender())
                                             .keyBy(t -> t.f0)
                                             .window(EndOfStreamWindows.get())
                                             .reduce(new CentroidAccumulator())
                                             .map(new CentroidAverager())
                                             .windowAll(EndOfStreamWindows.get())
-                                            .apply(toList);
-                            return DataStreamList.of(newCentroids);
+                                            .apply(new ModelDataGenerator());
+                            return DataStreamList.of(modelDataStream);
                         }
                     };
-
-            DataStream<DenseVector[]> newCentroids =
+            DataStream<KMeansModelData> newModelData =
                     IterationBody.forEachRound(
                                     DataStreamList.of(centroidIdAndPoints), perRoundSubBody)
                             .get(0);
-            DataStream<DenseVector[]> finalCentroids =
-                    newCentroids.flatMap(new ForwardInputsOfLastRound<>());
+
+            DataStream<DenseVector[]> newCentroids =
+                    newModelData.map(x -> x.centroids).setParallelism(1);
+
+            DataStream<KMeansModelData> finalModelData =
+                    newModelData.flatMap(new ForwardInputsOfLastRound<>());
 
             return new IterationBodyResult(
                     DataStreamList.of(newCentroids),
-                    DataStreamList.of(finalCentroids),
+                    DataStreamList.of(finalModelData),
                     terminationCriteria);
         }
     }
 
+    private static class ModelDataGenerator
+            implements AllWindowFunction<Tuple2<DenseVector, Double>, KMeansModelData, TimeWindow> {
+        @Override
+        public void apply(
+                TimeWindow timeWindow,
+                Iterable<Tuple2<DenseVector, Double>> iterable,
+                Collector<KMeansModelData> collector) {
+            List<Tuple2<DenseVector, Double>> centroidsAndWeights =

Review comment:
       Could we pass `k` (number of clusters) as a parameter for `ModelDataGenerator`, such that we can avoid creating a list of centroids? This could be more memory-efficient if `k` is large.
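   A minimal sketch of that suggestion. It assumes `k` is passed to the generator's constructor and that each window element is one `(centroid, weight)` pair, matching the `Tuple2<DenseVector, Double>` input of `ModelDataGenerator`. To stay self-contained, the sketch uses plain `double[]` vectors and `Map.Entry` pairs instead of Flink ML's `DenseVector` and Flink's `Tuple2`. The names `SimpleModelData` and `PresizedModelDataGenerator` are hypothetical. The output arrays are allocated once with size `k` and filled during iteration, so no intermediate list of all centroids is built.

```java
import java.util.Map;

/** Hypothetical stand-in for KMeansModelData: k centroids plus one weight per centroid. */
final class SimpleModelData {
    final double[][] centroids;
    final double[] weights;

    SimpleModelData(double[][] centroids, double[] weights) {
        this.centroids = centroids;
        this.weights = weights;
    }
}

/** Sketch of a generator that knows k up front and fills pre-sized arrays directly. */
final class PresizedModelDataGenerator {
    private final int k;

    PresizedModelDataGenerator(int k) {
        this.k = k;
    }

    SimpleModelData apply(Iterable<Map.Entry<double[], Double>> centroidsAndWeights) {
        double[][] centroids = new double[k][];
        double[] weights = new double[k];
        int i = 0;
        for (Map.Entry<double[], Double> entry : centroidsAndWeights) {
            if (i >= k) {
                throw new IllegalStateException("Received more than k = " + k + " centroids");
            }
            // Write straight into the output arrays instead of collecting into a List first.
            centroids[i] = entry.getKey();
            weights[i] = entry.getValue();
            i++;
        }
        if (i != k) {
            throw new IllegalStateException("Expected " + k + " centroids but got " + i);
        }
        return new SimpleModelData(centroids, weights);
    }
}
```

   This assumes exactly `k` elements arrive per window. A cluster may receive no points in a round, in which case `keyBy(t -> t.f0)` never sees that id. Fewer than `k` elements would then arrive, and the real operator would need an explicit policy for that case. The sketch simply fails fast.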

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModelData.java
##########
@@ -47,8 +47,16 @@
 
     public DenseVector[] centroids;
 
+    public DenseVector weights;

Review comment:
       nits: Could we add Javadoc to explain why `weights` is added here?
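   One possible shape for that Javadoc, shown on a hypothetical plain-Java stand-in (`DocumentedModelData`) rather than on `KMeansModelData` itself. The wording is an assumption based on the `OnlineKMeans` class Javadoc in this PR, which treats each centroid's weight as the (decayed) number of points assigned to it. The length check in the constructor is also illustrative. It ties in with the earlier point about keeping the model data self-contained.

```java
/** Hypothetical stand-in for KMeansModelData, used only to illustrate the Javadoc. */
final class DocumentedModelData {
    /** The current cluster centers. */
    final double[][] centroids;

    /**
     * The weight of each centroid, where {@code weights[i]} belongs to {@code centroids[i]}.
     *
     * <p>Online KMeans uses these weights as the (decayed) number of points assigned to each
     * centroid so far, so that the centroids estimated on a new mini-batch can be merged with the
     * existing ones through a weighted average.
     */
    final double[] weights;

    DocumentedModelData(double[][] centroids, double[] weights) {
        if (centroids.length != weights.length) {
            throw new IllegalArgumentException(
                    "Expected one weight per centroid, got "
                            + weights.length
                            + " weights for "
                            + centroids.length
                            + " centroids");
        }
        this.centroids = centroids;
        this.weights = weights;
    }

    /** Mirrors the single-argument constructor in the diff: every weight starts at 0. */
    DocumentedModelData(double[][] centroids) {
        this(centroids, new double[centroids.length]);
    }
}
```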

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,437 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.function.Supplier;
+
+/**
+ * OnlineKMeans extends the function of {@link KMeans}, supporting to train a K-Means model
+ * continuously according to an unbounded stream of train data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids from the weighted average between the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of points assigned to
+ * them. The weight of the original centroids is also the number of points, but additionally
+ * multiplying with the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If decay
+ * factor is 1, all batches are weighted equally. If decay factor is 0, new centroids are determined
+ * entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelData =
+                KMeansModelData.getModelDataStream(initModelDataTable);
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize());
+
+        DataStream<KMeansModelData> finalModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelData);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Saves the metadata AND bounded model data table (if exists) to the given path. */
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        String initModelDataPath = ReadWriteUtils.getDataPath(path);
+        if (Files.exists(Paths.get(initModelDataPath))) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure, double decayFactor, int batchSize) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<KMeansModelData> modelData = variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<KMeansModelData> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator())
+                            .flatMap(new GlobalBatchSplitter(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(KMeansModelData.class),
+                                    new ModelDataLocalUpdater(distanceMeasure, decayFactor))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new ModelDataGlobalReducer());
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelData), DataStreamList.of(modelData));
+        }
+    }
+
+    private static class ModelDataGlobalReducer implements ReduceFunction<KMeansModelData> {
+        @Override
+        public KMeansModelData reduce(KMeansModelData modelData, KMeansModelData newModelData) {
+            DenseVector weights = modelData.weights;
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector newWeights = newModelData.weights;
+            DenseVector[] newCentroids = newModelData.centroids;
+
+            int k = newCentroids.length;
+            int dim = newCentroids[0].size();
+
+            for (int i = 0; i < k; i++) {
+                for (int j = 0; j < dim; j++) {
+                    centroids[i].values[j] =
+                            (centroids[i].values[j] * weights.values[i]
+                                            + newCentroids[i].values[j] * newWeights.values[i])
+                                    / Math.max(weights.values[i] + newWeights.values[i], 1e-16);
+                }
+                weights.values[i] += newWeights.values[i];
+            }
+
+            return new KMeansModelData(centroids, weights);
+        }
+    }
+
+    private static class ModelDataLocalUpdater extends AbstractStreamOperator<KMeansModelData>
+            implements TwoInputStreamOperator<DenseVector[], KMeansModelData, KMeansModelData> {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private ListState<DenseVector[]> localBatchState;

Review comment:
       nits: Is `localBatchDataState` a better name?
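   For reference while reading `ModelDataGlobalReducer` above, here is a plain-Java sketch of the same weighted merge. Centroid `i` of the result is the weight-averaged position of centroid `i` in both partial models, and its weight is the sum of both weights. The diff writes the result back into the first argument's arrays. This hypothetical `WeightedCentroidMerge` helper returns new arrays instead.

```java
/** Plain-Java sketch of the weighted merge performed by ModelDataGlobalReducer in the diff. */
final class WeightedCentroidMerge {
    private WeightedCentroidMerge() {}

    /** Centroid i of the result is the weighted average of centroid i in both inputs. */
    static double[][] mergeCentroids(
            double[][] centroids, double[] weights, double[][] newCentroids, double[] newWeights) {
        int k = newCentroids.length;
        double[][] merged = new double[k][];
        for (int i = 0; i < k; i++) {
            int dim = newCentroids[i].length;
            merged[i] = new double[dim];
            // Same guard as the diff: avoids dividing by zero when both weights are 0.
            double denominator = Math.max(weights[i] + newWeights[i], 1e-16);
            for (int j = 0; j < dim; j++) {
                merged[i][j] =
                        (centroids[i][j] * weights[i] + newCentroids[i][j] * newWeights[i])
                                / denominator;
            }
        }
        return merged;
    }

    /** The merged weight of each centroid is the sum of both input weights. */
    static double[] mergeWeights(double[] weights, double[] newWeights) {
        double[] merged = new double[weights.length];
        for (int i = 0; i < weights.length; i++) {
            merged[i] = weights[i] + newWeights[i];
        }
        return merged;
    }
}
```

   Consider the case where both weights of a centroid are 0. The `1e-16` guard keeps the division finite, but in this sketch the merged centroid becomes the zero vector. Whether that case can arise in the real operator depends on how `ModelDataLocalUpdater` fills in the weights, which this excerpt does not show. Either way, it illustrates why valid weights in `KMeansModelData` matter.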

##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java
##########
@@ -0,0 +1,476 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.RestOptions;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.clustering.kmeans.KMeans;
+import org.apache.flink.ml.clustering.kmeans.KMeansModel;
+import org.apache.flink.ml.clustering.kmeans.KMeansModelData;
+import org.apache.flink.ml.clustering.kmeans.OnlineKMeans;
+import org.apache.flink.ml.clustering.kmeans.OnlineKMeansModel;
+import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.util.InMemorySinkFunction;
+import org.apache.flink.ml.util.InMemorySourceFunction;
+import org.apache.flink.runtime.minicluster.MiniCluster;
+import org.apache.flink.runtime.minicluster.MiniClusterConfiguration;
+import org.apache.flink.runtime.testutils.InMemoryReporter;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.TestLogger;
+
+import org.apache.commons.collections.CollectionUtils;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import static org.apache.flink.ml.clustering.KMeansTest.groupFeaturesByPrediction;
+
+/** Tests {@link OnlineKMeans} and {@link OnlineKMeansModel}. */
+public class OnlineKMeansTest extends TestLogger {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private static final DenseVector[] trainData1 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 0.0),
+                Vectors.dense(10.0, 0.3),
+                Vectors.dense(10.3, 0.0),
+                Vectors.dense(-10.0, 0.0),
+                Vectors.dense(-10.0, 0.6),
+                Vectors.dense(-10.6, 0.0)
+            };
+    private static final DenseVector[] trainData2 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 100.0),
+                Vectors.dense(10.0, 100.3),
+                Vectors.dense(10.3, 100.0),
+                Vectors.dense(-10.0, -100.0),
+                Vectors.dense(-10.0, -100.6),
+                Vectors.dense(-10.6, -100.0)
+            };
+    private static final DenseVector[] predictData =
+            new DenseVector[] {
+                Vectors.dense(10.0, 10.0),
+                Vectors.dense(10.3, 10.0),
+                Vectors.dense(10.0, 10.3),
+                Vectors.dense(-10.0, 10.0),
+                Vectors.dense(-10.3, 10.0),
+                Vectors.dense(-10.0, 10.3)
+            };
+    private static final List<Set<DenseVector>> expectedGroups1 =
+            Arrays.asList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3))),
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+    private static final List<Set<DenseVector>> expectedGroups2 =
+            Collections.singletonList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3),
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+
+    private static final int defaultParallelism = 4;
+    private static final int numTaskManagers = 2;
+    private static final int numSlotsPerTaskManager = 2;
+
+    private int currentModelDataVersion;
+
+    private InMemorySourceFunction<DenseVector> trainSource;
+    private InMemorySourceFunction<DenseVector> predictSource;
+    private InMemorySinkFunction<Row> outputSink;
+    private InMemorySinkFunction<KMeansModelData> modelDataSink;
+
+    // TODO: creates static mini cluster once for whole test class after dependency upgrades to
+    // Flink 1.15.
+    private InMemoryReporter reporter;
+    private MiniCluster miniCluster;
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+
+    private Table offlineTrainTable;
+    private Table onlineTrainTable;
+    private Table onlinePredictTable;
+
+    @Before
+    public void before() throws Exception {
+        currentModelDataVersion = 0;
+
+        trainSource = new InMemorySourceFunction<>();
+        predictSource = new InMemorySourceFunction<>();
+        outputSink = new InMemorySinkFunction<>();
+        modelDataSink = new InMemorySinkFunction<>();
+
+        Configuration config = new Configuration();
+        config.set(RestOptions.BIND_PORT, "18081-19091");
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        reporter = InMemoryReporter.createWithRetainedMetrics();
+        reporter.addToConfiguration(config);
+
+        miniCluster =
+                new MiniCluster(
+                        new MiniClusterConfiguration.Builder()
+                                .setConfiguration(config)
+                                .setNumTaskManagers(numTaskManagers)
+                                .setNumSlotsPerTaskManager(numSlotsPerTaskManager)
+                                .build());
+        miniCluster.start();
+
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(defaultParallelism);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        Schema schema = Schema.newBuilder().column("f0", DataTypes.of(DenseVector.class)).build();
+
+        offlineTrainTable =
+                tEnv.fromDataStream(env.fromElements(trainData1), schema).as("features");
+        onlineTrainTable =
+                tEnv.fromDataStream(
+                                env.addSource(trainSource, DenseVectorTypeInfo.INSTANCE), schema)
+                        .as("features");
+        onlinePredictTable =
+                tEnv.fromDataStream(
+                                env.addSource(predictSource, DenseVectorTypeInfo.INSTANCE), schema)
+                        .as("features");
+    }
+
+    @After
+    public void after() throws Exception {
+        miniCluster.close();
+    }
+
+    /**
+     * Performs transform() on the provided model with predictTable, and adds sinks for
+     * OnlineKMeansModel's transform output and model data.
+     */
+    private void transformAndOutputData(OnlineKMeansModel onlineModel) {
+        Table outputTable = onlineModel.transform(onlinePredictTable)[0];
+        tEnv.toDataStream(outputTable).addSink(outputSink);
+
+        Table modelDataTable = onlineModel.getModelData()[0];
+        KMeansModelData.getModelDataStream(modelDataTable).addSink(modelDataSink);
+    }
+
+    /** Blocks the thread until Model has set up init model data. */
+    private void waitInitModelDataSetup() throws InterruptedException {
+        while (reporter.findMetrics(OnlineKMeansModel.MODEL_DATA_VERSION_GAUGE_KEY).size()
+                < defaultParallelism) {
+            Thread.sleep(100);
+        }
+        waitModelDataUpdate();
+    }
+
+    /** Blocks the thread until the Model has received the next model-data-update event. */
+    @SuppressWarnings("unchecked")
+    private void waitModelDataUpdate() throws InterruptedException {
+        do {
+            int tmpModelDataVersion =
+                    reporter.findMetrics(OnlineKMeansModel.MODEL_DATA_VERSION_GAUGE_KEY).values()
+                            .stream()
+                            .map(x -> Integer.parseInt(((Gauge<String>) x).getValue()))
+                            .min(Integer::compareTo)
+                            .get();
+            if (tmpModelDataVersion == currentModelDataVersion) {
+                Thread.sleep(100);
+            } else {
+                currentModelDataVersion = tmpModelDataVersion;
+                break;
+            }
+        } while (true);
+    }
+
+    /**
+     * Inserts default predict data to the predict queue, fetches the prediction results, and
+     * asserts that the grouping result is as expected.
+     *
+     * @param expectedGroups A list containing sets of features, which is the expected group result
+     * @param featuresCol Name of the column in the table that contains the features
+     * @param predictionCol Name of the column in the table that contains the prediction result
+     */
+    private void predictAndAssert(
+            List<Set<DenseVector>> expectedGroups, String featuresCol, String predictionCol)
+            throws Exception {
+        predictSource.addAll(OnlineKMeansTest.predictData);
+        List<Row> rawResult = outputSink.poll(OnlineKMeansTest.predictData.length);
+        List<Set<DenseVector>> actualGroups =
+                groupFeaturesByPrediction(rawResult, featuresCol, predictionCol);
+        Assert.assertTrue(CollectionUtils.isEqualCollection(expectedGroups, actualGroups));
+    }
+
+    @Test
+    public void testParam() {
+        OnlineKMeans onlineKMeans = new OnlineKMeans();
+        Assert.assertEquals("features", onlineKMeans.getFeaturesCol());
+        Assert.assertEquals("prediction", onlineKMeans.getPredictionCol());
+        Assert.assertEquals("count", onlineKMeans.getBatchStrategy());
+        Assert.assertEquals(EuclideanDistanceMeasure.NAME, onlineKMeans.getDistanceMeasure());
+        Assert.assertEquals(32, onlineKMeans.getGlobalBatchSize());
+        Assert.assertEquals(0., onlineKMeans.getDecayFactor(), 1e-5);
+        Assert.assertEquals(OnlineKMeans.class.getName().hashCode(), onlineKMeans.getSeed());
+
+        onlineKMeans
+                .setFeaturesCol("test_feature")
+                .setPredictionCol("test_prediction")
+                .setGlobalBatchSize(5)
+                .setDecayFactor(0.25)
+                .setSeed(100);
+
+        Assert.assertEquals("test_feature", onlineKMeans.getFeaturesCol());
+        Assert.assertEquals("test_prediction", onlineKMeans.getPredictionCol());
+        Assert.assertEquals("count", onlineKMeans.getBatchStrategy());
+        Assert.assertEquals(EuclideanDistanceMeasure.NAME, onlineKMeans.getDistanceMeasure());
+        Assert.assertEquals(5, onlineKMeans.getGlobalBatchSize());
+        Assert.assertEquals(0.25, onlineKMeans.getDecayFactor(), 1e-5);
+        Assert.assertEquals(100, onlineKMeans.getSeed());
+    }
+
+    @Test
+    public void testFitAndPredict() throws Exception {
+        OnlineKMeans onlineKMeans =
+                new OnlineKMeans()
+                        .setFeaturesCol("features")
+                        .setPredictionCol("prediction")
+                        .setGlobalBatchSize(6)
+                        .setRandomCentroids(2, 0.);

Review comment:
       nits: Shall we change `0.` to `0.0`, just to keep the code style consistent?

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasBatchStrategy.java
##########
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringParam;
+import org.apache.flink.ml.param.WithParams;
+
+/** Interface for the shared batch strategy param. */
+public interface HasBatchStrategy<T> extends WithParams<T> {
+    String COUNT_STRATEGY = "count";

Review comment:
       How about explaining a bit what the `count` strategy means?
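   As far as this PR shows, `count` means that mini-batches are formed by record count. `OnlineKMeansIterationBody` calls `countWindowAll(batchSize)` with the global batch size, so the model is updated once per `globalBatchSize` training records. A minimal plain-Java sketch of that behavior, using a hypothetical `CountBatcher` class that is not part of Flink ML:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;

/**
 * Plain-Java sketch of a count-based mini-batch strategy: records are buffered until exactly
 * batchSize of them have arrived, then emitted together as one batch.
 */
final class CountBatcher<T> {
    private final int batchSize;
    private final Consumer<List<T>> downstream;
    private List<T> buffer = new ArrayList<>();

    CountBatcher(int batchSize, Consumer<List<T>> downstream) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive");
        }
        this.batchSize = batchSize;
        this.downstream = downstream;
    }

    void add(T record) {
        buffer.add(record);
        if (buffer.size() == batchSize) {
            downstream.accept(buffer);
            // Start a fresh buffer so the emitted batch is not modified afterwards.
            buffer = new ArrayList<>();
        }
    }

    /** Number of records still waiting for their batch to fill up. */
    int pending() {
        return buffer.size();
    }
}
```

   Like a count trigger, the sketch emits a batch only once it is full. An incomplete trailing batch therefore stays pending instead of being flushed.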

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,437 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.function.Supplier;
+
+/**
+ * OnlineKMeans extends the function of {@link KMeans}, supporting to train a K-Means model
+ * continuously according to an unbounded stream of train data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids from the weighted average between the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of points assigned to
+ * them. The weight of the original centroids is also the number of points, but additionally
+ * multiplying with the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If decay

Review comment:
       nit: "decay factor" -> "the decay factor" (e.g. "If the decay factor is 1, ...").
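For reference, the update rule this Javadoc describes can be written as c' = (d*w*c + s) / (d*w + n) and w' = d*w + n. Here c and w are a centroid and its weight, d is the decay factor, and s and n are the sum and count of the batch points assigned to that centroid. The sketch below implements that rule in plain Java, with no dependencies. The class and method names are hypothetical, and this is not the PR's code. The PR appears to spread the same computation across `ModelDataLocalUpdater` (per subtask) and `ModelDataGlobalReducer` (the weighted merge).

```java
import java.util.Arrays;

/** Hypothetical sketch of the decayed mini-batch K-Means update; not the PR's implementation. */
final class DecayedMiniBatchUpdate {

    /**
     * Updates one centroid in place and returns its new weight.
     *
     * @param centroid current centroid; overwritten with the updated centroid
     * @param weight current weight of the centroid (number of points seen so far)
     * @param batchSum element-wise sum of the batch points assigned to this centroid
     * @param batchCount number of batch points assigned to this centroid
     * @param decayFactor 1 weights all batches equally, 0 keeps only the latest batch
     */
    static double update(
            double[] centroid, double weight, double[] batchSum, int batchCount, double decayFactor) {
        // The old weight is always decayed, even if no point was assigned in this batch.
        double decayedWeight = weight * decayFactor;
        if (batchCount == 0) {
            return decayedWeight;
        }
        double newWeight = decayedWeight + batchCount;
        for (int j = 0; j < centroid.length; j++) {
            // Weighted average of the old centroid (weight decayedWeight) and the
            // batch mean batchSum[j] / batchCount (weight batchCount).
            centroid[j] = (centroid[j] * decayedWeight + batchSum[j]) / newWeight;
        }
        return newWeight;
    }

    public static void main(String[] args) {
        // Old centroid at the origin with weight 2; two new points summing to (4, 4).
        double[] keepAll = {0.0, 0.0};
        double w1 = update(keepAll, 2.0, new double[] {4.0, 4.0}, 2, 1.0);
        System.out.println(Arrays.toString(keepAll) + " weight=" + w1); // [1.0, 1.0] weight=4.0

        double[] forgetAll = {0.0, 0.0};
        double w0 = update(forgetAll, 2.0, new double[] {4.0, 4.0}, 2, 0.0);
        System.out.println(Arrays.toString(forgetAll) + " weight=" + w0); // [2.0, 2.0] weight=2.0
    }
}
```

With a decay factor of 1, the old centroid and the batch count equally per point. With 0, the centroid jumps to the batch mean.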

##########
File path: flink-ml-core/src/main/java/org/apache/flink/ml/common/distance/EuclideanDistanceMeasure.java
##########
@@ -34,6 +35,7 @@ public static EuclideanDistanceMeasure getInstance() {
 
     @Override
     public double distance(Vector v1, Vector v2) {
+        Preconditions.checkArgument(v1.size() == v2.size());
         double squaredDistance = 0.0;
 
         for (int i = 0; i < v1.size(); i++) {

Review comment:
       nit: do you think using BLAS here would be more efficient? It is okay to leave it as is, since it is not part of this PR.
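A common way to route this computation through BLAS-style routines is the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2*(a.b). When the squared norms of the centroids are cached, each point-to-centroid distance needs only one dot product, which maps onto a level-1 BLAS `dot`. The sketch below illustrates this on plain `double[]` arrays and is hypothetical. It does not use Flink ML's `BLAS` class, and any speedup would need to be measured. One caveat: the expanded form can lose precision through cancellation when the two vectors are close. The sketch therefore clamps negative results to zero.

```java
/** Hypothetical sketch: squared Euclidean distance via dot products with cached norms. */
final class DotProductDistance {

    /** Plain dot product; stands in for a BLAS level-1 "dot" routine. */
    static double dot(double[] a, double[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("Vectors must have the same size");
        }
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            sum += a[i] * b[i];
        }
        return sum;
    }

    /** Direct form, equivalent to the loop under review (before its square root). */
    static double squaredDistanceDirect(double[] a, double[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("Vectors must have the same size");
        }
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double diff = a[i] - b[i];
            sum += diff * diff;
        }
        return sum;
    }

    /** ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, clamped at 0 against rounding error. */
    static double squaredDistanceFromNorms(double[] a, double aNormSq, double[] b, double bNormSq) {
        return Math.max(0.0, aNormSq + bNormSq - 2.0 * dot(a, b));
    }

    public static void main(String[] args) {
        double[] point = {1.0, 2.0};
        double[][] centroids = {{0.0, 0.0}, {4.0, 6.0}};

        // Squared centroid norms are computed once and reused for every point in a batch.
        double[] centroidNormsSq = new double[centroids.length];
        for (int i = 0; i < centroids.length; i++) {
            centroidNormsSq[i] = dot(centroids[i], centroids[i]);
        }
        double pointNormSq = dot(point, point);

        for (int i = 0; i < centroids.length; i++) {
            // Prints "5.0 5.0" and then "25.0 25.0".
            System.out.println(
                    squaredDistanceDirect(point, centroids[i])
                            + " "
                            + squaredDistanceFromNorms(
                                    point, pointNormSq, centroids[i], centroidNormsSq[i]));
        }
    }
}
```

When distances are only compared, for example to find the closest centroid, the square root can also be skipped, because it is monotonic.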

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,437 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.function.Supplier;
+
+/**
+ * OnlineKMeans extends the function of {@link KMeans}, supporting to train a K-Means model
+ * continuously according to an unbounded stream of train data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids from the weighted average between the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of points assigned to
+ * them. The weight of the original centroids is also the number of points, but additionally
+ * multiplying with the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If decay
+ * factor is 1, all batches are weighted equally. If decay factor is 0, new centroids are determined
+ * entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelData =
+                KMeansModelData.getModelDataStream(initModelDataTable);
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize());
+
+        DataStream<KMeansModelData> finalModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelData);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Saves the metadata AND bounded model data table (if exists) to the given path. */
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);

Review comment:
       nit: would `onlineKMeans` be a better name here?

##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java
##########
@@ -0,0 +1,476 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.RestOptions;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.clustering.kmeans.KMeans;
+import org.apache.flink.ml.clustering.kmeans.KMeansModel;
+import org.apache.flink.ml.clustering.kmeans.KMeansModelData;
+import org.apache.flink.ml.clustering.kmeans.OnlineKMeans;
+import org.apache.flink.ml.clustering.kmeans.OnlineKMeansModel;
+import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.util.InMemorySinkFunction;
+import org.apache.flink.ml.util.InMemorySourceFunction;
+import org.apache.flink.runtime.minicluster.MiniCluster;
+import org.apache.flink.runtime.minicluster.MiniClusterConfiguration;
+import org.apache.flink.runtime.testutils.InMemoryReporter;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.TestLogger;
+
+import org.apache.commons.collections.CollectionUtils;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import static org.apache.flink.ml.clustering.KMeansTest.groupFeaturesByPrediction;
+
+/** Tests {@link OnlineKMeans} and {@link OnlineKMeansModel}. */
+public class OnlineKMeansTest extends TestLogger {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private static final DenseVector[] trainData1 =

Review comment:
       One quick question: should we make it `static` or not?

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,437 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.function.Supplier;
+
+/**
+ * OnlineKMeans extends the function of {@link KMeans}, supporting to train a K-Means model
+ * continuously according to an unbounded stream of train data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids from the weighted average between the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of points assigned to
+ * them. The weight of the original centroids is also the number of points, but additionally
+ * multiplying with the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If decay
+ * factor is 1, all batches are weighted equally. If decay factor is 0, new centroids are determined
+ * entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelData =
+                KMeansModelData.getModelDataStream(initModelDataTable);
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize());
+
+        DataStream<KMeansModelData> finalModelData =

Review comment:
       nit: could `finalModelData` be renamed to `onlineModelData`? It is not really the final model data, since it keeps being updated as new training data arrives.

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,437 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.function.Supplier;
+
+/**
+ * OnlineKMeans extends the function of {@link KMeans}, supporting to train a K-Means model
+ * continuously according to an unbounded stream of train data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids from the weighted average between the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of points assigned to
+ * them. The weight of the original centroids is also the number of points, but additionally
+ * multiplying with the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If decay
+ * factor is 1, all batches are weighted equally. If decay factor is 0, new centroids are determined
+ * entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelData =
+                KMeansModelData.getModelDataStream(initModelDataTable);
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize());
+
+        DataStream<KMeansModelData> finalModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelData);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Saves the metadata AND bounded model data table (if exists) to the given path. */
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        String initModelDataPath = ReadWriteUtils.getDataPath(path);
+        if (Files.exists(Paths.get(initModelDataPath))) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure, double decayFactor, int batchSize) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<KMeansModelData> modelData = variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<KMeansModelData> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator())
+                            .flatMap(new GlobalBatchSplitter(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(KMeansModelData.class),
+                                    new ModelDataLocalUpdater(distanceMeasure, decayFactor))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new ModelDataGlobalReducer());
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelData), DataStreamList.of(modelData));
+        }
+    }
+
+    private static class ModelDataGlobalReducer implements ReduceFunction<KMeansModelData> {
+        @Override
+        public KMeansModelData reduce(KMeansModelData modelData, KMeansModelData newModelData) {
+            DenseVector weights = modelData.weights;
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector newWeights = newModelData.weights;
+            DenseVector[] newCentroids = newModelData.centroids;
+
+            int k = newCentroids.length;
+            int dim = newCentroids[0].size();
+
+            for (int i = 0; i < k; i++) {
+                for (int j = 0; j < dim; j++) {
+                    centroids[i].values[j] =
+                            (centroids[i].values[j] * weights.values[i]
+                                            + newCentroids[i].values[j] * newWeights.values[i])
+                                    / Math.max(weights.values[i] + newWeights.values[i], 1e-16);
+                }
+                weights.values[i] += newWeights.values[i];
+            }
+
+            return new KMeansModelData(centroids, weights);
+        }
+    }
+
+    private static class ModelDataLocalUpdater extends AbstractStreamOperator<KMeansModelData>
+            implements TwoInputStreamOperator<DenseVector[], KMeansModelData, KMeansModelData> {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private ListState<DenseVector[]> localBatchState;
+        private ListState<KMeansModelData> modelDataState;
+
+        private ModelDataLocalUpdater(DistanceMeasure distanceMeasure, double decayFactor) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            localBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("localBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> pointsRecord) throws Exception {
+            localBatchState.add(pointsRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<KMeansModelData> modelDataRecord)
+                throws Exception {
+            modelDataState.add(modelDataRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        private void alignAndComputeModelData() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !localBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            KMeansModelData modelData =
+                    OperatorStateUtils.getUniqueElement(modelDataState, "modelData")
+                            .orElseThrow((Supplier<Exception>) NullPointerException::new);
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector weights = modelData.weights;
+            modelDataState.clear();
+
+            List<DenseVector[]> pointsList = IteratorUtils.toList(localBatchState.get().iterator());
+            DenseVector[] points = pointsList.remove(0);
+            localBatchState.update(pointsList);
+
+            int dim = centroids[0].size();
+            int k = centroids.length;
+            int parallelism = getRuntimeContext().getNumberOfParallelSubtasks();
+
+            // Computes new centroids.
+            DenseVector[] sums = new DenseVector[k];
+            int[] counts = new int[k];
+
+            for (int i = 0; i < k; i++) {
+                sums[i] = new DenseVector(dim);
+                counts[i] = 0;
+            }
+            for (DenseVector point : points) {
+                int closestCentroidId =
+                        KMeans.findClosestCentroidId(centroids, point, distanceMeasure);
+                counts[closestCentroidId]++;
+                for (int j = 0; j < dim; j++) {
+                    sums[closestCentroidId].values[j] += point.values[j];
+                }
+            }
+
+            // Considers weight and decay factor when updating centroids.
+            BLAS.scal(decayFactor / parallelism, weights);
+            for (int i = 0; i < k; i++) {
+                if (counts[i] == 0) {
+                    continue;
+                }
+
+                DenseVector centroid = centroids[i];
+                weights.values[i] = weights.values[i] + counts[i];
+                double lambda = counts[i] / weights.values[i];
+
+                BLAS.scal(1.0 - lambda, centroid);
+                BLAS.axpy(lambda / counts[i], sums[i], centroid);
+            }
+
+            output.collect(new StreamRecord<>(new KMeansModelData(centroids, weights)));
+        }
+    }
+
+    private static class FeaturesExtractor implements MapFunction<Row, DenseVector> {
+        private final String featuresCol;
+
+        private FeaturesExtractor(String featuresCol) {
+            this.featuresCol = featuresCol;
+        }
+
+        @Override
+        public DenseVector map(Row row) throws Exception {
+            return (DenseVector) row.getField(featuresCol);
+        }
+    }
+
+    // An operator that splits a global batch into evenly-sized local batches, and distributes them
+    // to downstream operator.
+    private static class GlobalBatchSplitter
+            implements FlatMapFunction<DenseVector[], DenseVector[]> {
+        private final int downStreamParallelism;
+
+        private GlobalBatchSplitter(int downStreamParallelism) {
+            this.downStreamParallelism = downStreamParallelism;
+        }
+
+        @Override
+        public void flatMap(DenseVector[] values, Collector<DenseVector[]> collector) {
+            // Calculate the batch sizes to be distributed on each subtask.
+            List<Integer> sizes = new ArrayList<>();

Review comment:
       Would the following code be more readable?
   
   ```
   int div = values.length / downStreamParallelism;
   int mod = values.length % downStreamParallelism;
   int offset = 0;
   for (int i = 0; i < downStreamParallelism; i++) {
       // The first `mod` subtasks each receive one extra element.
       int size = i < mod ? div + 1 : div;
       collector.collect(Arrays.copyOfRange(values, offset, offset + size));
       offset += size;
   }
   ```
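       A stand-alone version of the suggested splitting logic, useful for checking its arithmetic. It assumes a `List` collects the slices in place of Flink's `Collector`; `EvenSplitter` is a hypothetical name. Note that `offset` has to advance after every slice, and that the loop always emits exactly `parallelism` slices, some of which are empty when there are fewer points than subtasks.

   ```java
   import java.util.ArrayList;
   import java.util.Arrays;
   import java.util.List;

   /** Hypothetical helper: splits an array into near-equal contiguous slices. */
   class EvenSplitter {
       /**
        * Returns exactly `parallelism` slices whose sizes differ by at most one.
        * The first `values.length % parallelism` slices get one extra element.
        */
       static <T> List<T[]> split(T[] values, int parallelism) {
           int div = values.length / parallelism;
           int mod = values.length % parallelism;
           List<T[]> slices = new ArrayList<>(parallelism);
           int offset = 0;
           for (int i = 0; i < parallelism; i++) {
               int size = i < mod ? div + 1 : div;
               slices.add(Arrays.copyOfRange(values, offset, offset + size));
               offset += size; // move past the slice just emitted
           }
           return slices;
       }
   }
   ```

       For 7 values and a parallelism of 3 this yields slices of sizes 3, 2 and 2.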

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,437 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.function.Supplier;
+
+/**
+ * OnlineKMeans extends the functionality of {@link KMeans} by supporting continuous training of a
+ * K-Means model on an unbounded stream of training data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids as the weighted average of the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of points assigned to
+ * them. The weight of the original centroids is the number of points assigned to them so far,
+ * multiplied by the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If the decay
+ * factor is 1, all batches are weighted equally. If the decay factor is 0, new centroids are
+ * determined entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelData =
+                KMeansModelData.getModelDataStream(initModelDataTable);
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize());
+
+        DataStream<KMeansModelData> finalModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelData);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Saves the metadata AND the bounded model data table (if it exists) to the given path. */
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        String initModelDataPath = ReadWriteUtils.getDataPath(path);
+        if (Files.exists(Paths.get(initModelDataPath))) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure, double decayFactor, int batchSize) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<KMeansModelData> modelData = variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<KMeansModelData> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator())
+                            .flatMap(new GlobalBatchSplitter(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(KMeansModelData.class),
+                                    new ModelDataLocalUpdater(distanceMeasure, decayFactor))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new ModelDataGlobalReducer());
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelData), DataStreamList.of(modelData));
+        }
+    }
+
+    private static class ModelDataGlobalReducer implements ReduceFunction<KMeansModelData> {
+        @Override
+        public KMeansModelData reduce(KMeansModelData modelData, KMeansModelData newModelData) {
+            DenseVector weights = modelData.weights;
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector newWeights = newModelData.weights;
+            DenseVector[] newCentroids = newModelData.centroids;
+
+            int k = newCentroids.length;
+            int dim = newCentroids[0].size();
+
+            for (int i = 0; i < k; i++) {
+                for (int j = 0; j < dim; j++) {
+                    centroids[i].values[j] =
+                            (centroids[i].values[j] * weights.values[i]
+                                            + newCentroids[i].values[j] * newWeights.values[i])
+                                    / Math.max(weights.values[i] + newWeights.values[i], 1e-16);
+                }
+                weights.values[i] += newWeights.values[i];
+            }
+
+            return new KMeansModelData(centroids, weights);
+        }
+    }
+
+    private static class ModelDataLocalUpdater extends AbstractStreamOperator<KMeansModelData>
+            implements TwoInputStreamOperator<DenseVector[], KMeansModelData, KMeansModelData> {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private ListState<DenseVector[]> localBatchState;
+        private ListState<KMeansModelData> modelDataState;
+
+        private ModelDataLocalUpdater(DistanceMeasure distanceMeasure, double decayFactor) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            localBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("localBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> pointsRecord) throws Exception {
+            localBatchState.add(pointsRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<KMeansModelData> modelDataRecord)
+                throws Exception {
+            modelDataState.add(modelDataRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        private void alignAndComputeModelData() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !localBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            KMeansModelData modelData =
+                    OperatorStateUtils.getUniqueElement(modelDataState, "modelData")
+                            .orElseThrow((Supplier<Exception>) NullPointerException::new);
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector weights = modelData.weights;
+            modelDataState.clear();
+
+            List<DenseVector[]> pointsList = IteratorUtils.toList(localBatchState.get().iterator());
+            DenseVector[] points = pointsList.remove(0);
+            localBatchState.update(pointsList);
+
+            int dim = centroids[0].size();
+            int k = centroids.length;
+            int parallelism = getRuntimeContext().getNumberOfParallelSubtasks();
+
+            // Computes new centroids.
+            DenseVector[] sums = new DenseVector[k];
+            int[] counts = new int[k];
+
+            for (int i = 0; i < k; i++) {
+                sums[i] = new DenseVector(dim);
+                counts[i] = 0;

Review comment:
       nit: this line could be removed, since Java already zero-initializes the elements of a new `int[]`.
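       For reference, a dependency-free sketch of the update that `alignAndComputeModelData` and `ModelDataGlobalReducer` perform together. It assumes squared Euclidean distance and plain `double[]` in place of `DenseVector`; `MiniBatchKMeansSketch` and its methods are hypothetical names. Because each subtask decays the previous weights by `decayFactor / parallelism` before adding its local counts, the weight-averaged merge of all subtask results equals a single update over the whole global batch: new weight = decay * w + n, new centroid = (decay * w * c + sum of assigned points) / (decay * w + n).

   ```java
   /** Hypothetical sketch of the decayed mini-batch k-means update (local step + global merge). */
   class MiniBatchKMeansSketch {
       /** Centroids plus the (decayed) number of points each centroid has absorbed. */
       static final class Model {
           final double[][] centroids;
           final double[] weights;

           Model(double[][] centroids, double[] weights) {
               this.centroids = centroids;
               this.weights = weights;
           }
       }

       /** Index of the nearest centroid by squared Euclidean distance. */
       static int closest(double[][] centroids, double[] point) {
           int best = 0;
           double bestDist = Double.POSITIVE_INFINITY;
           for (int i = 0; i < centroids.length; i++) {
               double dist = 0;
               for (int j = 0; j < point.length; j++) {
                   double diff = centroids[i][j] - point[j];
                   dist += diff * diff;
               }
               if (dist < bestDist) {
                   bestDist = dist;
                   best = i;
               }
           }
           return best;
       }

       /** One subtask: decay old weights by decay / parallelism, then fold in the local batch. */
       static Model localUpdate(Model old, double[][] points, double decay, int parallelism) {
           int k = old.centroids.length;
           int dim = old.centroids[0].length;
           double[][] sums = new double[k][dim]; // zero-initialized by Java
           int[] counts = new int[k];            // likewise, no explicit reset needed
           for (double[] p : points) {
               int c = closest(old.centroids, p);
               counts[c]++;
               for (int j = 0; j < dim; j++) {
                   sums[c][j] += p[j];
               }
           }
           double[][] centroids = new double[k][];
           double[] weights = new double[k];
           for (int i = 0; i < k; i++) {
               centroids[i] = old.centroids[i].clone();
               weights[i] = old.weights[i] * decay / parallelism;
               if (counts[i] == 0) {
                   continue; // no new points: centroid unchanged, weight only decayed
               }
               weights[i] += counts[i];
               double lambda = counts[i] / weights[i];
               for (int j = 0; j < dim; j++) {
                   centroids[i][j] = (1 - lambda) * centroids[i][j] + lambda / counts[i] * sums[i][j];
               }
           }
           return new Model(centroids, weights);
       }

       /** Global step: weight-averaged merge of two subtask results. */
       static Model merge(Model a, Model b) {
           int k = a.centroids.length;
           int dim = a.centroids[0].length;
           double[][] centroids = new double[k][dim];
           double[] weights = new double[k];
           for (int i = 0; i < k; i++) {
               weights[i] = a.weights[i] + b.weights[i];
               for (int j = 0; j < dim; j++) {
                   centroids[i][j] =
                           (a.centroids[i][j] * a.weights[i] + b.centroids[i][j] * b.weights[i])
                                   / Math.max(weights[i], 1e-16);
               }
           }
           return new Model(centroids, weights);
       }
   }
   ```

       With initial centroids (0, 0) and (10, 10), weights 2 and 2, and a decay factor of 0.5, updating once on the batch (1, 1), (3, 3), (11, 11) moves the centroids to (4/3, 4/3) and (10.5, 10.5); splitting that batch across two subtasks and merging gives the same result.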

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeansModel.java
##########
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.co.CoProcessFunction;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * OnlineKMeansModel can be regarded as an advanced {@link KMeansModel} operator which can update
+ * model data in a streaming format, using the model data provided by {@link OnlineKMeans}.
+ */
+public class OnlineKMeansModel
+        implements Model<OnlineKMeansModel>, KMeansModelParams<OnlineKMeansModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public OnlineKMeansModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel setModelData(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        modelDataTable = inputs[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), Types.INT),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol()));
+
+        DataStream<Row> predictionResult =
+                KMeansModelData.getModelDataStream(modelDataTable)
+                        .broadcast()
+                        .connect(tEnv.toDataStream(inputs[0]))
+                        .process(
+                                new PredictLabelFunction(
+                                        getFeaturesCol(),
+                                        DistanceMeasure.getInstance(getDistanceMeasure())),
+                                outputTypeInfo);
+
+        return new Table[] {tEnv.fromDataStream(predictionResult)};
+    }
+
+    /** A utility function used for prediction. */
+    private static class PredictLabelFunction extends CoProcessFunction<KMeansModelData, Row, Row> {
+        private final String featuresCol;
+
+        private final DistanceMeasure distanceMeasure;
+
+        private DenseVector[] centroids;
+
+        // TODO: replace this with a complete solution of reading first model data from unbounded
+        // model data stream before processing the first predict data.
+        private final List<Row> bufferedPoints = new ArrayList<>();

Review comment:
       One stupid question: will `bufferedPoints` be checkpointed when a snapshot is taken?
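       To make the question concrete: a plain `List` field such as `bufferedPoints` is not part of a Flink checkpoint on its own; it would have to live in managed state (for example a `ListState` registered through `CheckpointedFunction`) to survive a failover. The hypothetical, Flink-free sketch below models the buffering behaviour of `PredictLabelFunction`, with `snapshot()`/`restore()` standing in for what such state would need to persist. `BufferingPredictor` and its methods are illustrative names, and squared Euclidean distance is assumed.

   ```java
   import java.util.ArrayList;
   import java.util.List;

   /** Hypothetical model of buffering predict data until the first model data arrives. */
   class BufferingPredictor {
       private double[][] centroids; // null until the first model data arrives
       private final List<double[]> bufferedPoints = new ArrayList<>();

       /** Model-data input: installs the centroids and labels every buffered point. */
       List<Integer> onModel(double[][] newCentroids) {
           centroids = newCentroids;
           List<Integer> labels = new ArrayList<>();
           for (double[] point : bufferedPoints) {
               labels.add(predict(point));
           }
           bufferedPoints.clear();
           return labels;
       }

       /** Predict-data input: labels the point if a model exists, otherwise buffers it. */
       List<Integer> onPoint(double[] point) {
           List<Integer> labels = new ArrayList<>();
           if (centroids == null) {
               bufferedPoints.add(point);
           } else {
               labels.add(predict(point));
           }
           return labels;
       }

       /** What a checkpoint would need to store for the buffer. */
       List<double[]> snapshot() {
           return new ArrayList<>(bufferedPoints);
       }

       /** Re-populates the buffer after recovery. */
       void restore(List<double[]> state) {
           bufferedPoints.addAll(state);
       }

       /** Index of the nearest centroid by squared Euclidean distance. */
       private int predict(double[] point) {
           int best = 0;
           double bestDist = Double.POSITIVE_INFINITY;
           for (int i = 0; i < centroids.length; i++) {
               double dist = 0;
               for (int j = 0; j < point.length; j++) {
                   double diff = centroids[i][j] - point[j];
                   dist += diff * diff;
               }
               if (dist < bestDist) {
                   bestDist = dist;
                   best = i;
               }
           }
           return best;
       }
   }
   ```

       If the buffer is not restored after a failover, the points that were buffered before the failure are never labelled.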

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,437 @@
+            // Computes new centroids.
+            DenseVector[] sums = new DenseVector[k];
+            int[] counts = new int[k];
+
+            for (int i = 0; i < k; i++) {
+                sums[i] = new DenseVector(dim);
+                counts[i] = 0;
+            }
+            for (DenseVector point : points) {
+                int closestCentroidId =
+                        KMeans.findClosestCentroidId(centroids, point, distanceMeasure);
+                counts[closestCentroidId]++;
+                for (int j = 0; j < dim; j++) {
+                    sums[closestCentroidId].values[j] += point.values[j];

Review comment:
       nit: this could be a BLAS operation.
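The nit refers to the per-dimension loop that adds `point` into `sums[closestCentroidId]`. The same operator already calls `BLAS.axpy(a, x, y)` (y += a * x) on `DenseVector`s later on. So the loop could presumably collapse to `BLAS.axpy(1.0, point, sums[closestCentroidId])`.

The self-contained sketch below illustrates the idea on plain `double[]` arrays, with two simplifications:
- `axpy` here is a hypothetical stand-in with the same semantics, not Flink ML's class.
- Cluster assignments are passed in up front instead of being computed with `findClosestCentroidId`.

```java
import java.util.Arrays;

public class AxpyAccumulation {

    /** Hypothetical stand-in for an axpy routine: y += a * x, element by element. */
    public static void axpy(double a, double[] x, double[] y) {
        if (x.length != y.length) {
            throw new IllegalArgumentException("Vector sizes differ: " + x.length + " vs " + y.length);
        }
        for (int j = 0; j < x.length; j++) {
            y[j] += a * x[j];
        }
    }

    /** Sums the points of each cluster; assignments[p] is the cluster index of points[p]. */
    public static double[][] sumByCluster(double[][] points, int[] assignments, int k) {
        int dim = points[0].length;
        double[][] sums = new double[k][dim]; // Java zero-initializes the sums.
        for (int p = 0; p < points.length; p++) {
            // One call replaces the inner "for (j ...) sums[c][j] += point[j]" loop.
            axpy(1.0, points[p], sums[assignments[p]]);
        }
        return sums;
    }

    public static void main(String[] args) {
        double[][] points = {{1.0, 2.0}, {3.0, 4.0}, {10.0, 10.0}};
        int[] assignments = {0, 0, 1};
        System.out.println(Arrays.deepToString(sumByCluster(points, assignments, 2)));
        // prints [[4.0, 6.0], [10.0, 10.0]]
    }
}
```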

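The decayed mini-batch update is spread over two operators:
- Each `ModelDataLocalUpdater` subtask scales the previous weights by `decayFactor / parallelism` before folding in its local points.
- The global reducer, whose tail is quoted above, merges the partial results by weight-averaging them.

The following sketch checks numerically that this reproduces the single-machine update c' = (d * w * c + S) / (d * w + n), w' = d * w + n. Here S and n are the sum and count of the batch points assigned to the cluster. The sketch simplifies to one dimension and a single cluster, and its method names are hypothetical.

```java
public class DecayedUpdateCheck {

    /** Single-machine rule. Returns {newCentroid, newWeight}. */
    public static double[] update(double c, double w, double d, double sum, int n) {
        double newWeight = d * w + n;
        return new double[] {(d * w * c + sum) / newWeight, newWeight};
    }

    /** Local step on one of p subtasks, mirroring the scal/axpy sequence in the PR. */
    public static double[] localUpdate(double c, double w, double d, int p, double sum, int n) {
        double weight = w * (d / p);
        if (n == 0) {
            return new double[] {c, weight}; // Centroid unchanged; only the weight decays.
        }
        weight += n;
        double lambda = n / weight;
        return new double[] {(1.0 - lambda) * c + (lambda / n) * sum, weight};
    }

    /** Global step: weighted average of two partial results; weights are summed. */
    public static double[] merge(double[] a, double[] b) {
        double total = a[1] + b[1];
        return new double[] {(a[0] * a[1] + b[0] * b[1]) / Math.max(total, 1e-16), total};
    }

    public static void main(String[] args) {
        double c = 1.0, w = 4.0, d = 0.5;
        // One batch {2, 4, 6} split over p = 2 subtasks: {2, 4} and {6}.
        double[] merged = merge(localUpdate(c, w, d, 2, 6.0, 2), localUpdate(c, w, d, 2, 6.0, 1));
        double[] single = update(c, w, d, 12.0, 3);
        System.out.println("distributed: " + merged[0] + ", " + merged[1]);
        System.out.println("single:      " + single[0] + ", " + single[1]);
    }
}
```

Why the split gives the same result: the reducer sums the weights, so the p copies of `(decayFactor / parallelism) * w` add back up to `decayFactor * w`.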
##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,437 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.function.Supplier;
+
+/**
+ * OnlineKMeans extends the functionality of {@link KMeans} by supporting continuous training of
+ * a K-Means model on an unbounded stream of training data.
+ *
+ * <p>OnlineKMeans uses the "mini-batch" KMeans update rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids as the weighted average of the original and the
+ * estimated centroids. The weight of each estimated centroid is the number of points assigned
+ * to it in the current batch. The weight of each original centroid is its accumulated weight from
+ * previous batches, multiplied by the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters estimated so far. If the decay
+ * factor is 1, all batches are weighted equally. If the decay factor is 0, the new centroids are
+ * determined entirely by the most recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelData =
+                KMeansModelData.getModelDataStream(initModelDataTable);
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize());
+
+        DataStream<KMeansModelData> finalModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelData);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Saves the metadata AND the bounded model data table (if it exists) to the given path. */
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        String initModelDataPath = ReadWriteUtils.getDataPath(path);
+        if (Files.exists(Paths.get(initModelDataPath))) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure, double decayFactor, int batchSize) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<KMeansModelData> modelData = variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<KMeansModelData> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator())
+                            .flatMap(new GlobalBatchSplitter(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(KMeansModelData.class),
+                                    new ModelDataLocalUpdater(distanceMeasure, decayFactor))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new ModelDataGlobalReducer());
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelData), DataStreamList.of(modelData));
+        }
+    }
+
+    private static class ModelDataGlobalReducer implements ReduceFunction<KMeansModelData> {
+        @Override
+        public KMeansModelData reduce(KMeansModelData modelData, KMeansModelData newModelData) {
+            DenseVector weights = modelData.weights;
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector newWeights = newModelData.weights;
+            DenseVector[] newCentroids = newModelData.centroids;
+
+            int k = newCentroids.length;
+            int dim = newCentroids[0].size();
+
+            for (int i = 0; i < k; i++) {
+                for (int j = 0; j < dim; j++) {
+                    centroids[i].values[j] =
+                            (centroids[i].values[j] * weights.values[i]
+                                            + newCentroids[i].values[j] * newWeights.values[i])
+                                    / Math.max(weights.values[i] + newWeights.values[i], 1e-16);
+                }
+                weights.values[i] += newWeights.values[i];
+            }
+
+            return new KMeansModelData(centroids, weights);
+        }
+    }
+
+    private static class ModelDataLocalUpdater extends AbstractStreamOperator<KMeansModelData>
+            implements TwoInputStreamOperator<DenseVector[], KMeansModelData, KMeansModelData> {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private ListState<DenseVector[]> localBatchState;
+        private ListState<KMeansModelData> modelDataState;
+
+        private ModelDataLocalUpdater(DistanceMeasure distanceMeasure, double decayFactor) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            localBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("localBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> pointsRecord) throws Exception {
+            localBatchState.add(pointsRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<KMeansModelData> modelDataRecord)
+                throws Exception {
+            modelDataState.add(modelDataRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        private void alignAndComputeModelData() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !localBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            KMeansModelData modelData =
+                    OperatorStateUtils.getUniqueElement(modelDataState, "modelData")
+                            .orElseThrow((Supplier<Exception>) NullPointerException::new);
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector weights = modelData.weights;
+            modelDataState.clear();
+
+            List<DenseVector[]> pointsList = IteratorUtils.toList(localBatchState.get().iterator());
+            DenseVector[] points = pointsList.remove(0);
+            localBatchState.update(pointsList);
+
+            int dim = centroids[0].size();
+            int k = centroids.length;
+            int parallelism = getRuntimeContext().getNumberOfParallelSubtasks();
+
+            // Computes new centroids.
+            DenseVector[] sums = new DenseVector[k];
+            int[] counts = new int[k];
+
+            for (int i = 0; i < k; i++) {
+                sums[i] = new DenseVector(dim);
+                counts[i] = 0;
+            }
+            for (DenseVector point : points) {
+                int closestCentroidId =
+                        KMeans.findClosestCentroidId(centroids, point, distanceMeasure);
+                counts[closestCentroidId]++;
+                for (int j = 0; j < dim; j++) {
+                    sums[closestCentroidId].values[j] += point.values[j];
+                }
+            }
+
+            // Considers weight and decay factor when updating centroids.
+            BLAS.scal(decayFactor / parallelism, weights);
+            for (int i = 0; i < k; i++) {
+                if (counts[i] == 0) {
+                    continue;
+                }
+
+                DenseVector centroid = centroids[i];
+                weights.values[i] = weights.values[i] + counts[i];
+                double lambda = counts[i] / weights.values[i];
+
+                BLAS.scal(1.0 - lambda, centroid);
+                BLAS.axpy(lambda / counts[i], sums[i], centroid);
+            }
+
+            output.collect(new StreamRecord<>(new KMeansModelData(centroids, weights)));
+        }
+    }
+
+    private static class FeaturesExtractor implements MapFunction<Row, DenseVector> {
+        private final String featuresCol;
+
+        private FeaturesExtractor(String featuresCol) {
+            this.featuresCol = featuresCol;
+        }
+
+        @Override
+        public DenseVector map(Row row) throws Exception {
+            return (DenseVector) row.getField(featuresCol);
+        }
+    }
+
+    // An operator that splits a global batch into evenly-sized local batches and distributes them
+    // to the downstream operator.
+    private static class GlobalBatchSplitter
+            implements FlatMapFunction<DenseVector[], DenseVector[]> {
+        private final int downStreamParallelism;
+
+        private GlobalBatchSplitter(int downStreamParallelism) {
+            this.downStreamParallelism = downStreamParallelism;
+        }
+
+        @Override
+        public void flatMap(DenseVector[] values, Collector<DenseVector[]> collector) {
+            // Calculates the batch size to be distributed to each subtask.
+            List<Integer> sizes = new ArrayList<>();
+            for (int i = 0; i < downStreamParallelism; i++) {
+                int start = i * values.length / downStreamParallelism;
+                int end = (i + 1) * values.length / downStreamParallelism;
+                sizes.add(end - start);
+            }
+
+            int offset = 0;
+            for (Integer size : sizes) {
+                collector.collect(Arrays.copyOfRange(values, offset, offset + size));
+                offset += size;
+            }
+        }
+    }
+
+    private static class GlobalBatchCreator
+            implements AllWindowFunction<DenseVector, DenseVector[], GlobalWindow> {
+        @Override
+        public void apply(
+                GlobalWindow timeWindow,
+                Iterable<DenseVector> iterable,
+                Collector<DenseVector[]> collector) {
+            List<DenseVector> points = IteratorUtils.toList(iterable.iterator());
+            collector.collect(points.toArray(new DenseVector[0]));
+        }
+    }
+
+    /**
+     * Sets the initial model data of the online training process with the provided model data
+     * table.
+     *
+     * <p>This would override the effect of previously invoked {@link #setInitialModelData(Table)}
+     * or {@link #setRandomCentroids(StreamTableEnvironment, int, int, double)}.
+     */
+    public OnlineKMeans setInitialModelData(Table initModelDataTable) {
+        this.initModelDataTable = initModelDataTable;
+        return this;
+    }
+
+    /**
+     * Sets the initial model data of the online training process with randomly created centroids.
+     *
+     * <p>This would override the effect of previously invoked {@link #setInitialModelData(Table)}
+     * or {@link #setRandomCentroids(StreamTableEnvironment, int, int, double)}.
+     *
+     * @param tEnv The stream table environment to create the centroids in.
+     * @param dim The dimension of the centroids to create.
+     * @param k The number of centroids to create.
+     * @param weight The weight of the centroids to create.
+     */
+    public OnlineKMeans setRandomCentroids(

Review comment:
       If users do not provide an initial model, could we follow `KMeans` and randomly initialize the model data? 
   
   Exposing this method to end users seems a bit unnecessary to me.
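The body of `setRandomCentroids(tEnv, dim, k, weight)` is cut off in this excerpt. The following is therefore only a hypothetical sketch of what a default random initialization could compute when no initial model is provided: `k` centroids of dimension `dim`, each carrying the same `weight`. Drawing coordinates from `Random#nextGaussian` with a fixed seed is an assumption made for illustration. It is not necessarily what the PR or `KMeans` does, and `ModelData` is a minimal stand-in for `KMeansModelData`.

```java
import java.util.Arrays;
import java.util.Random;

public class RandomCentroids {

    /** Minimal stand-in for the model data: k centroids plus one weight per centroid. */
    public static final class ModelData {
        public final double[][] centroids;
        public final double[] weights;

        ModelData(double[][] centroids, double[] weights) {
            this.centroids = centroids;
            this.weights = weights;
        }
    }

    /** Creates k random centroids of the given dimension, all with the same weight. */
    public static ModelData randomModelData(int dim, int k, double weight, long seed) {
        if (dim <= 0 || k <= 0) {
            throw new IllegalArgumentException("dim and k must be positive.");
        }
        Random random = new Random(seed); // Fixed seed keeps the initialization reproducible.
        double[][] centroids = new double[k][dim];
        for (double[] centroid : centroids) {
            for (int j = 0; j < dim; j++) {
                centroid[j] = random.nextGaussian();
            }
        }
        double[] weights = new double[k];
        Arrays.fill(weights, weight);
        return new ModelData(centroids, weights);
    }

    public static void main(String[] args) {
        ModelData data = randomModelData(2, 3, 0.0, 42L);
        System.out.println(Arrays.deepToString(data.centroids) + " " + Arrays.toString(data.weights));
    }
}
```

Under the update rule described in the class Javadoc, a weight of 0 means the first batch replaces the random centroids outright for every cluster that receives points, since c' = (d * 0 * c + S) / (0 + n) = S / n.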

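`GlobalBatchSplitter` cuts a global batch of n points into `downStreamParallelism` contiguous slices. The slice boundaries are `i * n / p` under integer division, so slice sizes differ by at most one. The sketch below reproduces only that size computation, so the arithmetic can be checked in isolation; the class and method names are hypothetical.

```java
import java.util.Arrays;

public class BatchSplitSizes {

    /** Sizes of the local batches produced for a global batch of n points and p subtasks. */
    public static int[] splitSizes(int n, int p) {
        int[] sizes = new int[p];
        for (int i = 0; i < p; i++) {
            // Same integer arithmetic as GlobalBatchSplitter: boundaries at floor(i * n / p).
            int start = i * n / p;
            int end = (i + 1) * n / p;
            sizes[i] = end - start;
        }
        return sizes;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(splitSizes(10, 3))); // [3, 3, 4]
        System.out.println(Arrays.toString(splitSizes(2, 4))); // [0, 1, 0, 1]
    }
}
```

The splitter always emits exactly p slices, some of them empty when n < p. Presumably this is what keeps one partial model per subtask flowing into `countWindowAll(parallelism)` downstream.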
##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,437 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.function.Supplier;
+
+/**
+ * OnlineKMeans extends the functionality of {@link KMeans} by supporting continuous training of
+ * a K-Means model on an unbounded stream of training data.
+ *
+ * <p>OnlineKMeans uses the "mini-batch" KMeans update rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids as the weighted average of the original and the
+ * estimated centroids. The weight of each estimated centroid is the number of points assigned
+ * to it in the current batch. The weight of each original centroid is its accumulated weight from
+ * previous batches, multiplied by the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters estimated so far. If the decay
+ * factor is 1, all batches are weighted equally. If the decay factor is 0, the new centroids are
+ * determined entirely by the most recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelData =
+                KMeansModelData.getModelDataStream(initModelDataTable);
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize());
+
+        DataStream<KMeansModelData> finalModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelData);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Saves the metadata AND the bounded model data table (if it exists) to the given path. */
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        String initModelDataPath = ReadWriteUtils.getDataPath(path);
+        if (Files.exists(Paths.get(initModelDataPath))) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure, double decayFactor, int batchSize) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<KMeansModelData> modelData = variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<KMeansModelData> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator())
+                            .flatMap(new GlobalBatchSplitter(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(KMeansModelData.class),
+                                    new ModelDataLocalUpdater(distanceMeasure, decayFactor))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new ModelDataGlobalReducer());
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelData), DataStreamList.of(modelData));
+        }
+    }
+
+    private static class ModelDataGlobalReducer implements ReduceFunction<KMeansModelData> {
+        @Override
+        public KMeansModelData reduce(KMeansModelData modelData, KMeansModelData newModelData) {
+            DenseVector weights = modelData.weights;
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector newWeights = newModelData.weights;
+            DenseVector[] newCentroids = newModelData.centroids;
+
+            int k = newCentroids.length;
+            int dim = newCentroids[0].size();
+
+            for (int i = 0; i < k; i++) {
+                for (int j = 0; j < dim; j++) {
+                    centroids[i].values[j] =
+                            (centroids[i].values[j] * weights.values[i]
+                                            + newCentroids[i].values[j] * newWeights.values[i])
+                                    / Math.max(weights.values[i] + newWeights.values[i], 1e-16);
+                }
+                weights.values[i] += newWeights.values[i];
+            }
+
+            return new KMeansModelData(centroids, weights);
+        }
+    }
+
+    private static class ModelDataLocalUpdater extends AbstractStreamOperator<KMeansModelData>
+            implements TwoInputStreamOperator<DenseVector[], KMeansModelData, KMeansModelData> {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private ListState<DenseVector[]> localBatchState;
+        private ListState<KMeansModelData> modelDataState;
+
+        private ModelDataLocalUpdater(DistanceMeasure distanceMeasure, double decayFactor) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            localBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("localBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> pointsRecord) throws Exception {
+            localBatchState.add(pointsRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<KMeansModelData> modelDataRecord)
+                throws Exception {
+            modelDataState.add(modelDataRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        private void alignAndComputeModelData() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !localBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            KMeansModelData modelData =
+                    OperatorStateUtils.getUniqueElement(modelDataState, "modelData")
+                            .orElseThrow((Supplier<Exception>) NullPointerException::new);
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector weights = modelData.weights;
+            modelDataState.clear();
+
+            List<DenseVector[]> pointsList = IteratorUtils.toList(localBatchState.get().iterator());
+            DenseVector[] points = pointsList.remove(0);
+            localBatchState.update(pointsList);
+
+            int dim = centroids[0].size();
+            int k = centroids.length;
+            int parallelism = getRuntimeContext().getNumberOfParallelSubtasks();
+
+            // Computes new centroids.
+            DenseVector[] sums = new DenseVector[k];
+            int[] counts = new int[k];
+
+            for (int i = 0; i < k; i++) {
+                sums[i] = new DenseVector(dim);
+                counts[i] = 0;
+            }
+            for (DenseVector point : points) {
+                int closestCentroidId =
+                        KMeans.findClosestCentroidId(centroids, point, distanceMeasure);
+                counts[closestCentroidId]++;
+                for (int j = 0; j < dim; j++) {
+                    sums[closestCentroidId].values[j] += point.values[j];
+                }
+            }
+
+            // Considers weight and decay factor when updating centroids.
+            BLAS.scal(decayFactor / parallelism, weights);
+            for (int i = 0; i < k; i++) {
+                if (counts[i] == 0) {
+                    continue;
+                }
+
+                DenseVector centroid = centroids[i];
+                weights.values[i] = weights.values[i] + counts[i];
+                double lambda = counts[i] / weights.values[i];
+
+                BLAS.scal(1.0 - lambda, centroid);
+                BLAS.axpy(lambda / counts[i], sums[i], centroid);
+            }
+
+            output.collect(new StreamRecord<>(new KMeansModelData(centroids, weights)));
+        }
+    }
+
+    private static class FeaturesExtractor implements MapFunction<Row, DenseVector> {
+        private final String featuresCol;
+
+        private FeaturesExtractor(String featuresCol) {
+            this.featuresCol = featuresCol;
+        }
+
+        @Override
+        public DenseVector map(Row row) throws Exception {
+            return (DenseVector) row.getField(featuresCol);
+        }
+    }
+
+    // An operator that splits a global batch into evenly-sized local batches, and distributes them
+    // to downstream operator.
+    private static class GlobalBatchSplitter
+            implements FlatMapFunction<DenseVector[], DenseVector[]> {
+        private final int downStreamParallelism;
+
+        private GlobalBatchSplitter(int downStreamParallelism) {
+            this.downStreamParallelism = downStreamParallelism;
+        }
+
+        @Override
+        public void flatMap(DenseVector[] values, Collector<DenseVector[]> collector) {
+            // Calculate the batch sizes to be distributed on each subtask.
+            List<Integer> sizes = new ArrayList<>();
+            for (int i = 0; i < downStreamParallelism; i++) {
+                int start = i * values.length / downStreamParallelism;
+                int end = (i + 1) * values.length / downStreamParallelism;
+                sizes.add(end - start);
+            }
+
+            int offset = 0;
+            for (Integer size : sizes) {
+                collector.collect(Arrays.copyOfRange(values, offset, offset + size));
+                offset += size;
+            }
+        }
+    }
+
+    private static class GlobalBatchCreator
+            implements AllWindowFunction<DenseVector, DenseVector[], GlobalWindow> {
+        @Override
+        public void apply(
+                GlobalWindow timeWindow,
+                Iterable<DenseVector> iterable,
+                Collector<DenseVector[]> collector) {
+            List<DenseVector> points = IteratorUtils.toList(iterable.iterator());
+            collector.collect(points.toArray(new DenseVector[0]));
+        }
+    }
+
+    /**
+     * Sets the initial model data of the online training process with the provided model data
+     * table.
+     *
+     * <p>This would override the effect of previously invoked {@link #setInitialModelData(Table)}
+     * or {@link #setRandomCentroids(StreamTableEnvironment, int, int, double)}.
+     */
+    public OnlineKMeans setInitialModelData(Table initModelDataTable) {
+        this.initModelDataTable = initModelDataTable;
+        return this;
+    }
+
+    /**
+     * Sets the initial model data of the online training process with randomly created centroids.
+     *
+     * <p>This would override the effect of previously invoked {@link #setInitialModelData(Table)}
+     * or {@link #setRandomCentroids(StreamTableEnvironment, int, int, double)}.
+     *
+     * @param tEnv The stream table environment to create the centroids in.
+     * @param dim The dimension of the centroids to create.
+     * @param k The number of centroids to create.
+     * @param weight The weight of the centroids to create.
+     */
+    public OnlineKMeans setRandomCentroids(
+            StreamTableEnvironment tEnv, int dim, int k, double weight) {

Review comment:
       +1
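   The boundary arithmetic in `GlobalBatchSplitter` above (local batch i spans `[i * n / p, (i + 1) * n / p)`) can be checked in isolation. Below is a minimal, framework-free sketch. It assumes plain arrays in place of `DenseVector[]` and returns a list instead of using a `Collector`. The class name `BatchSplitSketch` is hypothetical:

   ```java
   import java.util.ArrayList;
   import java.util.Arrays;
   import java.util.List;

   /** Splits a global batch into near-even local batches, mirroring GlobalBatchSplitter's arithmetic. */
   final class BatchSplitSketch {
       private BatchSplitSketch() {}

       /** Returns {@code parallelism} local batches whose sizes differ by at most one. */
       static <T> List<T[]> split(T[] values, int parallelism) {
           List<T[]> batches = new ArrayList<>();
           int offset = 0;
           for (int i = 0; i < parallelism; i++) {
               // Same boundary as the hunk: end = (i + 1) * n / p; the long cast avoids int overflow.
               int end = (int) ((long) (i + 1) * values.length / parallelism);
               batches.add(Arrays.copyOfRange(values, offset, end));
               offset = end;
           }
           return batches;
       }
   }
   ```

   With n = 10 and p = 4, this yields local batches of sizes 2, 3, 2 and 3. When n < p, some subtasks receive empty batches, just as they would with the hunk's formula.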

##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java
##########
@@ -0,0 +1,476 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.RestOptions;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.clustering.kmeans.KMeans;
+import org.apache.flink.ml.clustering.kmeans.KMeansModel;
+import org.apache.flink.ml.clustering.kmeans.KMeansModelData;
+import org.apache.flink.ml.clustering.kmeans.OnlineKMeans;
+import org.apache.flink.ml.clustering.kmeans.OnlineKMeansModel;
+import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.util.InMemorySinkFunction;
+import org.apache.flink.ml.util.InMemorySourceFunction;
+import org.apache.flink.runtime.minicluster.MiniCluster;
+import org.apache.flink.runtime.minicluster.MiniClusterConfiguration;
+import org.apache.flink.runtime.testutils.InMemoryReporter;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.TestLogger;
+
+import org.apache.commons.collections.CollectionUtils;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import static org.apache.flink.ml.clustering.KMeansTest.groupFeaturesByPrediction;
+
+/** Tests {@link OnlineKMeans} and {@link OnlineKMeansModel}. */
+public class OnlineKMeansTest extends TestLogger {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private static final DenseVector[] trainData1 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 0.0),
+                Vectors.dense(10.0, 0.3),
+                Vectors.dense(10.3, 0.0),
+                Vectors.dense(-10.0, 0.0),
+                Vectors.dense(-10.0, 0.6),
+                Vectors.dense(-10.6, 0.0)
+            };
+    private static final DenseVector[] trainData2 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 100.0),
+                Vectors.dense(10.0, 100.3),
+                Vectors.dense(10.3, 100.0),
+                Vectors.dense(-10.0, -100.0),
+                Vectors.dense(-10.0, -100.6),
+                Vectors.dense(-10.6, -100.0)
+            };
+    private static final DenseVector[] predictData =
+            new DenseVector[] {
+                Vectors.dense(10.0, 10.0),
+                Vectors.dense(10.3, 10.0),
+                Vectors.dense(10.0, 10.3),
+                Vectors.dense(-10.0, 10.0),
+                Vectors.dense(-10.3, 10.0),
+                Vectors.dense(-10.0, 10.3)
+            };
+    private static final List<Set<DenseVector>> expectedGroups1 =
+            Arrays.asList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3))),
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+    private static final List<Set<DenseVector>> expectedGroups2 =
+            Collections.singletonList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3),
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+
+    private static final int defaultParallelism = 4;
+    private static final int numTaskManagers = 2;
+    private static final int numSlotsPerTaskManager = 2;
+
+    private int currentModelDataVersion;
+
+    private InMemorySourceFunction<DenseVector> trainSource;
+    private InMemorySourceFunction<DenseVector> predictSource;
+    private InMemorySinkFunction<Row> outputSink;
+    private InMemorySinkFunction<KMeansModelData> modelDataSink;
+
+    // TODO: creates static mini cluster once for whole test class after dependency upgrades to
+    // Flink 1.15.
+    private InMemoryReporter reporter;
+    private MiniCluster miniCluster;
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+
+    private Table offlineTrainTable;
+    private Table onlineTrainTable;
+    private Table onlinePredictTable;
+
+    @Before
+    public void before() throws Exception {
+        currentModelDataVersion = 0;
+
+        trainSource = new InMemorySourceFunction<>();
+        predictSource = new InMemorySourceFunction<>();
+        outputSink = new InMemorySinkFunction<>();
+        modelDataSink = new InMemorySinkFunction<>();
+
+        Configuration config = new Configuration();
+        config.set(RestOptions.BIND_PORT, "18081-19091");
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        reporter = InMemoryReporter.createWithRetainedMetrics();
+        reporter.addToConfiguration(config);
+
+        miniCluster =
+                new MiniCluster(
+                        new MiniClusterConfiguration.Builder()
+                                .setConfiguration(config)
+                                .setNumTaskManagers(numTaskManagers)
+                                .setNumSlotsPerTaskManager(numSlotsPerTaskManager)
+                                .build());
+        miniCluster.start();
+
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(defaultParallelism);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        Schema schema = Schema.newBuilder().column("f0", DataTypes.of(DenseVector.class)).build();

Review comment:
       Nit: `schema` seems unnecessary here. Could we simply replace `f0` with `features` to simplify the code?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Support Online KMeans in Flink ML

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r829776167



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasDecayFactor.java
##########
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
+
+/** Interface for the shared decay factor param. */
+public interface HasDecayFactor<T> extends WithParams<T> {
+    Param<Double> DECAY_FACTOR =
+            new DoubleParam(
+                    "decayFactor",
+                    "The forgetfulness of the previous centroids.",
+                    0.,
+                    ParamValidators.gtEq(0));

Review comment:
       Spark does not set an upper bound on this parameter. It only requires decayFactor to be nonnegative, as follows:
   ```scala
     def setDecayFactor(a: Double): this.type = {
       require(a >= 0,
         s"Decay factor must be nonnegative but got ${a}")
       this.decayFactor = a
       this
     }
   ```
   I understand your concern, though. If this value is larger than 1, it no longer represents "forgetfulness", because the weight of the initial model data keeps growing. Maybe we can require it to be smaller than 1 for now, and remove this limit once we find relevant use cases. What do you think?
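   Here is the concern in concrete terms. In the hunk's `ModelDataLocalUpdater`, every centroid's weight is scaled by the decay factor on each batch. Each subtask applies `decayFactor / parallelism`, and this sketch assumes the per-subtask results are summed downstream. Under that assumption, a centroid that receives no new points has weight `w * d^t` after `t` batches.

   Below is a minimal sketch of a bounded check and of that weight evolution. It assumes the bound is inclusive, so that 1 (no forgetting) stays valid. `DecayFactorSketch` and `isValid` are hypothetical names and are not Flink ML's `ParamValidators` API:

   ```java
   /** Illustrates a bounded decay factor and how an unused centroid's weight evolves. */
   final class DecayFactorSketch {
       private DecayFactorSketch() {}

       /** Hypothetical validator assuming the inclusive range [0, 1]. */
       static boolean isValid(double decayFactor) {
           return decayFactor >= 0.0 && decayFactor <= 1.0;
       }

       /** Weight after {@code batches} batches that assign no points: each batch multiplies by d. */
       static double weightAfter(double initialWeight, double decayFactor, int batches) {
           double weight = initialWeight;
           for (int i = 0; i < batches; i++) {
               weight *= decayFactor;
           }
           return weight;
       }
   }
   ```

   With d = 0.5, an unused centroid's weight halves every batch. With d = 2 it doubles, so each later batch moves that centroid less and less.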





[GitHub] [flink-ml] yunfengzhou-hub commented on pull request #70: [FLINK-26313] Support Online KMeans in Flink ML

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#issuecomment-1072229314


   Thanks for the comments, @lindong28. I have updated the PR and replied to the comments.



[GitHub] [flink-ml] lindong28 commented on a change in pull request #70: [FLINK-26313] Support Streaming KMeans in Flink ML

Posted by GitBox <gi...@apache.org>.
lindong28 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r822304676



##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/util/MockKVStore.java
##########
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util;
+
+import java.util.HashMap;
+import java.util.Map;
+
+/** Class that manages global key-value pairs used in unit tests. */
+@SuppressWarnings({"unchecked"})
+public class MockKVStore {
+    private static final Map<String, Object> map = new HashMap<>();
+
+    private static long counter = 0;
+
+    /**
+     * Returns a prefix string to make sure key-value pairs created in different test cases would
+     * not have the same key.
+     */
+    public static synchronized String createNonDuplicatePrefix() {

Review comment:
       Could we simplify `MockKVStore` by removing this method and letting the caller generate a random int64 value instead?
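   Below is a minimal sketch of the caller-side alternative. It assumes `java.util.concurrent.ThreadLocalRandom` as the source of the random 64-bit value. `RandomPrefixSketch` is a hypothetical name:

   ```java
   import java.util.concurrent.ThreadLocalRandom;

   /** Each test case derives its key prefix from a random 64-bit value instead of a shared counter. */
   final class RandomPrefixSketch {
       private RandomPrefixSketch() {}

       static String newPrefix() {
           // A random long makes prefix collisions between test cases very unlikely, not impossible.
           return Long.toHexString(ThreadLocalRandom.current().nextLong()) + "-";
       }

       static String key(String prefix, String name) {
           return prefix + name;
       }
   }
   ```

   The trade-off is that uniqueness becomes probabilistic (a 2^-64 chance per pair of test cases). That is usually acceptable for unit tests.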

##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/util/MockMessageQueues.java
##########
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
+
+/** A class that manages global message queues used in unit tests. */
+@SuppressWarnings({"unchecked", "rawtypes"})
+public class MockMessageQueues {
+    private static final Map<String, BlockingQueue> queueMap = new HashMap<>();
+    private static long counter = 0;
+
+    public static synchronized String createMessageQueue() {
+        String id = String.valueOf(counter);
+        queueMap.put(id, new LinkedBlockingQueue<>());
+        counter++;
+        return id;
+    }
+
+    @SafeVarargs
+    public static <T> void offerAll(String id, T... values) throws InterruptedException {
+        for (T value : values) {
+            offer(id, value);
+        }
+    }
+
+    public static <T> void offer(String id, T value) throws InterruptedException {
+        offer(id, value, 1, TimeUnit.MINUTES);
+    }
+
+    public static <T> void offer(String id, T value, long timeout, TimeUnit unit)
+            throws InterruptedException {
+        boolean success = queueMap.get(id).offer(value, timeout, unit);
+        if (!success) {
+            throw new RuntimeException(
+                    "Failed to offer " + value + " to blocking queue " + id + ".");
+        }
+    }
+
+    public static <T> List<T> poll(String id, int num) throws InterruptedException {
+        List<T> result = new ArrayList<>();
+        for (int i = 0; i < num; i++) {
+            result.add(poll(id));
+        }
+        return result;
+    }
+
+    public static <T> T poll(String id) throws InterruptedException {
+        return poll(id, 1, TimeUnit.MINUTES);
+    }
+
+    public static <T> T poll(String id, long timeout, TimeUnit unit) throws InterruptedException {
+        T value = (T) queueMap.get(id).poll(timeout, unit);
+        if (value == null) {
+            throw new RuntimeException("Failed to poll next value from blocking queue " + id + ".");
+        }
+        return value;
+    }
+
+    public static void deleteBlockingQueue(String id) {

Review comment:
       Given that we always remove all queues when this method is called, would it be simpler to replace this method with a `clear()` method that removes all queues?
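   Below is a minimal sketch of the suggested shape, with a single `clear()` intended for the test's `@After` method. It goes beyond the comment in one respect: it swaps the `HashMap` for a `ConcurrentHashMap` and the counter for an `AtomicLong`. This is an assumption, on the grounds that sources and sinks may access the registry from task threads. `GlobalBlockingQueuesSketch` and its `size()` helper are hypothetical:

   ```java
   import java.util.Map;
   import java.util.concurrent.BlockingQueue;
   import java.util.concurrent.ConcurrentHashMap;
   import java.util.concurrent.LinkedBlockingQueue;
   import java.util.concurrent.TimeUnit;
   import java.util.concurrent.atomic.AtomicLong;

   /** Registry of global blocking queues for tests, with one clear() instead of per-id deletion. */
   final class GlobalBlockingQueuesSketch {
       // Assumption beyond the review: concurrent collections, since task threads use the registry.
       private static final Map<String, BlockingQueue<Object>> QUEUES = new ConcurrentHashMap<>();
       private static final AtomicLong COUNTER = new AtomicLong();

       private GlobalBlockingQueuesSketch() {}

       static String createBlockingQueue() {
           String id = String.valueOf(COUNTER.getAndIncrement());
           QUEUES.put(id, new LinkedBlockingQueue<>());
           return id;
       }

       static void offer(String id, Object value) throws InterruptedException {
           if (!QUEUES.get(id).offer(value, 1, TimeUnit.MINUTES)) {
               throw new RuntimeException("Failed to offer " + value + " to blocking queue " + id + ".");
           }
       }

       @SuppressWarnings("unchecked")
       static <T> T poll(String id, long timeout, TimeUnit unit) throws InterruptedException {
           T value = (T) QUEUES.get(id).poll(timeout, unit);
           if (value == null) {
               throw new RuntimeException("Failed to poll next value from blocking queue " + id + ".");
           }
           return value;
       }

       /** Drops every queue; meant to be called once from the test's @After method. */
       static void clear() {
           QUEUES.clear();
       }

       /** Hypothetical helper, only for checking the registry in tests. */
       static int size() {
           return QUEUES.size();
       }
   }
   ```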

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/StreamingKMeansModel.java
##########
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.co.CoProcessFunction;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * StreamingKMeansModel can be regarded as an advanced {@link KMeansModel} operator which can update
+ * model data in a streaming format, using the model data provided by {@link StreamingKMeans}.
+ */
+public class StreamingKMeansModel
+        implements Model<StreamingKMeansModel>, KMeansModelParams<StreamingKMeansModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public StreamingKMeansModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public StreamingKMeansModel setModelData(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        modelDataTable = inputs[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), Types.INT),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol()));
+
+        DataStream<Row> predictionResult =
+                KMeansModelData.getModelDataStream(modelDataTable)
+                        .broadcast()
+                        .connect(tEnv.toDataStream(inputs[0]))
+                        .process(
+                                new PredictLabelFunction(
+                                        getFeaturesCol(),
+                                        DistanceMeasure.getInstance(getDistanceMeasure())),
+                                outputTypeInfo);
+
+        return new Table[] {tEnv.fromDataStream(predictionResult)};
+    }
+
+    /** A utility function used for prediction. */
+    private static class PredictLabelFunction extends CoProcessFunction<KMeansModelData, Row, Row> {
+        private final String featuresCol;
+
+        private final DistanceMeasure distanceMeasure;
+
+        private DenseVector[] centroids;
+
+        // TODO: replace this with a complete solution of reading first model data from unbounded
+        // model data stream before processing the first predict data.
+        private final List<Row> cache = new ArrayList<>();

Review comment:
       It is not clear what this `cache` represents, and a cache is usually used to improve performance.
   
   How about renaming this variable to `bufferedPoints`?
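   The TODO above suggests this field holds predict rows that arrive before the first model data. If so, the name `bufferedPoints` matches what it does.

   Below is a framework-free sketch of that buffer-then-flush behavior. It uses `double[]` points, squared Euclidean distance, and a `Consumer<Integer>` in place of `Row`, `DistanceMeasure`, and `Collector`. All names are hypothetical:

   ```java
   import java.util.ArrayList;
   import java.util.List;
   import java.util.function.Consumer;

   /** Buffers points until the first model data arrives, then predicts them and later points. */
   final class BufferedPredictorSketch {
       private final List<double[]> bufferedPoints = new ArrayList<>();
       private final Consumer<Integer> output; // receives the predicted cluster id
       private double[][] centroids;

       BufferedPredictorSketch(Consumer<Integer> output) {
           this.output = output;
       }

       void onModelData(double[][] newCentroids) {
           centroids = newCentroids;
           // Flush points that were waiting for the first model data.
           for (double[] point : bufferedPoints) {
               output.accept(closest(point));
           }
           bufferedPoints.clear();
       }

       void onPoint(double[] point) {
           if (centroids == null) {
               bufferedPoints.add(point);
           } else {
               output.accept(closest(point));
           }
       }

       int bufferedCount() {
           return bufferedPoints.size();
       }

       // Squared Euclidean distance suffices to pick the nearest centroid.
       private int closest(double[] point) {
           int best = 0;
           double bestDist = Double.MAX_VALUE;
           for (int i = 0; i < centroids.length; i++) {
               double dist = 0;
               for (int j = 0; j < point.length; j++) {
                   double diff = point[j] - centroids[i][j];
                   dist += diff * diff;
               }
               if (dist < bestDist) {
                   bestDist = dist;
                   best = i;
               }
           }
           return best;
       }
   }
   ```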

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/StreamingKMeans.java
##########
@@ -0,0 +1,404 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * StreamingKMeans extends the function of {@link KMeans}, supporting to train a K-Means model

Review comment:
       Would it be useful to provide more details about the algorithm here, similar to Spark's StreamingKMeans Javadoc?
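   As a starting point for such a Javadoc, here is the per-batch update that `ModelDataLocalUpdater` (earlier in this thread) applies:

   - `w' = d * w + m`
   - `c' = (d * w * c + m * x) / w'`

   Here `d` is the decay factor, `w` and `c` are the old weight and centroid, and `m` is the number of points (with mean `x`) assigned to the centroid in this batch. This reading assumes the per-subtask results, each scaled by `d / parallelism`, are summed downstream. The sketch below therefore covers a single centroid with parallelism 1. `DecayedCentroidUpdateSketch` is a hypothetical name:

   ```java
   /** One centroid's decayed update, as read from ModelDataLocalUpdater with parallelism 1. */
   final class DecayedCentroidUpdateSketch {
       private DecayedCentroidUpdateSketch() {}

       /** Updates {@code centroid} in place and returns the new weight. */
       static double update(
               double[] centroid, double weight, double decayFactor, double[] sum, int count) {
           double newWeight = decayFactor * weight + count;
           if (count == 0) {
               return newWeight; // the centroid stays put; only its weight decays
           }
           double lambda = count / newWeight;
           for (int j = 0; j < centroid.length; j++) {
               // (1 - lambda) * c + lambda * mean, which equals (d * w * c + m * mean) / w'.
               centroid[j] = (1.0 - lambda) * centroid[j] + lambda * sum[j] / count;
           }
           return newWeight;
       }
   }
   ```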

##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/util/MockMessageQueues.java
##########
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
+
+/** A class that manages global message queues used in unit tests. */
+@SuppressWarnings({"unchecked", "rawtypes"})
+public class MockMessageQueues {

Review comment:
       Would it be more self-explanatory to rename it `GlobalBlockingQueues`?
   
   Given that this class represents real queues, it is not clear what `mock` means here.
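   
   Here is a minimal sketch of what the renamed class could look like. It assumes the class only needs to create, delete, write to, and read from JVM-wide queues identified by string ids. The method names are hypothetical and need not match the current `MockMessageQueues` API exactly.

```java
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;

/** Hypothetical sketch: a JVM-wide registry of blocking queues identified by string ids. */
final class GlobalBlockingQueues {
    // Static so that source/sink functions deserialized in the same JVM (e.g. a MiniCluster)
    // can reach a queue given only its id.
    private static final Map<String, BlockingQueue<Object>> QUEUES = new ConcurrentHashMap<>();

    private GlobalBlockingQueues() {}

    /** Creates an empty unbounded queue and returns its unique id. */
    static String createBlockingQueue() {
        String id = UUID.randomUUID().toString();
        QUEUES.put(id, new LinkedBlockingQueue<>());
        return id;
    }

    /** Removes the queue; later accesses with this id fail. */
    static void deleteBlockingQueue(String id) {
        QUEUES.remove(id);
    }

    /** Appends a value to the queue (never blocks, since the queue is unbounded). */
    static void offer(String id, Object value) {
        getQueue(id).add(value);
    }

    /** Waits up to the timeout for a value; returns null on timeout or interruption. */
    @SuppressWarnings("unchecked")
    static <T> T poll(String id, long timeout, TimeUnit unit) {
        try {
            return (T) getQueue(id).poll(timeout, unit);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // preserve the interrupt for the caller
            return null;
        }
    }

    private static BlockingQueue<Object> getQueue(String id) {
        BlockingQueue<Object> queue = QUEUES.get(id);
        if (queue == null) {
            throw new IllegalStateException("No blocking queue with id " + id);
        }
        return queue;
    }
}
```

   Only the string id needs to be serialized into the source and sink functions; the queues themselves stay in the static map.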

##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/util/MockKVStore.java
##########
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util;
+
+import java.util.HashMap;
+import java.util.Map;
+
+/** Class that manages global key-value pairs used in unit tests. */
+@SuppressWarnings({"unchecked"})
+public class MockKVStore {

Review comment:
       It looks like `MockKVStore` is only used to store the metrics reported by `TestMetricReporter`.
   
   Would it be simpler to remove `MockKVStore` and let `TestMetricReporter` expose `static <T> T get(String key)`?
   
   And if `TestMetricReporter` is expected to be singleton, would it be more self-explanatory to rename it `SingletonInMemoryMetricReporter`?
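   
   Below is a minimal sketch of the storage side of such a reporter. To keep it self-contained, Flink's `MetricReporter` callbacks are omitted, and the class shape and method names are hypothetical. Flink instantiates reporters reflectively from the `metrics.reporter.<name>.class` option, as in the test configuration, so the values are kept in a static map. That lets tests read them without holding the reporter instance.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Hypothetical sketch of the storage side of a SingletonInMemoryMetricReporter. The Flink
 * MetricReporter callbacks are omitted; a real report() would call record(...) per gauge.
 */
final class SingletonInMemoryMetricReporter {
    // Static because Flink creates the reporter reflectively, so tests never hold the instance.
    private static final Map<String, Object> VALUES = new ConcurrentHashMap<>();

    private SingletonInMemoryMetricReporter() {}

    /** Stores the latest reported value for a metric key; ConcurrentHashMap rejects null values. */
    static void record(String key, Object value) {
        VALUES.put(key, value);
    }

    /** Returns the latest value reported for the key, or null if none was reported yet. */
    @SuppressWarnings("unchecked")
    static <T> T get(String key) {
        return (T) VALUES.get(key);
    }

    /** Drops a key so that later tests do not observe stale values. */
    static void remove(String key) {
        VALUES.remove(key);
    }
}
```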

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasDecayFactor.java
##########
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
+
+/** Interface for the shared decay factor param. */
+public interface HasDecayFactor<T> extends WithParams<T> {
+    Param<Double> DECAY_FACTOR =
+            new DoubleParam(
+                    "decayFactor",
+                    "The forgetfulness of the previous centroids.",
+                    0.,
+                    ParamValidators.gtEq(0));

Review comment:
       Does this value need to be smaller than 1?
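   
   Assume the decay factor multiplies the accumulated centroid weight, as in the recursion `n' = a * n + m` that is often used for decayed mini-batch K-Means (not confirmed for this PR). Under that assumption, values above 1 would amplify old data instead of forgetting it. The hypothetical helpers below iterate that recursion:
   - for `a < 1` the weight converges to `m / (1 - a)`;
   - for `a = 1` it grows linearly;
   - for `a > 1` it grows geometrically.

   This suggests adding an upper bound of 1 to the validator.

```java
/** Hypothetical helpers illustrating how the decay factor affects the accumulated weight. */
final class DecayFactorCheck {
    private DecayFactorCheck() {}

    /** Accepts decay factors in [0, 1]: 0 forgets all history, 1 forgets nothing. */
    static boolean isValidDecayFactor(double decayFactor) {
        return decayFactor >= 0.0 && decayFactor <= 1.0;
    }

    /**
     * Iterates n' = a * n + m, starting from n = 0, for the given number of batches of
     * batchSize points each, and returns the final accumulated weight.
     */
    static double weightAfter(double decayFactor, long batchSize, int batches) {
        double weight = 0.0;
        for (int i = 0; i < batches; i++) {
            weight = decayFactor * weight + batchSize;
        }
        return weight;
    }
}
```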

##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/StreamingKMeansTest.java
##########
@@ -0,0 +1,527 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.core.execution.JobClient;
+import org.apache.flink.ml.clustering.kmeans.KMeans;
+import org.apache.flink.ml.clustering.kmeans.KMeansModel;
+import org.apache.flink.ml.clustering.kmeans.KMeansModelData;
+import org.apache.flink.ml.clustering.kmeans.StreamingKMeans;
+import org.apache.flink.ml.clustering.kmeans.StreamingKMeansModel;
+import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.util.MockKVStore;
+import org.apache.flink.ml.util.MockMessageQueues;
+import org.apache.flink.ml.util.MockSinkFunction;
+import org.apache.flink.ml.util.MockSourceFunction;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.ml.util.TestMetricReporter;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.CollectionUtils;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import static org.apache.flink.ml.clustering.KMeansTest.groupFeaturesByPrediction;
+
+/** Tests {@link StreamingKMeans} and {@link StreamingKMeansModel}. */
+public class StreamingKMeansTest {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private static final DenseVector[] trainData1 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 0.0),
+                Vectors.dense(10.0, 0.3),
+                Vectors.dense(10.3, 0.0),
+                Vectors.dense(-10.0, 0.0),
+                Vectors.dense(-10.0, 0.6),
+                Vectors.dense(-10.6, 0.0)
+            };
+    private static final DenseVector[] trainData2 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 100.0),
+                Vectors.dense(10.0, 100.3),
+                Vectors.dense(10.3, 100.0),
+                Vectors.dense(-10.0, -100.0),
+                Vectors.dense(-10.0, -100.6),
+                Vectors.dense(-10.6, -100.0)
+            };
+    private static final DenseVector[] predictData =
+            new DenseVector[] {
+                Vectors.dense(10.0, 10.0),
+                Vectors.dense(10.3, 10.0),
+                Vectors.dense(10.0, 10.3),
+                Vectors.dense(-10.0, 10.0),
+                Vectors.dense(-10.3, 10.0),
+                Vectors.dense(-10.0, 10.3)
+            };
+    private static final List<Set<DenseVector>> expectedGroups1 =
+            Arrays.asList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3))),
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+    private static final List<Set<DenseVector>> expectedGroups2 =
+            Collections.singletonList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3),
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+
+    private String trainId;
+    private String predictId;
+    private String outputId;
+    private String modelDataId;
+    private String metricReporterPrefix;
+    private String modelDataVersionGaugeKey;
+    private String currentModelDataVersion;
+    private List<String> queueIds;
+    private List<String> kvStoreKeys;
+    private List<JobClient> clients;
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table offlineTrainTable;
+    private Table trainTable;
+    private Table predictTable;
+
+    @Before
+    public void before() {
+        metricReporterPrefix = MockKVStore.createNonDuplicatePrefix();
+        kvStoreKeys = new ArrayList<>();
+        kvStoreKeys.add(metricReporterPrefix);
+
+        modelDataVersionGaugeKey =
+                TestMetricReporter.getKey(metricReporterPrefix, "modelDataVersion");
+
+        currentModelDataVersion = "0";
+
+        trainId = MockMessageQueues.createMessageQueue();
+        predictId = MockMessageQueues.createMessageQueue();
+        outputId = MockMessageQueues.createMessageQueue();
+        modelDataId = MockMessageQueues.createMessageQueue();
+        queueIds = new ArrayList<>();
+        queueIds.addAll(Arrays.asList(trainId, predictId, outputId, modelDataId));
+
+        clients = new ArrayList<>();
+
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        config.setString("metrics.reporters", "test_reporter");
+        config.setString(
+                "metrics.reporter.test_reporter.class", TestMetricReporter.class.getName());
+        config.setString("metrics.reporter.test_reporter.interval", "100 MILLISECONDS");
+        config.setString("metrics.reporter.test_reporter.prefix", metricReporterPrefix);
+
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        Schema schema = Schema.newBuilder().column("f0", DataTypes.of(DenseVector.class)).build();
+
+        offlineTrainTable =
+                tEnv.fromDataStream(env.fromElements(trainData1), schema).as("features");
+        trainTable =
+                tEnv.fromDataStream(
+                                env.addSource(
+                                        new MockSourceFunction<>(trainId),
+                                        DenseVectorTypeInfo.INSTANCE),
+                                schema)
+                        .as("features");
+        predictTable =
+                tEnv.fromDataStream(
+                                env.addSource(
+                                        new MockSourceFunction<>(predictId),
+                                        DenseVectorTypeInfo.INSTANCE),
+                                schema)
+                        .as("features");
+    }
+
+    @After
+    public void after() {
+        for (JobClient client : clients) {
+            try {
+                client.cancel();
+            } catch (IllegalStateException e) {
+                if (!e.getMessage()
+                        .equals("MiniCluster is not yet running or has already been shut down.")) {
+                    throw e;
+                }
+            }
+        }
+        clients.clear();
+
+        for (String queueId : queueIds) {
+            MockMessageQueues.deleteBlockingQueue(queueId);
+        }
+        queueIds.clear();
+
+        for (String key : kvStoreKeys) {
+            MockKVStore.remove(key);
+        }
+        kvStoreKeys.clear();
+    }
+
+    /** Adds sinks for StreamingKMeansModel's transform output and model data. */
+    private void configModelSink(StreamingKMeansModel streamingModel) {

Review comment:
       Could we make this method name more self-explanatory by indicating that it involves the `transform(...)` operation?

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/StreamingKMeans.java
##########
@@ -0,0 +1,404 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * StreamingKMeans extends the function of {@link KMeans}, supporting to train a K-Means model
+ * continuously according to an unbounded stream of train data.
+ */
+public class StreamingKMeans
+        implements Estimator<StreamingKMeans, StreamingKMeansModel>,
+                StreamingKMeansParams<StreamingKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public StreamingKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public StreamingKMeans(Table... initModelDataTables) {
+        Preconditions.checkArgument(initModelDataTables.length == 1);
+        this.initModelDataTable = initModelDataTables[0];
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public StreamingKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        Preconditions.checkArgument(HasBatchStrategy.COUNT_STRATEGY.equals(getBatchStrategy()));
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+        points.getTransformation().setParallelism(1);

Review comment:
       Is there a way to process the input stream in parallel, instead of limiting its parallelism to 1?
   
   If that is hard to implement, I think it is acceptable to use parallelism 1 in this PR. In that case, since the parallelism affects the performance of this class, could we explain this limitation in the Javadoc of this class?
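   
   One possible direction, sketched under the assumption that each batch can be split across subtasks: per-centroid point sums and counts are associative. Each subtask could therefore aggregate its share locally, and only the small `k x dims` summaries would need to be merged at parallelism 1 before the centroids are updated. The class below is hypothetical and uses squared Euclidean distance for the assignment. Defining a count-based batch across parallel subtasks would still need some coordination.

```java
/**
 * Hypothetical sketch: per-subtask partial statistics for one mini-batch. Sums and counts are
 * associative, so partials from parallel subtasks can be merged in any order.
 */
final class PartialCentroidStats {
    final double[][] sums; // sums[c] = sum of the points assigned to centroid c
    final long[] counts; // counts[c] = number of points assigned to centroid c

    PartialCentroidStats(int k, int dims) {
        sums = new double[k][dims];
        counts = new long[k];
    }

    /** Assigns the point to its nearest centroid (squared Euclidean distance) and accumulates it. */
    void add(double[][] centroids, double[] point) {
        int best = 0;
        double bestDistance = Double.POSITIVE_INFINITY;
        for (int c = 0; c < centroids.length; c++) {
            double distance = 0.0;
            for (int i = 0; i < point.length; i++) {
                double diff = point[i] - centroids[c][i];
                distance += diff * diff;
            }
            if (distance < bestDistance) {
                bestDistance = distance;
                best = c;
            }
        }
        for (int i = 0; i < point.length; i++) {
            sums[best][i] += point[i];
        }
        counts[best]++;
    }

    /** Adds another subtask's partial statistics into this one and returns this. */
    PartialCentroidStats merge(PartialCentroidStats other) {
        for (int c = 0; c < counts.length; c++) {
            counts[c] += other.counts[c];
            for (int i = 0; i < sums[c].length; i++) {
                sums[c][i] += other.sums[c][i];
            }
        }
        return this;
    }
}
```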

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/StreamingKMeans.java
##########
@@ -0,0 +1,404 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * StreamingKMeans extends the function of {@link KMeans}, supporting to train a K-Means model
+ * continuously according to an unbounded stream of train data.
+ */
+public class StreamingKMeans
+        implements Estimator<StreamingKMeans, StreamingKMeansModel>,
+                StreamingKMeansParams<StreamingKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public StreamingKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public StreamingKMeans(Table... initModelDataTables) {
+        Preconditions.checkArgument(initModelDataTables.length == 1);
+        this.initModelDataTable = initModelDataTables[0];
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public StreamingKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        Preconditions.checkArgument(HasBatchStrategy.COUNT_STRATEGY.equals(getBatchStrategy()));
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+        points.getTransformation().setParallelism(1);
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            initModelDataStream = createRandomCentroids(env, getDims(), getK(), getSeed());
+        } else {
+            initModelDataStream = KMeansModelData.getModelDataStream(initModelDataTable);
+        }
+        DataStream<Tuple2<KMeansModelData, DenseVector>> initModelDataWithWeightsStream =
+                initModelDataStream.map(new InitWeightAssigner(getInitWeights()));
+        initModelDataWithWeightsStream.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new StreamingKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getBatchSize(),
+                        getK());
+
+        DataStream<KMeansModelData> finalModelDataStream =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelDataWithWeightsStream),
+                                DataStreamList.of(points),
+                                body)
+                        .get(0);
+        finalModelDataStream = finalModelDataStream.union(initModelDataStream);

Review comment:
       Does this call guarantee that all values from `initModelDataStream` come before those from `finalModelDataStream` in the output stream?
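   
   As far as I know, `DataStream#union` does not order elements across its inputs, so records from the two streams may interleave arbitrarily. Suppose each model-data record carried a monotonically increasing version, for example 0 for the initial model (an assumption here). The consumer could then ignore stale records, which makes the result independent of arrival order. A minimal hypothetical sketch:

```java
/**
 * Hypothetical sketch: keeps only the newest model data, so the result does not depend on
 * the order in which union() delivers the initial and the trained model data.
 */
final class LatestModelHolder<M> {
    private long version = -1; // no model received yet
    private M model;

    /** Applies the update only if its version is newer than the current one. */
    boolean offer(long newVersion, M newModel) {
        if (newVersion <= version) {
            return false; // stale or duplicate record, e.g. a late initial model
        }
        version = newVersion;
        model = newModel;
        return true;
    }

    M current() {
        return model;
    }

    long version() {
        return version;
    }
}
```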

##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/StreamingKMeansTest.java
##########
@@ -0,0 +1,527 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.core.execution.JobClient;
+import org.apache.flink.ml.clustering.kmeans.KMeans;
+import org.apache.flink.ml.clustering.kmeans.KMeansModel;
+import org.apache.flink.ml.clustering.kmeans.KMeansModelData;
+import org.apache.flink.ml.clustering.kmeans.StreamingKMeans;
+import org.apache.flink.ml.clustering.kmeans.StreamingKMeansModel;
+import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.util.MockKVStore;
+import org.apache.flink.ml.util.MockMessageQueues;
+import org.apache.flink.ml.util.MockSinkFunction;
+import org.apache.flink.ml.util.MockSourceFunction;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.ml.util.TestMetricReporter;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.CollectionUtils;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import static org.apache.flink.ml.clustering.KMeansTest.groupFeaturesByPrediction;
+
+/** Tests {@link StreamingKMeans} and {@link StreamingKMeansModel}. */
+public class StreamingKMeansTest {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private static final DenseVector[] trainData1 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 0.0),
+                Vectors.dense(10.0, 0.3),
+                Vectors.dense(10.3, 0.0),
+                Vectors.dense(-10.0, 0.0),
+                Vectors.dense(-10.0, 0.6),
+                Vectors.dense(-10.6, 0.0)
+            };
+    private static final DenseVector[] trainData2 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 100.0),
+                Vectors.dense(10.0, 100.3),
+                Vectors.dense(10.3, 100.0),
+                Vectors.dense(-10.0, -100.0),
+                Vectors.dense(-10.0, -100.6),
+                Vectors.dense(-10.6, -100.0)
+            };
+    private static final DenseVector[] predictData =
+            new DenseVector[] {
+                Vectors.dense(10.0, 10.0),
+                Vectors.dense(10.3, 10.0),
+                Vectors.dense(10.0, 10.3),
+                Vectors.dense(-10.0, 10.0),
+                Vectors.dense(-10.3, 10.0),
+                Vectors.dense(-10.0, 10.3)
+            };
+    private static final List<Set<DenseVector>> expectedGroups1 =
+            Arrays.asList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3))),
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+    private static final List<Set<DenseVector>> expectedGroups2 =
+            Collections.singletonList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3),
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+
+    private String trainId;
+    private String predictId;
+    private String outputId;
+    private String modelDataId;
+    private String metricReporterPrefix;
+    private String modelDataVersionGaugeKey;
+    private String currentModelDataVersion;
+    private List<String> queueIds;
+    private List<String> kvStoreKeys;
+    private List<JobClient> clients;
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table offlineTrainTable;
+    private Table trainTable;
+    private Table predictTable;
+
+    @Before
+    public void before() {
+        metricReporterPrefix = MockKVStore.createNonDuplicatePrefix();
+        kvStoreKeys = new ArrayList<>();
+        kvStoreKeys.add(metricReporterPrefix);
+
+        modelDataVersionGaugeKey =
+                TestMetricReporter.getKey(metricReporterPrefix, "modelDataVersion");
+
+        currentModelDataVersion = "0";
+
+        trainId = MockMessageQueues.createMessageQueue();
+        predictId = MockMessageQueues.createMessageQueue();
+        outputId = MockMessageQueues.createMessageQueue();
+        modelDataId = MockMessageQueues.createMessageQueue();
+        queueIds = new ArrayList<>();
+        queueIds.addAll(Arrays.asList(trainId, predictId, outputId, modelDataId));
+
+        clients = new ArrayList<>();
+
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        config.setString("metrics.reporters", "test_reporter");
+        config.setString(
+                "metrics.reporter.test_reporter.class", TestMetricReporter.class.getName());
+        config.setString("metrics.reporter.test_reporter.interval", "100 MILLISECONDS");
+        config.setString("metrics.reporter.test_reporter.prefix", metricReporterPrefix);
+
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        Schema schema = Schema.newBuilder().column("f0", DataTypes.of(DenseVector.class)).build();
+
+        offlineTrainTable =
+                tEnv.fromDataStream(env.fromElements(trainData1), schema).as("features");
+        trainTable =
+                tEnv.fromDataStream(
+                                env.addSource(
+                                        new MockSourceFunction<>(trainId),
+                                        DenseVectorTypeInfo.INSTANCE),
+                                schema)
+                        .as("features");
+        predictTable =
+                tEnv.fromDataStream(
+                                env.addSource(
+                                        new MockSourceFunction<>(predictId),
+                                        DenseVectorTypeInfo.INSTANCE),
+                                schema)
+                        .as("features");
+    }
+
+    @After
+    public void after() {
+        for (JobClient client : clients) {
+            try {
+                client.cancel();
+            } catch (IllegalStateException e) {
+                if (!e.getMessage()
+                        .equals("MiniCluster is not yet running or has already been shut down.")) {
+                    throw e;
+                }
+            }
+        }
+        clients.clear();
+
+        for (String queueId : queueIds) {
+            MockMessageQueues.deleteBlockingQueue(queueId);
+        }
+        queueIds.clear();
+
+        for (String key : kvStoreKeys) {
+            MockKVStore.remove(key);
+        }
+        kvStoreKeys.clear();
+    }
+
+    /** Adds sinks for StreamingKMeansModel's transform output and model data. */
+    private void configModelSink(StreamingKMeansModel streamingModel) {
+        Table outputTable = streamingModel.transform(predictTable)[0];
+        DataStream<Row> output = tEnv.toDataStream(outputTable);
+        output.addSink(new MockSinkFunction<>(outputId));
+
+        Table modelDataTable = streamingModel.getModelData()[0];
+        DataStream<KMeansModelData> modelDataStream =
+                KMeansModelData.getModelDataStream(modelDataTable);
+        modelDataStream.addSink(new MockSinkFunction<>(modelDataId));
+    }
+
+    /** Blocks the thread until Model has set up init model data. */
+    private void waitInitModelDataSetup() {
+        while (!MockKVStore.containsKey(modelDataVersionGaugeKey)) {
+            Thread.yield();
+        }
+        waitModelDataUpdate();
+    }
+
+    /** Blocks the thread until the Model has received the next model-data-update event. */
+    private void waitModelDataUpdate() {
+        do {
+            String tmpModelDataVersion = MockKVStore.get(modelDataVersionGaugeKey);
+            if (tmpModelDataVersion.equals(currentModelDataVersion)) {
+                Thread.yield();
+            } else {
+                currentModelDataVersion = tmpModelDataVersion;
+                break;
+            }
+        } while (true);
+    }
+
+    /**
+     * Inserts default predict data to the predict queue, fetches the prediction results, and
+     * asserts that the grouping result is as expected.
+     *
+     * @param expectedGroups A list containing sets of features, which is the expected group result
+     * @param featuresCol Name of the column in the table that contains the features
+     * @param predictionCol Name of the column in the table that contains the prediction result
+     */
+    private void predictAndAssert(
+            List<Set<DenseVector>> expectedGroups, String featuresCol, String predictionCol)
+            throws Exception {
+        MockMessageQueues.offerAll(predictId, StreamingKMeansTest.predictData);
+        List<Row> rawResult =
+                MockMessageQueues.poll(outputId, StreamingKMeansTest.predictData.length);
+        List<Set<DenseVector>> actualGroups =
+                groupFeaturesByPrediction(rawResult, featuresCol, predictionCol);
+        Assert.assertTrue(CollectionUtils.isEqualCollection(expectedGroups, actualGroups));
+    }
+
+    @Test
+    public void testParam() {
+        StreamingKMeans streamingKMeans = new StreamingKMeans();
+        Assert.assertEquals("features", streamingKMeans.getFeaturesCol());
+        Assert.assertEquals("prediction", streamingKMeans.getPredictionCol());
+        Assert.assertEquals(EuclideanDistanceMeasure.NAME, streamingKMeans.getDistanceMeasure());
+        Assert.assertEquals("random", streamingKMeans.getInitMode());
+        Assert.assertEquals(2, streamingKMeans.getK());
+        Assert.assertEquals(1, streamingKMeans.getDims());
+        Assert.assertEquals("count", streamingKMeans.getBatchStrategy());
+        Assert.assertEquals(1, streamingKMeans.getBatchSize());
+        Assert.assertEquals(0., streamingKMeans.getDecayFactor(), 1e-5);
+        Assert.assertEquals("random", streamingKMeans.getInitMode());
+        Assert.assertEquals(StreamingKMeans.class.getName().hashCode(), streamingKMeans.getSeed());
+
+        streamingKMeans
+                .setK(9)
+                .setFeaturesCol("test_feature")
+                .setPredictionCol("test_prediction")
+                .setK(3)
+                .setDims(5)
+                .setBatchSize(5)
+                .setDecayFactor(0.25)
+                .setInitMode("direct")
+                .setSeed(100);
+
+        Assert.assertEquals("test_feature", streamingKMeans.getFeaturesCol());
+        Assert.assertEquals("test_prediction", streamingKMeans.getPredictionCol());
+        Assert.assertEquals(3, streamingKMeans.getK());
+        Assert.assertEquals(5, streamingKMeans.getDims());
+        Assert.assertEquals("count", streamingKMeans.getBatchStrategy());
+        Assert.assertEquals(5, streamingKMeans.getBatchSize());
+        Assert.assertEquals(0.25, streamingKMeans.getDecayFactor(), 1e-5);
+        Assert.assertEquals("direct", streamingKMeans.getInitMode());
+        Assert.assertEquals(100, streamingKMeans.getSeed());
+    }
+
+    @Test
+    public void testFitAndPredict() throws Exception {
+        StreamingKMeans streamingKMeans =
+                new StreamingKMeans()
+                        .setInitMode("random")
+                        .setDims(2)
+                        .setInitWeights(new Double[] {0., 0.})
+                        .setBatchSize(6)
+                        .setFeaturesCol("features")
+                        .setPredictionCol("prediction");
+        StreamingKMeansModel streamingModel = streamingKMeans.fit(trainTable);
+        configModelSink(streamingModel);
+
+        clients.add(env.executeAsync());
+        waitInitModelDataSetup();
+
+        MockMessageQueues.offerAll(trainId, trainData1);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups1,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+
+        MockMessageQueues.offerAll(trainId, trainData2);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups2,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+    }
+
+    @Test
+    public void testInitWithOfflineKMeans() throws Exception {
+        KMeans offlineKMeans =
+                new KMeans().setFeaturesCol("features").setPredictionCol("prediction");
+        KMeansModel offlineModel = offlineKMeans.fit(offlineTrainTable);
+
+        StreamingKMeans streamingKMeans =
+                new StreamingKMeans(offlineModel.getModelData())
+                        .setInitMode("direct")
+                        .setDims(2)
+                        .setInitWeights(new Double[] {0., 0.})
+                        .setBatchSize(6);
+        ReadWriteUtils.updateExistingParams(streamingKMeans, offlineKMeans.getParamMap());
+
+        StreamingKMeansModel streamingModel = streamingKMeans.fit(trainTable);
+        configModelSink(streamingModel);
+
+        clients.add(env.executeAsync());
+        waitInitModelDataSetup();
+        predictAndAssert(
+                expectedGroups1,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+
+        MockMessageQueues.offerAll(trainId, trainData2);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups2,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+    }
+
+    @Test
+    public void testDecayFactor() throws Exception {
+        StreamingKMeans streamingKMeans =
+                new StreamingKMeans()
+                        .setInitMode("random")
+                        .setDims(2)
+                        .setInitWeights(new Double[] {0., 0.})
+                        .setDecayFactor(100.0)
+                        .setBatchSize(6)
+                        .setFeaturesCol("features")
+                        .setPredictionCol("prediction");
+        StreamingKMeansModel streamingModel = streamingKMeans.fit(trainTable);
+        configModelSink(streamingModel);
+
+        clients.add(env.executeAsync());
+        waitInitModelDataSetup();
+
+        MockMessageQueues.offerAll(trainId, trainData1);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups1,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+
+        MockMessageQueues.offerAll(trainId, trainData2);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups1,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+    }
+
+    @Test
+    public void testSaveAndReload() throws Exception {
+        KMeans offlineKMeans =

Review comment:
       Would it be simpler to name the variable after its class name (e.g. `kMeans` instead of `offlineKMeans`)?
   
   Currently we use `KMeans` to refer to the offline algorithm and `StreamingKMeans` to refer to the online algorithm, so we already expect users to understand that `KMeans` refers to the offline algorithm.

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/StreamingKMeans.java
##########
@@ -0,0 +1,404 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * StreamingKMeans extends the functionality of {@link KMeans}, supporting continuous training of a
+ * K-Means model on an unbounded stream of training data.
+ */
+public class StreamingKMeans
+        implements Estimator<StreamingKMeans, StreamingKMeansModel>,
+                StreamingKMeansParams<StreamingKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public StreamingKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public StreamingKMeans(Table... initModelDataTables) {
+        Preconditions.checkArgument(initModelDataTables.length == 1);
+        this.initModelDataTable = initModelDataTables[0];
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public StreamingKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        Preconditions.checkArgument(HasBatchStrategy.COUNT_STRATEGY.equals(getBatchStrategy()));
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+        points.getTransformation().setParallelism(1);
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            initModelDataStream = createRandomCentroids(env, getDims(), getK(), getSeed());

Review comment:
       If `initMode` is `random`, should we assert that `initModelDataTable == null`? And if `initMode` is `direct`, should we assert that `initModelDataTable != null`?
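   A minimal sketch of what such a check could look like, written as standalone Java so it runs without Flink. The class `InitModeValidation` and the method `checkInitModelData` are hypothetical names, `Object` stands in for the `Table` held in `initModelDataTable`, and the private `checkArgument` only mimics Flink's `Preconditions.checkArgument`:

```java
/**
 * Hypothetical helper (not part of the PR) that checks whether initMode and the init model data
 * are consistent, as suggested in the review comment.
 */
final class InitModeValidation {
    private InitModeValidation() {}

    // Mimics Flink's Preconditions.checkArgument: throws IllegalArgumentException on failure.
    private static void checkArgument(boolean condition, String message) {
        if (!condition) {
            throw new IllegalArgumentException(message);
        }
    }

    /**
     * @param initMode the value of the initMode parameter, "random" or "direct"
     * @param initModelData stand-in for StreamingKMeans#initModelDataTable; may be null
     */
    static void checkInitModelData(String initMode, Object initModelData) {
        switch (initMode) {
            case "random":
                // Random centroids are generated, so a user-provided table would be silently ignored.
                checkArgument(
                        initModelData == null,
                        "Init model data must not be provided when initMode is \"random\".");
                break;
            case "direct":
                // Direct mode reads the initial centroids from the provided table.
                checkArgument(
                        initModelData != null,
                        "Init model data must be provided when initMode is \"direct\".");
                break;
            default:
                throw new IllegalArgumentException("Unsupported initMode: " + initMode);
        }
    }
}
```

   Placing such a check at the start of `fit()`, rather than in the constructors, seems preferable, since `initMode` can still be changed through `setInitMode` after construction (as the tests do on a default-constructed instance, and as `load` does).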

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/StreamingKMeans.java
##########
@@ -0,0 +1,404 @@
+public class StreamingKMeans

Review comment:
       Should this be renamed `OnlineKMeans`?
   
   Same for the other class names, e.g. `StreamingKMeansModel` and `StreamingKMeansParams`.

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/StreamingKMeans.java
##########
@@ -0,0 +1,404 @@
+public class StreamingKMeans
+        implements Estimator<StreamingKMeans, StreamingKMeansModel>,
+                StreamingKMeansParams<StreamingKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public StreamingKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public StreamingKMeans(Table... initModelDataTables) {
+        Preconditions.checkArgument(initModelDataTables.length == 1);
+        this.initModelDataTable = initModelDataTables[0];
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public StreamingKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        Preconditions.checkArgument(HasBatchStrategy.COUNT_STRATEGY.equals(getBatchStrategy()));
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+        points.getTransformation().setParallelism(1);
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            initModelDataStream = createRandomCentroids(env, getDims(), getK(), getSeed());
+        } else {
+            initModelDataStream = KMeansModelData.getModelDataStream(initModelDataTable);
+        }
+        DataStream<Tuple2<KMeansModelData, DenseVector>> initModelDataWithWeightsStream =
+                initModelDataStream.map(new InitWeightAssigner(getInitWeights()));
+        initModelDataWithWeightsStream.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new StreamingKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getBatchSize(),
+                        getK());
+
+        DataStream<KMeansModelData> finalModelDataStream =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelDataWithWeightsStream),
+                                DataStreamList.of(points),
+                                body)
+                        .get(0);
+        finalModelDataStream = finalModelDataStream.union(initModelDataStream);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelDataStream);
+        StreamingKMeansModel model = new StreamingKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class InitWeightAssigner
+            implements MapFunction<KMeansModelData, Tuple2<KMeansModelData, DenseVector>> {
+        private final double[] initWeights;
+
+        private InitWeightAssigner(Double[] initWeights) {
+            this.initWeights = ArrayUtils.toPrimitive(initWeights);
+        }
+
+        @Override
+        public Tuple2<KMeansModelData, DenseVector> map(KMeansModelData modelData)
+                throws Exception {
+            return Tuple2.of(modelData, Vectors.dense(initWeights));
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static StreamingKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        StreamingKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        Path initModelDataPath = Paths.get(path, "data");
+        if (Files.exists(initModelDataPath)) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+            kMeans.setInitMode("direct");
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class StreamingKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+        private final int k;
+
+        public StreamingKMeansIterationBody(
+                DistanceMeasure distanceMeasure, double decayFactor, int batchSize, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+            this.k = k;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<Tuple2<KMeansModelData, DenseVector>> modelDataWithWeights =
+                    variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            DataStream<Tuple2<KMeansModelData, DenseVector>> newModelDataWithWeights =
+                    points.countWindowAll(batchSize)
+                            .aggregate(new MiniBatchCreator())
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "UpdateModelData",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new UpdateModelDataOperator(distanceMeasure, decayFactor, k))
+                            .setParallelism(1);
+
+            DataStream<KMeansModelData> newModelData =
+                    newModelDataWithWeights.map(
+                            (MapFunction<Tuple2<KMeansModelData, DenseVector>, KMeansModelData>)
+                                    x -> x.f0);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelDataWithWeights), DataStreamList.of(newModelData));
+        }
+    }
+
+    // TODO: change this single-threaded implementation to support training in a distributed way,
+    // after model data version mechanism is implemented.
+    private static class UpdateModelDataOperator
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, DenseVector>>
+            implements TwoInputStreamOperator<
+                    DenseVector[],
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int k;
+        private ListState<DenseVector[]> miniBatchState;
+        private ListState<KMeansModelData> modelDataState;
+        private ListState<DenseVector> weightsState;
+
+        public UpdateModelDataOperator(DistanceMeasure distanceMeasure, double decayFactor, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.k = k;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            miniBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("miniBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+
+            weightsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "weights", DenseVectorTypeInfo.INSTANCE));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> streamRecord) throws Exception {
+            miniBatchState.add(streamRecord.getValue());
+            processElement();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<Tuple2<KMeansModelData, DenseVector>> streamRecord)
+                throws Exception {
+            modelDataState.add(streamRecord.getValue().f0);
+            weightsState.add(streamRecord.getValue().f1);
+            processElement();
+        }
+
+        private void processElement() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !miniBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            // Retrieves data from states.
+            List<KMeansModelData> modelDataList =
+                    IteratorUtils.toList(modelDataState.get().iterator());
+            if (modelDataList.size() != 1) {
+                throw new RuntimeException(
+                        "The operator received "
+                                + modelDataList.size()
+                                + " list of model data in this round");
+            }
+            DenseVector[] centroids = modelDataList.get(0).centroids;
+            modelDataState.clear();

Review comment:
       Is it correct to clear `modelDataState` here?
   
   Doing so means that a model data record can only be used to process input points received before it, not input points received after it, right?
   
   I think it is not guaranteed that there is exactly one mini-batch between two model data records.
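A minimal, Flink-free sketch of the buffering quoted above may help pin the question down. Everything in it is hypothetical: `ModelDataBufferSketch`, its methods, and the string stand-ins for mini-batches and model data are not part of the PR. It only mirrors `processElement1`, `processElement2` and `processElement` as they appear in the diff. Each model data record is paired with exactly one buffered mini-batch and then dropped, so any remaining batches wait for the next model record. Two model records buffered before any batch trigger an exception, like the `RuntimeException` in the operator.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;

/**
 * Hypothetical, Flink-free model of the buffering in UpdateModelDataOperator. Strings stand in
 * for mini-batches and model data records.
 */
public class ModelDataBufferSketch {
    private final Deque<String> miniBatches = new ArrayDeque<>();
    private final List<String> modelData = new ArrayList<>();
    private final List<String> processed = new ArrayList<>();

    void onMiniBatch(String batch) { // mirrors processElement1
        miniBatches.addLast(batch);
        tryProcess();
    }

    void onModelData(String model) { // mirrors processElement2
        modelData.add(model);
        tryProcess();
    }

    private void tryProcess() { // mirrors processElement
        if (modelData.isEmpty() || miniBatches.isEmpty()) {
            return;
        }
        if (modelData.size() != 1) {
            throw new IllegalStateException("Received " + modelData.size() + " model data records");
        }
        // Exactly one buffered mini-batch is consumed per model record...
        processed.add(miniBatches.pollFirst() + "@" + modelData.get(0));
        // ...and the model record is dropped, as modelDataState.clear() does.
        modelData.clear();
    }

    List<String> processed() {
        return processed;
    }

    int pendingBatches() {
        return miniBatches.size();
    }

    public static void main(String[] args) {
        ModelDataBufferSketch op = new ModelDataBufferSketch();
        op.onMiniBatch("b1");
        op.onMiniBatch("b2");
        op.onModelData("m0");
        System.out.println(op.processed() + " pending=" + op.pendingBatches()); // [b1@m0] pending=1
        op.onModelData("m1"); // e.g. an updated model fed back by the iteration
        System.out.println(op.processed() + " pending=" + op.pendingBatches()); // [b1@m0, b2@m1] pending=0
    }
}
```

In this model, correctness hinges on model records and mini-batches arriving one-to-one, which is the ordering assumption the comment questions.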

##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java
##########
@@ -207,25 +211,21 @@ public void testSaveLoadAndPredict() throws Exception {
         KMeansModel loadedModel =
                 StageTestUtils.saveAndReload(env, model, tempFolder.newFolder().getAbsolutePath());
         Table output = loadedModel.transform(dataTable)[0];
-        assertEquals(
-                Collections.singletonList("centroids"),
-                loadedModel.getModelData()[0].getResolvedSchema().getColumnNames());
         assertEquals(
                 Arrays.asList("features", "prediction"),
                 output.getResolvedSchema().getColumnNames());
 
+        List<Row> results = IteratorUtils.toList(output.execute().collect());
         List<Set<DenseVector>> actualGroups =
-                executeAndCollect(output, kmeans.getFeaturesCol(), kmeans.getPredictionCol());
+                groupFeaturesByPrediction(
+                        results, kmeans.getFeaturesCol(), kmeans.getPredictionCol());
         assertTrue(CollectionUtils.isEqualCollection(expectedGroups, actualGroups));
     }
 
     @Test
     public void testGetModelData() throws Exception {
         KMeans kmeans = new KMeans().setMaxIter(2).setK(2);
         KMeansModel model = kmeans.fit(dataTable);
-        assertEquals(
-                Collections.singletonList("centroids"),

Review comment:
       Hmm... why do we remove this check?
   
   After removing this, is there a test that verifies the model data table schema is correct?

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/StreamingKMeans.java
##########
@@ -0,0 +1,404 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * StreamingKMeans extends the functionality of {@link KMeans} by supporting continuous training
+ * of a K-Means model on an unbounded stream of training data.
+ */
+public class StreamingKMeans
+        implements Estimator<StreamingKMeans, StreamingKMeansModel>,
+                StreamingKMeansParams<StreamingKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public StreamingKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public StreamingKMeans(Table... initModelDataTables) {
+        Preconditions.checkArgument(initModelDataTables.length == 1);
+        this.initModelDataTable = initModelDataTables[0];
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public StreamingKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        Preconditions.checkArgument(HasBatchStrategy.COUNT_STRATEGY.equals(getBatchStrategy()));
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+        points.getTransformation().setParallelism(1);
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            initModelDataStream = createRandomCentroids(env, getDims(), getK(), getSeed());
+        } else {
+            initModelDataStream = KMeansModelData.getModelDataStream(initModelDataTable);
+        }
+        DataStream<Tuple2<KMeansModelData, DenseVector>> initModelDataWithWeightsStream =
+                initModelDataStream.map(new InitWeightAssigner(getInitWeights()));
+        initModelDataWithWeightsStream.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new StreamingKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getBatchSize(),
+                        getK());
+
+        DataStream<KMeansModelData> finalModelDataStream =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelDataWithWeightsStream),
+                                DataStreamList.of(points),
+                                body)
+                        .get(0);
+        finalModelDataStream = finalModelDataStream.union(initModelDataStream);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelDataStream);
+        StreamingKMeansModel model = new StreamingKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class InitWeightAssigner
+            implements MapFunction<KMeansModelData, Tuple2<KMeansModelData, DenseVector>> {
+        private final double[] initWeights;
+
+        private InitWeightAssigner(Double[] initWeights) {
+            this.initWeights = ArrayUtils.toPrimitive(initWeights);
+        }
+
+        @Override
+        public Tuple2<KMeansModelData, DenseVector> map(KMeansModelData modelData)
+                throws Exception {
+            return Tuple2.of(modelData, Vectors.dense(initWeights));
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static StreamingKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        StreamingKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        Path initModelDataPath = Paths.get(path, "data");
+        if (Files.exists(initModelDataPath)) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+            kMeans.setInitMode("direct");
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class StreamingKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+        private final int k;
+
+        public StreamingKMeansIterationBody(
+                DistanceMeasure distanceMeasure, double decayFactor, int batchSize, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+            this.k = k;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<Tuple2<KMeansModelData, DenseVector>> modelDataWithWeights =
+                    variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            DataStream<Tuple2<KMeansModelData, DenseVector>> newModelDataWithWeights =
+                    points.countWindowAll(batchSize)
+                            .aggregate(new MiniBatchCreator())
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "UpdateModelData",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new UpdateModelDataOperator(distanceMeasure, decayFactor, k))
+                            .setParallelism(1);
+
+            DataStream<KMeansModelData> newModelData =
+                    newModelDataWithWeights.map(
+                            (MapFunction<Tuple2<KMeansModelData, DenseVector>, KMeansModelData>)
+                                    x -> x.f0);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelDataWithWeights), DataStreamList.of(newModelData));
+        }
+    }
+
+    // TODO: change this single-threaded implementation to support training in a distributed way
+    // after the model data version mechanism is implemented.
+    private static class UpdateModelDataOperator
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, DenseVector>>
+            implements TwoInputStreamOperator<
+                    DenseVector[],
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int k;
+        private ListState<DenseVector[]> miniBatchState;
+        private ListState<KMeansModelData> modelDataState;
+        private ListState<DenseVector> weightsState;
+
+        public UpdateModelDataOperator(DistanceMeasure distanceMeasure, double decayFactor, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.k = k;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            miniBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("miniBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+
+            weightsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "weights", DenseVectorTypeInfo.INSTANCE));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> streamRecord) throws Exception {
+            miniBatchState.add(streamRecord.getValue());
+            processElement();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<Tuple2<KMeansModelData, DenseVector>> streamRecord)
+                throws Exception {
+            modelDataState.add(streamRecord.getValue().f0);
+            weightsState.add(streamRecord.getValue().f1);
+            processElement();
+        }
+
+        private void processElement() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !miniBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            // Retrieves data from states.
+            List<KMeansModelData> modelDataList =
+                    IteratorUtils.toList(modelDataState.get().iterator());
+            if (modelDataList.size() != 1) {
+                throw new RuntimeException(
+                        "The operator received "
+                                + modelDataList.size()
+                                + " list of model data in this round");
+            }
+            DenseVector[] centroids = modelDataList.get(0).centroids;
+            modelDataState.clear();
+
+            List<DenseVector> weightsList = IteratorUtils.toList(weightsState.get().iterator());
+            if (weightsList.size() != 1) {
+                throw new RuntimeException(
+                        "The operator received "
+                                + weightsList.size()
+                                + " list of weights in this round");
+            }
+            DenseVector weights = weightsList.get(0);
+            weightsState.clear();
+
+            List<DenseVector[]> pointsList = IteratorUtils.toList(miniBatchState.get().iterator());
+            DenseVector[] points = pointsList.get(0);
+            pointsList.remove(0);
+            miniBatchState.clear();
+            miniBatchState.addAll(pointsList);
+
+            // Computes new centroids.
+            DenseVector[] sums = new DenseVector[k];
+            int dims = centroids[0].size();
+            int[] counts = new int[k];
+
+            for (int i = 0; i < k; i++) {
+                sums[i] = new DenseVector(dims);
+                counts[i] = 0;
+            }
+            for (DenseVector point : points) {
+                int closestCentroidId =
+                        KMeans.findClosestCentroidId(centroids, point, distanceMeasure);
+                counts[closestCentroidId]++;
+                for (int j = 0; j < dims; j++) {
+                    sums[closestCentroidId].values[j] += point.values[j];
+                }
+            }
+
+            // Considers weight and decay factor when updating centroids.
+            BLAS.scal(decayFactor, weights);
+            for (int i = 0; i < k; i++) {
+                DenseVector centroid = centroids[i];
+
+                double updatedWeight = weights.values[i] + counts[i];

Review comment:
       Would it be simpler to replace `updatedWeight` with `weights.values[i]`?
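The loop above is cut off at `updatedWeight`, so the sketch below assumes the rest of it computes the usual decayed-weight streaming k-means update. With decay factor α, old weight w, batch count n and batch sum s for a cluster, the new centroid is (c * αw + s) / (αw + n) and the new weight is αw + n. The sketch uses plain `double` arrays instead of `DenseVector`/`BLAS`, and `DecayedCentroidUpdateSketch.update` is a hypothetical name. It shows one way to store the updated weight directly in `weights[i]`, as the comment suggests: fold the decayed weight into the centroid first, then increment the weight, then divide.

```java
import java.util.Arrays;

/** Hypothetical sketch of a decayed-weight streaming k-means centroid update on plain arrays. */
public final class DecayedCentroidUpdateSketch {
    private DecayedCentroidUpdateSketch() {}

    /**
     * Updates centroids and weights in place for one mini-batch.
     *
     * @param centroids k x dims centroids from the current model data
     * @param weights k cluster weights, updated in place
     * @param sums k x dims per-cluster sums of the mini-batch points
     * @param counts k per-cluster point counts of the mini-batch
     * @param decayFactor factor applied to the historical weights
     */
    static void update(
            double[][] centroids, double[] weights, double[][] sums, int[] counts, double decayFactor) {
        for (int i = 0; i < weights.length; i++) {
            // Same effect as BLAS.scal(decayFactor, weights) in the PR.
            weights[i] *= decayFactor;
        }
        for (int i = 0; i < centroids.length; i++) {
            if (counts[i] == 0) {
                // No new points: the weighted mean reduces to the old centroid.
                continue;
            }
            double[] centroid = centroids[i];
            // Weighted sum of the decayed history and the new points...
            for (int j = 0; j < centroid.length; j++) {
                centroid[j] = centroid[j] * weights[i] + sums[i][j];
            }
            // ...then the updated weight is written straight back, with no updatedWeight local.
            weights[i] += counts[i];
            for (int j = 0; j < centroid.length; j++) {
                centroid[j] /= weights[i];
            }
        }
    }

    public static void main(String[] args) {
        double[][] centroids = {{0.0}, {10.0}};
        double[] weights = {2.0, 4.0};
        double[][] sums = {{6.0}, {0.0}}; // points 2.0 and 4.0 fell into cluster 0
        int[] counts = {2, 0};
        update(centroids, weights, sums, counts, 0.5);
        System.out.println(Arrays.deepToString(centroids) + " " + Arrays.toString(weights));
        // [[2.0], [10.0]] [3.0, 2.0]
    }
}
```

Skipping clusters with no new points leaves their centroids unchanged while their weights still decay, and it also avoids dividing by a weight that may have decayed to zero.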

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/StreamingKMeans.java
##########
@@ -0,0 +1,404 @@
+            List<DenseVector[]> pointsList = IteratorUtils.toList(miniBatchState.get().iterator());
+            DenseVector[] points = pointsList.get(0);
+            pointsList.remove(0);
+            miniBatchState.clear();
+            miniBatchState.addAll(pointsList);

Review comment:
       Hmm... why do we add back these points? Would it be simpler to just process all buffered points in this method call?
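
A minimal plain-Java sketch of that suggestion: drain every buffered mini-batch in one call and apply the batches in arrival order, instead of consuming one batch and writing the rest back to `miniBatchState`. It assumes a Spark-style decayed weighted-average update (each cluster's old weight is multiplied by the decay factor before the batch's per-cluster sums are merged), squared Euclidean distance, and plain arrays in place of `DenseVector` and `ListState`. The class and method names are hypothetical, and the PR's actual update step is not shown in this excerpt.

```java
import java.util.Deque;

/**
 * Hypothetical plain-Java sketch (not Flink ML code): folds every buffered mini-batch into the
 * centroids in one call, instead of consuming one batch and writing the rest back to state.
 */
final class MiniBatchKMeansSketch {
    private MiniBatchKMeansSketch() {}

    /** Drains the buffer, applying each mini-batch in arrival order. */
    static void processAllBuffered(
            Deque<double[][]> buffer, double[][] centroids, double[] weights, double decayFactor) {
        while (!buffer.isEmpty()) {
            update(buffer.poll(), centroids, weights, decayFactor);
        }
    }

    /**
     * One mini-batch step (assumed Spark-style rule): each cluster's old weight is multiplied by
     * decayFactor, then the batch points assigned to it are merged as a weighted average.
     */
    static void update(
            double[][] points, double[][] centroids, double[] weights, double decayFactor) {
        int k = centroids.length;
        int dims = centroids[0].length;
        double[][] sums = new double[k][dims];
        int[] counts = new int[k];
        for (double[] p : points) {
            int c = closest(p, centroids);
            counts[c]++;
            for (int d = 0; d < dims; d++) {
                sums[c][d] += p[d];
            }
        }
        for (int c = 0; c < k; c++) {
            double oldWeight = weights[c] * decayFactor;
            double newWeight = oldWeight + counts[c];
            if (newWeight > 0) {
                // Weighted average of the decayed old centroid and the new points.
                for (int d = 0; d < dims; d++) {
                    centroids[c][d] = (centroids[c][d] * oldWeight + sums[c][d]) / newWeight;
                }
            }
            weights[c] = newWeight;
        }
    }

    /** Index of the centroid with the smallest squared Euclidean distance to p. */
    static int closest(double[] p, double[][] centroids) {
        int best = 0;
        double bestDist = Double.MAX_VALUE;
        for (int c = 0; c < centroids.length; c++) {
            double dist = 0;
            for (int d = 0; d < p.length; d++) {
                double diff = p[d] - centroids[c][d];
                dist += diff * diff;
            }
            if (dist < bestDist) {
                bestDist = dist;
                best = c;
            }
        }
        return best;
    }
}
```

One thing worth checking before adopting this: if all buffered batches are folded in before new model data is emitted, only the final centroids go back through the feedback edge, so the intermediate model versions would not be observable downstream.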

##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/StreamingKMeansTest.java
##########
@@ -0,0 +1,527 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.core.execution.JobClient;
+import org.apache.flink.ml.clustering.kmeans.KMeans;
+import org.apache.flink.ml.clustering.kmeans.KMeansModel;
+import org.apache.flink.ml.clustering.kmeans.KMeansModelData;
+import org.apache.flink.ml.clustering.kmeans.StreamingKMeans;
+import org.apache.flink.ml.clustering.kmeans.StreamingKMeansModel;
+import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.util.MockKVStore;
+import org.apache.flink.ml.util.MockMessageQueues;
+import org.apache.flink.ml.util.MockSinkFunction;
+import org.apache.flink.ml.util.MockSourceFunction;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.ml.util.TestMetricReporter;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.CollectionUtils;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import static org.apache.flink.ml.clustering.KMeansTest.groupFeaturesByPrediction;
+
+/** Tests {@link StreamingKMeans} and {@link StreamingKMeansModel}. */
+public class StreamingKMeansTest {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private static final DenseVector[] trainData1 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 0.0),
+                Vectors.dense(10.0, 0.3),
+                Vectors.dense(10.3, 0.0),
+                Vectors.dense(-10.0, 0.0),
+                Vectors.dense(-10.0, 0.6),
+                Vectors.dense(-10.6, 0.0)
+            };
+    private static final DenseVector[] trainData2 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 100.0),
+                Vectors.dense(10.0, 100.3),
+                Vectors.dense(10.3, 100.0),
+                Vectors.dense(-10.0, -100.0),
+                Vectors.dense(-10.0, -100.6),
+                Vectors.dense(-10.6, -100.0)
+            };
+    private static final DenseVector[] predictData =
+            new DenseVector[] {
+                Vectors.dense(10.0, 10.0),
+                Vectors.dense(10.3, 10.0),
+                Vectors.dense(10.0, 10.3),
+                Vectors.dense(-10.0, 10.0),
+                Vectors.dense(-10.3, 10.0),
+                Vectors.dense(-10.0, 10.3)
+            };
+    private static final List<Set<DenseVector>> expectedGroups1 =
+            Arrays.asList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3))),
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+    private static final List<Set<DenseVector>> expectedGroups2 =
+            Collections.singletonList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3),
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+
+    private String trainId;
+    private String predictId;
+    private String outputId;
+    private String modelDataId;
+    private String metricReporterPrefix;
+    private String modelDataVersionGaugeKey;
+    private String currentModelDataVersion;
+    private List<String> queueIds;
+    private List<String> kvStoreKeys;
+    private List<JobClient> clients;
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table offlineTrainTable;
+    private Table trainTable;
+    private Table predictTable;
+
+    @Before
+    public void before() {
+        metricReporterPrefix = MockKVStore.createNonDuplicatePrefix();
+        kvStoreKeys = new ArrayList<>();
+        kvStoreKeys.add(metricReporterPrefix);
+
+        modelDataVersionGaugeKey =
+                TestMetricReporter.getKey(metricReporterPrefix, "modelDataVersion");
+
+        currentModelDataVersion = "0";
+
+        trainId = MockMessageQueues.createMessageQueue();
+        predictId = MockMessageQueues.createMessageQueue();
+        outputId = MockMessageQueues.createMessageQueue();
+        modelDataId = MockMessageQueues.createMessageQueue();
+        queueIds = new ArrayList<>();
+        queueIds.addAll(Arrays.asList(trainId, predictId, outputId, modelDataId));
+
+        clients = new ArrayList<>();
+
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        config.setString("metrics.reporters", "test_reporter");
+        config.setString(
+                "metrics.reporter.test_reporter.class", TestMetricReporter.class.getName());
+        config.setString("metrics.reporter.test_reporter.interval", "100 MILLISECONDS");
+        config.setString("metrics.reporter.test_reporter.prefix", metricReporterPrefix);
+
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        Schema schema = Schema.newBuilder().column("f0", DataTypes.of(DenseVector.class)).build();
+
+        offlineTrainTable =
+                tEnv.fromDataStream(env.fromElements(trainData1), schema).as("features");
+        trainTable =
+                tEnv.fromDataStream(
+                                env.addSource(
+                                        new MockSourceFunction<>(trainId),
+                                        DenseVectorTypeInfo.INSTANCE),
+                                schema)
+                        .as("features");
+        predictTable =
+                tEnv.fromDataStream(
+                                env.addSource(
+                                        new MockSourceFunction<>(predictId),
+                                        DenseVectorTypeInfo.INSTANCE),
+                                schema)
+                        .as("features");
+    }
+
+    @After
+    public void after() {
+        for (JobClient client : clients) {
+            try {
+                client.cancel();
+            } catch (IllegalStateException e) {
+                if (!e.getMessage()
+                        .equals("MiniCluster is not yet running or has already been shut down.")) {
+                    throw e;
+                }
+            }
+        }
+        clients.clear();
+
+        for (String queueId : queueIds) {
+            MockMessageQueues.deleteBlockingQueue(queueId);
+        }
+        queueIds.clear();
+
+        for (String key : kvStoreKeys) {
+            MockKVStore.remove(key);
+        }
+        kvStoreKeys.clear();
+    }
+
+    /** Adds sinks for StreamingKMeansModel's transform output and model data. */
+    private void configModelSink(StreamingKMeansModel streamingModel) {
+        Table outputTable = streamingModel.transform(predictTable)[0];
+        DataStream<Row> output = tEnv.toDataStream(outputTable);
+        output.addSink(new MockSinkFunction<>(outputId));
+
+        Table modelDataTable = streamingModel.getModelData()[0];
+        DataStream<KMeansModelData> modelDataStream =
+                KMeansModelData.getModelDataStream(modelDataTable);
+        modelDataStream.addSink(new MockSinkFunction<>(modelDataId));
+    }
+
+    /** Blocks the thread until Model has set up init model data. */
+    private void waitInitModelDataSetup() {
+        while (!MockKVStore.containsKey(modelDataVersionGaugeKey)) {
+            Thread.yield();
+        }
+        waitModelDataUpdate();
+    }
+
+    /** Blocks the thread until the Model has received the next model-data-update event. */
+    private void waitModelDataUpdate() {
+        do {
+            String tmpModelDataVersion = MockKVStore.get(modelDataVersionGaugeKey);
+            if (tmpModelDataVersion.equals(currentModelDataVersion)) {
+                Thread.yield();
+            } else {
+                currentModelDataVersion = tmpModelDataVersion;
+                break;
+            }
+        } while (true);
+    }
+
+    /**
+     * Inserts default predict data to the predict queue, fetches the prediction results, and
+     * asserts that the grouping result is as expected.
+     *
+     * @param expectedGroups A list containing sets of features, which is the expected group result
+     * @param featuresCol Name of the column in the table that contains the features
+     * @param predictionCol Name of the column in the table that contains the prediction result
+     */
+    private void predictAndAssert(
+            List<Set<DenseVector>> expectedGroups, String featuresCol, String predictionCol)
+            throws Exception {
+        MockMessageQueues.offerAll(predictId, StreamingKMeansTest.predictData);
+        List<Row> rawResult =
+                MockMessageQueues.poll(outputId, StreamingKMeansTest.predictData.length);
+        List<Set<DenseVector>> actualGroups =
+                groupFeaturesByPrediction(rawResult, featuresCol, predictionCol);
+        Assert.assertTrue(CollectionUtils.isEqualCollection(expectedGroups, actualGroups));
+    }
+
+    @Test
+    public void testParam() {
+        StreamingKMeans streamingKMeans = new StreamingKMeans();
+        Assert.assertEquals("features", streamingKMeans.getFeaturesCol());
+        Assert.assertEquals("prediction", streamingKMeans.getPredictionCol());
+        Assert.assertEquals(EuclideanDistanceMeasure.NAME, streamingKMeans.getDistanceMeasure());
+        Assert.assertEquals("random", streamingKMeans.getInitMode());
+        Assert.assertEquals(2, streamingKMeans.getK());
+        Assert.assertEquals(1, streamingKMeans.getDims());
+        Assert.assertEquals("count", streamingKMeans.getBatchStrategy());
+        Assert.assertEquals(1, streamingKMeans.getBatchSize());
+        Assert.assertEquals(0., streamingKMeans.getDecayFactor(), 1e-5);
+        Assert.assertEquals("random", streamingKMeans.getInitMode());
+        Assert.assertEquals(StreamingKMeans.class.getName().hashCode(), streamingKMeans.getSeed());
+
+        streamingKMeans
+                .setK(9)
+                .setFeaturesCol("test_feature")
+                .setPredictionCol("test_prediction")
+                .setK(3)
+                .setDims(5)
+                .setBatchSize(5)
+                .setDecayFactor(0.25)
+                .setInitMode("direct")
+                .setSeed(100);
+
+        Assert.assertEquals("test_feature", streamingKMeans.getFeaturesCol());
+        Assert.assertEquals("test_prediction", streamingKMeans.getPredictionCol());
+        Assert.assertEquals(3, streamingKMeans.getK());
+        Assert.assertEquals(5, streamingKMeans.getDims());
+        Assert.assertEquals("count", streamingKMeans.getBatchStrategy());
+        Assert.assertEquals(5, streamingKMeans.getBatchSize());
+        Assert.assertEquals(0.25, streamingKMeans.getDecayFactor(), 1e-5);
+        Assert.assertEquals("direct", streamingKMeans.getInitMode());
+        Assert.assertEquals(100, streamingKMeans.getSeed());
+    }
+
+    @Test
+    public void testFitAndPredict() throws Exception {
+        StreamingKMeans streamingKMeans =
+                new StreamingKMeans()
+                        .setInitMode("random")
+                        .setDims(2)
+                        .setInitWeights(new Double[] {0., 0.})
+                        .setBatchSize(6)
+                        .setFeaturesCol("features")
+                        .setPredictionCol("prediction");
+        StreamingKMeansModel streamingModel = streamingKMeans.fit(trainTable);
+        configModelSink(streamingModel);
+
+        clients.add(env.executeAsync());
+        waitInitModelDataSetup();
+
+        MockMessageQueues.offerAll(trainId, trainData1);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups1,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+
+        MockMessageQueues.offerAll(trainId, trainData2);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups2,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+    }
+
+    @Test
+    public void testInitWithOfflineKMeans() throws Exception {
+        KMeans offlineKMeans =
+                new KMeans().setFeaturesCol("features").setPredictionCol("prediction");
+        KMeansModel offlineModel = offlineKMeans.fit(offlineTrainTable);
+
+        StreamingKMeans streamingKMeans =
+                new StreamingKMeans(offlineModel.getModelData())
+                        .setInitMode("direct")
+                        .setDims(2)
+                        .setInitWeights(new Double[] {0., 0.})
+                        .setBatchSize(6);
+        ReadWriteUtils.updateExistingParams(streamingKMeans, offlineKMeans.getParamMap());
+
+        StreamingKMeansModel streamingModel = streamingKMeans.fit(trainTable);
+        configModelSink(streamingModel);
+
+        clients.add(env.executeAsync());
+        waitInitModelDataSetup();
+        predictAndAssert(
+                expectedGroups1,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+
+        MockMessageQueues.offerAll(trainId, trainData2);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups2,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+    }
+
+    @Test
+    public void testDecayFactor() throws Exception {
+        StreamingKMeans streamingKMeans =
+                new StreamingKMeans()
+                        .setInitMode("random")
+                        .setDims(2)
+                        .setInitWeights(new Double[] {0., 0.})
+                        .setDecayFactor(100.0)
+                        .setBatchSize(6)
+                        .setFeaturesCol("features")
+                        .setPredictionCol("prediction");
+        StreamingKMeansModel streamingModel = streamingKMeans.fit(trainTable);
+        configModelSink(streamingModel);
+
+        clients.add(env.executeAsync());
+        waitInitModelDataSetup();
+
+        MockMessageQueues.offerAll(trainId, trainData1);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups1,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+
+        MockMessageQueues.offerAll(trainId, trainData2);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups1,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+    }
+
+    @Test
+    public void testSaveAndReload() throws Exception {
+        KMeans offlineKMeans =
+                new KMeans().setFeaturesCol("features").setPredictionCol("prediction");
+        KMeansModel offlineModel = offlineKMeans.fit(offlineTrainTable);
+
+        StreamingKMeans streamingKMeans =
+                new StreamingKMeans(offlineModel.getModelData())
+                        .setDims(2)
+                        .setBatchSize(6)
+                        .setInitWeights(new Double[] {0., 0.});
+        ReadWriteUtils.updateExistingParams(streamingKMeans, offlineKMeans.getParamMap());
+
+        StreamingKMeans loadedKMeans =
+                StageTestUtils.saveAndReload(
+                        env, streamingKMeans, tempFolder.newFolder().getAbsolutePath());
+
+        StreamingKMeansModel streamingModel = loadedKMeans.fit(trainTable);
+
+        String modelDataPassId = MockMessageQueues.createMessageQueue();
+        queueIds.add(modelDataPassId);
+
+        String modelSavePath = tempFolder.newFolder().getAbsolutePath();
+        streamingModel.save(modelSavePath);
+        KMeansModelData.getModelDataStream(streamingModel.getModelData()[0])
+                .addSink(new MockSinkFunction<>(modelDataPassId));
+        clients.add(env.executeAsync());

Review comment:
       Hmm... what is the issue with save() if we use `executeAsync()`?
   
   `KafkaConsumerTestBase::runMultipleSourcesOnePartitionExactlyOnceTest()` also uses an unbounded source and verifies that the expected sequence of values is generated. Maybe follow that practice as an example?
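
As far as I can tell, the Kafka connector tests end an unbounded job by having a validating sink throw a dedicated success exception once every expected value has arrived, and the test harness treats that exception as a pass. Below is a plain-Java analogue of that pattern, with no Flink dependency and hypothetical class names; `runUntilSuccess` stands in for executing the job and inspecting the failure cause.

```java
import java.util.BitSet;
import java.util.Iterator;
import java.util.function.Consumer;

/** Hypothetical plain-Java analogue of the "validating sink + success exception" test pattern. */
final class ValidatingSinkSketch {
    private ValidatingSinkSketch() {}

    /** Signals that every expected value arrived; it marks success, not a real failure. */
    static final class SuccessException extends RuntimeException {
        SuccessException() {
            super("All expected values received");
        }
    }

    /** Expects each of 0..expectedCount-1 exactly once; throws SuccessException when complete. */
    static final class ValidatingSink implements Consumer<Long> {
        private final int expectedCount;
        private final BitSet seen = new BitSet();

        ValidatingSink(int expectedCount) {
            this.expectedCount = expectedCount;
        }

        @Override
        public void accept(Long value) {
            long v = value;
            if (v < 0 || v >= expectedCount) {
                throw new IllegalStateException("Unexpected value " + v);
            }
            int index = (int) v;
            if (seen.get(index)) {
                throw new IllegalStateException("Duplicate value " + v);
            }
            seen.set(index);
            if (seen.cardinality() == expectedCount) {
                throw new SuccessException();
            }
        }
    }

    /**
     * Stand-in for executing the job: pumps a (possibly unbounded) source into the sink and reports
     * whether the run ended through SuccessException. Validation errors propagate to the caller.
     */
    static boolean runUntilSuccess(Iterator<Long> source, Consumer<Long> sink) {
        try {
            while (source.hasNext()) {
                sink.accept(source.next());
            }
            return false; // source ended before all expected values were seen
        } catch (SuccessException e) {
            return true;
        }
    }
}
```

In a real Flink test the exception surfaces as a job failure, so the harness has to look for it in the cause chain. Compared with the `Thread.yield()` polling in this test, the job terminates by itself, and a wrong or duplicated value fails fast instead of hanging.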

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/StreamingKMeans.java
##########
@@ -0,0 +1,404 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * StreamingKMeans extends the functionality of {@link KMeans} by supporting continuous training of
+ * a K-Means model on an unbounded stream of training data.
+ */
+public class StreamingKMeans
+        implements Estimator<StreamingKMeans, StreamingKMeansModel>,
+                StreamingKMeansParams<StreamingKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public StreamingKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public StreamingKMeans(Table... initModelDataTables) {
+        Preconditions.checkArgument(initModelDataTables.length == 1);
+        this.initModelDataTable = initModelDataTables[0];
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public StreamingKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        Preconditions.checkArgument(HasBatchStrategy.COUNT_STRATEGY.equals(getBatchStrategy()));
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+        points.getTransformation().setParallelism(1);
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            initModelDataStream = createRandomCentroids(env, getDims(), getK(), getSeed());
+        } else {
+            initModelDataStream = KMeansModelData.getModelDataStream(initModelDataTable);
+        }
+        DataStream<Tuple2<KMeansModelData, DenseVector>> initModelDataWithWeightsStream =
+                initModelDataStream.map(new InitWeightAssigner(getInitWeights()));
+        initModelDataWithWeightsStream.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new StreamingKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getBatchSize(),
+                        getK());
+
+        DataStream<KMeansModelData> finalModelDataStream =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelDataWithWeightsStream),
+                                DataStreamList.of(points),
+                                body)
+                        .get(0);
+        finalModelDataStream = finalModelDataStream.union(initModelDataStream);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelDataStream);
+        StreamingKMeansModel model = new StreamingKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class InitWeightAssigner
+            implements MapFunction<KMeansModelData, Tuple2<KMeansModelData, DenseVector>> {
+        private final double[] initWeights;
+
+        private InitWeightAssigner(Double[] initWeights) {
+            this.initWeights = ArrayUtils.toPrimitive(initWeights);
+        }
+
+        @Override
+        public Tuple2<KMeansModelData, DenseVector> map(KMeansModelData modelData)
+                throws Exception {
+            return Tuple2.of(modelData, Vectors.dense(initWeights));
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static StreamingKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        StreamingKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        Path initModelDataPath = Paths.get(path, "data");
+        if (Files.exists(initModelDataPath)) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+            kMeans.setInitMode("direct");
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class StreamingKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+        private final int k;
+
+        public StreamingKMeansIterationBody(
+                DistanceMeasure distanceMeasure, double decayFactor, int batchSize, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+            this.k = k;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<Tuple2<KMeansModelData, DenseVector>> modelDataWithWeights =
+                    variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            DataStream<Tuple2<KMeansModelData, DenseVector>> newModelDataWithWeights =
+                    points.countWindowAll(batchSize)
+                            .aggregate(new MiniBatchCreator())
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "UpdateModelData",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new UpdateModelDataOperator(distanceMeasure, decayFactor, k))
+                            .setParallelism(1);
+
+            DataStream<KMeansModelData> newModelData =
+                    newModelDataWithWeights.map(
+                            (MapFunction<Tuple2<KMeansModelData, DenseVector>, KMeansModelData>)
+                                    x -> x.f0);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelDataWithWeights), DataStreamList.of(newModelData));
+        }
+    }
+
+    // TODO: Change this single-threaded implementation to support distributed training once the
+    // model data version mechanism is implemented.
+    private static class UpdateModelDataOperator
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, DenseVector>>
+            implements TwoInputStreamOperator<
+                    DenseVector[],
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int k;
+        private ListState<DenseVector[]> miniBatchState;
+        private ListState<KMeansModelData> modelDataState;
+        private ListState<DenseVector> weightsState;
+
+        public UpdateModelDataOperator(DistanceMeasure distanceMeasure, double decayFactor, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.k = k;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            miniBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("miniBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+
+            weightsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "weights", DenseVectorTypeInfo.INSTANCE));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> streamRecord) throws Exception {
+            miniBatchState.add(streamRecord.getValue());
+            processElement();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<Tuple2<KMeansModelData, DenseVector>> streamRecord)
+                throws Exception {
+            modelDataState.add(streamRecord.getValue().f0);
+            weightsState.add(streamRecord.getValue().f1);
+            processElement();
+        }
+
+        private void processElement() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !miniBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            // Retrieves data from states.
+            List<KMeansModelData> modelDataList =
+                    IteratorUtils.toList(modelDataState.get().iterator());
+            if (modelDataList.size() != 1) {

Review comment:
       Maybe use `OperatorStateUtils.getUniqueElement` for simplicity?
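For illustration, here is a framework-free sketch of what such a "unique element" helper does. It assumes the helper returns an `Optional` that is empty when the state has no element and fails when the state has more than one. `UniqueElementSketch` is a hypothetical name, not the Flink ML utility. It works on a plain `Iterable` because `ListState#get()` returns an `Iterable`, which keeps the sketch self-contained.

```java
import java.util.Iterator;
import java.util.Optional;

/** Hypothetical sketch of a "get the single element of a list state" helper. */
public class UniqueElementSketch {

    /**
     * Returns the only element of {@code elements}, or an empty Optional if there is none.
     *
     * @param elements contents of the state, e.g. what ListState#get() returns
     * @param stateName name used in the error message
     * @throws IllegalStateException if more than one element is present
     */
    public static <T> Optional<T> getUniqueElement(Iterable<T> elements, String stateName) {
        Iterator<T> iterator = elements.iterator();
        if (!iterator.hasNext()) {
            // Nothing has been stored yet, e.g. no model data has arrived.
            return Optional.empty();
        }
        T first = iterator.next();
        if (iterator.hasNext()) {
            // The state is expected to hold at most one element.
            throw new IllegalStateException(
                    "The state " + stateName + " contains more than one element.");
        }
        return Optional.of(first);
    }
}
```

With a helper like this, the emptiness check and the `modelDataList.size() != 1` check in `processElement()` could collapse into a single call on `modelDataState.get()`.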

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,409 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;

Review comment:
       Since OnlineKMeans is a different algorithm from KMeans, should we move its files to a dedicated folder, e.g. `onlinekmeans`?

##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/StreamingKMeansTest.java
##########
@@ -0,0 +1,527 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.core.execution.JobClient;
+import org.apache.flink.ml.clustering.kmeans.KMeans;
+import org.apache.flink.ml.clustering.kmeans.KMeansModel;
+import org.apache.flink.ml.clustering.kmeans.KMeansModelData;
+import org.apache.flink.ml.clustering.kmeans.StreamingKMeans;
+import org.apache.flink.ml.clustering.kmeans.StreamingKMeansModel;
+import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.util.MockKVStore;
+import org.apache.flink.ml.util.MockMessageQueues;
+import org.apache.flink.ml.util.MockSinkFunction;
+import org.apache.flink.ml.util.MockSourceFunction;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.ml.util.TestMetricReporter;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.CollectionUtils;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import static org.apache.flink.ml.clustering.KMeansTest.groupFeaturesByPrediction;
+
+/** Tests {@link StreamingKMeans} and {@link StreamingKMeansModel}. */
+public class StreamingKMeansTest {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private static final DenseVector[] trainData1 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 0.0),
+                Vectors.dense(10.0, 0.3),
+                Vectors.dense(10.3, 0.0),
+                Vectors.dense(-10.0, 0.0),
+                Vectors.dense(-10.0, 0.6),
+                Vectors.dense(-10.6, 0.0)
+            };
+    private static final DenseVector[] trainData2 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 100.0),
+                Vectors.dense(10.0, 100.3),
+                Vectors.dense(10.3, 100.0),
+                Vectors.dense(-10.0, -100.0),
+                Vectors.dense(-10.0, -100.6),
+                Vectors.dense(-10.6, -100.0)
+            };
+    private static final DenseVector[] predictData =
+            new DenseVector[] {
+                Vectors.dense(10.0, 10.0),
+                Vectors.dense(10.3, 10.0),
+                Vectors.dense(10.0, 10.3),
+                Vectors.dense(-10.0, 10.0),
+                Vectors.dense(-10.3, 10.0),
+                Vectors.dense(-10.0, 10.3)
+            };
+    private static final List<Set<DenseVector>> expectedGroups1 =
+            Arrays.asList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3))),
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+    private static final List<Set<DenseVector>> expectedGroups2 =
+            Collections.singletonList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3),
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+
+    private String trainId;
+    private String predictId;
+    private String outputId;
+    private String modelDataId;
+    private String metricReporterPrefix;
+    private String modelDataVersionGaugeKey;
+    private String currentModelDataVersion;
+    private List<String> queueIds;
+    private List<String> kvStoreKeys;
+    private List<JobClient> clients;
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table offlineTrainTable;
+    private Table trainTable;
+    private Table predictTable;
+
+    @Before
+    public void before() {
+        metricReporterPrefix = MockKVStore.createNonDuplicatePrefix();
+        kvStoreKeys = new ArrayList<>();
+        kvStoreKeys.add(metricReporterPrefix);
+
+        modelDataVersionGaugeKey =
+                TestMetricReporter.getKey(metricReporterPrefix, "modelDataVersion");
+
+        currentModelDataVersion = "0";
+
+        trainId = MockMessageQueues.createMessageQueue();
+        predictId = MockMessageQueues.createMessageQueue();
+        outputId = MockMessageQueues.createMessageQueue();
+        modelDataId = MockMessageQueues.createMessageQueue();
+        queueIds = new ArrayList<>();
+        queueIds.addAll(Arrays.asList(trainId, predictId, outputId, modelDataId));
+
+        clients = new ArrayList<>();
+
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        config.setString("metrics.reporters", "test_reporter");
+        config.setString(
+                "metrics.reporter.test_reporter.class", TestMetricReporter.class.getName());
+        config.setString("metrics.reporter.test_reporter.interval", "100 MILLISECONDS");
+        config.setString("metrics.reporter.test_reporter.prefix", metricReporterPrefix);
+
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        Schema schema = Schema.newBuilder().column("f0", DataTypes.of(DenseVector.class)).build();
+
+        offlineTrainTable =
+                tEnv.fromDataStream(env.fromElements(trainData1), schema).as("features");
+        trainTable =
+                tEnv.fromDataStream(
+                                env.addSource(
+                                        new MockSourceFunction<>(trainId),
+                                        DenseVectorTypeInfo.INSTANCE),
+                                schema)
+                        .as("features");
+        predictTable =
+                tEnv.fromDataStream(
+                                env.addSource(
+                                        new MockSourceFunction<>(predictId),
+                                        DenseVectorTypeInfo.INSTANCE),
+                                schema)
+                        .as("features");
+    }
+
+    @After
+    public void after() {
+        for (JobClient client : clients) {
+            try {
+                client.cancel();
+            } catch (IllegalStateException e) {
+                if (!e.getMessage()
+                        .equals("MiniCluster is not yet running or has already been shut down.")) {
+                    throw e;
+                }
+            }
+        }
+        clients.clear();
+
+        for (String queueId : queueIds) {
+            MockMessageQueues.deleteBlockingQueue(queueId);
+        }
+        queueIds.clear();
+
+        for (String key : kvStoreKeys) {
+            MockKVStore.remove(key);
+        }
+        kvStoreKeys.clear();
+    }
+
+    /** Adds sinks for StreamingKMeansModel's transform output and model data. */
+    private void configModelSink(StreamingKMeansModel streamingModel) {
+        Table outputTable = streamingModel.transform(predictTable)[0];
+        DataStream<Row> output = tEnv.toDataStream(outputTable);
+        output.addSink(new MockSinkFunction<>(outputId));
+
+        Table modelDataTable = streamingModel.getModelData()[0];
+        DataStream<KMeansModelData> modelDataStream =
+                KMeansModelData.getModelDataStream(modelDataTable);
+        modelDataStream.addSink(new MockSinkFunction<>(modelDataId));
+    }
+
+    /** Blocks the thread until Model has set up init model data. */
+    private void waitInitModelDataSetup() {
+        while (!MockKVStore.containsKey(modelDataVersionGaugeKey)) {
+            Thread.yield();
+        }
+        waitModelDataUpdate();
+    }
+
+    /** Blocks the thread until the Model has received the next model-data-update event. */
+    private void waitModelDataUpdate() {
+        do {
+            String tmpModelDataVersion = MockKVStore.get(modelDataVersionGaugeKey);
+            if (tmpModelDataVersion.equals(currentModelDataVersion)) {
+                Thread.yield();
+            } else {
+                currentModelDataVersion = tmpModelDataVersion;
+                break;
+            }
+        } while (true);
+    }
+
+    /**
+     * Inserts default predict data to the predict queue, fetches the prediction results, and
+     * asserts that the grouping result is as expected.
+     *
+     * @param expectedGroups A list containing sets of features, which is the expected group result
+     * @param featuresCol Name of the column in the table that contains the features
+     * @param predictionCol Name of the column in the table that contains the prediction result
+     */
+    private void predictAndAssert(
+            List<Set<DenseVector>> expectedGroups, String featuresCol, String predictionCol)
+            throws Exception {
+        MockMessageQueues.offerAll(predictId, StreamingKMeansTest.predictData);
+        List<Row> rawResult =
+                MockMessageQueues.poll(outputId, StreamingKMeansTest.predictData.length);
+        List<Set<DenseVector>> actualGroups =
+                groupFeaturesByPrediction(rawResult, featuresCol, predictionCol);
+        Assert.assertTrue(CollectionUtils.isEqualCollection(expectedGroups, actualGroups));
+    }
+
+    @Test
+    public void testParam() {
+        StreamingKMeans streamingKMeans = new StreamingKMeans();
+        Assert.assertEquals("features", streamingKMeans.getFeaturesCol());
+        Assert.assertEquals("prediction", streamingKMeans.getPredictionCol());
+        Assert.assertEquals(EuclideanDistanceMeasure.NAME, streamingKMeans.getDistanceMeasure());
+        Assert.assertEquals("random", streamingKMeans.getInitMode());
+        Assert.assertEquals(2, streamingKMeans.getK());
+        Assert.assertEquals(1, streamingKMeans.getDims());
+        Assert.assertEquals("count", streamingKMeans.getBatchStrategy());
+        Assert.assertEquals(1, streamingKMeans.getBatchSize());
+        Assert.assertEquals(0., streamingKMeans.getDecayFactor(), 1e-5);
+        Assert.assertEquals("random", streamingKMeans.getInitMode());
+        Assert.assertEquals(StreamingKMeans.class.getName().hashCode(), streamingKMeans.getSeed());
+
+        streamingKMeans
+                .setK(9)
+                .setFeaturesCol("test_feature")
+                .setPredictionCol("test_prediction")
+                .setK(3)
+                .setDims(5)
+                .setBatchSize(5)
+                .setDecayFactor(0.25)
+                .setInitMode("direct")
+                .setSeed(100);
+
+        Assert.assertEquals("test_feature", streamingKMeans.getFeaturesCol());
+        Assert.assertEquals("test_prediction", streamingKMeans.getPredictionCol());
+        Assert.assertEquals(3, streamingKMeans.getK());
+        Assert.assertEquals(5, streamingKMeans.getDims());
+        Assert.assertEquals("count", streamingKMeans.getBatchStrategy());
+        Assert.assertEquals(5, streamingKMeans.getBatchSize());
+        Assert.assertEquals(0.25, streamingKMeans.getDecayFactor(), 1e-5);
+        Assert.assertEquals("direct", streamingKMeans.getInitMode());
+        Assert.assertEquals(100, streamingKMeans.getSeed());
+    }
+
+    @Test
+    public void testFitAndPredict() throws Exception {
+        StreamingKMeans streamingKMeans =
+                new StreamingKMeans()
+                        .setInitMode("random")
+                        .setDims(2)
+                        .setInitWeights(new Double[] {0., 0.})
+                        .setBatchSize(6)
+                        .setFeaturesCol("features")
+                        .setPredictionCol("prediction");
+        StreamingKMeansModel streamingModel = streamingKMeans.fit(trainTable);
+        configModelSink(streamingModel);
+
+        clients.add(env.executeAsync());
+        waitInitModelDataSetup();
+
+        MockMessageQueues.offerAll(trainId, trainData1);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups1,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+
+        MockMessageQueues.offerAll(trainId, trainData2);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups2,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+    }
+
+    @Test
+    public void testInitWithOfflineKMeans() throws Exception {
+        KMeans offlineKMeans =
+                new KMeans().setFeaturesCol("features").setPredictionCol("prediction");
+        KMeansModel offlineModel = offlineKMeans.fit(offlineTrainTable);
+
+        StreamingKMeans streamingKMeans =
+                new StreamingKMeans(offlineModel.getModelData())
+                        .setInitMode("direct")
+                        .setDims(2)
+                        .setInitWeights(new Double[] {0., 0.})
+                        .setBatchSize(6);
+        ReadWriteUtils.updateExistingParams(streamingKMeans, offlineKMeans.getParamMap());
+
+        StreamingKMeansModel streamingModel = streamingKMeans.fit(trainTable);
+        configModelSink(streamingModel);
+
+        clients.add(env.executeAsync());
+        waitInitModelDataSetup();
+        predictAndAssert(
+                expectedGroups1,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+
+        MockMessageQueues.offerAll(trainId, trainData2);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups2,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+    }
+
+    @Test
+    public void testDecayFactor() throws Exception {
+        StreamingKMeans streamingKMeans =
+                new StreamingKMeans()
+                        .setInitMode("random")
+                        .setDims(2)
+                        .setInitWeights(new Double[] {0., 0.})
+                        .setDecayFactor(100.0)
+                        .setBatchSize(6)
+                        .setFeaturesCol("features")
+                        .setPredictionCol("prediction");
+        StreamingKMeansModel streamingModel = streamingKMeans.fit(trainTable);
+        configModelSink(streamingModel);
+
+        clients.add(env.executeAsync());
+        waitInitModelDataSetup();
+
+        MockMessageQueues.offerAll(trainId, trainData1);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups1,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+
+        MockMessageQueues.offerAll(trainId, trainData2);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups1,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+    }
+
+    @Test
+    public void testSaveAndReload() throws Exception {
+        KMeans offlineKMeans =
+                new KMeans().setFeaturesCol("features").setPredictionCol("prediction");
+        KMeansModel offlineModel = offlineKMeans.fit(offlineTrainTable);
+
+        StreamingKMeans streamingKMeans =
+                new StreamingKMeans(offlineModel.getModelData())
+                        .setDims(2)
+                        .setBatchSize(6)
+                        .setInitWeights(new Double[] {0., 0.});
+        ReadWriteUtils.updateExistingParams(streamingKMeans, offlineKMeans.getParamMap());
+
+        StreamingKMeans loadedKMeans =
+                StageTestUtils.saveAndReload(
+                        env, streamingKMeans, tempFolder.newFolder().getAbsolutePath());
+
+        StreamingKMeansModel streamingModel = loadedKMeans.fit(trainTable);
+
+        String modelDataPassId = MockMessageQueues.createMessageQueue();
+        queueIds.add(modelDataPassId);
+
+        String modelSavePath = tempFolder.newFolder().getAbsolutePath();
+        streamingModel.save(modelSavePath);
+        KMeansModelData.getModelDataStream(streamingModel.getModelData()[0])
+                .addSink(new MockSinkFunction<>(modelDataPassId));
+        clients.add(env.executeAsync());

Review comment:
       Sounds good.

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/StreamingKMeans.java
##########
@@ -0,0 +1,404 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * StreamingKMeans extends the function of {@link KMeans}, supporting to train a K-Means model

Review comment:
       Sounds good.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r831986897



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/StreamingKMeans.java
##########
@@ -0,0 +1,404 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * StreamingKMeans extends the function of {@link KMeans}, supporting to train a K-Means model
+ * continuously according to an unbounded stream of train data.
+ */
+public class StreamingKMeans
+        implements Estimator<StreamingKMeans, StreamingKMeansModel>,
+                StreamingKMeansParams<StreamingKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public StreamingKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public StreamingKMeans(Table... initModelDataTables) {
+        Preconditions.checkArgument(initModelDataTables.length == 1);
+        this.initModelDataTable = initModelDataTables[0];
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public StreamingKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        Preconditions.checkArgument(HasBatchStrategy.COUNT_STRATEGY.equals(getBatchStrategy()));
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+        points.getTransformation().setParallelism(1);

Review comment:
       Following our offline discussion, I'll add a single-threaded operator that distributes mini-batches, while the rest of the training process still runs in parallel.
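Below is a minimal, framework-free sketch of the single-threaded distribution step. It assumes that each mini-batch of `batchSize` points is split round-robin into `parallelism` sub-batches, one per downstream training subtask. `MiniBatchSplitter` is a hypothetical name. In the PR this logic would sit inside a Flink operator with parallelism 1, and the actual splitting strategy there may differ.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/** Hypothetical sketch: buffers points and splits each full mini-batch across subtasks. */
public class MiniBatchSplitter<T> {
    private final int batchSize;
    private final int parallelism;
    private final List<T> buffer = new ArrayList<>();

    public MiniBatchSplitter(int batchSize, int parallelism) {
        if (batchSize <= 0 || parallelism <= 0) {
            throw new IllegalArgumentException("batchSize and parallelism must be positive");
        }
        this.batchSize = batchSize;
        this.parallelism = parallelism;
    }

    /**
     * Adds one point. Returns one sub-batch per downstream subtask once a full mini-batch
     * has been collected, or an empty list while the mini-batch is still incomplete.
     */
    public List<List<T>> add(T point) {
        buffer.add(point);
        if (buffer.size() < batchSize) {
            return Collections.emptyList();
        }
        List<List<T>> subBatches = new ArrayList<>(parallelism);
        for (int i = 0; i < parallelism; i++) {
            subBatches.add(new ArrayList<>());
        }
        // Round-robin assignment keeps sub-batch sizes within one element of each other.
        for (int i = 0; i < buffer.size(); i++) {
            subBatches.get(i % parallelism).add(buffer.get(i));
        }
        buffer.clear();
        return subBatches;
    }
}
```

With the test settings above (`setBatchSize(6)`, parallelism 4), each mini-batch would be split into sub-batches of sizes 2, 2, 1 and 1. Only this splitting step needs to be sequential. Each subtask can then compute its partial cluster statistics independently.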







[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r832775815



##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/util/InMemorySourceFunction.java
##########
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util;
+
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.streaming.api.functions.source.RichSourceFunction;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+
+import java.util.Map;
+import java.util.UUID;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
+
+/** A {@link SourceFunction} implementation that can directly receive records from tests. */
+@SuppressWarnings({"unchecked", "rawtypes"})
+public class InMemorySourceFunction<T> extends RichSourceFunction<T> {
+    private static final Map<UUID, BlockingQueue> queueMap = new ConcurrentHashMap<>();
+    private final UUID id;
+    private BlockingQueue<T> queue;
+    private boolean isRunning = true;
+
+    public InMemorySourceFunction() {
+        id = UUID.randomUUID();
+        queue = new LinkedBlockingQueue();
+        queueMap.put(id, queue);
+    }
+
+    @Override
+    public void open(Configuration parameters) throws Exception {
+        super.open(parameters);
+        queue = queueMap.get(id);
+    }
+
+    @Override
+    public void close() throws Exception {
+        super.close();
+        queueMap.remove(id);
+    }
+
+    @Override
+    public void run(SourceContext<T> context) {
+        while (isRunning) {
+            T value = queue.poll();
+            if (value == null) {
+                Thread.yield();

Review comment:
       To add a dummy value, users would have to pass in either a `T value` or a `Class<T> clazz`, and with the latter option `T` must have a no-arg constructor. I don't think that is simple enough. Since the purpose is to avoid a busy loop, how about using `poll(long, TimeUnit)` here?
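   The suggestion can be sketched without Flink as below. This is a minimal, hypothetical stand-alone class: `TimedPollLoop`, the 10 ms timeout and the `List` sink are assumptions, and the real source emits through Flink's `SourceContext` instead. `BlockingQueue.poll(long, TimeUnit)` parks the thread until a record arrives or the timeout expires, so an idle source no longer spins on `Thread.yield()`, and no dummy value is needed to wake it up when it is cancelled.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;

/** Hypothetical stand-in for the source's run loop, using a timed poll instead of a spin. */
class TimedPollLoop<T> {
    private final BlockingQueue<T> queue = new LinkedBlockingQueue<>();
    // volatile so that cancel() called from another thread is seen by the polling thread.
    private volatile boolean isRunning = true;

    public void offer(T value) {
        queue.offer(value);
    }

    public boolean isDrained() {
        return queue.isEmpty();
    }

    public void cancel() {
        isRunning = false;
    }

    /** Moves records into sink until cancelled; an empty queue blocks for up to 10 ms per poll. */
    public void run(List<T> sink) throws InterruptedException {
        while (isRunning) {
            T value = queue.poll(10, TimeUnit.MILLISECONDS);
            if (value != null) {
                sink.add(value);
            }
        }
    }

    /** Feeds three records through the loop on a separate thread and returns what it emitted. */
    public static List<Integer> demo() throws InterruptedException {
        TimedPollLoop<Integer> loop = new TimedPollLoop<>();
        List<Integer> sink = new ArrayList<>();
        Thread runner = new Thread(() -> {
            try {
                loop.run(sink);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        });
        runner.start();
        loop.offer(1);
        loop.offer(2);
        loop.offer(3);
        while (!loop.isDrained()) {
            Thread.sleep(1);
        }
        // Every polled record is added to sink before the loop re-checks isRunning.
        loop.cancel();
        runner.join();
        return sink;
    }
}
```

   In this sketch the timeout only bounds how long a cancelled loop takes to notice the flag; it does not delay records, because `poll` returns as soon as one is available.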




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [flink-ml] yunfengzhou-hub commented on pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#issuecomment-1075872552


   Thanks for the comments. I have updated the PR according to these comments.





[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r832001871



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,409 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;

Review comment:
       I think keeping `OnlineKMeans` in the `kmeans` package is more suitable, because:
   - I don't see `OnlineKMeans` and `KMeans` as different algorithms. They are the same algorithm applied under different external conditions.
   - In the implementation, `OnlineKMeans` and `KMeans` share some infrastructure code that nothing else uses, such as `KMeansModelData`, `KMeansParams` and `findClosestCentroidId`. If they were put into different packages, I am worried that there would be tight coupling between the `kmeans` and `onlinekmeans` packages.
   
   These reasons are not definitive, and I would still be glad to make the change if there are more important reasons to do so.
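   `findClosestCentroidId` is an example of the shared logic mentioned above: the batch and the online variant assign each point to its nearest centroid in exactly the same way. A minimal stand-alone sketch, assuming Euclidean distance and plain `double[]` vectors; the helper in the PR takes `DenseVector`s plus a configurable `DistanceMeasure`, and the `ClosestCentroid` class here is hypothetical.

```java
/** Hypothetical helper illustrating the idea behind a shared findClosestCentroidId. */
final class ClosestCentroid {
    private ClosestCentroid() {}

    /** Squared Euclidean distance; it selects the same nearest centroid as plain Euclidean. */
    static double squaredDistance(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return sum;
    }

    /** Returns the index of the centroid nearest to point; ties keep the lower index. */
    static int findClosestCentroidId(double[][] centroids, double[] point) {
        int closestId = 0;
        double minDistance = Double.POSITIVE_INFINITY;
        for (int i = 0; i < centroids.length; i++) {
            double distance = squaredDistance(centroids[i], point);
            if (distance < minDistance) {
                minDistance = distance;
                closestId = i;
            }
        }
        return closestId;
    }
}
```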







[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r833144569



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,562 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Random;
+
+/**
+ * OnlineKMeans extends the function of {@link KMeans}, supporting to train a K-Means model
+ * continuously according to an unbounded stream of train data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids from the weighted average between the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of points assigned to
+ * them. The weight of the original centroids is also the number of points, but additionally
+ * multiplying with the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If decay
+ * factor is 1, all batches are weighted equally. If decay factor is 0, new centroids are determined
+ * entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public OnlineKMeans(Table initModelDataTable) {
+        this.initModelDataTable = initModelDataTable;
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            Preconditions.checkState(initModelDataTable == null);
+            initModelDataStream = createRandomCentroids(env, getDim(), getK(), getSeed());
+        } else {
+            initModelDataStream = KMeansModelData.getModelDataStream(initModelDataTable);
+        }
+        DataStream<Tuple2<KMeansModelData, DenseVector>> initModelDataWithWeightsStream =
+                initModelDataStream.map(new InitWeightAssigner(getInitWeights()));
+        initModelDataWithWeightsStream.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize(),
+                        getK(),
+                        getDim());
+
+        DataStream<KMeansModelData> finalModelDataStream =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelDataWithWeightsStream),
+                                DataStreamList.of(points),
+                                body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelDataStream);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class InitWeightAssigner
+            implements MapFunction<KMeansModelData, Tuple2<KMeansModelData, DenseVector>> {
+        private final double[] initWeights;
+
+        private InitWeightAssigner(Double[] initWeights) {
+            this.initWeights = ArrayUtils.toPrimitive(initWeights);
+        }
+
+        @Override
+        public Tuple2<KMeansModelData, DenseVector> map(KMeansModelData modelData)
+                throws Exception {
+            return Tuple2.of(modelData, Vectors.dense(initWeights));
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        Path initModelDataPath = Paths.get(path, "data");

Review comment:
       OK. I'll make the change.
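   The class Javadoc quoted in this diff describes the decayed mini-batch update only in words. Below is a minimal single-subtask sketch of that rule, assuming parallelism 1 and plain `double[]` vectors. `DecayedMiniBatchUpdate` and its methods are hypothetical, and the PR's `ModelDataLocalUpdater` additionally divides the decay factor by the parallelism because per-subtask results are summed afterwards. Each centroid's old weight is multiplied by the decay factor, the batch count is added, and the centroid moves toward the batch mean by `lambda = count / newWeight`.

```java
/** Hypothetical single-subtask sketch of the decayed mini-batch k-means update. */
final class DecayedMiniBatchUpdate {
    private DecayedMiniBatchUpdate() {}

    /** Index of the nearest centroid by squared Euclidean distance. */
    private static int closest(double[][] centroids, double[] point) {
        int best = 0;
        double bestDistance = Double.POSITIVE_INFINITY;
        for (int i = 0; i < centroids.length; i++) {
            double d = 0.0;
            for (int j = 0; j < point.length; j++) {
                double diff = centroids[i][j] - point[j];
                d += diff * diff;
            }
            if (d < bestDistance) {
                bestDistance = d;
                best = i;
            }
        }
        return best;
    }

    /** Updates centroids and weights in place using one mini-batch of points. */
    static void update(double[][] centroids, double[] weights, double[][] batch, double decayFactor) {
        int k = centroids.length;
        int dim = centroids[0].length;
        double[][] sums = new double[k][dim];
        int[] counts = new int[k];
        for (double[] point : batch) {
            int id = closest(centroids, point);
            counts[id]++;
            for (int j = 0; j < dim; j++) {
                sums[id][j] += point[j];
            }
        }
        for (int i = 0; i < k; i++) {
            // The old weight is discounted for every cluster, including ones that got no points.
            weights[i] *= decayFactor;
            if (counts[i] == 0) {
                continue;
            }
            weights[i] += counts[i];
            double lambda = counts[i] / weights[i];
            for (int j = 0; j < dim; j++) {
                double batchMean = sums[i][j] / counts[i];
                centroids[i][j] = (1.0 - lambda) * centroids[i][j] + lambda * batchMean;
            }
        }
    }
}
```

   With `decayFactor = 0`, `lambda` is 1 and each updated centroid is simply the mean of its batch points; with `decayFactor = 1`, the centroid is a weighted running mean of everything assigned to it so far (including its initial weight). These are the two extremes the Javadoc describes.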







[GitHub] [flink-ml] lindong28 merged pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
lindong28 merged pull request #70:
URL: https://github.com/apache/flink-ml/pull/70


   





[GitHub] [flink-ml] lindong28 commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
lindong28 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r834933083



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,437 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.function.Supplier;
+
+/**
+ * OnlineKMeans extends the function of {@link KMeans}, supporting to train a K-Means model
+ * continuously according to an unbounded stream of train data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids from the weighted average between the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of points assigned to
+ * them. The weight of the original centroids is also the number of points, but additionally
+ * multiplying with the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If decay
+ * factor is 1, all batches are weighted equally. If decay factor is 0, new centroids are determined
+ * entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelData =
+                KMeansModelData.getModelDataStream(initModelDataTable);
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize());
+
+        DataStream<KMeansModelData> finalModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelData);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Saves the metadata AND bounded model data table (if exists) to the given path. */
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        String initModelDataPath = ReadWriteUtils.getDataPath(path);
+        if (Files.exists(Paths.get(initModelDataPath))) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure, double decayFactor, int batchSize) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<KMeansModelData> modelData = variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<KMeansModelData> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator())
+                            .flatMap(new GlobalBatchSplitter(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(KMeansModelData.class),
+                                    new ModelDataLocalUpdater(distanceMeasure, decayFactor))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new ModelDataGlobalReducer());
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelData), DataStreamList.of(modelData));
+        }
+    }
+
+    private static class ModelDataGlobalReducer implements ReduceFunction<KMeansModelData> {
+        @Override
+        public KMeansModelData reduce(KMeansModelData modelData, KMeansModelData newModelData) {
+            DenseVector weights = modelData.weights;
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector newWeights = newModelData.weights;
+            DenseVector[] newCentroids = newModelData.centroids;
+
+            int k = newCentroids.length;
+            int dim = newCentroids[0].size();
+
+            for (int i = 0; i < k; i++) {
+                for (int j = 0; j < dim; j++) {
+                    centroids[i].values[j] =
+                            (centroids[i].values[j] * weights.values[i]
+                                            + newCentroids[i].values[j] * newWeights.values[i])
+                                    / Math.max(weights.values[i] + newWeights.values[i], 1e-16);
+                }
+                weights.values[i] += newWeights.values[i];
+            }
+
+            return new KMeansModelData(centroids, weights);
+        }
+    }
+
+    private static class ModelDataLocalUpdater extends AbstractStreamOperator<KMeansModelData>
+            implements TwoInputStreamOperator<DenseVector[], KMeansModelData, KMeansModelData> {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private ListState<DenseVector[]> localBatchState;
+        private ListState<KMeansModelData> modelDataState;
+
+        private ModelDataLocalUpdater(DistanceMeasure distanceMeasure, double decayFactor) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            localBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("localBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> pointsRecord) throws Exception {
+            localBatchState.add(pointsRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<KMeansModelData> modelDataRecord)
+                throws Exception {
+            modelDataState.add(modelDataRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        private void alignAndComputeModelData() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !localBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            KMeansModelData modelData =
+                    OperatorStateUtils.getUniqueElement(modelDataState, "modelData")
+                            .orElseThrow((Supplier<Exception>) NullPointerException::new);
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector weights = modelData.weights;
+            modelDataState.clear();
+
+            List<DenseVector[]> pointsList = IteratorUtils.toList(localBatchState.get().iterator());
+            DenseVector[] points = pointsList.remove(0);
+            localBatchState.update(pointsList);
+
+            int dim = centroids[0].size();
+            int k = centroids.length;
+            int parallelism = getRuntimeContext().getNumberOfParallelSubtasks();
+
+            // Computes new centroids.
+            DenseVector[] sums = new DenseVector[k];
+            int[] counts = new int[k];
+
+            for (int i = 0; i < k; i++) {
+                sums[i] = new DenseVector(dim);
+                counts[i] = 0;
+            }
+            for (DenseVector point : points) {
+                int closestCentroidId =
+                        KMeans.findClosestCentroidId(centroids, point, distanceMeasure);
+                counts[closestCentroidId]++;
+                for (int j = 0; j < dim; j++) {
+                    sums[closestCentroidId].values[j] += point.values[j];
+                }
+            }
+
+            // Considers weight and decay factor when updating centroids.
+            BLAS.scal(decayFactor / parallelism, weights);
+            for (int i = 0; i < k; i++) {
+                if (counts[i] == 0) {
+                    continue;
+                }
+
+                DenseVector centroid = centroids[i];
+                weights.values[i] = weights.values[i] + counts[i];
+                double lambda = counts[i] / weights.values[i];
+
+                BLAS.scal(1.0 - lambda, centroid);
+                BLAS.axpy(lambda / counts[i], sums[i], centroid);
+            }
+
+            output.collect(new StreamRecord<>(new KMeansModelData(centroids, weights)));
+        }
+    }
+
+    private static class FeaturesExtractor implements MapFunction<Row, DenseVector> {
+        private final String featuresCol;
+
+        private FeaturesExtractor(String featuresCol) {
+            this.featuresCol = featuresCol;
+        }
+
+        @Override
+        public DenseVector map(Row row) throws Exception {
+            return (DenseVector) row.getField(featuresCol);
+        }
+    }
+
+    // An operator that splits a global batch into evenly-sized local batches, and distributes them
+    // to downstream operator.
+    private static class GlobalBatchSplitter
+            implements FlatMapFunction<DenseVector[], DenseVector[]> {
+        private final int downStreamParallelism;
+
+        private GlobalBatchSplitter(int downStreamParallelism) {
+            this.downStreamParallelism = downStreamParallelism;
+        }
+
+        @Override
+        public void flatMap(DenseVector[] values, Collector<DenseVector[]> collector) {
+            // Calculate the batch sizes to be distributed on each subtask.
+            List<Integer> sizes = new ArrayList<>();
+            for (int i = 0; i < downStreamParallelism; i++) {
+                int start = i * values.length / downStreamParallelism;
+                int end = (i + 1) * values.length / downStreamParallelism;
+                sizes.add(end - start);
+            }
+
+            int offset = 0;
+            for (Integer size : sizes) {
+                collector.collect(Arrays.copyOfRange(values, offset, offset + size));
+                offset += size;
+            }
+        }
+    }
+
+    private static class GlobalBatchCreator
+            implements AllWindowFunction<DenseVector, DenseVector[], GlobalWindow> {
+        @Override
+        public void apply(
+                GlobalWindow timeWindow,
+                Iterable<DenseVector> iterable,
+                Collector<DenseVector[]> collector) {
+            List<DenseVector> points = IteratorUtils.toList(iterable.iterator());
+            collector.collect(points.toArray(new DenseVector[0]));
+        }
+    }
+
+    /**
+     * Sets the initial model data of the online training process with the provided model data
+     * table.
+     *
+     * <p>This would override the effect of previously invoked {@link #setInitialModelData(Table)}
+     * or {@link #setRandomCentroids(StreamTableEnvironment, int, int, double)}.
+     */
+    public OnlineKMeans setInitialModelData(Table initModelDataTable) {
+        this.initModelDataTable = initModelDataTable;
+        return this;
+    }
+
+    /**
+     * Sets the initial model data of the online training process with randomly created centroids.
+     *
+     * <p>This would override the effect of previously invoked {@link #setInitialModelData(Table)}
+     * or {@link #setRandomCentroids(StreamTableEnvironment, int, int, double)}.
+     *
+     * @param tEnv The stream table environment to create the centroids in.
+     * @param dim The dimension of the centroids to create.
+     * @param k The number of centroids to create.
+     * @param weight The weight of the centroids to create.
+     */
+    public OnlineKMeans setRandomCentroids(
+            StreamTableEnvironment tEnv, int dim, int k, double weight) {

Review comment:
       Assuming all options meet our use case (e.g. they allow users to get the information they need), we should choose the option that is the most usable, the simplest, and the easiest to understand.
   
   The issue with not exposing getK() is that `K` seems like an important attribute of the clustering algorithm, and users will want to get this information. Maybe @zhipeng93 can comment on this.
   
   The issue with using `setInitMode().setDim().setWeight()` is that users need to call 3 methods instead of 1 to set random centroids, and they also need to understand that setDim() and setWeight() depend on initMode. This could make the algorithm harder to use. What do you think?
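For reference, the single-call form bundles three values: `dim` (centroid dimension), `k` (number of centroids) and `weight` (initial weight of each centroid). The plain-Java sketch below only illustrates what those values describe. It does not use the Flink Table API. The class name `RandomCentroidsSketch`, the `seed` parameter and the uniform [0, 1) sampling are assumptions made for illustration, not the behavior of `OnlineKMeans#setRandomCentroids`.

```java
import java.util.Arrays;
import java.util.Random;

/** Hypothetical illustration of "k random centroids of dimension dim, each with one weight". */
class RandomCentroidsSketch {
    /** Returns k centroids with dim coordinates each, drawn uniformly from [0, 1). */
    static double[][] randomCentroids(int dim, int k, long seed) {
        if (dim <= 0 || k <= 0) {
            throw new IllegalArgumentException("dim and k must be positive");
        }
        Random random = new Random(seed); // fixed seed keeps the sketch reproducible
        double[][] centroids = new double[k][dim];
        for (double[] centroid : centroids) {
            for (int i = 0; i < dim; i++) {
                centroid[i] = random.nextDouble();
            }
        }
        return centroids;
    }

    /** All centroids start with the same weight, matching the single weight argument. */
    static double[] initialWeights(int k, double weight) {
        double[] weights = new double[k];
        Arrays.fill(weights, weight);
        return weights;
    }

    public static void main(String[] args) {
        double[][] centroids = randomCentroids(2, 3, 42L);
        double[] weights = initialWeights(3, 0.0);
        System.out.println(centroids.length + " centroids of dim " + centroids[0].length);
        System.out.println("weights: " + Arrays.toString(weights));
    }
}
```

Passing all three values in one call keeps them together. Separate `setDim()`/`setWeight()` setters, by contrast, only take effect under a particular init mode, which is the dependency raised in the comment.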
   




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [flink-ml] lindong28 commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
lindong28 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r834931138



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModelParams.java
##########
@@ -21,27 +21,11 @@
 import org.apache.flink.ml.common.param.HasDistanceMeasure;
 import org.apache.flink.ml.common.param.HasFeaturesCol;
 import org.apache.flink.ml.common.param.HasPredictionCol;
-import org.apache.flink.ml.param.IntParam;
-import org.apache.flink.ml.param.Param;
-import org.apache.flink.ml.param.ParamValidators;
 
 /**
- * Params of {@link KMeansModel}.
+ * Params of {@link KMeansModel} and {@link OnlineKMeansModel}.
  *
  * @param <T> The class type of this instance.
  */
 public interface KMeansModelParams<T>
-        extends HasDistanceMeasure<T>, HasFeaturesCol<T>, HasPredictionCol<T> {
-
-    Param<Integer> K =

Review comment:
       Would it still be useful for users to query the k of the KMeansModel (and of clustering algorithms in general)?
   
   Note that line 54 of Spark's `python/pyspark/mllib/clustering.py` shows an example Python code snippet that prints k.
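If `K` is dropped from `KMeansModelParams`, one hypothetical way to keep k queryable is to derive it from the model data, which holds one centroid per cluster. The sketch below uses a made-up `ModelDataSketch` class in place of `KMeansModelData`. It is an illustration under that assumption, not the Flink ML API.

```java
/** Hypothetical stand-in for KMeansModelData: exactly one centroid per cluster. */
class ModelDataSketch {
    private final double[][] centroids;

    ModelDataSketch(double[][] centroids) {
        this.centroids = centroids;
    }

    /** k is implied by the model data: it is the number of centroids. */
    int getK() {
        return centroids.length;
    }

    public static void main(String[] args) {
        ModelDataSketch model = new ModelDataSketch(new double[][] {{0.0, 0.0}, {5.0, 5.0}});
        System.out.println("k = " + model.getK()); // prints "k = 2"
    }
}
```

For `OnlineKMeansModel`, the model data is an unbounded stream rather than a fixed array. A value derived this way would therefore presumably only be available once model data has arrived, which is one argument for keeping k as a parameter.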





[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r835697293



##########
File path: flink-ml-core/src/main/java/org/apache/flink/ml/common/distance/EuclideanDistanceMeasure.java
##########
@@ -34,6 +35,7 @@ public static EuclideanDistanceMeasure getInstance() {
 
     @Override
     public double distance(Vector v1, Vector v2) {
+        Preconditions.checkArgument(v1.size() == v2.size());
         double squaredDistance = 0.0;
 
         for (int i = 0; i < v1.size(); i++) {

Review comment:
       I agree that we should improve it with BLAS, but perhaps in a separate PR. I'll leave a TODO here.
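The plain-Java sketch below mirrors the loop in the diff, including the size precondition. `distanceViaDot` shows one common way to express the same distance through dot products, using ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * dot(a, b). A BLAS-based version could use that kind of formulation, but it is only an assumption that the eventual Flink ML change will take this form. The class and method names are hypothetical.

```java
/** Plain-Java sketch of the Euclidean distance in the diff, plus a dot-product formulation. */
class EuclideanDistanceSketch {
    private static void checkSameSize(double[] v1, double[] v2) {
        // Same intent as Preconditions.checkArgument(v1.size() == v2.size()) in the diff.
        if (v1.length != v2.length) {
            throw new IllegalArgumentException(
                    "Vector sizes differ: " + v1.length + " vs " + v2.length);
        }
    }

    /** Direct loop over coordinates, as in the diff. */
    static double distance(double[] v1, double[] v2) {
        checkSameSize(v1, v2);
        double squaredDistance = 0.0;
        for (int i = 0; i < v1.length; i++) {
            double diff = v1[i] - v2[i];
            squaredDistance += diff * diff;
        }
        return Math.sqrt(squaredDistance);
    }

    private static double dot(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            sum += a[i] * b[i];
        }
        return sum;
    }

    /** ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * dot(a, b); the norms could be cached per centroid. */
    static double distanceViaDot(double[] v1, double[] v2) {
        checkSameSize(v1, v2);
        double squared = dot(v1, v1) + dot(v2, v2) - 2.0 * dot(v1, v2);
        return Math.sqrt(Math.max(0.0, squared)); // clamp tiny negative rounding errors
    }

    public static void main(String[] args) {
        System.out.println(distance(new double[] {0, 0}, new double[] {3, 4})); // 5.0
    }
}
```

The dot-product form can lose precision when the two vectors are nearly equal, which is why the sketch clamps the squared distance at zero.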





[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r836040424



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeansModel.java
##########
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.co.CoProcessFunction;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * OnlineKMeansModel can be regarded as an advanced {@link KMeansModel} operator which can update
+ * model data in a streaming format, using the model data provided by {@link OnlineKMeans}.
+ */
+public class OnlineKMeansModel
+        implements Model<OnlineKMeansModel>, KMeansModelParams<OnlineKMeansModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public OnlineKMeansModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel setModelData(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        modelDataTable = inputs[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), Types.INT),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol()));
+
+        DataStream<Row> predictionResult =
+                KMeansModelData.getModelDataStream(modelDataTable)
+                        .broadcast()
+                        .connect(tEnv.toDataStream(inputs[0]))
+                        .process(
+                                new PredictLabelFunction(
+                                        getFeaturesCol(),
+                                        DistanceMeasure.getInstance(getDistanceMeasure())),
+                                outputTypeInfo);
+
+        return new Table[] {tEnv.fromDataStream(predictionResult)};
+    }
+
+    /** A utility function used for prediction. */
+    private static class PredictLabelFunction extends CoProcessFunction<KMeansModelData, Row, Row> {
+        private final String featuresCol;
+
+        private final DistanceMeasure distanceMeasure;
+
+        private DenseVector[] centroids;
+
+        // TODO: replace this with a complete solution of reading first model data from unbounded
+        // model data stream before processing the first predict data.
+        private final List<Row> bufferedPoints = new ArrayList<>();

Review comment:
       I would suggest that we check in a technically correct implementation at each step. If it becomes unnecessary later, we can delete it then.
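One way to read "technically correct" here is that predict rows arriving before the first model data are buffered, then flushed once a model is available. The framework-free sketch below shows that logic. `onModelData` and `onPoint` are hypothetical stand-ins for `processElement1`/`processElement2` of the `CoProcessFunction` in the diff, and nearest-centroid prediction uses squared Euclidean distance for simplicity.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical sketch: buffer predict data until the first model data arrives. */
class BufferUntilModelSketch {
    private double[][] centroids; // null until the first model data arrives
    private final List<double[]> bufferedPoints = new ArrayList<>();
    private final List<Integer> output = new ArrayList<>();

    /** Model-data input: install the new centroids, then flush everything buffered so far. */
    void onModelData(double[][] newCentroids) {
        centroids = newCentroids;
        for (double[] point : bufferedPoints) {
            output.add(predict(point));
        }
        bufferedPoints.clear();
    }

    /** Predict-data input: buffer while no model exists, otherwise predict immediately. */
    void onPoint(double[] point) {
        if (centroids == null) {
            bufferedPoints.add(point);
        } else {
            output.add(predict(point));
        }
    }

    /** Index of the closest centroid by squared Euclidean distance. */
    private int predict(double[] point) {
        int best = 0;
        double bestDistance = Double.MAX_VALUE;
        for (int c = 0; c < centroids.length; c++) {
            double d = 0.0;
            for (int i = 0; i < point.length; i++) {
                double diff = point[i] - centroids[c][i];
                d += diff * diff;
            }
            if (d < bestDistance) {
                bestDistance = d;
                best = c;
            }
        }
        return best;
    }

    List<Integer> getOutput() {
        return output;
    }

    int bufferedCount() {
        return bufferedPoints.size();
    }
}
```

In an actual Flink operator, the buffer would typically need to live in managed state (for example a `ListState`) rather than in a plain `ArrayList` field, so that buffered rows survive a failover. The sketch leaves that out.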





[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r836084920



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,407 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Supplier;
+
+/**
+ * OnlineKMeans extends the functionality of {@link KMeans} by supporting continuous training of a
+ * K-Means model on an unbounded stream of training data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids as the weighted average of the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of points assigned to
+ * them. The weight of the original centroids is also the number of points, but additionally
+ * multiplied by the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If the decay
+ * factor is 1, all batches are weighted equally. If the decay factor is 0, new centroids are
+ * determined entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelData =
+                KMeansModelData.getModelDataStream(initModelDataTable);
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getK(),
+                        getDecayFactor(),
+                        getGlobalBatchSize());
+
+        DataStream<KMeansModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Saves the metadata AND bounded model data table (if exists) to the given path. */
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans onlineKMeans = ReadWriteUtils.loadStageParam(path);
+
+        String initModelDataPath = ReadWriteUtils.getDataPath(path);
+        if (Files.exists(Paths.get(initModelDataPath))) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            onlineKMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+        }
+
+        return onlineKMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final int k;
+        private final double decayFactor;
+        private final int batchSize;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure, int k, double decayFactor, int batchSize) {
+            this.distanceMeasure = distanceMeasure;
+            this.k = k;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<KMeansModelData> modelData = variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();

Review comment:
       As pointed out by @yunfengzhou-hub, a use case could be:
   > I think even if the number of elements in a batch is smaller than parallelism, the training process should still function correctly, because when batches are created according to time window, the operator would have no idea about how many elements are in each batch.
   
   By `consistent user experience`, I mean that given a training set as input, users should be able to run it on one worker or on an arbitrary number of workers.
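To make the batch-smaller-than-parallelism case concrete, the sketch below splits a global batch across workers so that local sizes differ by at most one, which means some workers may receive zero points. It then computes per-worker partial sums and counts and merges them. Because an empty partition contributes nothing, the merged result is the same for one worker or many. It uses 1-D points and hypothetical names to stay short, and it is not necessarily how `OnlineKMeans` divides `getGlobalBatchSize()` among subtasks.

```java
import java.util.Arrays;

/** Hypothetical sketch: split one global batch across workers, then merge partial results. */
class BatchSplitSketch {
    /** Local batch size per worker; sizes differ by at most one and sum to globalBatchSize. */
    static int[] localBatchSizes(int globalBatchSize, int parallelism) {
        int[] sizes = new int[parallelism];
        for (int i = 0; i < parallelism; i++) {
            sizes[i] = globalBatchSize / parallelism + (i < globalBatchSize % parallelism ? 1 : 0);
        }
        return sizes;
    }

    /** Per-worker sum and count per cluster for 1-D points in [from, to); empty slices give zeros. */
    static double[][] partial(double[] points, int from, int to, double[] centroids) {
        double[][] acc = new double[centroids.length][2]; // acc[c][0] = sum, acc[c][1] = count
        for (int p = from; p < to; p++) {
            int closest = 0;
            for (int c = 1; c < centroids.length; c++) {
                if (Math.abs(points[p] - centroids[c]) < Math.abs(points[p] - centroids[closest])) {
                    closest = c;
                }
            }
            acc[closest][0] += points[p];
            acc[closest][1] += 1;
        }
        return acc;
    }

    /** Splits the batch over `parallelism` workers, merges their partials, returns cluster means. */
    static double[] batchMeans(double[] batch, double[] centroids, int parallelism) {
        double[][] merged = new double[centroids.length][2];
        int offset = 0;
        for (int size : localBatchSizes(batch.length, parallelism)) {
            double[][] acc = partial(batch, offset, offset + size, centroids);
            for (int c = 0; c < centroids.length; c++) {
                merged[c][0] += acc[c][0];
                merged[c][1] += acc[c][1];
            }
            offset += size;
        }
        double[] means = new double[centroids.length];
        for (int c = 0; c < centroids.length; c++) {
            // A cluster that received no points keeps its previous centroid.
            means[c] = merged[c][1] == 0 ? centroids[c] : merged[c][0] / merged[c][1];
        }
        return means;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(localBatchSizes(3, 4))); // [1, 1, 1, 0]
        double[] batch = {1.0, 2.0, 9.0};
        double[] centroids = {0.0, 10.0};
        System.out.println(
                Arrays.equals(batchMeans(batch, centroids, 1), batchMeans(batch, centroids, 4)));
    }
}
```

Here `localBatchSizes(3, 4)` yields `[1, 1, 1, 0]`: the fourth worker gets an empty slice and simply reports zero sums and counts.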

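The decayed mini-batch rule described in the `OnlineKMeans` Javadoc can be written out for a single cluster as follows. This sketch assumes that the "original" weight is the (already decayed) number of points the centroid represents, as in the streaming k-means update rule documented for Spark MLlib's `StreamingKMeans`. The class and method names are hypothetical.

```java
/** Hypothetical sketch of the decayed mini-batch K-Means update for one cluster. */
class DecayedUpdateSketch {
    /** Centroid coordinates plus the (decayed) number of points it represents. */
    static final class Cluster {
        final double[] centroid;
        final double weight;

        Cluster(double[] centroid, double weight) {
            this.centroid = centroid;
            this.weight = weight;
        }
    }

    /**
     * Merges the batch estimate for one cluster into the current cluster.
     *
     * @param current centroid and weight accumulated so far
     * @param batchMean mean of the batch points assigned to this cluster
     * @param batchCount number of batch points assigned to this cluster
     * @param decayFactor value in [0, 1]; 1 weights all batches equally, 0 keeps only the batch
     */
    static Cluster update(Cluster current, double[] batchMean, long batchCount, double decayFactor) {
        double oldWeight = current.weight * decayFactor; // decayed weight of the history
        if (batchCount == 0) {
            // No new points for this cluster: keep the centroid, only the weight decays.
            return new Cluster(current.centroid.clone(), oldWeight);
        }
        double newWeight = oldWeight + batchCount;
        double[] centroid = new double[current.centroid.length];
        for (int i = 0; i < centroid.length; i++) {
            // Weighted average of the original and the estimated centroid.
            centroid[i] = (current.centroid[i] * oldWeight + batchMean[i] * batchCount) / newWeight;
        }
        return new Cluster(centroid, newWeight);
    }
}
```

For example, with decay factor 1, a centroid at (0, 0) with weight 2 and a batch mean of (3, 3) from one point moves to (1, 1) with weight 3. With decay factor 0, the same centroid jumps straight to the batch mean (3, 3).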





[GitHub] [flink-ml] lindong28 commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
lindong28 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r836004944



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,407 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Supplier;
+
+/**
+ * OnlineKMeans extends the functionality of {@link KMeans} by supporting continuous training of a
+ * K-Means model on an unbounded stream of training data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids as the weighted average of the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of points assigned to
+ * them. The weight of the original centroids is also the number of points, but additionally
+ * multiplied by the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If the decay
+ * factor is 1, all batches are weighted equally. If the decay factor is 0, new centroids are
+ * determined entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelData =
+                KMeansModelData.getModelDataStream(initModelDataTable);
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getK(),
+                        getDecayFactor(),
+                        getGlobalBatchSize());
+
+        DataStream<KMeansModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Saves the metadata AND bounded model data table (if it exists) to the given path. */
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans onlineKMeans = ReadWriteUtils.loadStageParam(path);
+
+        String initModelDataPath = ReadWriteUtils.getDataPath(path);
+        if (Files.exists(Paths.get(initModelDataPath))) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            onlineKMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+        }
+
+        return onlineKMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final int k;
+        private final double decayFactor;
+        private final int batchSize;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure, int k, double decayFactor, int batchSize) {
+            this.distanceMeasure = distanceMeasure;
+            this.k = k;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<KMeansModelData> modelData = variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<KMeansModelData> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator())
+                            .flatMap(new GlobalBatchSplitter(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(KMeansModelData.class),
+                                    new ModelDataLocalUpdater(distanceMeasure, k, decayFactor))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new ModelDataGlobalReducer());
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelData), DataStreamList.of(modelData));
+        }
+    }
+
+    /**
+     * Operator that collects a KMeansModelData from each upstream subtask and outputs the weighted
+     * average of the collected model data.
+     */
+    private static class ModelDataGlobalReducer implements ReduceFunction<KMeansModelData> {
+        @Override
+        public KMeansModelData reduce(KMeansModelData modelData, KMeansModelData newModelData) {
+            DenseVector weights = modelData.weights;
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector newWeights = newModelData.weights;
+            DenseVector[] newCentroids = newModelData.centroids;
+
+            int k = newCentroids.length;
+            int dim = newCentroids[0].size();
+
+            for (int i = 0; i < k; i++) {
+                for (int j = 0; j < dim; j++) {
+                    centroids[i].values[j] =
+                            (centroids[i].values[j] * weights.values[i]
+                                            + newCentroids[i].values[j] * newWeights.values[i])
+                                    / Math.max(weights.values[i] + newWeights.values[i], 1e-16);
+                }
+                weights.values[i] += newWeights.values[i];
+            }
+
+            return new KMeansModelData(centroids, weights);
+        }
+    }
+
+    /**
+     * An operator that updates KMeans model data locally. It performs the following steps:
+     *
+     * <ul>
+     *   <li>Finds the closest centroid id (cluster) of each input point
+     *   <li>Computes the new centroids from the average of the input points that belong to the
+     *       same cluster
+     *   <li>Computes the weighted average of current and new centroids. The weight of a new
+     *       centroid is the number of input points that belong to this cluster. The weight of a
+     *       current centroid is its original weight scaled by $ decayFactor / parallelism $.
+     *   <li>Generates new model data from the weighted average of centroids, and the sum of
+     *       weights.
+     * </ul>
+     */
+    private static class ModelDataLocalUpdater extends AbstractStreamOperator<KMeansModelData>
+            implements TwoInputStreamOperator<DenseVector[], KMeansModelData, KMeansModelData> {
+        private final DistanceMeasure distanceMeasure;
+        private final int k;
+        private final double decayFactor;
+        private ListState<DenseVector[]> localBatchDataState;
+        private ListState<KMeansModelData> modelDataState;
+
+        private ModelDataLocalUpdater(DistanceMeasure distanceMeasure, int k, double decayFactor) {
+            this.distanceMeasure = distanceMeasure;
+            this.k = k;
+            this.decayFactor = decayFactor;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            localBatchDataState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("localBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> pointsRecord) throws Exception {
+            localBatchDataState.add(pointsRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<KMeansModelData> modelDataRecord)
+                throws Exception {
+            Preconditions.checkArgument(modelDataRecord.getValue().centroids.length == k);
+            modelDataState.add(modelDataRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        private void alignAndComputeModelData() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !localBatchDataState.get().iterator().hasNext()) {
+                return;
+            }
+
+            KMeansModelData modelData =
+                    OperatorStateUtils.getUniqueElement(modelDataState, "modelData")
+                            .orElseThrow((Supplier<Exception>) NullPointerException::new);

Review comment:
       `getUniqueElement()` also guarantees that there is exactly one element: it throws `IllegalStateException` if the state contains more than one element.
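To make the guarantee in the review comment concrete, here is a minimal plain-Java sketch of the contract it describes. It is a hypothetical stand-in, not Flink's `OperatorStateUtils.getUniqueElement`: an empty state yields `Optional.empty()`, a state with more than one element throws `IllegalStateException`, and chaining `orElseThrow` rejects the empty case, so only a state with exactly one element passes. The class and method names are illustrative assumptions.

```java
import java.util.Iterator;
import java.util.Optional;

/** Hypothetical stand-in for the contract described in the review; not Flink's code. */
final class UniqueElementSketch {
    private UniqueElementSketch() {}

    /** Returns empty for no elements, the element for one, and throws for more than one. */
    static <T> Optional<T> getUniqueElement(Iterable<T> state, String stateName) {
        Iterator<T> it = state.iterator();
        if (!it.hasNext()) {
            return Optional.empty();
        }
        T first = it.next();
        if (it.hasNext()) {
            throw new IllegalStateException(
                    "The state " + stateName + " contains more than one element");
        }
        return Optional.of(first);
    }

    /** Chaining orElseThrow also rejects the empty case, so exactly one element is required. */
    static <T> T requireExactlyOne(Iterable<T> state, String stateName) {
        return getUniqueElement(state, stateName).orElseThrow(NullPointerException::new);
    }
}
```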

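The mini-batch update with decay described in the `OnlineKMeans` Javadoc can be sketched in a single process as follows. This is a minimal illustration under stated assumptions, not the flink-ml implementation: squared Euclidean distance is assumed for cluster assignment, and `DecayedMiniBatchKMeans`, `closest` and `update` are hypothetical names. For each cluster, the old weight is scaled by the decay factor, every assigned batch point adds one unit of weight, and the new centroid is the weighted average of the old centroid and the batch points.

```java
/**
 * Hypothetical single-process sketch of the decayed mini-batch K-Means update.
 * Assumes squared Euclidean distance; not the flink-ml implementation.
 */
final class DecayedMiniBatchKMeans {
    private DecayedMiniBatchKMeans() {}

    /** Returns the index of the centroid closest to the point (squared Euclidean distance). */
    static int closest(double[][] centroids, double[] point) {
        int best = 0;
        double bestDist = Double.POSITIVE_INFINITY;
        for (int i = 0; i < centroids.length; i++) {
            double dist = 0;
            for (int j = 0; j < point.length; j++) {
                double diff = centroids[i][j] - point[j];
                dist += diff * diff;
            }
            if (dist < bestDist) {
                bestDist = dist;
                best = i;
            }
        }
        return best;
    }

    /**
     * Applies one mini-batch in place. For each cluster with n points summing to s:
     * newWeight = decay * weight + n, newCentroid = (decay * weight * centroid + s) / newWeight.
     * Clusters that receive no points keep their position and only lose weight.
     */
    static void update(double[][] centroids, double[] weights, double[][] batch, double decay) {
        int k = centroids.length;
        int dim = centroids[0].length;
        double[][] sums = new double[k][dim];
        int[] counts = new int[k];
        for (double[] point : batch) {
            int c = closest(centroids, point);
            counts[c]++;
            for (int j = 0; j < dim; j++) {
                sums[c][j] += point[j];
            }
        }
        for (int i = 0; i < k; i++) {
            double decayedWeight = decay * weights[i];
            double newWeight = decayedWeight + counts[i];
            if (newWeight > 0) {
                for (int j = 0; j < dim; j++) {
                    centroids[i][j] = (decayedWeight * centroids[i][j] + sums[i][j]) / newWeight;
                }
            }
            weights[i] = newWeight;
        }
    }
}
```

In the PR's distributed version, `ModelDataLocalUpdater` scales the original weight by `decayFactor / parallelism` on each subtask; summed over all `parallelism` subtasks, those partial weights add back up to `decayFactor` times the original weight, which is consistent with the single-process rule above.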

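`ModelDataGlobalReducer` merges partial model data pairwise: each merged centroid is the weight-averaged position of the two inputs, and the weights are summed. A minimal plain-Java sketch of that arithmetic follows. `PartialModel` is a hypothetical stand-in for `KMeansModelData`, and unlike the quoted reducer, this sketch allocates new arrays instead of updating the first argument in place.

```java
/** Hypothetical stand-in for partial model data: k centroids and their weights. */
final class PartialModel {
    final double[][] centroids;
    final double[] weights;

    PartialModel(double[][] centroids, double[] weights) {
        this.centroids = centroids;
        this.weights = weights;
    }

    /**
     * Weighted merge of two partial models: centroids are weight-averaged and weights summed.
     * The Math.max(..., 1e-16) guard avoids division by zero. Inputs are left unchanged.
     */
    static PartialModel merge(PartialModel a, PartialModel b) {
        int k = a.centroids.length;
        int dim = a.centroids[0].length;
        double[][] centroids = new double[k][dim];
        double[] weights = new double[k];
        for (int i = 0; i < k; i++) {
            double total = a.weights[i] + b.weights[i];
            for (int j = 0; j < dim; j++) {
                centroids[i][j] =
                        (a.centroids[i][j] * a.weights[i] + b.centroids[i][j] * b.weights[i])
                                / Math.max(total, 1e-16);
            }
            weights[i] = total;
        }
        return new PartialModel(centroids, weights);
    }
}
```

One consequence of the `Math.max(..., 1e-16)` guard, both here and in the quoted reducer: when a cluster has zero weight in both inputs, the numerator is also zero, so the merged centroid lands at the origin instead of keeping its previous position.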


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [flink-ml] lindong28 commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
lindong28 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r836003430



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeansModel.java
##########
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.co.CoProcessFunction;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * OnlineKMeansModel can be regarded as an advanced {@link KMeansModel} operator that updates its
+ * model data continuously from a stream, such as the model data produced by {@link OnlineKMeans}.
+ */
+public class OnlineKMeansModel
+        implements Model<OnlineKMeansModel>, KMeansModelParams<OnlineKMeansModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public OnlineKMeansModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel setModelData(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        modelDataTable = inputs[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), Types.INT),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol()));
+
+        DataStream<Row> predictionResult =
+                KMeansModelData.getModelDataStream(modelDataTable)
+                        .broadcast()
+                        .connect(tEnv.toDataStream(inputs[0]))
+                        .process(
+                                new PredictLabelFunction(
+                                        getFeaturesCol(),
+                                        DistanceMeasure.getInstance(getDistanceMeasure())),
+                                outputTypeInfo);
+
+        return new Table[] {tEnv.fromDataStream(predictionResult)};
+    }
+
+    /** A utility function used for prediction. */
+    private static class PredictLabelFunction extends CoProcessFunction<KMeansModelData, Row, Row> {
+        private final String featuresCol;
+
+        private final DistanceMeasure distanceMeasure;
+
+        private DenseVector[] centroids;
+
+        // TODO: replace this with a complete solution that reads the first model data from the
+        // unbounded model data stream before processing the first prediction input.
+        private final List<Row> bufferedPoints = new ArrayList<>();

Review comment:
       Note that one of our implementation strategies is to minimize the amount of checked-in code that will soon be thrown away.
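For context on what the quoted TODO works around: the `bufferedPoints` list lets the prediction function hold back inputs that arrive before the first model data. The following is a minimal plain-Java sketch of that pattern. It is a hypothetical illustration, not flink-ml code: `onModelData` and `onPoint` stand in for the two inputs of the `CoProcessFunction`, and squared Euclidean distance is assumed. Because the buffer grows without bound until model data arrives, a workaround like this is inherently temporary, which is the kind of code the reviewer suggests keeping to a minimum.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.IntConsumer;

/** Hypothetical sketch: buffer prediction inputs until the first model data arrives. */
final class BufferingPredictor {
    private final List<double[]> bufferedPoints = new ArrayList<>();
    private final IntConsumer output;
    private double[][] centroids; // null until the first model data arrives

    BufferingPredictor(IntConsumer output) {
        this.output = output;
    }

    /** Model-data input: installs the latest centroids and flushes any buffered points. */
    void onModelData(double[][] newCentroids) {
        centroids = newCentroids;
        for (double[] point : bufferedPoints) {
            output.accept(predict(point));
        }
        bufferedPoints.clear();
    }

    /** Prediction input: emits the closest cluster index, or buffers if no model exists yet. */
    void onPoint(double[] point) {
        if (centroids == null) {
            bufferedPoints.add(point);
        } else {
            output.accept(predict(point));
        }
    }

    /** Index of the closest centroid under squared Euclidean distance. */
    private int predict(double[] point) {
        int best = 0;
        double bestDist = Double.POSITIVE_INFINITY;
        for (int i = 0; i < centroids.length; i++) {
            double dist = 0;
            for (int j = 0; j < point.length; j++) {
                double diff = centroids[i][j] - point[j];
                dist += diff * diff;
            }
            if (dist < bestDist) {
                bestDist = dist;
                best = i;
            }
        }
        return best;
    }
}
```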







[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r834867544



##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/util/InMemorySourceFunction.java
##########
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util;
+
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.streaming.api.functions.source.RichSourceFunction;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+
+import java.util.Arrays;
+import java.util.Map;
+import java.util.UUID;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
+
+/** A {@link SourceFunction} implementation that can directly receive records from tests. */
+@SuppressWarnings({"unchecked", "rawtypes"})
+public class InMemorySourceFunction<T> extends RichSourceFunction<T> {
+    private static final Map<UUID, BlockingQueue> queueMap = new ConcurrentHashMap<>();
+    private final UUID id;
+    private BlockingQueue<T> queue;
+    private boolean isRunning = true;
+
+    public InMemorySourceFunction() {
+        id = UUID.randomUUID();
+        queue = new LinkedBlockingQueue();
+        queueMap.put(id, queue);
+    }
+
+    @Override
+    public void open(Configuration parameters) throws Exception {
+        super.open(parameters);
+        queue = queueMap.get(id);
+    }
+
+    @Override
+    public void close() throws Exception {
+        super.close();
+        queueMap.remove(id);
+    }
+
+    @Override
+    public void run(SourceContext<T> context) throws InterruptedException {
+        while (isRunning) {
+            context.collect(queue.poll(1, TimeUnit.MINUTES));

Review comment:
       I agree. I'll make the change.
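The reply does not say which change was agreed on. One concern the highlighted line plausibly raises is that `BlockingQueue.poll(long, TimeUnit)` returns `null` when the timeout elapses, in which case `context.collect` would receive `null`. The sketch below shows a null-safe polling loop in plain Java. `PollingLoop`, `pollOnce` and the `Consumer` standing in for `SourceContext.collect` are hypothetical, and a real legacy `SourceFunction` would typically also emit while holding `context.getCheckpointLock()`.

```java
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;

/** Hypothetical sketch of a polling source loop that never emits null. */
final class PollingLoop<T> {
    // volatile: cancel() is typically called from a different thread than run().
    private volatile boolean isRunning = true;

    /** Polls once; emits and returns true only if a record arrived before the timeout. */
    boolean pollOnce(BlockingQueue<T> queue, Consumer<? super T> collector, long timeoutMillis)
            throws InterruptedException {
        T record = queue.poll(timeoutMillis, TimeUnit.MILLISECONDS);
        if (record == null) {
            return false; // timed out: nothing to emit, so null never reaches the collector
        }
        collector.accept(record);
        return true;
    }

    /** Keeps polling until cancel() is called. */
    void run(BlockingQueue<T> queue, Consumer<? super T> collector) throws InterruptedException {
        while (isRunning) {
            pollOnce(queue, collector, 100);
        }
    }

    void cancel() {
        isRunning = false;
    }
}
```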







[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r832769140



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,569 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Random;
+
+/**
+ * OnlineKMeans extends {@link KMeans} to support training a K-Means model continuously on an
+ * unbounded stream of training data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids from the weighted average between the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of points assigned to
+ * them. The weight of the original centroids is also the number of points, additionally
+ * multiplied by the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If the decay
+ * factor is 1, all batches are weighted equally. If the decay factor is 0, new centroids are
+ * determined entirely by recent data. Lower values correspond to more forgetting.
+ *
+ * <p>NOTE: The current naive implementation of this class performs training in a single
+ * thread. This does not affect correctness, but it limits performance.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public OnlineKMeans(Table... initModelDataTables) {
+        Preconditions.checkArgument(initModelDataTables.length == 1);
+        this.initModelDataTable = initModelDataTables[0];
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        Preconditions.checkArgument(HasBatchStrategy.COUNT_STRATEGY.equals(getBatchStrategy()));
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            Preconditions.checkState(initModelDataTable == null);
+            initModelDataStream = createRandomCentroids(env, getDims(), getK(), getSeed());
+        } else {
+            initModelDataStream = KMeansModelData.getModelDataStream(initModelDataTable);
+        }
+        DataStream<Tuple2<KMeansModelData, DenseVector>> initModelDataWithWeightsStream =
+                initModelDataStream.map(new InitWeightAssigner(getInitWeights()));
+        initModelDataWithWeightsStream.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getBatchSize(),
+                        getK(),
+                        getDims());
+
+        DataStream<KMeansModelData> finalModelDataStream =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelDataWithWeightsStream),
+                                DataStreamList.of(points),
+                                body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelDataStream);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class InitWeightAssigner
+            implements MapFunction<KMeansModelData, Tuple2<KMeansModelData, DenseVector>> {
+        private final double[] initWeights;
+
+        private InitWeightAssigner(Double[] initWeights) {
+            this.initWeights = ArrayUtils.toPrimitive(initWeights);
+        }
+
+        @Override
+        public Tuple2<KMeansModelData, DenseVector> map(KMeansModelData modelData)
+                throws Exception {
+            return Tuple2.of(modelData, Vectors.dense(initWeights));
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        Path initModelDataPath = Paths.get(path, "data");
+        if (Files.exists(initModelDataPath)) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+            kMeans.setInitMode("direct");
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+        private final int k;
+        private final int dims;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure,
+                double decayFactor,
+                int batchSize,
+                int k,
+                int dims) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+            this.k = k;
+            this.dims = dims;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<Tuple2<KMeansModelData, DenseVector>> modelDataWithWeights =
+                    variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<Tuple2<KMeansModelData, DenseVector>> newModelDataWithWeights =
+                    points.countWindowAll(batchSize)
+                            .aggregate(new MiniBatchCreator())
+                            .flatMap(new MiniBatchDistributor(parallelism))
+                            .rebalance()
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "ModelDataPartialUpdater",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new ModelDataPartialUpdater(distanceMeasure, k))
+                            .setParallelism(parallelism)
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "ModelDataGlobalUpdater",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new ModelDataGlobalUpdater(k, dims, parallelism, decayFactor))
+                            .setParallelism(1);
+
+            DataStream<KMeansModelData> outputModelData =
+                    modelDataWithWeights.map(
+                            (MapFunction<Tuple2<KMeansModelData, DenseVector>, KMeansModelData>)
+                                    x -> x.f0);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelDataWithWeights), DataStreamList.of(outputModelData));
+        }
+    }
+
+    private static class ModelDataGlobalUpdater
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, DenseVector>>
+            implements TwoInputStreamOperator<
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final int k;
+        private final int dims;
+        private final int upstreamParallelism;
+        private final double decayFactor;
+
+        private ListState<Integer> partialModelDataReceivingState;
+        private ListState<Boolean> initModelDataReceivingState;
+        private ListState<KMeansModelData> modelDataState;
+        private ListState<DenseVector> weightsState;
+
+        private ModelDataGlobalUpdater(
+                int k, int dims, int upstreamParallelism, double decayFactor) {
+            this.k = k;
+            this.dims = dims;
+            this.upstreamParallelism = upstreamParallelism;
+            this.decayFactor = decayFactor;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            partialModelDataReceivingState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "partialModelDataReceiving",
+                                            TypeInformation.of(Integer.class)));

Review comment:
       I agree that `BasicTypeInfo.INT_TYPE_INFO` is a simpler choice. I'll make the change.
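For context, `ModelDataGlobalUpdater` in the diff above runs with parallelism 1. It combines the outputs of `upstreamParallelism` instances of `ModelDataPartialUpdater`. The plain-Java sketch below illustrates that partial/global split. It assumes that each partial updater assigns its points to the closest centroid (squared Euclidean distance here) and emits per-cluster coordinate sums and counts. The PR's partial updater actually emits `Tuple2<KMeansModelData, DenseVector>` records, so `PartialSumsSketch`, `assignAndSum` and `merge` are hypothetical names and not the PR's code.

```java
import java.util.List;

/** Hypothetical sketch: per-subtask accumulation followed by a single global merge. */
class PartialSumsSketch {
    /** Per-cluster coordinate sums and point counts produced by one subtask. */
    static final class Partial {
        final double[][] sums;
        final long[] counts;

        Partial(int k, int dims) {
            sums = new double[k][dims];
            counts = new long[k];
        }
    }

    /** Index of the centroid closest to the point under squared Euclidean distance. */
    static int findClosest(double[][] centroids, double[] point) {
        int best = 0;
        double bestDist = Double.MAX_VALUE;
        for (int i = 0; i < centroids.length; i++) {
            double dist = 0;
            for (int j = 0; j < point.length; j++) {
                double diff = centroids[i][j] - point[j];
                dist += diff * diff;
            }
            if (dist < bestDist) {
                bestDist = dist;
                best = i;
            }
        }
        return best;
    }

    /** Runs on each parallel subtask over its share of the mini-batch. */
    static Partial assignAndSum(double[][] centroids, List<double[]> points) {
        int k = centroids.length;
        int dims = centroids[0].length;
        Partial partial = new Partial(k, dims);
        for (double[] p : points) {
            int c = findClosest(centroids, p);
            partial.counts[c]++;
            for (int j = 0; j < dims; j++) {
                partial.sums[c][j] += p[j];
            }
        }
        return partial;
    }

    /** Runs once (parallelism 1) after all upstream partial results have arrived. */
    static Partial merge(List<Partial> partials, int k, int dims) {
        Partial total = new Partial(k, dims);
        for (Partial p : partials) {
            for (int i = 0; i < k; i++) {
                total.counts[i] += p.counts[i];
                for (int j = 0; j < dims; j++) {
                    total.sums[i][j] += p.sums[i][j];
                }
            }
        }
        return total;
    }
}
```

Summing counts and coordinate sums is associative. The global step can therefore merge the partial results in any order, as long as it waits until all of them have arrived.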




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r832767085



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,569 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Random;
+
+/**
+ * OnlineKMeans extends the functionality of {@link KMeans} by supporting continuous training of a
+ * K-Means model on an unbounded stream of training data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids as the weighted average of the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of points assigned to
+ * them. The weight of the original centroids is also the number of points, additionally
+ * multiplied by the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If the decay
+ * factor is 1, all batches are weighted equally. If the decay factor is 0, new centroids are
+ * determined entirely by recent data. Lower values correspond to more forgetting.
+ *
+ * <p>NOTE: The current naive implementation of this class performs training in a single thread.
+ * This does not affect correctness, but it does cause performance issues.

Review comment:
       Yes. I'll remove it.
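The decayed mini-batch rule described in the Javadoc above can be written per cluster as `c' = (c * w * d + s) / (w * d + n)` and `w' = w * d + n`. Here `c` and `w` are the current centroid and its weight, `d` is the decay factor, and `s` and `n` are the coordinate sum and the count of the points assigned to that cluster in the current batch. The plain-array Java sketch below follows that reading. `DecayedUpdateSketch` and its parameters are hypothetical names, not the PR's code.

```java
/** Hypothetical sketch of the decayed mini-batch centroid update (not the PR's code). */
class DecayedUpdateSketch {
    /**
     * Updates centroids and weights in place.
     *
     * @param centroids current centroids, k x dims
     * @param weights current weight of each centroid
     * @param batchSums per-cluster coordinate sums of the points assigned in this batch
     * @param batchCounts number of points assigned to each cluster in this batch
     * @param decayFactor in [0, 1]; 1 weighs all batches equally, 0 keeps only the latest batch
     */
    static void update(
            double[][] centroids,
            double[] weights,
            double[][] batchSums,
            long[] batchCounts,
            double decayFactor) {
        for (int i = 0; i < centroids.length; i++) {
            // Weight of the centroid estimated so far, discounted by the decay factor.
            double oldWeight = weights[i] * decayFactor;
            // The current batch's estimate is weighted by the number of points assigned to it.
            double newWeight = oldWeight + batchCounts[i];
            if (newWeight > 0) {
                for (int j = 0; j < centroids[i].length; j++) {
                    // batchSums[i][j] equals (batch mean) * batchCounts[i], so this is the
                    // weighted average of the old centroid and the batch mean.
                    centroids[i][j] = (centroids[i][j] * oldWeight + batchSums[i][j]) / newWeight;
                }
            }
            weights[i] = newWeight;
        }
    }
}
```

With `decayFactor = 1`, every batch contributes in proportion to its point count. With `decayFactor = 0`, the old centroid gets zero weight, and the result is simply the mean of the current batch.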







[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r832099135



##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/util/MockKVStore.java
##########
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util;
+
+import java.util.HashMap;
+import java.util.Map;
+
+/** Class that manages global key-value pairs used in unit tests. */
+@SuppressWarnings({"unchecked"})
+public class MockKVStore {
+    private static final Map<String, Object> map = new HashMap<>();
+
+    private static long counter = 0;
+
+    /**
+     * Returns a prefix string to make sure key-value pairs created in different test cases would
+     * not have the same key.
+     */
+    public static synchronized String createNonDuplicatePrefix() {

Review comment:
       Got it. Since we are now using `InMemoryReporter` and therefore won't be able to implement this in `MockKVStore`, I'll apply this principle to the design of the simplified test blocking queue mechanism.
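The `createNonDuplicatePrefix()` idea in the `MockKVStore` diff above can be sketched as follows. The idea is a process-wide map plus a counter that hands out a distinct key prefix to each test case. This is a minimal illustration, not the PR's class. `PrefixedKVStoreSketch`, `put`, `get` and `removeWithPrefix` are hypothetical names, and the sketch swaps the synchronized counter for an `AtomicLong` and a `ConcurrentHashMap`.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;

/** Hypothetical process-wide key-value store for tests, in the spirit of MockKVStore. */
final class PrefixedKVStoreSketch {
    private static final Map<String, Object> MAP = new ConcurrentHashMap<>();
    private static final AtomicLong COUNTER = new AtomicLong();

    private PrefixedKVStoreSketch() {}

    /**
     * Returns a prefix that no earlier call in this JVM has returned. The trailing "-" keeps one
     * prefix from being a prefix of another (e.g. "prefix-1-" vs. "prefix-10-").
     */
    static String createNonDuplicatePrefix() {
        return "prefix-" + COUNTER.getAndIncrement() + "-";
    }

    static void put(String key, Object value) {
        MAP.put(key, value);
    }

    @SuppressWarnings("unchecked")
    static <T> T get(String key) {
        return (T) MAP.get(key);
    }

    /** Removes every entry created under the given prefix, e.g. in a test's teardown. */
    static void removeWithPrefix(String prefix) {
        MAP.keySet().removeIf(key -> key.startsWith(prefix));
    }
}
```

Because every key a test case writes starts with its own prefix, test cases that share the JVM (or run concurrently) cannot overwrite each other's entries. Each test case can also clean up only what it created.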







[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r832001871



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,409 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;

Review comment:
       I think keeping `OnlineKMeans` in the `kmeans` package is more suitable, for the following reasons:
   - I don't really see `OnlineKMeans` and `KMeans` as different algorithms. They are the same algorithm running under different external conditions.
   - In the implementation, `OnlineKMeans` and `KMeans` share some infrastructure code that only they use, such as `KMeansModelData`, `KMeansParams` and `findClosestCentroidId`. If they were put into different packages, I am worried that there would be tight coupling between the `kmeans` and `onlinekmeans` packages.
   
   The reasons above are not decisive, and I would still be glad to make the change if there are other, more important reasons.







[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r831989881



##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/util/TestMetricReporter.java
##########
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util;
+
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.metrics.Metric;
+import org.apache.flink.metrics.MetricConfig;
+import org.apache.flink.metrics.MetricGroup;
+import org.apache.flink.metrics.reporter.MetricReporter;
+import org.apache.flink.metrics.reporter.Scheduled;
+
+import java.util.HashMap;
+import java.util.Map;
+
+/** A subclass of {@link MetricReporter} that outputs metrics to a global map. */
+public class TestMetricReporter implements MetricReporter, Scheduled {

Review comment:
       Got it. I'll use InMemoryReporter.







[GitHub] [flink-ml] lindong28 commented on a change in pull request #70: [FLINK-26313] Support Online KMeans in Flink ML

Posted by GitBox <gi...@apache.org>.
lindong28 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r830079643



##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/util/TestMetricReporter.java
##########
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util;
+
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.metrics.Metric;
+import org.apache.flink.metrics.MetricConfig;
+import org.apache.flink.metrics.MetricGroup;
+import org.apache.flink.metrics.reporter.MetricReporter;
+import org.apache.flink.metrics.reporter.Scheduled;
+
+import java.util.HashMap;
+import java.util.Map;
+
+/** A subclass of {@link MetricReporter} that outputs metrics to a global map. */
+public class TestMetricReporter implements MetricReporter, Scheduled {

Review comment:
       Is there any way to reuse `InMemoryReporter` from Flink? According to its Javadoc, `InMemoryReporter` seems to serve exactly the purpose of this class.
   
   The `MetricGroup` used in `InMemoryReporter` could also serve the purpose of the prefix here.
   
   Instead of storing and fetching metric values in a map, could we just get the metric object directly, like this (where `reporter` is the `InMemoryReporter` instance):
   
   
   `((Gauge<?>) reporter.findMetric(...).get()).getValue()`
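To illustrate the difference described above, here is a self-contained sketch that keeps references to live metric objects, keyed by group and name. A test can then read a gauge's current value on demand, rather than a copy that a scheduled reporter wrote into a map. The types and methods (`LiveMetricLookupSketch`, `SimpleGauge`, `register`, `findMetric`) are hypothetical stand-ins, not Flink's `InMemoryReporter` or `Gauge` API. The group string plays the role that the `MetricGroup` would play in place of a hand-made key prefix.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.function.Supplier;

/** Hypothetical stand-in for a reporter that retains live metric objects (not Flink's API). */
final class LiveMetricLookupSketch {
    /** A gauge computes its value when asked, so readers always see the current value. */
    interface SimpleGauge<T> extends Supplier<T> {}

    // group name -> (metric name -> metric object); the group replaces a hand-made key prefix.
    private final Map<String, Map<String, Object>> metricsByGroup = new HashMap<>();

    void register(String group, String name, Object metric) {
        metricsByGroup.computeIfAbsent(group, g -> new HashMap<>()).put(name, metric);
    }

    Optional<Object> findMetric(String group, String name) {
        Map<String, Object> metrics = metricsByGroup.get(group);
        if (metrics == null) {
            return Optional.empty();
        }
        return Optional.ofNullable(metrics.get(name));
    }
}
```

Because the registry holds the gauge itself, the value read at assertion time reflects updates made after registration. There is no reporting interval to wait for.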

##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java
##########
@@ -0,0 +1,511 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.core.execution.JobClient;
+import org.apache.flink.ml.clustering.kmeans.KMeans;
+import org.apache.flink.ml.clustering.kmeans.KMeansModel;
+import org.apache.flink.ml.clustering.kmeans.KMeansModelData;
+import org.apache.flink.ml.clustering.kmeans.OnlineKMeans;
+import org.apache.flink.ml.clustering.kmeans.OnlineKMeansModel;
+import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.util.GlobalBlockingQueues;
+import org.apache.flink.ml.util.MockSinkFunction;
+import org.apache.flink.ml.util.MockSourceFunction;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.ml.util.TestMetricReporter;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.CollectionUtils;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.apache.flink.ml.clustering.KMeansTest.groupFeaturesByPrediction;
+
+/** Tests {@link OnlineKMeans} and {@link OnlineKMeansModel}. */
+public class OnlineKMeansTest {

Review comment:
       All existing tests in Flink ML extend `AbstractTestBase`. According to its Javadoc, `AbstractTestBase` can improve unit test performance. It also cleans up running jobs after each unit test, provides logging, etc.
   
   Should `OnlineKMeansTest` also extend `AbstractTestBase`, both for the benefits mentioned above and for consistency?
   







[GitHub] [flink-ml] lindong28 commented on a change in pull request #70: [FLINK-26313] Support Online KMeans in Flink ML

Posted by GitBox <gi...@apache.org>.
lindong28 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r830429430



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/StreamingKMeans.java
##########
@@ -0,0 +1,404 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * StreamingKMeans extends the functionality of {@link KMeans} by supporting continuous training of
+ * a K-Means model on an unbounded stream of training data.
+ */
+public class StreamingKMeans
+        implements Estimator<StreamingKMeans, StreamingKMeansModel>,
+                StreamingKMeansParams<StreamingKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public StreamingKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public StreamingKMeans(Table... initModelDataTables) {
+        Preconditions.checkArgument(initModelDataTables.length == 1);
+        this.initModelDataTable = initModelDataTables[0];
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public StreamingKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        Preconditions.checkArgument(HasBatchStrategy.COUNT_STRATEGY.equals(getBatchStrategy()));
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+        points.getTransformation().setParallelism(1);

Review comment:
       Sounds good. Let's leave it as a TODO and explain it in the Java doc.







[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r833154781



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,562 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Random;
+
+/**
+ * OnlineKMeans extends the functionality of {@link KMeans} by supporting continuous training of a
+ * K-Means model on an unbounded stream of training data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids as the weighted average of the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of points assigned to
+ * them. The weight of the original centroids is also the number of points, additionally
+ * multiplied by the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If the decay
+ * factor is 1, all batches are weighted equally. If the decay factor is 0, new centroids are
+ * determined entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public OnlineKMeans(Table initModelDataTable) {
+        this.initModelDataTable = initModelDataTable;
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            Preconditions.checkState(initModelDataTable == null);
+            initModelDataStream = createRandomCentroids(env, getDim(), getK(), getSeed());
+        } else {
+            initModelDataStream = KMeansModelData.getModelDataStream(initModelDataTable);
+        }
+        DataStream<Tuple2<KMeansModelData, DenseVector>> initModelDataWithWeightsStream =
+                initModelDataStream.map(new InitWeightAssigner(getInitWeights()));
+        initModelDataWithWeightsStream.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize(),
+                        getK(),
+                        getDim());
+
+        DataStream<KMeansModelData> finalModelDataStream =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelDataWithWeightsStream),
+                                DataStreamList.of(points),
+                                body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelDataStream);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class InitWeightAssigner
+            implements MapFunction<KMeansModelData, Tuple2<KMeansModelData, DenseVector>> {
+        private final double[] initWeights;
+
+        private InitWeightAssigner(Double[] initWeights) {
+            this.initWeights = ArrayUtils.toPrimitive(initWeights);
+        }
+
+        @Override
+        public Tuple2<KMeansModelData, DenseVector> map(KMeansModelData modelData)
+                throws Exception {
+            return Tuple2.of(modelData, Vectors.dense(initWeights));
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        Path initModelDataPath = Paths.get(path, "data");
+        if (Files.exists(initModelDataPath)) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+            kMeans.setInitMode("direct");
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+        private final int k;
+        private final int dim;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure,
+                double decayFactor,
+                int batchSize,
+                int k,
+                int dim) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+            this.k = k;
+            this.dim = dim;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<Tuple2<KMeansModelData, DenseVector>> modelDataWithWeights =
+                    variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<Tuple2<KMeansModelData, DenseVector>> newModelDataWithWeights =
+                    points.countWindowAll(batchSize)
+                            .aggregate(new MiniBatchCreator())
+                            .flatMap(new MiniBatchDistributor(parallelism))
+                            .rebalance()
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "ModelDataPartialUpdater",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new ModelDataPartialUpdater(distanceMeasure, k))
+                            .setParallelism(parallelism)
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "ModelDataGlobalUpdater",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new ModelDataGlobalUpdater(k, dim, parallelism, decayFactor))
+                            .setParallelism(1);
+
+            DataStream<KMeansModelData> outputModelData =
+                    modelDataWithWeights.map(
+                            (MapFunction<Tuple2<KMeansModelData, DenseVector>, KMeansModelData>)
+                                    x -> x.f0);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelDataWithWeights), DataStreamList.of(outputModelData));
+        }
+    }
+
+    private static class ModelDataGlobalUpdater
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, DenseVector>>
+            implements TwoInputStreamOperator<
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final int k;
+        private final int dim;
+        private final int upstreamParallelism;
+        private final double decayFactor;
+
+        private ListState<Integer> partialModelDataReceivingState;
+        private ListState<Boolean> initModelDataReceivingState;
+        private ListState<KMeansModelData> modelDataState;
+        private ListState<DenseVector> weightsState;
+
+        private ModelDataGlobalUpdater(
+                int k, int dim, int upstreamParallelism, double decayFactor) {
+            this.k = k;
+            this.dim = dim;
+            this.upstreamParallelism = upstreamParallelism;
+            this.decayFactor = decayFactor;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            partialModelDataReceivingState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "partialModelDataReceiving",
+                                            BasicTypeInfo.INT_TYPE_INFO));
+
+            initModelDataReceivingState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "initModelDataReceiving",
+                                            BasicTypeInfo.BOOLEAN_TYPE_INFO));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+
+            weightsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "weights", DenseVectorTypeInfo.INSTANCE));
+
+            initStateValues();
+        }
+
+        private void initStateValues() throws Exception {
+            partialModelDataReceivingState.update(Collections.singletonList(0));
+            initModelDataReceivingState.update(Collections.singletonList(false));
+            DenseVector[] emptyCentroids = new DenseVector[k];
+            for (int i = 0; i < k; i++) {
+                emptyCentroids[i] = new DenseVector(dim);
+            }
+            modelDataState.update(Collections.singletonList(new KMeansModelData(emptyCentroids)));
+            weightsState.update(Collections.singletonList(new DenseVector(k)));
+        }
+
+        @Override
+        public void processElement1(
+                StreamRecord<Tuple2<KMeansModelData, DenseVector>> partialModelDataUpdateRecord)
+                throws Exception {
+            int partialModelDataReceiving =
+                    getUniqueElement(partialModelDataReceivingState, "partialModelDataReceiving");
+            Preconditions.checkState(partialModelDataReceiving < upstreamParallelism);
+            partialModelDataReceivingState.update(
+                    Collections.singletonList(partialModelDataReceiving + 1));
+            processElement(
+                    partialModelDataUpdateRecord.getValue().f0.centroids,
+                    partialModelDataUpdateRecord.getValue().f1,
+                    1.0);
+        }
+
+        @Override
+        public void processElement2(
+                StreamRecord<Tuple2<KMeansModelData, DenseVector>> initModelDataRecord)
+                throws Exception {
+            boolean initModelDataReceiving =
+                    getUniqueElement(initModelDataReceivingState, "initModelDataReceiving");
+            Preconditions.checkState(!initModelDataReceiving);
+            initModelDataReceivingState.update(Collections.singletonList(true));
+            processElement(
+                    initModelDataRecord.getValue().f0.centroids,
+                    initModelDataRecord.getValue().f1,
+                    decayFactor);
+        }
+
+        private void processElement(
+                DenseVector[] newCentroids, DenseVector newWeights, double decayFactor)
+                throws Exception {
+            DenseVector weights = getUniqueElement(weightsState, "weights");
+            DenseVector[] centroids = getUniqueElement(modelDataState, "modelData").centroids;
+
+            for (int i = 0; i < k; i++) {
+                newWeights.values[i] *= decayFactor;
+                for (int j = 0; j < dim; j++) {
+                    centroids[i].values[j] =
+                            (centroids[i].values[j] * weights.values[i]
+                                            + newCentroids[i].values[j] * newWeights.values[i])
+                                    / Math.max(weights.values[i] + newWeights.values[i], 1e-16);
+                }
+                weights.values[i] += newWeights.values[i];
+            }
+
+            if (getUniqueElement(initModelDataReceivingState, "initModelDataReceiving")
+                    && getUniqueElement(partialModelDataReceivingState, "partialModelDataReceiving")
+                            >= upstreamParallelism) {
+                output.collect(
+                        new StreamRecord<>(Tuple2.of(new KMeansModelData(centroids), weights)));
+                initStateValues();
+            } else {
+                modelDataState.update(Collections.singletonList(new KMeansModelData(centroids)));
+                weightsState.update(Collections.singletonList(weights));
+            }
+        }
+    }
+
+    private static <T> T getUniqueElement(ListState<T> state, String stateName) throws Exception {
+        T value = OperatorStateUtils.getUniqueElement(state, stateName).orElse(null);
+        return Objects.requireNonNull(value);
+    }
+
+    private static class ModelDataPartialUpdater
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, DenseVector>>
+            implements TwoInputStreamOperator<
+                    DenseVector[],
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final DistanceMeasure distanceMeasure;
+        private final int k;
+        private ListState<DenseVector[]> miniBatchState;
+        private ListState<KMeansModelData> modelDataState;
+
+        private ModelDataPartialUpdater(DistanceMeasure distanceMeasure, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.k = k;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            miniBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("miniBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> pointsRecord) throws Exception {
+            miniBatchState.add(pointsRecord.getValue());
+            processElement();
+        }
+
+        @Override
+        public void processElement2(
+                StreamRecord<Tuple2<KMeansModelData, DenseVector>> modelDataAndWeightsRecord)
+                throws Exception {
+            modelDataState.add(modelDataAndWeightsRecord.getValue().f0);
+            processElement();
+        }
+
+        private void processElement() throws Exception {

Review comment:
       Using the word `maybe` suggests that the method does not have guaranteed behavior. I'll rename it to `alignAndComputeModelData`, as it outputs one record after receiving one input from each upstream operator.
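The alignment behavior described in the comment can be pictured with a minimal plain-Java sketch that uses no Flink APIs. It is an illustration under stated assumptions, not the PR's `ModelDataPartialUpdater`: the class `AlignAndComputeSketch`, its `Partial` result type and the `addMiniBatch`/`addModel` methods are hypothetical names, and only `alignAndComputeModelData` comes from the comment. The sketch assumes squared Euclidean distance, represents vectors as `double[]`, and assumes each partial result carries the per-cluster mean of the mini-batch plus the per-cluster point counts used as weights.

```java
import java.util.ArrayDeque;
import java.util.Deque;

/**
 * Hypothetical, Flink-free sketch of an "align and compute" step: one mini-batch from the first
 * input is paired with one model from the second input, and exactly one partial result is
 * produced per aligned pair.
 */
final class AlignAndComputeSketch {
    /** Partial model data: per-cluster mean of the assigned points and per-cluster counts. */
    static final class Partial {
        final double[][] centroids;
        final double[] counts;

        Partial(double[][] centroids, double[] counts) {
            this.centroids = centroids;
            this.counts = counts;
        }
    }

    private final Deque<double[][]> miniBatches = new ArrayDeque<>();
    private final Deque<double[][]> models = new ArrayDeque<>();

    /** First input: returns a partial result if a model is already waiting, else null. */
    Partial addMiniBatch(double[][] points) {
        miniBatches.add(points);
        return alignAndComputeModelData();
    }

    /** Second input: returns a partial result if a mini-batch is already waiting, else null. */
    Partial addModel(double[][] centroids) {
        models.add(centroids);
        return alignAndComputeModelData();
    }

    private Partial alignAndComputeModelData() {
        if (miniBatches.isEmpty() || models.isEmpty()) {
            return null; // Not aligned yet: wait for a record from the other input.
        }
        double[][] points = miniBatches.poll();
        double[][] centroids = models.poll();
        int k = centroids.length;
        int dim = centroids[0].length;
        double[][] means = new double[k][dim];
        double[] counts = new double[k];
        for (double[] p : points) {
            int c = nearest(p, centroids);
            counts[c] += 1.0;
            for (int j = 0; j < dim; j++) {
                means[c][j] += p[j];
            }
        }
        for (int i = 0; i < k; i++) {
            // Turn sums into means; clusters with no points keep a zero vector and zero weight.
            if (counts[i] > 0) {
                for (int j = 0; j < dim; j++) {
                    means[i][j] /= counts[i];
                }
            }
        }
        return new Partial(means, counts);
    }

    /** Index of the closest centroid by squared Euclidean distance. */
    private static int nearest(double[] p, double[][] centroids) {
        int best = 0;
        double bestDist = Double.POSITIVE_INFINITY;
        for (int i = 0; i < centroids.length; i++) {
            double d = 0.0;
            for (int j = 0; j < p.length; j++) {
                double diff = p[j] - centroids[i][j];
                d += diff * diff;
            }
            if (d < bestDist) {
                bestDist = d;
                best = i;
            }
        }
        return best;
    }
}
```

Because each output consumes exactly one record from each queue, the sketch emits one partial result per aligned (mini-batch, model) pair, which is the guarantee the new name is meant to convey.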




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r834922711



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModelData.java
##########
@@ -47,8 +47,16 @@
 
     public DenseVector[] centroids;
 
+    public DenseVector weights;
+
+    public KMeansModelData(DenseVector[] centroids, DenseVector weights) {
+        this.centroids = centroids;
+        this.weights = weights;
+    }
+
     public KMeansModelData(DenseVector[] centroids) {
         this.centroids = centroids;
+        this.weights = new DenseVector(centroids.length);

Review comment:
       I agree. Besides, when users call `OnlineKMeansModel.setModelData` directly, they should be allowed to set the weights to `null` or `Vectors.dense()` (an empty vector), as the prediction process does not use the weights.
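To illustrate why prediction can tolerate `null` or empty weights while training cannot, here is a minimal plain-Java sketch. It is not the Flink ML `KMeansModelData` class; `WeightedModelSketch`, `predict` and `merge` are hypothetical names. It assumes squared Euclidean distance for prediction and, for training, the decayed weighted-average update described in the `OnlineKMeans` Javadoc of this PR: the old weights are multiplied by the decay factor, each centroid becomes the weighted average of the old and the batch centroid, and the new weight is the sum of both weights.

```java
/**
 * Hypothetical sketch: model data with centroids plus optional per-centroid weights. Weights only
 * matter when merging models during training; prediction reads the centroids only.
 */
final class WeightedModelSketch {
    final double[][] centroids;
    final double[] weights; // May be null or empty when the model is only used for prediction.

    WeightedModelSketch(double[][] centroids, double[] weights) {
        this.centroids = centroids;
        this.weights = weights;
    }

    /** Index of the closest centroid by squared Euclidean distance; never reads the weights. */
    int predict(double[] point) {
        int best = 0;
        double bestDist = Double.POSITIVE_INFINITY;
        for (int i = 0; i < centroids.length; i++) {
            double d = 0.0;
            for (int j = 0; j < point.length; j++) {
                double diff = point[j] - centroids[i][j];
                d += diff * diff;
            }
            if (d < bestDist) {
                bestDist = d;
                best = i;
            }
        }
        return best;
    }

    /** Training-side merge: requires exactly one weight per centroid. */
    WeightedModelSketch merge(double[][] batchCentroids, double[] batchWeights, double decayFactor) {
        if (weights == null || weights.length != centroids.length) {
            throw new IllegalStateException("training requires one weight per centroid");
        }
        int k = centroids.length;
        double[][] merged = new double[k][];
        double[] mergedWeights = new double[k];
        for (int i = 0; i < k; i++) {
            double oldWeight = weights[i] * decayFactor; // Forget part of the history.
            double newWeight = batchWeights[i];
            // Guard against division by zero when both weights are 0.
            double total = Math.max(oldWeight + newWeight, 1e-16);
            merged[i] = new double[centroids[i].length];
            for (int j = 0; j < merged[i].length; j++) {
                merged[i][j] = (centroids[i][j] * oldWeight + batchCentroids[i][j] * newWeight) / total;
            }
            mergedWeights[i] = oldWeight + newWeight;
        }
        return new WeightedModelSketch(merged, mergedWeights);
    }
}
```

Only `merge` depends on the weights, so a model created purely for prediction can leave them unset.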







[GitHub] [flink-ml] lindong28 commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
lindong28 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r836003024



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeansModel.java
##########
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.co.CoProcessFunction;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * OnlineKMeansModel can be regarded as an advanced {@link KMeansModel} operator that can update
+ * its model data from a stream, using the model data produced by {@link OnlineKMeans}.
+ */
+public class OnlineKMeansModel
+        implements Model<OnlineKMeansModel>, KMeansModelParams<OnlineKMeansModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public OnlineKMeansModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel setModelData(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        modelDataTable = inputs[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), Types.INT),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol()));
+
+        DataStream<Row> predictionResult =
+                KMeansModelData.getModelDataStream(modelDataTable)
+                        .broadcast()
+                        .connect(tEnv.toDataStream(inputs[0]))
+                        .process(
+                                new PredictLabelFunction(
+                                        getFeaturesCol(),
+                                        DistanceMeasure.getInstance(getDistanceMeasure())),
+                                outputTypeInfo);
+
+        return new Table[] {tEnv.fromDataStream(predictionResult)};
+    }
+
+    /** A utility function used for prediction. */
+    private static class PredictLabelFunction extends CoProcessFunction<KMeansModelData, Row, Row> {
+        private final String featuresCol;
+
+        private final DistanceMeasure distanceMeasure;
+
+        private DenseVector[] centroids;
+
+        // TODO: replace this with a complete solution of reading first model data from unbounded
+        // model data stream before processing the first predict data.
+        private final List<Row> bufferedPoints = new ArrayList<>();

Review comment:
       I think our long-term goal is `a technically correct implementation that also works on large datasets in production`.
   
   If our algorithm only works on so-called small datasets, and there is no guarantee about what counts as `small`, then effectively our algorithm cannot be used in production. And if our algorithm cannot be used in production, it won't matter whether we provide checkpointing or not.
   
   Given that we will need to make our algorithm usable in production anyway, the TODO has to be addressed sooner or later. So it seems simpler not to add checkpoint code that we will throw away in the near future.
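The buffering that the TODO refers to can be pictured with the minimal plain-Java sketch below. It is an assumption about how `bufferedPoints` behaves, inferred from its name and the TODO comment, not the actual `PredictLabelFunction`; `BufferUntilModelSketch`, `onModel` and `onPoint` are hypothetical names. Points that arrive before the first model are held in an in-memory list and flushed once centroids are known. Nothing bounds that list and nothing includes it in a checkpoint, which is why this approach is only safe for small inputs.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical sketch: prediction inputs that arrive before the first model are buffered in
 * memory and flushed once centroids are known. The buffer is unbounded and not checkpointed.
 */
final class BufferUntilModelSketch {
    private double[][] centroids; // Null until the first model arrives.
    private final List<double[]> bufferedPoints = new ArrayList<>();
    private final List<Integer> predictions = new ArrayList<>();

    /** Model input: install (or replace) the centroids and flush any buffered points. */
    void onModel(double[][] newCentroids) {
        centroids = newCentroids;
        for (double[] p : bufferedPoints) {
            predictions.add(nearest(p));
        }
        bufferedPoints.clear();
    }

    /** Data input: predict immediately if a model exists, otherwise buffer the point. */
    void onPoint(double[] point) {
        if (centroids == null) {
            bufferedPoints.add(point); // Grows without limit if no model ever arrives.
        } else {
            predictions.add(nearest(point));
        }
    }

    int bufferedCount() {
        return bufferedPoints.size();
    }

    List<Integer> predictions() {
        return predictions;
    }

    /** Index of the closest centroid by squared Euclidean distance. */
    private int nearest(double[] p) {
        int best = 0;
        double bestDist = Double.POSITIVE_INFINITY;
        for (int i = 0; i < centroids.length; i++) {
            double d = 0.0;
            for (int j = 0; j < p.length; j++) {
                double diff = p[j] - centroids[i][j];
                d += diff * diff;
            }
            if (d < bestDist) {
                bestDist = d;
                best = i;
            }
        }
        return best;
    }
}
```

Any production-ready replacement would have to bound or persist this buffer, which is the work the TODO defers.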







[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r835703295



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeansModel.java
##########
@@ -0,0 +1,182 @@
+    /** A utility function used for prediction. */
+    private static class PredictLabelFunction extends CoProcessFunction<KMeansModelData, Row, Row> {
+        private final String featuresCol;
+
+        private final DistanceMeasure distanceMeasure;
+
+        private DenseVector[] centroids;
+
+        // TODO: replace this with a complete solution of reading first model data from unbounded
+        // model data stream before processing the first predict data.
+        private final List<Row> bufferedPoints = new ArrayList<>();

Review comment:
       It won't be checkpointed in the current implementation. I expected that we would resolve the TODO above soon, and that the checkpoint mechanism would be implemented in that infrastructure code. Do you think we should enable checkpointing in this PR?







[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r834930203



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,437 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.function.Supplier;
+
+/**
+ * OnlineKMeans extends the functionality of {@link KMeans}, training a K-Means model
+ * continuously on an unbounded stream of training data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After estimating the centroids on the current batch, OnlineKMeans
+ * computes each new centroid as the weighted average of the original and the estimated centroid.
+ * The weight of an estimated centroid is the number of points assigned to it in the current
+ * batch. The weight of an original centroid is its accumulated weight from previous batches,
+ * multiplied by the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If the decay
+ * factor is 1, all batches are weighted equally. If the decay factor is 0, new centroids are
+ * determined entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelData =
+                KMeansModelData.getModelDataStream(initModelDataTable);
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize());
+
+        DataStream<KMeansModelData> finalModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelData);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Saves the metadata AND the bounded model data table (if it exists) to the given path. */
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        String initModelDataPath = ReadWriteUtils.getDataPath(path);
+        if (Files.exists(Paths.get(initModelDataPath))) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure, double decayFactor, int batchSize) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<KMeansModelData> modelData = variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<KMeansModelData> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator())
+                            .flatMap(new GlobalBatchSplitter(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(KMeansModelData.class),
+                                    new ModelDataLocalUpdater(distanceMeasure, decayFactor))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new ModelDataGlobalReducer());
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelData), DataStreamList.of(modelData));
+        }
+    }
+
+    private static class ModelDataGlobalReducer implements ReduceFunction<KMeansModelData> {
+        @Override
+        public KMeansModelData reduce(KMeansModelData modelData, KMeansModelData newModelData) {
+            DenseVector weights = modelData.weights;
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector newWeights = newModelData.weights;
+            DenseVector[] newCentroids = newModelData.centroids;
+
+            int k = newCentroids.length;
+            int dim = newCentroids[0].size();
+
+            for (int i = 0; i < k; i++) {
+                for (int j = 0; j < dim; j++) {
+                    centroids[i].values[j] =
+                            (centroids[i].values[j] * weights.values[i]
+                                            + newCentroids[i].values[j] * newWeights.values[i])
+                                    / Math.max(weights.values[i] + newWeights.values[i], 1e-16);
+                }
+                weights.values[i] += newWeights.values[i];
+            }
+
+            return new KMeansModelData(centroids, weights);
+        }
+    }
+
+    private static class ModelDataLocalUpdater extends AbstractStreamOperator<KMeansModelData>
+            implements TwoInputStreamOperator<DenseVector[], KMeansModelData, KMeansModelData> {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private ListState<DenseVector[]> localBatchState;
+        private ListState<KMeansModelData> modelDataState;
+
+        private ModelDataLocalUpdater(DistanceMeasure distanceMeasure, double decayFactor) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            localBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("localBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> pointsRecord) throws Exception {
+            localBatchState.add(pointsRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<KMeansModelData> modelDataRecord)
+                throws Exception {
+            modelDataState.add(modelDataRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        private void alignAndComputeModelData() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !localBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            KMeansModelData modelData =
+                    OperatorStateUtils.getUniqueElement(modelDataState, "modelData")
+                            .orElseThrow((Supplier<Exception>) NullPointerException::new);
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector weights = modelData.weights;
+            modelDataState.clear();
+
+            List<DenseVector[]> pointsList = IteratorUtils.toList(localBatchState.get().iterator());
+            DenseVector[] points = pointsList.remove(0);
+            localBatchState.update(pointsList);
+
+            int dim = centroids[0].size();
+            int k = centroids.length;
+            int parallelism = getRuntimeContext().getNumberOfParallelSubtasks();
+
+            // Computes new centroids.
+            DenseVector[] sums = new DenseVector[k];
+            int[] counts = new int[k];
+
+            for (int i = 0; i < k; i++) {
+                sums[i] = new DenseVector(dim);
+                counts[i] = 0;
+            }
+            for (DenseVector point : points) {
+                int closestCentroidId =
+                        KMeans.findClosestCentroidId(centroids, point, distanceMeasure);
+                counts[closestCentroidId]++;
+                for (int j = 0; j < dim; j++) {
+                    sums[closestCentroidId].values[j] += point.values[j];
+                }
+            }
+
+            // Considers weight and decay factor when updating centroids.
+            BLAS.scal(decayFactor / parallelism, weights);
+            for (int i = 0; i < k; i++) {
+                if (counts[i] == 0) {
+                    continue;
+                }
+
+                DenseVector centroid = centroids[i];
+                weights.values[i] = weights.values[i] + counts[i];
+                double lambda = counts[i] / weights.values[i];
+
+                BLAS.scal(1.0 - lambda, centroid);
+                BLAS.axpy(lambda / counts[i], sums[i], centroid);
+            }
+
+            output.collect(new StreamRecord<>(new KMeansModelData(centroids, weights)));
+        }
+    }
+
+    private static class FeaturesExtractor implements MapFunction<Row, DenseVector> {
+        private final String featuresCol;
+
+        private FeaturesExtractor(String featuresCol) {
+            this.featuresCol = featuresCol;
+        }
+
+        @Override
+        public DenseVector map(Row row) throws Exception {
+            return (DenseVector) row.getField(featuresCol);
+        }
+    }
+
+    // An operator that splits a global batch into evenly-sized local batches (whose sizes differ
+    // by at most one) and distributes them to the subtasks of the downstream operator.
+    private static class GlobalBatchSplitter
+            implements FlatMapFunction<DenseVector[], DenseVector[]> {
+        private final int downStreamParallelism;
+
+        private GlobalBatchSplitter(int downStreamParallelism) {
+            this.downStreamParallelism = downStreamParallelism;
+        }
+
+        @Override
+        public void flatMap(DenseVector[] values, Collector<DenseVector[]> collector) {
+            // Calculate the batch sizes to be distributed on each subtask.
+            List<Integer> sizes = new ArrayList<>();
+            for (int i = 0; i < downStreamParallelism; i++) {
+                int start = i * values.length / downStreamParallelism;
+                int end = (i + 1) * values.length / downStreamParallelism;
+                sizes.add(end - start);
+            }
+
+            int offset = 0;
+            for (Integer size : sizes) {
+                collector.collect(Arrays.copyOfRange(values, offset, offset + size));
+                offset += size;
+            }
+        }
+    }
+
+    private static class GlobalBatchCreator
+            implements AllWindowFunction<DenseVector, DenseVector[], GlobalWindow> {
+        @Override
+        public void apply(
+                GlobalWindow timeWindow,
+                Iterable<DenseVector> iterable,
+                Collector<DenseVector[]> collector) {
+            List<DenseVector> points = IteratorUtils.toList(iterable.iterator());
+            collector.collect(points.toArray(new DenseVector[0]));
+        }
+    }
+
+    /**
+     * Sets the initial model data of the online training process with the provided model data
+     * table.
+     *
+     * <p>This overrides the effect of any previous call to {@link #setInitialModelData(Table)} or
+     * {@link #setRandomCentroids(StreamTableEnvironment, int, int, double)}.
+     */
+    public OnlineKMeans setInitialModelData(Table initModelDataTable) {
+        this.initModelDataTable = initModelDataTable;
+        return this;
+    }
+
+    /**
+     * Sets the initial model data of the online training process with randomly created centroids.
+     *
+     * <p>This overrides the effect of any previous call to {@link #setInitialModelData(Table)} or
+     * {@link #setRandomCentroids(StreamTableEnvironment, int, int, double)}.
+     *
+     * @param tEnv The stream table environment to create the centroids in.
+     * @param dim The dimension of the centroids to create.
+     * @param k The number of centroids to create.
+     * @param weight The weight of the centroids to create.
+     */
+    public OnlineKMeans setRandomCentroids(
+            StreamTableEnvironment tEnv, int dim, int k, double weight) {

Review comment:
       Regarding the first issue: this means we still need to save `dim`, `weight` and `initMode` as metadata of this stage, and the most convenient way is to register them as parameters through the WithParams interface. But we still want to expose `setRandomCentroids` as the API instead of `setInitMode().setDim().setWeight()`, right? Does that mean we should still define these parameters but not expose their getter/setter methods?
   
   I agree with the second issue.
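The decayed mini-batch rule described in the OnlineKMeans Javadoc and implemented by ModelDataLocalUpdater in the hunk above can be sketched on a single node as follows. This is a hedged illustration, not the PR's code: it assumes plain double arrays instead of DenseVector, squared Euclidean distance instead of a configurable DistanceMeasure, and a parallelism of 1 (the quoted operator divides the decay factor by the parallelism only because ModelDataGlobalReducer later sums the per-subtask weights). The class name DecayedMiniBatchKMeans is hypothetical.

```java
/**
 * Hypothetical single-node sketch of the decayed mini-batch K-Means update. Assumes squared
 * Euclidean distance and parallelism 1; not the PR's implementation.
 */
final class DecayedMiniBatchKMeans {
    private DecayedMiniBatchKMeans() {}

    /** Returns the index of the centroid nearest to {@code point} (squared Euclidean distance). */
    static int closestCentroid(double[][] centroids, double[] point) {
        int best = 0;
        double bestDistance = Double.POSITIVE_INFINITY;
        for (int i = 0; i < centroids.length; i++) {
            double distance = 0.0;
            for (int j = 0; j < point.length; j++) {
                double diff = centroids[i][j] - point[j];
                distance += diff * diff;
            }
            if (distance < bestDistance) {
                bestDistance = distance;
                best = i;
            }
        }
        return best;
    }

    /** Updates {@code centroids} and {@code weights} in place with one mini-batch of points. */
    static void update(
            double[][] centroids, double[] weights, double[][] batch, double decayFactor) {
        int k = centroids.length;
        int dim = centroids[0].length;

        // Assign every point to its nearest centroid and accumulate per-cluster sums and counts.
        double[][] sums = new double[k][dim];
        int[] counts = new int[k];
        for (double[] point : batch) {
            int c = closestCentroid(centroids, point);
            counts[c]++;
            for (int j = 0; j < dim; j++) {
                sums[c][j] += point[j];
            }
        }

        for (int i = 0; i < k; i++) {
            // Forget part of the history: every weight decays, even for clusters with no new points.
            weights[i] *= decayFactor;
            if (counts[i] == 0) {
                continue;
            }
            // New total weight = decayed old weight + number of new points.
            weights[i] += counts[i];
            // lambda is the share of the new batch in the total weight of cluster i.
            double lambda = counts[i] / weights[i];
            for (int j = 0; j < dim; j++) {
                double batchMean = sums[i][j] / counts[i];
                centroids[i][j] = (1.0 - lambda) * centroids[i][j] + lambda * batchMean;
            }
        }
    }
}
```

With a decay factor of 0, lambda becomes 1 for every cluster that received points, so those centroids move exactly to the batch mean, which matches "determined entirely by recent data"; clusters that received no points keep their position while their weight still decays.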





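GlobalBatchSplitter in the first hunk places the boundary of local batch i at floor(i * n / p), where n is the length of the global batch and p the downstream parallelism, so local batch sizes differ by at most one. The plain-Java sketch below isolates that arithmetic; BatchSplitSketch is a hypothetical name, and the code is an illustration rather than the PR's implementation.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Hypothetical sketch of GlobalBatchSplitter's size arithmetic; not the PR's implementation. */
final class BatchSplitSketch {
    private BatchSplitSketch() {}

    /** Sizes of the local batches obtained by splitting n items across parallelism subtasks. */
    static int[] splitSizes(int n, int parallelism) {
        int[] sizes = new int[parallelism];
        for (int i = 0; i < parallelism; i++) {
            // Boundaries sit at floor(i * n / parallelism), so sizes differ by at most one.
            int start = i * n / parallelism;
            int end = (i + 1) * n / parallelism;
            sizes[i] = end - start;
        }
        return sizes;
    }

    /** Splits {@code values} into consecutive local batches using the sizes above. */
    static <T> List<T[]> split(T[] values, int parallelism) {
        List<T[]> batches = new ArrayList<>();
        int offset = 0;
        for (int size : splitSizes(values.length, parallelism)) {
            batches.add(Arrays.copyOfRange(values, offset, offset + size));
            offset += size;
        }
        return batches;
    }
}
```

For n = 10 and p = 3 this yields sizes 3, 3 and 4. When the global batch is smaller than the parallelism (for example n = 2 and p = 4, giving 0, 1, 0, 1), some subtasks receive an empty local batch; in the quoted ModelDataLocalUpdater such a batch leaves every count at zero, so that subtask only scales the weights.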


[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r836012984



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,407 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Supplier;
+
+/**
+ * OnlineKMeans extends the functionality of {@link KMeans} by supporting continuous training of a
+ * K-Means model on an unbounded stream of training data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids as the weighted average of the original and the
+ * estimated centroids. The weight of an estimated centroid is the number of points assigned to it.
+ * The weight of an original centroid is likewise its number of points, additionally multiplied by
+ * the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If the decay
+ * factor is 1, all batches are weighted equally. If the decay factor is 0, new centroids are
+ * determined entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelData =
+                KMeansModelData.getModelDataStream(initModelDataTable);
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getK(),
+                        getDecayFactor(),
+                        getGlobalBatchSize());
+
+        DataStream<KMeansModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Saves the metadata AND the bounded model data table (if it exists) to the given path. */
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans onlineKMeans = ReadWriteUtils.loadStageParam(path);
+
+        String initModelDataPath = ReadWriteUtils.getDataPath(path);
+        if (Files.exists(Paths.get(initModelDataPath))) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            onlineKMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+        }
+
+        return onlineKMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final int k;
+        private final double decayFactor;
+        private final int batchSize;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure, int k, double decayFactor, int batchSize) {
+            this.distanceMeasure = distanceMeasure;
+            this.k = k;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<KMeansModelData> modelData = variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<KMeansModelData> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator())
+                            .flatMap(new GlobalBatchSplitter(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(KMeansModelData.class),
+                                    new ModelDataLocalUpdater(distanceMeasure, k, decayFactor))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new ModelDataGlobalReducer());
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelData), DataStreamList.of(modelData));
+        }
+    }
+
+    /**
+     * Operator that collects a KMeansModelData from each upstream subtask and outputs the weighted
+     * average of the collected model data.
+     */
+    private static class ModelDataGlobalReducer implements ReduceFunction<KMeansModelData> {
+        @Override
+        public KMeansModelData reduce(KMeansModelData modelData, KMeansModelData newModelData) {
+            DenseVector weights = modelData.weights;
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector newWeights = newModelData.weights;
+            DenseVector[] newCentroids = newModelData.centroids;
+
+            int k = newCentroids.length;
+            int dim = newCentroids[0].size();
+
+            for (int i = 0; i < k; i++) {
+                for (int j = 0; j < dim; j++) {
+                    centroids[i].values[j] =
+                            (centroids[i].values[j] * weights.values[i]
+                                            + newCentroids[i].values[j] * newWeights.values[i])
+                                    / Math.max(weights.values[i] + newWeights.values[i], 1e-16);
+                }
+                weights.values[i] += newWeights.values[i];
+            }
+
+            return new KMeansModelData(centroids, weights);
+        }
+    }
+
+    /**
+     * An operator that updates KMeans model data locally. It performs the following operations:
+     *
+     * <ul>
+     *   <li>Finds the closest centroid (cluster) for each input point.
+     *   <li>Computes the new centroids as the average of the input points that belong to the same
+     *       cluster.
+     *   <li>Computes the weighted average of the current and new centroids. The weight of a new
+     *       centroid is the number of input points that belong to its cluster. The weight of a
+     *       current centroid is its original weight scaled by {@code decayFactor / parallelism}.
+     *   <li>Generates new model data from the weighted average of centroids and the sum of
+     *       weights.
+     * </ul>
+     */
+    private static class ModelDataLocalUpdater extends AbstractStreamOperator<KMeansModelData>
+            implements TwoInputStreamOperator<DenseVector[], KMeansModelData, KMeansModelData> {
+        private final DistanceMeasure distanceMeasure;
+        private final int k;
+        private final double decayFactor;
+        private ListState<DenseVector[]> localBatchDataState;
+        private ListState<KMeansModelData> modelDataState;
+
+        private ModelDataLocalUpdater(DistanceMeasure distanceMeasure, int k, double decayFactor) {
+            this.distanceMeasure = distanceMeasure;
+            this.k = k;
+            this.decayFactor = decayFactor;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            localBatchDataState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("localBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> pointsRecord) throws Exception {
+            localBatchDataState.add(pointsRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<KMeansModelData> modelDataRecord)
+                throws Exception {
+            Preconditions.checkArgument(modelDataRecord.getValue().centroids.length == k);
+            modelDataState.add(modelDataRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        private void alignAndComputeModelData() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !localBatchDataState.get().iterator().hasNext()) {
+                return;
+            }
+
+            KMeansModelData modelData =
+                    OperatorStateUtils.getUniqueElement(modelDataState, "modelData")
+                            .orElseThrow((Supplier<Exception>) NullPointerException::new);

Review comment:
       Got it. Since `getUniqueElement` already throws the exception, I'll change the implementation here to use `get()`.

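A plain `get()` is safe at that point because alignAndComputeModelData returns early when modelDataState is empty, so the Optional returned by getUniqueElement is always present when it is read. The sketch below mimics that pattern with java.util.Optional; uniqueElement is a hypothetical stand-in whose behaviour (an empty Optional for empty state, an IllegalStateException for more than one element) is an assumption for illustration, not a copy of Flink ML's OperatorStateUtils.

```java
import java.util.Iterator;
import java.util.List;
import java.util.Optional;

/** Hypothetical illustration of the guarded Optional access discussed above. */
final class UniqueElementSketch {
    private UniqueElementSketch() {}

    /**
     * Hypothetical stand-in for a "unique element" state helper: empty state yields an empty
     * Optional, and more than one element is treated as an error.
     */
    static <T> Optional<T> uniqueElement(Iterable<T> state, String stateName) {
        Iterator<T> iterator = state.iterator();
        if (!iterator.hasNext()) {
            return Optional.empty();
        }
        T value = iterator.next();
        if (iterator.hasNext()) {
            throw new IllegalStateException("State " + stateName + " holds more than one element");
        }
        return Optional.of(value);
    }

    /** Returns null when there is nothing to align yet, otherwise the single model element. */
    static <T> T alignOrNull(List<T> modelDataState) {
        if (modelDataState.isEmpty()) {
            // Mirrors the early return in alignAndComputeModelData.
            return null;
        }
        // The emptiness check above guarantees the Optional is present, so get() cannot fail here.
        return uniqueElement(modelDataState, "modelData").get();
    }
}
```

On Java 10 and later, the no-argument `Optional.orElseThrow()` is an equivalent, more explicit alternative to `get()`; both throw `NoSuchElementException` when the value is absent.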


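ModelDataGlobalReducer in the second hunk merges per-subtask model data pairwise, weighting each centroid by its weight. Because each ModelDataLocalUpdater scales the broadcast weight w by decayFactor / parallelism before adding its own count, summing the local results gives a total weight of decayFactor * w + N and a centroid of (decayFactor * w * c + S) / (decayFactor * w + N), where c is the previous centroid, N the number of points assigned to the cluster in the global batch, and S their sum; this equals what a single-node update over the whole global batch would produce. The sketch below shows only the merge step; WeightedCentroidMerge and its Model holder are hypothetical names, and unlike the quoted reducer this version allocates new arrays instead of mutating its first argument.

```java
/**
 * Hypothetical sketch of the weighted merge performed by ModelDataGlobalReducer. Unlike the
 * quoted reducer, it allocates new arrays instead of mutating its first argument.
 */
final class WeightedCentroidMerge {
    private WeightedCentroidMerge() {}

    /** Minimal holder for k centroids and their weights (a stand-in for KMeansModelData). */
    static final class Model {
        final double[][] centroids;
        final double[] weights;

        Model(double[][] centroids, double[] weights) {
            this.centroids = centroids;
            this.weights = weights;
        }
    }

    /** Returns the per-cluster weighted average of two models; weights are summed. */
    static Model merge(Model a, Model b) {
        int k = a.centroids.length;
        int dim = a.centroids[0].length;
        double[][] centroids = new double[k][dim];
        double[] weights = new double[k];
        for (int i = 0; i < k; i++) {
            double total = a.weights[i] + b.weights[i];
            // Same guard as the quoted reducer: avoid dividing by zero when both weights are 0.
            double denominator = Math.max(total, 1e-16);
            for (int j = 0; j < dim; j++) {
                centroids[i][j] =
                        (a.centroids[i][j] * a.weights[i] + b.centroids[i][j] * b.weights[i])
                                / denominator;
            }
            weights[i] = total;
        }
        return new Model(centroids, weights);
    }
}
```

One consequence shared by this sketch and the quoted reducer is that a cluster whose combined weight is 0 (for example when decayFactor is 0 and no point in the global batch was assigned to it) ends up at the zero vector, because the numerator is 0 and the denominator is clamped to 1e-16.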




[GitHub] [flink-ml] lindong28 commented on a change in pull request #70: [FLINK-26313] Support Online KMeans in Flink ML

Posted by GitBox <gi...@apache.org>.
lindong28 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r830430626



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/StreamingKMeans.java
##########
@@ -0,0 +1,404 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * StreamingKMeans extends the functionality of {@link KMeans} by supporting continuous training
+ * of a K-Means model on an unbounded stream of training data.
+ */
+public class StreamingKMeans
+        implements Estimator<StreamingKMeans, StreamingKMeansModel>,
+                StreamingKMeansParams<StreamingKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public StreamingKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public StreamingKMeans(Table... initModelDataTables) {
+        Preconditions.checkArgument(initModelDataTables.length == 1);
+        this.initModelDataTable = initModelDataTables[0];
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public StreamingKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        Preconditions.checkArgument(HasBatchStrategy.COUNT_STRATEGY.equals(getBatchStrategy()));
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+        points.getTransformation().setParallelism(1);
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            initModelDataStream = createRandomCentroids(env, getDims(), getK(), getSeed());
+        } else {
+            initModelDataStream = KMeansModelData.getModelDataStream(initModelDataTable);
+        }
+        DataStream<Tuple2<KMeansModelData, DenseVector>> initModelDataWithWeightsStream =
+                initModelDataStream.map(new InitWeightAssigner(getInitWeights()));
+        initModelDataWithWeightsStream.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new StreamingKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getBatchSize(),
+                        getK());
+
+        DataStream<KMeansModelData> finalModelDataStream =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelDataWithWeightsStream),
+                                DataStreamList.of(points),
+                                body)
+                        .get(0);
+        finalModelDataStream = finalModelDataStream.union(initModelDataStream);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelDataStream);
+        StreamingKMeansModel model = new StreamingKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class InitWeightAssigner
+            implements MapFunction<KMeansModelData, Tuple2<KMeansModelData, DenseVector>> {
+        private final double[] initWeights;
+
+        private InitWeightAssigner(Double[] initWeights) {
+            this.initWeights = ArrayUtils.toPrimitive(initWeights);
+        }
+
+        @Override
+        public Tuple2<KMeansModelData, DenseVector> map(KMeansModelData modelData)
+                throws Exception {
+            return Tuple2.of(modelData, Vectors.dense(initWeights));
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static StreamingKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        StreamingKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        Path initModelDataPath = Paths.get(path, "data");
+        if (Files.exists(initModelDataPath)) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+            kMeans.setInitMode("direct");
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class StreamingKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+        private final int k;
+
+        public StreamingKMeansIterationBody(
+                DistanceMeasure distanceMeasure, double decayFactor, int batchSize, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+            this.k = k;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<Tuple2<KMeansModelData, DenseVector>> modelDataWithWeights =
+                    variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            DataStream<Tuple2<KMeansModelData, DenseVector>> newModelDataWithWeights =
+                    points.countWindowAll(batchSize)
+                            .aggregate(new MiniBatchCreator())
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "UpdateModelData",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new UpdateModelDataOperator(distanceMeasure, decayFactor, k))
+                            .setParallelism(1);
+
+            DataStream<KMeansModelData> newModelData =
+                    newModelDataWithWeights.map(
+                            (MapFunction<Tuple2<KMeansModelData, DenseVector>, KMeansModelData>)
+                                    x -> x.f0);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelDataWithWeights), DataStreamList.of(newModelData));
+        }
+    }
+
+    // TODO: change this single-threaded implementation to support distributed training
+    // after the model data version mechanism is implemented.
+    private static class UpdateModelDataOperator
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, DenseVector>>
+            implements TwoInputStreamOperator<
+                    DenseVector[],
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int k;
+        private ListState<DenseVector[]> miniBatchState;
+        private ListState<KMeansModelData> modelDataState;
+        private ListState<DenseVector> weightsState;
+
+        public UpdateModelDataOperator(DistanceMeasure distanceMeasure, double decayFactor, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.k = k;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            miniBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("miniBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+
+            weightsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "weights", DenseVectorTypeInfo.INSTANCE));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> streamRecord) throws Exception {
+            miniBatchState.add(streamRecord.getValue());
+            processElement();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<Tuple2<KMeansModelData, DenseVector>> streamRecord)
+                throws Exception {
+            modelDataState.add(streamRecord.getValue().f0);
+            weightsState.add(streamRecord.getValue().f1);
+            processElement();
+        }
+
+        private void processElement() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !miniBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            // Retrieves data from states.
+            List<KMeansModelData> modelDataList =
+                    IteratorUtils.toList(modelDataState.get().iterator());
+            if (modelDataList.size() != 1) {
+                throw new RuntimeException(
+                        "The operator received "
+                                + modelDataList.size()
+                                + " model data records in this round; expected exactly 1");
+            }
+            DenseVector[] centroids = modelDataList.get(0).centroids;
+            modelDataState.clear();
+
+            List<DenseVector> weightsList = IteratorUtils.toList(weightsState.get().iterator());
+            if (weightsList.size() != 1) {
+                throw new RuntimeException(
+                        "The operator received "
+                                + weightsList.size()
+                                + " weight vectors in this round; expected exactly 1");
+            }
+            DenseVector weights = weightsList.get(0);
+            weightsState.clear();
+
+            List<DenseVector[]> pointsList = IteratorUtils.toList(miniBatchState.get().iterator());
+            DenseVector[] points = pointsList.get(0);
+            pointsList.remove(0);
+            miniBatchState.clear();
+            miniBatchState.addAll(pointsList);

Review comment:
       I am not sure that the model data and the input data will arrive at this operator in lockstep once we take into account the physical transmission latency of that data.
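   
   To make the concern concrete, here is a minimal plain-Java sketch with no Flink dependencies. The names `PairingBuffer`, `onBatch` and `onModel` are hypothetical and are not the PR's API. It queues mini-batches and holds at most one model. An update runs only when both are present, so the result does not depend on which input arrives first.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;

// Hypothetical sketch (not the PR's operator): pairs each model with one mini-batch,
// whatever order the two inputs arrive in.
class PairingBuffer {
    // Mini-batches received while no model was available, oldest first.
    private final Deque<double[][]> pendingBatches = new ArrayDeque<>();
    // Version of the model currently held; null while the updated model is being fed back.
    private Integer modelVersion = null;
    // (modelVersion, batchSize) pairs that have been processed, in processing order.
    private final List<int[]> processed = new ArrayList<>();

    void onBatch(double[][] batch) {
        pendingBatches.addLast(batch);
        tryProcess();
    }

    void onModel(int version) {
        if (modelVersion != null) {
            // With a single feedback loop, at most one model should be outstanding.
            throw new IllegalStateException(
                    "Received model " + version + " while still holding " + modelVersion);
        }
        modelVersion = version;
        tryProcess();
    }

    private void tryProcess() {
        if (modelVersion == null || pendingBatches.isEmpty()) {
            return; // Wait until both a model and a mini-batch are available.
        }
        double[][] batch = pendingBatches.pollFirst();
        processed.add(new int[] {modelVersion, batch.length});
        // The model is consumed; the next one only arrives after going round the feedback edge.
        modelVersion = null;
    }

    List<int[]> processed() {
        return processed;
    }

    int pendingCount() {
        return pendingBatches.size();
    }
}
```

   Batches that arrive early simply queue up. The operator therefore needs neither lockstep delivery nor any knowledge of how many batches arrived before the model.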




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [flink-ml] lindong28 commented on a change in pull request #70: [FLINK-26313] Support Online KMeans in Flink ML

Posted by GitBox <gi...@apache.org>.
lindong28 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r830429628



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/StreamingKMeans.java
##########
@@ -0,0 +1,404 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * StreamingKMeans extends {@link KMeans} by supporting continuous training of a K-Means model
+ * from an unbounded stream of training data.
+ */
+public class StreamingKMeans
+        implements Estimator<StreamingKMeans, StreamingKMeansModel>,
+                StreamingKMeansParams<StreamingKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public StreamingKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public StreamingKMeans(Table... initModelDataTables) {
+        Preconditions.checkArgument(initModelDataTables.length == 1);
+        this.initModelDataTable = initModelDataTables[0];
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public StreamingKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        Preconditions.checkArgument(HasBatchStrategy.COUNT_STRATEGY.equals(getBatchStrategy()));
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+        points.getTransformation().setParallelism(1);
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            initModelDataStream = createRandomCentroids(env, getDims(), getK(), getSeed());
+        } else {
+            initModelDataStream = KMeansModelData.getModelDataStream(initModelDataTable);
+        }
+        DataStream<Tuple2<KMeansModelData, DenseVector>> initModelDataWithWeightsStream =
+                initModelDataStream.map(new InitWeightAssigner(getInitWeights()));
+        initModelDataWithWeightsStream.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new StreamingKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getBatchSize(),
+                        getK());
+
+        DataStream<KMeansModelData> finalModelDataStream =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelDataWithWeightsStream),
+                                DataStreamList.of(points),
+                                body)
+                        .get(0);
+        finalModelDataStream = finalModelDataStream.union(initModelDataStream);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelDataStream);
+        StreamingKMeansModel model = new StreamingKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class InitWeightAssigner
+            implements MapFunction<KMeansModelData, Tuple2<KMeansModelData, DenseVector>> {
+        private final double[] initWeights;
+
+        private InitWeightAssigner(Double[] initWeights) {
+            this.initWeights = ArrayUtils.toPrimitive(initWeights);
+        }
+
+        @Override
+        public Tuple2<KMeansModelData, DenseVector> map(KMeansModelData modelData)
+                throws Exception {
+            return Tuple2.of(modelData, Vectors.dense(initWeights));
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static StreamingKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        StreamingKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        Path initModelDataPath = Paths.get(path, "data");
+        if (Files.exists(initModelDataPath)) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+            kMeans.setInitMode("direct");
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class StreamingKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+        private final int k;
+
+        public StreamingKMeansIterationBody(
+                DistanceMeasure distanceMeasure, double decayFactor, int batchSize, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+            this.k = k;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<Tuple2<KMeansModelData, DenseVector>> modelDataWithWeights =
+                    variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            DataStream<Tuple2<KMeansModelData, DenseVector>> newModelDataWithWeights =
+                    points.countWindowAll(batchSize)
+                            .aggregate(new MiniBatchCreator())
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "UpdateModelData",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new UpdateModelDataOperator(distanceMeasure, decayFactor, k))
+                            .setParallelism(1);
+
+            DataStream<KMeansModelData> newModelData =
+                    newModelDataWithWeights.map(
+                            (MapFunction<Tuple2<KMeansModelData, DenseVector>, KMeansModelData>)
+                                    x -> x.f0);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelDataWithWeights), DataStreamList.of(newModelData));
+        }
+    }
+
+    // TODO: change this single-threaded implementation to support distributed training
+    // after the model data version mechanism is implemented.
+    private static class UpdateModelDataOperator
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, DenseVector>>
+            implements TwoInputStreamOperator<
+                    DenseVector[],
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int k;
+        private ListState<DenseVector[]> miniBatchState;
+        private ListState<KMeansModelData> modelDataState;
+        private ListState<DenseVector> weightsState;
+
+        public UpdateModelDataOperator(DistanceMeasure distanceMeasure, double decayFactor, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.k = k;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            miniBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("miniBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+
+            weightsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "weights", DenseVectorTypeInfo.INSTANCE));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> streamRecord) throws Exception {
+            miniBatchState.add(streamRecord.getValue());
+            processElement();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<Tuple2<KMeansModelData, DenseVector>> streamRecord)
+                throws Exception {
+            modelDataState.add(streamRecord.getValue().f0);
+            weightsState.add(streamRecord.getValue().f1);
+            processElement();
+        }
+
+        private void processElement() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !miniBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            // Retrieves data from states.
+            List<KMeansModelData> modelDataList =
+                    IteratorUtils.toList(modelDataState.get().iterator());
+            if (modelDataList.size() != 1) {
+                throw new RuntimeException(
+                        "The operator received "
+                                + modelDataList.size()
+                                + " model data records in this round; expected exactly 1");
+            }
+            DenseVector[] centroids = modelDataList.get(0).centroids;
+            modelDataState.clear();

Review comment:
       How do we know whether "there are zero batches of data when model data records are received"?
   
   The input data stream and the model data stream are just two independent inputs to this operator. We don't know exactly how long it takes for the feedback data to come back into the operator after it is emitted.
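   
   For reference, the following step has to run once a model and a mini-batch have been paired, however long the delay between them. This is a plain-Java sketch of one decayed mini-batch k-means update using squared Euclidean distance. It assumes the rule used by Spark MLlib's StreamingKMeans; the actual update code is not shown in this hunk, so whether the PR uses exactly this rule is an assumption. `decay` stands in for the PR's `decayFactor`, and `weights` for the per-centroid `DenseVector` carried alongside the model data. The class `DecayedKMeansUpdate` is hypothetical.

```java
// Hypothetical sketch of one decayed mini-batch k-means step. The update rule is an assumption
// (Spark MLlib's StreamingKMeans rule); assignment uses squared Euclidean distance.
class DecayedKMeansUpdate {
    // Returns the index of the centroid closest to point p.
    static int closest(double[][] centroids, double[] p) {
        int best = 0;
        double bestDist = Double.POSITIVE_INFINITY;
        for (int c = 0; c < centroids.length; c++) {
            double dist = 0;
            for (int d = 0; d < p.length; d++) {
                double diff = centroids[c][d] - p[d];
                dist += diff * diff;
            }
            if (dist < bestDist) {
                bestDist = dist;
                best = c;
            }
        }
        return best;
    }

    // Updates centroids and weights in place:
    //   w' = decay * w + m,   c' = (decay * w * c + m * mean) / w'
    // where m is the number of batch points assigned to c and mean is their average.
    static void update(double[][] centroids, double[] weights, double[][] batch, double decay) {
        int k = centroids.length;
        int dim = centroids[0].length;
        double[][] sums = new double[k][dim];
        int[] counts = new int[k];
        for (double[] p : batch) {
            int c = closest(centroids, p);
            counts[c]++;
            for (int d = 0; d < dim; d++) {
                sums[c][d] += p[d];
            }
        }
        for (int c = 0; c < k; c++) {
            weights[c] *= decay; // Discount the history of every centroid.
            if (counts[c] == 0) {
                continue; // No new points: the centroid stays where it is.
            }
            double newWeight = weights[c] + counts[c];
            double lambda = counts[c] / newWeight; // Share of the new points in the updated weight.
            for (int d = 0; d < dim; d++) {
                centroids[c][d] = (1 - lambda) * centroids[c][d] + lambda * sums[c][d] / counts[c];
            }
            weights[c] = newWeight;
        }
    }
}
```

   With `decay = 1` the history is never forgotten, and the step reduces to the running-mean update of plain mini-batch k-means. Smaller values make the model follow recent batches more closely.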







[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Support Streaming KMeans in Flink ML

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r829710601



##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/util/MockKVStore.java
##########
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util;
+
+import java.util.HashMap;
+import java.util.Map;
+
+/** Class that manages global key-value pairs used in unit tests. */
+@SuppressWarnings({"unchecked"})
+public class MockKVStore {

Review comment:
       `TestMetricReporter` is not a singleton, as the following configuration in the test shows:
   ```java
   config.setString(
           "metrics.reporter.test_reporter.class", TestMetricReporter.class.getName());
   ```
   Flink uses this setting to instantiate `TestMetricReporter` by itself, so it will not be a singleton.
   
   But I agree that we can remove `MockKVStore` for now, since `TestMetricReporter` is its only user. `TestMetricReporter` can store values in its own static variables and provide a static getter, so that test cases can retrieve them on the Flink client side.
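   
   Here is a minimal sketch of that approach, written in plain Java so it runs without Flink. `HypotheticalTestReporter` stands in for `TestMetricReporter` and deliberately does not implement Flink's `MetricReporter` interface. `ReflectiveInstantiationDemo` is also hypothetical; it imitates Flink creating the reporter from its class name. Every instance writes into one static map, so test code can read the values through a static getter without ever holding a reporter instance.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical stand-in for TestMetricReporter; not Flink's MetricReporter API.
class HypotheticalTestReporter {
    // Shared by all instances, so values are visible however many instances get created.
    private static final Map<String, Object> VALUES = new ConcurrentHashMap<>();

    // Called on an instance (in Flink, on whichever instance the runtime created).
    void report(String key, Object value) {
        VALUES.put(key, value);
    }

    // Static getter that test code can call without a reference to any instance.
    static Object get(String key) {
        return VALUES.get(key);
    }
}

// Hypothetical helper that mimics a framework instantiating a configured class by name.
class ReflectiveInstantiationDemo {
    static HypotheticalTestReporter instantiate(String className) {
        try {
            return (HypotheticalTestReporter)
                    Class.forName(className).getDeclaredConstructor().newInstance();
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException("Cannot instantiate " + className, e);
        }
    }
}
```

   The static map is per JVM (per class loader), which matches how the test reads the values back from the same process.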







[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Support Online KMeans in Flink ML

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r829745929



##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/util/MockKVStore.java
##########
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util;
+
+import java.util.HashMap;
+import java.util.Map;
+
+/** Class that manages global key-value pairs used in unit tests. */
+@SuppressWarnings({"unchecked"})
+public class MockKVStore {
+    private static final Map<String, Object> map = new HashMap<>();
+
+    private static long counter = 0;
+
+    /**
+     * Returns a prefix string to make sure key-value pairs created in different test cases would
+     * not have the same key.
+     */
+    public static synchronized String createNonDuplicatePrefix() {

Review comment:
       I agree that it would be better to remove this method. However, a random int64 value carries some risk here, as test cases running in parallel could generate identical keys. I'll try to avoid this problem with a more specific key pattern, such as the class name combined with a monotonically increasing int64.
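   
   Here is a minimal sketch of that key pattern, using a hypothetical `KeyPrefixes` helper that is not part of the PR. It swaps the `synchronized` method and `static long` counter of `MockKVStore` for an `AtomicLong`, which keeps the counter monotonic and thread-safe within one JVM. The class name keeps different test classes apart.

```java
import java.util.concurrent.atomic.AtomicLong;

// Hypothetical helper: builds key prefixes from the test class name plus a JVM-wide counter.
class KeyPrefixes {
    // Monotonically increasing across all callers in this JVM, so a value is never reused.
    private static final AtomicLong COUNTER = new AtomicLong();

    static String create(Class<?> testClass) {
        return testClass.getName() + "-" + COUNTER.incrementAndGet() + "-";
    }
}
```

   Uniqueness is only guaranteed within a single JVM. Test forks in separate JVMs each start their own counter, which is harmless as long as the key-value map is itself per-JVM static state.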







[GitHub] [flink-ml] lindong28 commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
lindong28 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r836010275



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,407 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Supplier;
+
+/**
+ * OnlineKMeans extends the functionality of {@link KMeans} by supporting continuous training of a
+ * K-Means model on an unbounded stream of training data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids from the weighted average between the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of points assigned to
+ * them in the current batch. The weight of the original centroids is the number of points
+ * assigned to them so far, multiplied by the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If the decay
+ * factor is 1, all batches are weighted equally. If the decay factor is 0, new centroids are
+ * determined entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelData =
+                KMeansModelData.getModelDataStream(initModelDataTable);
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getK(),
+                        getDecayFactor(),
+                        getGlobalBatchSize());
+
+        DataStream<KMeansModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Saves the metadata AND bounded model data table (if exists) to the given path. */
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans onlineKMeans = ReadWriteUtils.loadStageParam(path);
+
+        String initModelDataPath = ReadWriteUtils.getDataPath(path);
+        if (Files.exists(Paths.get(initModelDataPath))) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            onlineKMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+        }
+
+        return onlineKMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final int k;
+        private final double decayFactor;
+        private final int batchSize;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure, int k, double decayFactor, int batchSize) {
+            this.distanceMeasure = distanceMeasure;
+            this.k = k;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<KMeansModelData> modelData = variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();

Review comment:
       Yes, I think the reason is related to performance.
   
   If `parallelism > batchSize`, some slots (and their CPU resources) are definitely wasted. Is there any reason a user would want this? If not, the user must have chosen this setup by mistake. Would it be more user-friendly to alert the user to this issue by throwing an exception?
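   To make the cost concrete, here is a minimal sketch assuming the global batch size is split as evenly as possible across the parallel subtasks. The diff above is cut off before the actual splitting logic, so `BatchSizeCheck`, `splitBatchSize` and `checkParallelism` are hypothetical helpers. With a global batch size of 6 and parallelism 4 every subtask gets data; with a batch size of 2, two subtasks stay idle, which is the case the proposed exception would reject.

   ```java
   import java.util.Arrays;

   // Hypothetical sketch: splits a global batch size across subtasks and rejects idle subtasks.
   final class BatchSizeCheck {
       private BatchSizeCheck() {}

       // Splits globalBatchSize into `parallelism` near-equal parts; the first
       // (globalBatchSize % parallelism) subtasks receive one extra element.
       public static int[] splitBatchSize(int globalBatchSize, int parallelism) {
           if (parallelism <= 0) {
               throw new IllegalArgumentException("Parallelism must be positive.");
           }
           int[] sizes = new int[parallelism];
           int base = globalBatchSize / parallelism;
           int remainder = globalBatchSize % parallelism;
           for (int i = 0; i < parallelism; i++) {
               sizes[i] = base + (i < remainder ? 1 : 0);
           }
           return sizes;
       }

       // Fails fast when some subtasks would always be assigned an empty local batch.
       public static void checkParallelism(int globalBatchSize, int parallelism) {
           if (parallelism > globalBatchSize) {
               throw new IllegalArgumentException(
                       "Parallelism " + parallelism + " is larger than the global batch size "
                               + globalBatchSize + ", so some subtasks would never receive data.");
           }
       }

       public static void main(String[] args) {
           System.out.println(Arrays.toString(splitBatchSize(6, 4))); // [2, 2, 1, 1]
           System.out.println(Arrays.toString(splitBatchSize(2, 4))); // [1, 1, 0, 0]
           checkParallelism(6, 4); // passes silently
           try {
               checkParallelism(2, 4);
           } catch (IllegalArgumentException e) {
               System.out.println("rejected: " + e.getMessage());
           }
       }
   }
   ```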

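   For reference, the decayed mini-batch update described in the `OnlineKMeans` Javadoc quoted above can be sketched per centroid as `c_new = (c_old * n * a + sum_batch) / (n * a + m)`, where `n` is the centroid's weight so far, `a` the decay factor, and `m` the number of batch points assigned to it. This is an illustrative sketch, not the PR's implementation: `DecayedMiniBatchUpdate` is a hypothetical class and plain `double[]` stands in for Flink ML's `DenseVector`.

   ```java
   import java.util.Arrays;

   // Hypothetical sketch of a decayed mini-batch K-Means update for a single centroid.
   final class DecayedMiniBatchUpdate {
       private DecayedMiniBatchUpdate() {}

       // Overwrites `centroid` with the updated value and returns the centroid's new weight.
       // weight:      number of points assigned to the centroid so far
       // batchSum:    element-wise sum of the batch points assigned to this centroid
       // batchCount:  number of batch points assigned to this centroid
       // decayFactor: in [0, 1]; smaller values forget the past faster
       public static double update(
               double[] centroid, double weight, double[] batchSum, long batchCount, double decayFactor) {
           double decayedWeight = weight * decayFactor;
           double newWeight = decayedWeight + batchCount;
           if (newWeight == 0) {
               return 0; // nothing observed yet; keep the centroid unchanged
           }
           for (int i = 0; i < centroid.length; i++) {
               // Weighted average of the old centroid (weight decayedWeight) and the batch mean
               // (weight batchCount); batchSum[i] already equals batchCount * batchMean[i].
               centroid[i] = (centroid[i] * decayedWeight + batchSum[i]) / newWeight;
           }
           return newWeight;
       }

       public static void main(String[] args) {
           double[] centroid = {0.0, 0.0};
           // Three points assigned in this batch: (10, 0), (10, 0.3), (10.3, 0).
           double[] batchSum = {30.3, 0.3};
           double newWeight = update(centroid, 3.0, batchSum, 3, 0.5);
           System.out.println(Arrays.toString(centroid) + " weight=" + newWeight);
       }
   }
   ```

   With a decay factor of 0 the old centroid gets zero weight and the result is exactly the batch mean; with a decay factor of 1 every point seen so far counts equally, matching the two extremes described in the Javadoc.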






[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r836041520



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
+            int parallelism = points.getParallelism();

Review comment:
       I hold a different opinion on this.
   For a consistent user experience, we should probably also support the case where the parallelism is larger than the batch size, but we could alert the user with a warning.
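   The warning-based alternative could look roughly like the sketch below. It assumes the global batch size is split evenly across subtasks, so `parallelism - globalBatchSize` subtasks receive nothing per batch when the parallelism is larger. `BatchSizeWarning` and `warnIfIdleSubtasks` are hypothetical names; Flink code normally logs through SLF4J, and `java.util.logging` is used here only to keep the sketch dependency-free.

   ```java
   import java.util.logging.Logger;

   // Hypothetical sketch: accept an oversubscribed setup but warn about idle subtasks.
   final class BatchSizeWarning {
       private static final Logger LOG = Logger.getLogger(BatchSizeWarning.class.getName());

       private BatchSizeWarning() {}

       // Returns true if a warning was logged, false if every subtask receives data.
       public static boolean warnIfIdleSubtasks(int globalBatchSize, int parallelism) {
           if (parallelism <= globalBatchSize) {
               return false;
           }
           // Under an even split, only globalBatchSize subtasks get one element each.
           int idle = parallelism - globalBatchSize;
           LOG.warning(
                   String.format(
                           "Parallelism (%d) is larger than the global batch size (%d); "
                                   + "%d subtask(s) will receive no training data in each batch.",
                           parallelism, globalBatchSize, idle));
           return true;
       }

       public static void main(String[] args) {
           System.out.println(warnIfIdleSubtasks(6, 4));
           System.out.println(warnIfIdleSubtasks(2, 4));
       }
   }
   ```

   Compared with throwing an exception, this keeps every parallelism setting valid while still surfacing the likely misconfiguration in the logs.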







[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Support Streaming KMeans in Flink ML

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r821608309



##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/StreamingKMeansTest.java
##########
@@ -0,0 +1,527 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.core.execution.JobClient;
+import org.apache.flink.ml.clustering.kmeans.KMeans;
+import org.apache.flink.ml.clustering.kmeans.KMeansModel;
+import org.apache.flink.ml.clustering.kmeans.KMeansModelData;
+import org.apache.flink.ml.clustering.kmeans.StreamingKMeans;
+import org.apache.flink.ml.clustering.kmeans.StreamingKMeansModel;
+import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.util.MockKVStore;
+import org.apache.flink.ml.util.MockMessageQueues;
+import org.apache.flink.ml.util.MockSinkFunction;
+import org.apache.flink.ml.util.MockSourceFunction;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.ml.util.TestMetricReporter;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.CollectionUtils;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import static org.apache.flink.ml.clustering.KMeansTest.groupFeaturesByPrediction;
+
+/** Tests {@link StreamingKMeans} and {@link StreamingKMeansModel}. */
+public class StreamingKMeansTest {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private static final DenseVector[] trainData1 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 0.0),
+                Vectors.dense(10.0, 0.3),
+                Vectors.dense(10.3, 0.0),
+                Vectors.dense(-10.0, 0.0),
+                Vectors.dense(-10.0, 0.6),
+                Vectors.dense(-10.6, 0.0)
+            };
+    private static final DenseVector[] trainData2 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 100.0),
+                Vectors.dense(10.0, 100.3),
+                Vectors.dense(10.3, 100.0),
+                Vectors.dense(-10.0, -100.0),
+                Vectors.dense(-10.0, -100.6),
+                Vectors.dense(-10.6, -100.0)
+            };
+    private static final DenseVector[] predictData =
+            new DenseVector[] {
+                Vectors.dense(10.0, 10.0),
+                Vectors.dense(10.3, 10.0),
+                Vectors.dense(10.0, 10.3),
+                Vectors.dense(-10.0, 10.0),
+                Vectors.dense(-10.3, 10.0),
+                Vectors.dense(-10.0, 10.3)
+            };
+    private static final List<Set<DenseVector>> expectedGroups1 =
+            Arrays.asList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3))),
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+    private static final List<Set<DenseVector>> expectedGroups2 =
+            Collections.singletonList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3),
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+
+    private String trainId;
+    private String predictId;
+    private String outputId;
+    private String modelDataId;
+    private String metricReporterPrefix;
+    private String modelDataVersionGaugeKey;
+    private String currentModelDataVersion;
+    private List<String> queueIds;
+    private List<String> kvStoreKeys;
+    private List<JobClient> clients;
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table offlineTrainTable;
+    private Table trainTable;
+    private Table predictTable;
+
+    @Before
+    public void before() {
+        metricReporterPrefix = MockKVStore.createNonDuplicatePrefix();
+        kvStoreKeys = new ArrayList<>();
+        kvStoreKeys.add(metricReporterPrefix);
+
+        modelDataVersionGaugeKey =
+                TestMetricReporter.getKey(metricReporterPrefix, "modelDataVersion");
+
+        currentModelDataVersion = "0";
+
+        trainId = MockMessageQueues.createMessageQueue();
+        predictId = MockMessageQueues.createMessageQueue();
+        outputId = MockMessageQueues.createMessageQueue();
+        modelDataId = MockMessageQueues.createMessageQueue();
+        queueIds = new ArrayList<>();
+        queueIds.addAll(Arrays.asList(trainId, predictId, outputId, modelDataId));
+
+        clients = new ArrayList<>();
+
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        config.setString("metrics.reporters", "test_reporter");
+        config.setString(
+                "metrics.reporter.test_reporter.class", TestMetricReporter.class.getName());
+        config.setString("metrics.reporter.test_reporter.interval", "100 MILLISECONDS");
+        config.setString("metrics.reporter.test_reporter.prefix", metricReporterPrefix);
+
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        Schema schema = Schema.newBuilder().column("f0", DataTypes.of(DenseVector.class)).build();
+
+        offlineTrainTable =
+                tEnv.fromDataStream(env.fromElements(trainData1), schema).as("features");
+        trainTable =
+                tEnv.fromDataStream(
+                                env.addSource(
+                                        new MockSourceFunction<>(trainId),
+                                        DenseVectorTypeInfo.INSTANCE),
+                                schema)
+                        .as("features");
+        predictTable =
+                tEnv.fromDataStream(
+                                env.addSource(
+                                        new MockSourceFunction<>(predictId),
+                                        DenseVectorTypeInfo.INSTANCE),
+                                schema)
+                        .as("features");
+    }
+
+    @After
+    public void after() {
+        for (JobClient client : clients) {
+            try {
+                client.cancel();
+            } catch (IllegalStateException e) {
+                if (!e.getMessage()
+                        .equals("MiniCluster is not yet running or has already been shut down.")) {
+                    throw e;
+                }
+            }
+        }
+        clients.clear();
+
+        for (String queueId : queueIds) {
+            MockMessageQueues.deleteBlockingQueue(queueId);
+        }
+        queueIds.clear();
+
+        for (String key : kvStoreKeys) {
+            MockKVStore.remove(key);
+        }
+        kvStoreKeys.clear();
+    }
+
+    /** Adds sinks for StreamingKMeansModel's transform output and model data. */
+    private void configModelSink(StreamingKMeansModel streamingModel) {
+        Table outputTable = streamingModel.transform(predictTable)[0];
+        DataStream<Row> output = tEnv.toDataStream(outputTable);
+        output.addSink(new MockSinkFunction<>(outputId));
+
+        Table modelDataTable = streamingModel.getModelData()[0];
+        DataStream<KMeansModelData> modelDataStream =
+                KMeansModelData.getModelDataStream(modelDataTable);
+        modelDataStream.addSink(new MockSinkFunction<>(modelDataId));
+    }
+
+    /** Blocks the thread until Model has set up init model data. */
+    private void waitInitModelDataSetup() {
+        while (!MockKVStore.containsKey(modelDataVersionGaugeKey)) {
+            Thread.yield();
+        }
+        waitModelDataUpdate();
+    }
+
+    /** Blocks the thread until the Model has received the next model-data-update event. */
+    private void waitModelDataUpdate() {
+        do {
+            String tmpModelDataVersion = MockKVStore.get(modelDataVersionGaugeKey);
+            if (tmpModelDataVersion.equals(currentModelDataVersion)) {
+                Thread.yield();
+            } else {
+                currentModelDataVersion = tmpModelDataVersion;
+                break;
+            }
+        } while (true);
+    }
+
+    /**
+     * Inserts default predict data to the predict queue, fetches the prediction results, and
+     * asserts that the grouping result is as expected.
+     *
+     * @param expectedGroups A list containing sets of features, which is the expected group result
+     * @param featuresCol Name of the column in the table that contains the features
+     * @param predictionCol Name of the column in the table that contains the prediction result
+     */
+    private void predictAndAssert(
+            List<Set<DenseVector>> expectedGroups, String featuresCol, String predictionCol)
+            throws Exception {
+        MockMessageQueues.offerAll(predictId, StreamingKMeansTest.predictData);
+        List<Row> rawResult =
+                MockMessageQueues.poll(outputId, StreamingKMeansTest.predictData.length);
+        List<Set<DenseVector>> actualGroups =
+                groupFeaturesByPrediction(rawResult, featuresCol, predictionCol);
+        Assert.assertTrue(CollectionUtils.isEqualCollection(expectedGroups, actualGroups));
+    }
+
+    @Test
+    public void testParam() {
+        StreamingKMeans streamingKMeans = new StreamingKMeans();
+        Assert.assertEquals("features", streamingKMeans.getFeaturesCol());
+        Assert.assertEquals("prediction", streamingKMeans.getPredictionCol());
+        Assert.assertEquals(EuclideanDistanceMeasure.NAME, streamingKMeans.getDistanceMeasure());
+        Assert.assertEquals("random", streamingKMeans.getInitMode());
+        Assert.assertEquals(2, streamingKMeans.getK());
+        Assert.assertEquals(1, streamingKMeans.getDims());
+        Assert.assertEquals("count", streamingKMeans.getBatchStrategy());
+        Assert.assertEquals(1, streamingKMeans.getBatchSize());
+        Assert.assertEquals(0., streamingKMeans.getDecayFactor(), 1e-5);
+        Assert.assertEquals("random", streamingKMeans.getInitMode());
+        Assert.assertEquals(StreamingKMeans.class.getName().hashCode(), streamingKMeans.getSeed());
+
+        streamingKMeans
+                .setK(9)
+                .setFeaturesCol("test_feature")
+                .setPredictionCol("test_prediction")
+                .setK(3)
+                .setDims(5)
+                .setBatchSize(5)
+                .setDecayFactor(0.25)
+                .setInitMode("direct")
+                .setSeed(100);
+
+        Assert.assertEquals("test_feature", streamingKMeans.getFeaturesCol());
+        Assert.assertEquals("test_prediction", streamingKMeans.getPredictionCol());
+        Assert.assertEquals(3, streamingKMeans.getK());
+        Assert.assertEquals(5, streamingKMeans.getDims());
+        Assert.assertEquals("count", streamingKMeans.getBatchStrategy());
+        Assert.assertEquals(5, streamingKMeans.getBatchSize());
+        Assert.assertEquals(0.25, streamingKMeans.getDecayFactor(), 1e-5);
+        Assert.assertEquals("direct", streamingKMeans.getInitMode());
+        Assert.assertEquals(100, streamingKMeans.getSeed());
+    }
+
+    @Test
+    public void testFitAndPredict() throws Exception {
+        StreamingKMeans streamingKMeans =
+                new StreamingKMeans()
+                        .setInitMode("random")
+                        .setDims(2)
+                        .setInitWeights(new Double[] {0., 0.})
+                        .setBatchSize(6)
+                        .setFeaturesCol("features")
+                        .setPredictionCol("prediction");
+        StreamingKMeansModel streamingModel = streamingKMeans.fit(trainTable);
+        configModelSink(streamingModel);
+
+        clients.add(env.executeAsync());
+        waitInitModelDataSetup();
+
+        MockMessageQueues.offerAll(trainId, trainData1);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups1,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+
+        MockMessageQueues.offerAll(trainId, trainData2);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups2,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+    }
+
+    @Test
+    public void testInitWithOfflineKMeans() throws Exception {
+        KMeans offlineKMeans =
+                new KMeans().setFeaturesCol("features").setPredictionCol("prediction");
+        KMeansModel offlineModel = offlineKMeans.fit(offlineTrainTable);
+
+        StreamingKMeans streamingKMeans =
+                new StreamingKMeans(offlineModel.getModelData())
+                        .setInitMode("direct")
+                        .setDims(2)
+                        .setInitWeights(new Double[] {0., 0.})
+                        .setBatchSize(6);
+        ReadWriteUtils.updateExistingParams(streamingKMeans, offlineKMeans.getParamMap());
+
+        StreamingKMeansModel streamingModel = streamingKMeans.fit(trainTable);
+        configModelSink(streamingModel);
+
+        clients.add(env.executeAsync());
+        waitInitModelDataSetup();
+        predictAndAssert(
+                expectedGroups1,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+
+        MockMessageQueues.offerAll(trainId, trainData2);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups2,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+    }
+
+    @Test
+    public void testDecayFactor() throws Exception {
+        StreamingKMeans streamingKMeans =
+                new StreamingKMeans()
+                        .setInitMode("random")
+                        .setDims(2)
+                        .setInitWeights(new Double[] {0., 0.})
+                        .setDecayFactor(100.0)
+                        .setBatchSize(6)
+                        .setFeaturesCol("features")
+                        .setPredictionCol("prediction");
+        StreamingKMeansModel streamingModel = streamingKMeans.fit(trainTable);
+        configModelSink(streamingModel);
+
+        clients.add(env.executeAsync());
+        waitInitModelDataSetup();
+
+        MockMessageQueues.offerAll(trainId, trainData1);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups1,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+
+        MockMessageQueues.offerAll(trainId, trainData2);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups1,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+    }
+
+    @Test
+    public void testSaveAndReload() throws Exception {
+        KMeans offlineKMeans =
+                new KMeans().setFeaturesCol("features").setPredictionCol("prediction");
+        KMeansModel offlineModel = offlineKMeans.fit(offlineTrainTable);
+
+        StreamingKMeans streamingKMeans =
+                new StreamingKMeans(offlineModel.getModelData())
+                        .setDims(2)
+                        .setBatchSize(6)
+                        .setInitWeights(new Double[] {0., 0.});
+        ReadWriteUtils.updateExistingParams(streamingKMeans, offlineKMeans.getParamMap());
+
+        StreamingKMeans loadedKMeans =
+                StageTestUtils.saveAndReload(
+                        env, streamingKMeans, tempFolder.newFolder().getAbsolutePath());
+
+        StreamingKMeansModel streamingModel = loadedKMeans.fit(trainTable);
+
+        String modelDataPassId = MockMessageQueues.createMessageQueue();
+        queueIds.add(modelDataPassId);
+
+        String modelSavePath = tempFolder.newFolder().getAbsolutePath();
+        streamingModel.save(modelSavePath);
+        KMeansModelData.getModelDataStream(streamingModel.getModelData()[0])
+                .addSink(new MockSinkFunction<>(modelDataPassId));
+        clients.add(env.executeAsync());

Review comment:
       Even if all the sinks used in `save()` are bounded, `env.execute()` still blocks indefinitely as long as the environment contains an unbounded source, so I have to use `executeAsync()` here. This would also affect how `Pipeline` and `Graph` are saved. How could this be resolved?
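       The blocking behaviour described above can be modelled without Flink. In Flink, `StreamExecutionEnvironment#execute()` waits for the job to finish, which never happens while an unbounded source is running. `executeAsync()` instead returns a `JobClient` whose `cancel()` can stop the job later. The sketch below uses only `java.util.concurrent`. `AsyncJobSketch`, `submitAsync` and `cancelAndAwait` are hypothetical names standing in for those Flink calls; they are not Flink API.

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;

public class AsyncJobSketch {
    /** Counts "records" consumed by the simulated job. */
    static final AtomicLong consumed = new AtomicLong();

    /** Simulates a job with an unbounded source: it only stops when interrupted (cancelled). */
    static void runUnboundedJob() {
        while (!Thread.currentThread().isInterrupted()) {
            consumed.incrementAndGet();
        }
    }

    /** Analogue of executeAsync(): start the job and hand back a handle without waiting. */
    static Future<?> submitAsync(ExecutorService executor) {
        return executor.submit(AsyncJobSketch::runUnboundedJob);
    }

    /** Analogue of cancelling through the returned handle; returns true once the job has stopped. */
    static boolean cancelAndAwait(Future<?> job, ExecutorService executor, long timeoutMillis) {
        job.cancel(true); // interrupts the worker thread, which ends the loop above
        executor.shutdown();
        try {
            return executor.awaitTermination(timeoutMillis, TimeUnit.MILLISECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return false;
        }
    }

    public static void main(String[] args) {
        ExecutorService executor = Executors.newSingleThreadExecutor();
        Future<?> job = submitAsync(executor);
        // Calling job.get() here would block forever, just as execute() does with an
        // unbounded source. The caller can instead do other work (e.g. feed test data).
        System.out.println("job finished on its own: " + job.isDone());
        System.out.println("stopped after cancel: " + cancelAndAwait(job, executor, 5_000));
    }
}
```

       The key point is that the asynchronous call returns a handle immediately, so the caller keeps control and decides when the unbounded job ends.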




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [flink-ml] lindong28 commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
lindong28 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r834919072



##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java
##########
@@ -0,0 +1,464 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.RestOptions;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.clustering.kmeans.KMeans;
+import org.apache.flink.ml.clustering.kmeans.KMeansModel;
+import org.apache.flink.ml.clustering.kmeans.KMeansModelData;
+import org.apache.flink.ml.clustering.kmeans.OnlineKMeans;
+import org.apache.flink.ml.clustering.kmeans.OnlineKMeansModel;
+import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.util.InMemorySinkFunction;
+import org.apache.flink.ml.util.InMemorySourceFunction;
+import org.apache.flink.runtime.minicluster.MiniCluster;
+import org.apache.flink.runtime.minicluster.MiniClusterConfiguration;
+import org.apache.flink.runtime.testutils.InMemoryReporter;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.CollectionUtils;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import static org.apache.flink.ml.clustering.KMeansTest.groupFeaturesByPrediction;
+
+/** Tests {@link OnlineKMeans} and {@link OnlineKMeansModel}. */
+public class OnlineKMeansTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private static final DenseVector[] trainData1 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 0.0),
+                Vectors.dense(10.0, 0.3),
+                Vectors.dense(10.3, 0.0),
+                Vectors.dense(-10.0, 0.0),
+                Vectors.dense(-10.0, 0.6),
+                Vectors.dense(-10.6, 0.0)
+            };
+    private static final DenseVector[] trainData2 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 100.0),
+                Vectors.dense(10.0, 100.3),
+                Vectors.dense(10.3, 100.0),
+                Vectors.dense(-10.0, -100.0),
+                Vectors.dense(-10.0, -100.6),
+                Vectors.dense(-10.6, -100.0)
+            };
+    private static final DenseVector[] predictData =
+            new DenseVector[] {
+                Vectors.dense(10.0, 10.0),
+                Vectors.dense(10.3, 10.0),
+                Vectors.dense(10.0, 10.3),
+                Vectors.dense(-10.0, 10.0),
+                Vectors.dense(-10.3, 10.0),
+                Vectors.dense(-10.0, 10.3)
+            };
+    private static final List<Set<DenseVector>> expectedGroups1 =
+            Arrays.asList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3))),
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+    private static final List<Set<DenseVector>> expectedGroups2 =
+            Collections.singletonList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3),
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+
+    private static final int defaultParallelism = 4;
+    private static final int numTaskManagers = 2;
+    private static final int numSlotsPerTaskManager = 2;
+
+    private int currentModelDataVersion;
+
+    private InMemorySourceFunction<DenseVector> trainSource;
+    private InMemorySourceFunction<DenseVector> predictSource;
+    private InMemorySinkFunction<Row> outputSink;
+    private InMemorySinkFunction<KMeansModelData> modelDataSink;
+
+    private InMemoryReporter reporter;
+    private MiniCluster miniCluster;
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+
+    private Table offlineTrainTable;
+    private Table trainTable;
+    private Table predictTable;
+
+    @Before
+    public void before() throws Exception {
+        currentModelDataVersion = 0;
+
+        trainSource = new InMemorySourceFunction<>();
+        predictSource = new InMemorySourceFunction<>();
+        outputSink = new InMemorySinkFunction<>();
+        modelDataSink = new InMemorySinkFunction<>();
+
+        Configuration config = new Configuration();
+        config.set(RestOptions.BIND_PORT, "18081-19091");
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        reporter = InMemoryReporter.createWithRetainedMetrics();
+        reporter.addToConfiguration(config);
+
+        miniCluster =
+                new MiniCluster(
+                        new MiniClusterConfiguration.Builder()
+                                .setConfiguration(config)
+                                .setNumTaskManagers(numTaskManagers)
+                                .setNumSlotsPerTaskManager(numSlotsPerTaskManager)
+                                .build());
+        miniCluster.start();

Review comment:
       OK. Could we add a TODO describing what to optimize, and how, after upgrading to Flink 1.15?
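       The setup above registers an `InMemoryReporter` so the test can observe the model-data version through a metric, with `currentModelDataVersion` tracking the last version seen. Below is a stdlib-only sketch of the polling that such a wait helper might perform. It assumes the version is exposed as an integer gauge. `GaugeWait` and `waitForNewVersion` are hypothetical names. In the real test, the `IntSupplier` would presumably wrap a Flink `Gauge#getValue()` call.

```java
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.IntSupplier;

public class GaugeWait {
    /**
     * Polls versionGauge until it reports a value greater than currentVersion and returns that
     * value; throws IllegalStateException if the timeout elapses first.
     */
    static int waitForNewVersion(IntSupplier versionGauge, int currentVersion, long timeoutMillis) {
        long deadline = System.currentTimeMillis() + timeoutMillis;
        while (true) {
            int version = versionGauge.getAsInt();
            if (version > currentVersion) {
                return version;
            }
            if (System.currentTimeMillis() >= deadline) {
                throw new IllegalStateException(
                        "model data version did not advance past " + currentVersion);
            }
            try {
                Thread.sleep(10); // brief back-off between polls
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException("interrupted while waiting", e);
            }
        }
    }

    public static void main(String[] args) {
        AtomicInteger gauge = new AtomicInteger(0);
        // Simulates the training job publishing model data version 1 a little later.
        Thread trainer =
                new Thread(
                        () -> {
                            try {
                                Thread.sleep(100);
                            } catch (InterruptedException e) {
                                Thread.currentThread().interrupt();
                            }
                            gauge.set(1);
                        });
        trainer.start();
        int newVersion = waitForNewVersion(gauge::get, 0, 5_000);
        System.out.println("model data version is now " + newVersion);
    }
}
```

       A bounded timeout keeps a stuck job from hanging the test forever. The caller would then store the returned value as its new current version.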


[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Support Online KMeans in Flink ML

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r829794798



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/StreamingKMeans.java
##########
@@ -0,0 +1,404 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * StreamingKMeans extends the functionality of {@link KMeans} by supporting continuous training of a
+ * K-Means model on an unbounded stream of training data.
+ */
+public class StreamingKMeans
+        implements Estimator<StreamingKMeans, StreamingKMeansModel>,
+                StreamingKMeansParams<StreamingKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public StreamingKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public StreamingKMeans(Table... initModelDataTables) {
+        Preconditions.checkArgument(initModelDataTables.length == 1);
+        this.initModelDataTable = initModelDataTables[0];
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public StreamingKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        Preconditions.checkArgument(HasBatchStrategy.COUNT_STRATEGY.equals(getBatchStrategy()));
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+        points.getTransformation().setParallelism(1);
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            initModelDataStream = createRandomCentroids(env, getDims(), getK(), getSeed());
+        } else {
+            initModelDataStream = KMeansModelData.getModelDataStream(initModelDataTable);
+        }
+        DataStream<Tuple2<KMeansModelData, DenseVector>> initModelDataWithWeightsStream =
+                initModelDataStream.map(new InitWeightAssigner(getInitWeights()));
+        initModelDataWithWeightsStream.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new StreamingKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getBatchSize(),
+                        getK());
+
+        DataStream<KMeansModelData> finalModelDataStream =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelDataWithWeightsStream),
+                                DataStreamList.of(points),
+                                body)
+                        .get(0);
+        finalModelDataStream = finalModelDataStream.union(initModelDataStream);

Review comment:
       Not with the current implementation. I tried using the output of the iteration body as the feedback stream, and with that modification this can be guaranteed.
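       The `fit()` snippet above passes the decay factor and batch size into `StreamingKMeansIterationBody`, whose body is not shown in this thread. As a sketch, one mini-batch update of a single centroid could look like the code below. This assumes the Spark MLlib-style streaming k-means rule: the old weight `n` is discounted by the decay factor `a`, and the centroid moves toward the batch mean `x` of `m` points as `c' = (n*a*c + m*x) / (n*a + m)`, with new weight `n*a + m`. That rule is an assumption; the PR's exact formula is not confirmed here. `DecayedCentroidUpdate` and `update` are hypothetical names.

```java
import java.util.Arrays;

public class DecayedCentroidUpdate {
    /**
     * Updates one centroid in place with the mini-batch of points assigned to it and returns
     * the centroid's new weight (assumed Spark-style rule, see lead-in).
     */
    static double update(double[] centroid, double weight, double[][] batch, double decayFactor) {
        int m = batch.length;
        double discounted = weight * decayFactor; // n * a: how much history still counts
        if (m == 0) {
            return discounted; // no new points: only the weight decays
        }
        double newWeight = discounted + m; // n * a + m
        double lambda = m / newWeight; // share of the new centroid given to this batch
        for (int d = 0; d < centroid.length; d++) {
            double sum = 0.0;
            for (double[] p : batch) {
                sum += p[d];
            }
            double batchMean = sum / m;
            centroid[d] = (1.0 - lambda) * centroid[d] + lambda * batchMean;
        }
        return newWeight;
    }

    public static void main(String[] args) {
        // With an initial weight of 0 the first batch fully replaces the initial centroid.
        double[] c = {0.0, 0.0};
        double w = update(c, 0.0, new double[][] {{10.0, 0.0}, {10.0, 0.3}, {10.3, 0.0}}, 0.5);
        System.out.println(Arrays.toString(c) + " weight=" + w);
    }
}
```

       Under this assumed rule, a small decay factor lets the new batch dominate. A large one, such as the `100.0` used in `testDecayFactor`, keeps the centroid close to its old position. That is consistent with that test still expecting `expectedGroups1` after the second batch.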





[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Support Streaming KMeans in Flink ML

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r829707954



##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/StreamingKMeansTest.java
##########
@@ -0,0 +1,527 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.core.execution.JobClient;
+import org.apache.flink.ml.clustering.kmeans.KMeans;
+import org.apache.flink.ml.clustering.kmeans.KMeansModel;
+import org.apache.flink.ml.clustering.kmeans.KMeansModelData;
+import org.apache.flink.ml.clustering.kmeans.StreamingKMeans;
+import org.apache.flink.ml.clustering.kmeans.StreamingKMeansModel;
+import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.util.MockKVStore;
+import org.apache.flink.ml.util.MockMessageQueues;
+import org.apache.flink.ml.util.MockSinkFunction;
+import org.apache.flink.ml.util.MockSourceFunction;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.ml.util.TestMetricReporter;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.CollectionUtils;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import static org.apache.flink.ml.clustering.KMeansTest.groupFeaturesByPrediction;
+
+/** Tests {@link StreamingKMeans} and {@link StreamingKMeansModel}. */
+public class StreamingKMeansTest {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private static final DenseVector[] trainData1 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 0.0),
+                Vectors.dense(10.0, 0.3),
+                Vectors.dense(10.3, 0.0),
+                Vectors.dense(-10.0, 0.0),
+                Vectors.dense(-10.0, 0.6),
+                Vectors.dense(-10.6, 0.0)
+            };
+    private static final DenseVector[] trainData2 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 100.0),
+                Vectors.dense(10.0, 100.3),
+                Vectors.dense(10.3, 100.0),
+                Vectors.dense(-10.0, -100.0),
+                Vectors.dense(-10.0, -100.6),
+                Vectors.dense(-10.6, -100.0)
+            };
+    private static final DenseVector[] predictData =
+            new DenseVector[] {
+                Vectors.dense(10.0, 10.0),
+                Vectors.dense(10.3, 10.0),
+                Vectors.dense(10.0, 10.3),
+                Vectors.dense(-10.0, 10.0),
+                Vectors.dense(-10.3, 10.0),
+                Vectors.dense(-10.0, 10.3)
+            };
+    private static final List<Set<DenseVector>> expectedGroups1 =
+            Arrays.asList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3))),
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+    private static final List<Set<DenseVector>> expectedGroups2 =
+            Collections.singletonList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3),
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+
+    private String trainId;
+    private String predictId;
+    private String outputId;
+    private String modelDataId;
+    private String metricReporterPrefix;
+    private String modelDataVersionGaugeKey;
+    private String currentModelDataVersion;
+    private List<String> queueIds;
+    private List<String> kvStoreKeys;
+    private List<JobClient> clients;
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table offlineTrainTable;
+    private Table trainTable;
+    private Table predictTable;
+
+    @Before
+    public void before() {
+        metricReporterPrefix = MockKVStore.createNonDuplicatePrefix();
+        kvStoreKeys = new ArrayList<>();
+        kvStoreKeys.add(metricReporterPrefix);
+
+        modelDataVersionGaugeKey =
+                TestMetricReporter.getKey(metricReporterPrefix, "modelDataVersion");
+
+        currentModelDataVersion = "0";
+
+        trainId = MockMessageQueues.createMessageQueue();
+        predictId = MockMessageQueues.createMessageQueue();
+        outputId = MockMessageQueues.createMessageQueue();
+        modelDataId = MockMessageQueues.createMessageQueue();
+        queueIds = new ArrayList<>();
+        queueIds.addAll(Arrays.asList(trainId, predictId, outputId, modelDataId));
+
+        clients = new ArrayList<>();
+
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        config.setString("metrics.reporters", "test_reporter");
+        config.setString(
+                "metrics.reporter.test_reporter.class", TestMetricReporter.class.getName());
+        config.setString("metrics.reporter.test_reporter.interval", "100 MILLISECONDS");
+        config.setString("metrics.reporter.test_reporter.prefix", metricReporterPrefix);
+
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        Schema schema = Schema.newBuilder().column("f0", DataTypes.of(DenseVector.class)).build();
+
+        offlineTrainTable =
+                tEnv.fromDataStream(env.fromElements(trainData1), schema).as("features");
+        trainTable =
+                tEnv.fromDataStream(
+                                env.addSource(
+                                        new MockSourceFunction<>(trainId),
+                                        DenseVectorTypeInfo.INSTANCE),
+                                schema)
+                        .as("features");
+        predictTable =
+                tEnv.fromDataStream(
+                                env.addSource(
+                                        new MockSourceFunction<>(predictId),
+                                        DenseVectorTypeInfo.INSTANCE),
+                                schema)
+                        .as("features");
+    }
+
+    @After
+    public void after() {
+        for (JobClient client : clients) {
+            try {
+                client.cancel();
+            } catch (IllegalStateException e) {
+                // Constant-first comparison avoids an NPE when the exception has no message.
+                if (!"MiniCluster is not yet running or has already been shut down."
+                        .equals(e.getMessage())) {
+                    throw e;
+                }
+            }
+        }
+        clients.clear();
+
+        for (String queueId : queueIds) {
+            MockMessageQueues.deleteBlockingQueue(queueId);
+        }
+        queueIds.clear();
+
+        for (String key : kvStoreKeys) {
+            MockKVStore.remove(key);
+        }
+        kvStoreKeys.clear();
+    }
+
+    /** Adds sinks for StreamingKMeansModel's transform output and model data. */
+    private void configModelSink(StreamingKMeansModel streamingModel) {
+        Table outputTable = streamingModel.transform(predictTable)[0];
+        DataStream<Row> output = tEnv.toDataStream(outputTable);
+        output.addSink(new MockSinkFunction<>(outputId));
+
+        Table modelDataTable = streamingModel.getModelData()[0];
+        DataStream<KMeansModelData> modelDataStream =
+                KMeansModelData.getModelDataStream(modelDataTable);
+        modelDataStream.addSink(new MockSinkFunction<>(modelDataId));
+    }
+
+    /** Blocks the thread until the Model has set up the initial model data. */
+    private void waitInitModelDataSetup() {
+        while (!MockKVStore.containsKey(modelDataVersionGaugeKey)) {
+            Thread.yield();
+        }
+        waitModelDataUpdate();
+    }
+
+    /** Blocks the thread until the Model has received the next model-data-update event. */
+    private void waitModelDataUpdate() {
+        do {
+            String tmpModelDataVersion = MockKVStore.get(modelDataVersionGaugeKey);
+            if (tmpModelDataVersion.equals(currentModelDataVersion)) {
+                Thread.yield();
+            } else {
+                currentModelDataVersion = tmpModelDataVersion;
+                break;
+            }
+        } while (true);
+    }
+
+    /**
+     * Inserts the default predict data into the predict queue, fetches the prediction results, and
+     * asserts that the grouping result is as expected.
+     *
+     * @param expectedGroups The expected grouping result, as a list of feature sets (one per cluster)
+     * @param featuresCol Name of the column in the table that contains the features
+     * @param predictionCol Name of the column in the table that contains the prediction result
+     */
+    private void predictAndAssert(
+            List<Set<DenseVector>> expectedGroups, String featuresCol, String predictionCol)
+            throws Exception {
+        MockMessageQueues.offerAll(predictId, StreamingKMeansTest.predictData);
+        List<Row> rawResult =
+                MockMessageQueues.poll(outputId, StreamingKMeansTest.predictData.length);
+        List<Set<DenseVector>> actualGroups =
+                groupFeaturesByPrediction(rawResult, featuresCol, predictionCol);
+        Assert.assertTrue(CollectionUtils.isEqualCollection(expectedGroups, actualGroups));
+    }
+
+    @Test
+    public void testParam() {
+        StreamingKMeans streamingKMeans = new StreamingKMeans();
+        Assert.assertEquals("features", streamingKMeans.getFeaturesCol());
+        Assert.assertEquals("prediction", streamingKMeans.getPredictionCol());
+        Assert.assertEquals(EuclideanDistanceMeasure.NAME, streamingKMeans.getDistanceMeasure());
+        Assert.assertEquals("random", streamingKMeans.getInitMode());
+        Assert.assertEquals(2, streamingKMeans.getK());
+        Assert.assertEquals(1, streamingKMeans.getDims());
+        Assert.assertEquals("count", streamingKMeans.getBatchStrategy());
+        Assert.assertEquals(1, streamingKMeans.getBatchSize());
+        Assert.assertEquals(0., streamingKMeans.getDecayFactor(), 1e-5);
+        Assert.assertEquals(StreamingKMeans.class.getName().hashCode(), streamingKMeans.getSeed());
+
+        streamingKMeans
+                .setFeaturesCol("test_feature")
+                .setPredictionCol("test_prediction")
+                .setK(3)
+                .setDims(5)
+                .setBatchSize(5)
+                .setDecayFactor(0.25)
+                .setInitMode("direct")
+                .setSeed(100);
+
+        Assert.assertEquals("test_feature", streamingKMeans.getFeaturesCol());
+        Assert.assertEquals("test_prediction", streamingKMeans.getPredictionCol());
+        Assert.assertEquals(3, streamingKMeans.getK());
+        Assert.assertEquals(5, streamingKMeans.getDims());
+        Assert.assertEquals("count", streamingKMeans.getBatchStrategy());
+        Assert.assertEquals(5, streamingKMeans.getBatchSize());
+        Assert.assertEquals(0.25, streamingKMeans.getDecayFactor(), 1e-5);
+        Assert.assertEquals("direct", streamingKMeans.getInitMode());
+        Assert.assertEquals(100, streamingKMeans.getSeed());
+    }
+
+    @Test
+    public void testFitAndPredict() throws Exception {
+        StreamingKMeans streamingKMeans =
+                new StreamingKMeans()
+                        .setInitMode("random")
+                        .setDims(2)
+                        .setInitWeights(new Double[] {0., 0.})
+                        .setBatchSize(6)
+                        .setFeaturesCol("features")
+                        .setPredictionCol("prediction");
+        StreamingKMeansModel streamingModel = streamingKMeans.fit(trainTable);
+        configModelSink(streamingModel);
+
+        clients.add(env.executeAsync());
+        waitInitModelDataSetup();
+
+        MockMessageQueues.offerAll(trainId, trainData1);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups1,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+
+        MockMessageQueues.offerAll(trainId, trainData2);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups2,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+    }
+
+    @Test
+    public void testInitWithOfflineKMeans() throws Exception {
+        KMeans offlineKMeans =
+                new KMeans().setFeaturesCol("features").setPredictionCol("prediction");
+        KMeansModel offlineModel = offlineKMeans.fit(offlineTrainTable);
+
+        StreamingKMeans streamingKMeans =
+                new StreamingKMeans(offlineModel.getModelData())
+                        .setInitMode("direct")
+                        .setDims(2)
+                        .setInitWeights(new Double[] {0., 0.})
+                        .setBatchSize(6);
+        ReadWriteUtils.updateExistingParams(streamingKMeans, offlineKMeans.getParamMap());
+
+        StreamingKMeansModel streamingModel = streamingKMeans.fit(trainTable);
+        configModelSink(streamingModel);
+
+        clients.add(env.executeAsync());
+        waitInitModelDataSetup();
+        predictAndAssert(
+                expectedGroups1,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+
+        MockMessageQueues.offerAll(trainId, trainData2);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups2,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+    }
+
+    @Test
+    public void testDecayFactor() throws Exception {
+        StreamingKMeans streamingKMeans =
+                new StreamingKMeans()
+                        .setInitMode("random")
+                        .setDims(2)
+                        .setInitWeights(new Double[] {0., 0.})
+                        .setDecayFactor(100.0)
+                        .setBatchSize(6)
+                        .setFeaturesCol("features")
+                        .setPredictionCol("prediction");
+        StreamingKMeansModel streamingModel = streamingKMeans.fit(trainTable);
+        configModelSink(streamingModel);
+
+        clients.add(env.executeAsync());
+        waitInitModelDataSetup();
+
+        MockMessageQueues.offerAll(trainId, trainData1);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups1,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+
+        MockMessageQueues.offerAll(trainId, trainData2);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups1,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+    }
+
+    @Test
+    public void testSaveAndReload() throws Exception {
+        KMeans offlineKMeans =
+                new KMeans().setFeaturesCol("features").setPredictionCol("prediction");
+        KMeansModel offlineModel = offlineKMeans.fit(offlineTrainTable);
+
+        StreamingKMeans streamingKMeans =
+                new StreamingKMeans(offlineModel.getModelData())
+                        .setDims(2)
+                        .setBatchSize(6)
+                        .setInitWeights(new Double[] {0., 0.});
+        ReadWriteUtils.updateExistingParams(streamingKMeans, offlineKMeans.getParamMap());
+
+        StreamingKMeans loadedKMeans =
+                StageTestUtils.saveAndReload(
+                        env, streamingKMeans, tempFolder.newFolder().getAbsolutePath());
+
+        StreamingKMeansModel streamingModel = loadedKMeans.fit(trainTable);
+
+        String modelDataPassId = MockMessageQueues.createMessageQueue();
+        queueIds.add(modelDataPassId);
+
+        String modelSavePath = tempFolder.newFolder().getAbsolutePath();
+        streamingModel.save(modelSavePath);
+        KMeansModelData.getModelDataStream(streamingModel.getModelData()[0])
+                .addSink(new MockSinkFunction<>(modelDataPassId));
+        clients.add(env.executeAsync());

Review comment:
       In the current `save()` methods of `Pipeline` and `Graph`, I think we implicitly assume that calling `execute()` after `save()` always unblocks the process (i.e. the job eventually finishes), which is not true when online algorithms are involved. Without that assumption, using `executeAsync()` here would not be an issue.
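   To make the concern concrete with plain JDK concurrency (an analogy only, not Flink's API): in Flink, `execute()` blocks until the job finishes, while `executeAsync()` returns a `JobClient` right after submission. For an unbounded online job, only the asynchronous form gives control back, and the handle can later be cancelled, as `after()` does with `client.cancel()`. `AsyncSubmitSketch` and `unboundedJob()` are hypothetical names.

   ```java
   import java.util.concurrent.ExecutionException;
   import java.util.concurrent.ExecutorService;
   import java.util.concurrent.Executors;
   import java.util.concurrent.Future;
   import java.util.concurrent.TimeUnit;
   import java.util.concurrent.TimeoutException;

   public class AsyncSubmitSketch {
       // Hypothetical stand-in for an unbounded streaming job: it runs until interrupted.
       static Runnable unboundedJob() {
           return () -> {
               while (!Thread.currentThread().isInterrupted()) {
                   Thread.onSpinWait();
               }
           };
       }

       /**
        * Submits the job asynchronously, then waits for it with a timeout. Returns true if the wait
        * timed out and the job was cancelled, i.e. an untimed blocking wait would have hung forever.
        */
       public static boolean blockingWaitTimesOut(long millis) {
           ExecutorService executor = Executors.newSingleThreadExecutor();
           try {
               // Like executeAsync(): a handle is returned immediately after submission.
               Future<?> handle = executor.submit(unboundedJob());
               try {
                   // Like execute(): wait for completion, which never happens for this job.
                   handle.get(millis, TimeUnit.MILLISECONDS);
                   return false;
               } catch (TimeoutException e) {
                   // Like JobClient.cancel() in the test's after() method.
                   handle.cancel(true);
                   return handle.isCancelled();
               } catch (InterruptedException e) {
                   Thread.currentThread().interrupt();
                   return false;
               } catch (ExecutionException e) {
                   return false;
               }
           } finally {
               executor.shutdownNow();
           }
       }

       public static void main(String[] args) {
           System.out.println(blockingWaitTimesOut(100)); // true
       }
   }
   ```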

##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/util/MockKVStore.java
##########
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util;
+
+import java.util.HashMap;
+import java.util.Map;
+
+/** Class that manages global key-value pairs used in unit tests. */
+@SuppressWarnings({"unchecked"})
+public class MockKVStore {

Review comment:
       `TestMetricReporter` is not a singleton, as the following configuration in the test shows:
   ```java
   config.setString(
           "metrics.reporter.test_reporter.class", TestMetricReporter.class.getName());
   ```
   Flink uses this setting to instantiate `TestMetricReporter` by itself, so it is not a singleton.
   
   But I agree that we can remove `MockKVStore` for now, as `TestMetricReporter` is its only user. `TestMetricReporter` can store values in its own static fields and provide a static getter so that test cases can read them from the Flink client.
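   A minimal JDK-only sketch of that suggestion (it does not implement Flink's `MetricReporter` interface): the framework creates reporter instances reflectively from the class name, so several instances may exist, but they all write to one static, thread-safe map that test code reads through a static getter. `StaticValueReporter`, `report()`, `getValue()`, `removeByPrefix()` and the key format are hypothetical.

   ```java
   import java.util.Map;
   import java.util.concurrent.ConcurrentHashMap;

   public class StaticValueReporter {
       // Shared by all instances, so a value is visible no matter which instance reported it.
       private static final Map<String, String> VALUES = new ConcurrentHashMap<>();

       // Public no-arg constructor: the framework instantiates the class reflectively by name.
       public StaticValueReporter() {}

       /** Called on a framework-created instance, e.g. from a periodic reporting callback. */
       public void report(String key, String value) {
           VALUES.put(key, value);
       }

       /** Static accessor for test code; returns null if the key was never reported. */
       public static String getValue(String key) {
           return VALUES.get(key);
       }

       /** Removes all keys with the given prefix, e.g. from a test's @After method. */
       public static void removeByPrefix(String prefix) {
           VALUES.keySet().removeIf(k -> k.startsWith(prefix));
       }

       public static void main(String[] args) throws Exception {
           // Simulate the framework creating two instances from the configured class name.
           String className = StaticValueReporter.class.getName();
           StaticValueReporter a = (StaticValueReporter)
                   Class.forName(className).getDeclaredConstructor().newInstance();
           StaticValueReporter b = (StaticValueReporter)
                   Class.forName(className).getDeclaredConstructor().newInstance();
           a.report("prefix0.modelDataVersion", "1");
           b.report("prefix0.modelDataVersion", "2");
           System.out.println(a != b); // true: not a singleton
           System.out.println(getValue("prefix0.modelDataVersion")); // 2
       }
   }
   ```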

##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/util/MockMessageQueues.java
##########
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
+
+/** A class that manages global message queues used in unit tests. */
+@SuppressWarnings({"unchecked", "rawtypes"})
+public class MockMessageQueues {
+    // Accessed from both test threads and Flink task threads, so the map must be thread-safe.
+    private static final Map<String, BlockingQueue> queueMap = new ConcurrentHashMap<>();
+    private static long counter = 0;
+
+    public static synchronized String createMessageQueue() {
+        String id = String.valueOf(counter);
+        queueMap.put(id, new LinkedBlockingQueue<>());
+        counter++;
+        return id;
+    }
+
+    @SafeVarargs
+    public static <T> void offerAll(String id, T... values) throws InterruptedException {
+        for (T value : values) {
+            offer(id, value);
+        }
+    }
+
+    public static <T> void offer(String id, T value) throws InterruptedException {
+        offer(id, value, 1, TimeUnit.MINUTES);
+    }
+
+    public static <T> void offer(String id, T value, long timeout, TimeUnit unit)
+            throws InterruptedException {
+        boolean success = queueMap.get(id).offer(value, timeout, unit);
+        if (!success) {
+            throw new RuntimeException(
+                    "Failed to offer " + value + " to blocking queue " + id + ".");
+        }
+    }
+
+    public static <T> List<T> poll(String id, int num) throws InterruptedException {
+        List<T> result = new ArrayList<>();
+        for (int i = 0; i < num; i++) {
+            result.add(poll(id));
+        }
+        return result;
+    }
+
+    public static <T> T poll(String id) throws InterruptedException {
+        return poll(id, 1, TimeUnit.MINUTES);
+    }
+
+    public static <T> T poll(String id, long timeout, TimeUnit unit) throws InterruptedException {
+        T value = (T) queueMap.get(id).poll(timeout, unit);
+        if (value == null) {
+            throw new RuntimeException("Failed to poll next value from blocking queue " + id + ".");
+        }
+        return value;
+    }
+
+    public static void deleteBlockingQueue(String id) {

Review comment:
       When multiple test classes run in parallel, I'm afraid that calling a `clear()` method in one test case would affect the others. So I would prefer that each test class create and delete its own queues.
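   A sketch of the per-class ownership described above, assuming a global registry similar to `MockMessageQueues` (all names here are hypothetical): each test class records the ids it creates and deletes only those, so its cleanup cannot remove queues that a concurrently running class still uses, which a global `clear()` would.

   ```java
   import java.util.ArrayList;
   import java.util.List;
   import java.util.Map;
   import java.util.concurrent.BlockingQueue;
   import java.util.concurrent.ConcurrentHashMap;
   import java.util.concurrent.LinkedBlockingQueue;
   import java.util.concurrent.atomic.AtomicLong;

   public class QueueRegistrySketch {
       // Global registry shared by all test classes running in the same JVM.
       private static final Map<String, BlockingQueue<Object>> QUEUES = new ConcurrentHashMap<>();
       private static final AtomicLong COUNTER = new AtomicLong();

       static String createQueue() {
           String id = String.valueOf(COUNTER.getAndIncrement());
           QUEUES.put(id, new LinkedBlockingQueue<>());
           return id;
       }

       static void deleteQueue(String id) {
           QUEUES.remove(id);
       }

       public static boolean exists(String id) {
           return QUEUES.containsKey(id);
       }

       /** Per-test-class owner that tracks its own ids, like queueIds in StreamingKMeansTest. */
       public static class QueueOwner {
           private final List<String> ownedIds = new ArrayList<>();

           public String create() {
               String id = createQueue();
               ownedIds.add(id);
               return id;
           }

           /** Deletes only this owner's queues, e.g. from an @After method. */
           public void deleteOwned() {
               for (String id : ownedIds) {
                   deleteQueue(id);
               }
               ownedIds.clear();
           }
       }

       public static void main(String[] args) {
           QueueOwner testClassA = new QueueOwner();
           QueueOwner testClassB = new QueueOwner();
           String a = testClassA.create();
           String b = testClassB.create();
           testClassA.deleteOwned(); // B's queue is left untouched
           System.out.println(exists(a) + " " + exists(b)); // false true
       }
   }
   ```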

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/StreamingKMeans.java
##########
@@ -0,0 +1,404 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * StreamingKMeans extends the functionality of {@link KMeans} by supporting continuous training of a
+ * K-Means model from an unbounded stream of training data.
+ */
+public class StreamingKMeans
+        implements Estimator<StreamingKMeans, StreamingKMeansModel>,
+                StreamingKMeansParams<StreamingKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public StreamingKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public StreamingKMeans(Table... initModelDataTables) {
+        Preconditions.checkArgument(initModelDataTables.length == 1);
+        this.initModelDataTable = initModelDataTables[0];
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public StreamingKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        Preconditions.checkArgument(HasBatchStrategy.COUNT_STRATEGY.equals(getBatchStrategy()));
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+        points.getTransformation().setParallelism(1);

Review comment:
       I think this is hard to achieve for now. We need to create mini-batches of a fixed size from the training data, but if the parallelism is larger than 1, we have no mechanism to count the total number of records received by all subtasks together.
   
   One possible solution I have wanted to propose is to insert barriers into the training data stream, so that even if the training data is distributed across different subtasks, each subtask still knows when to finish the current batch as long as it receives the barrier. We have not had a chance to discuss this problem and possible solutions offline yet.
   
   For now I prefer to keep this limitation in this PR. I'll add a note about it to the Javadoc.
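   The barrier idea can be simulated without Flink. This is a sketch under stated assumptions, not how Flink ML implements it: data records go round-robin to `parallelism` subtasks, while a barrier is broadcast to every subtask after each `batchSize` records. Each subtask flushes its partial batch when it sees a barrier, and the partial batches for the same barrier together form exactly one global mini-batch, although no subtask knows the global count. `Event`, `distribute()` and the round-robin partitioning are hypothetical.

   ```java
   import java.util.ArrayList;
   import java.util.List;

   public class BarrierMiniBatchSketch {
       // Either a data record or a barrier marker (value is ignored for barriers).
       static final class Event {
           final boolean barrier;
           final double value;

           Event(boolean barrier, double value) {
               this.barrier = barrier;
               this.value = value;
           }
       }

       /** Upstream side: records go round-robin to one channel; barriers go to all channels. */
       static List<List<Event>> distribute(double[] records, int batchSize, int parallelism) {
           List<List<Event>> channels = new ArrayList<>();
           for (int i = 0; i < parallelism; i++) {
               channels.add(new ArrayList<>());
           }
           for (int i = 0; i < records.length; i++) {
               channels.get(i % parallelism).add(new Event(false, records[i]));
               if ((i + 1) % batchSize == 0) {
                   for (List<Event> channel : channels) {
                       channel.add(new Event(true, 0));
                   }
               }
           }
           return channels;
       }

       /** Subtask side: buffer records and emit the partial batch whenever a barrier arrives. */
       static List<List<Double>> partialBatches(List<Event> channel) {
           List<List<Double>> batches = new ArrayList<>();
           List<Double> current = new ArrayList<>();
           for (Event e : channel) {
               if (e.barrier) {
                   batches.add(current); // may be empty if this subtask got no records
                   current = new ArrayList<>();
               } else {
                   current.add(e.value);
               }
           }
           return batches; // records after the last barrier stay buffered
       }

       /** Sizes of the global mini-batches: partial batch j of every subtask, merged. */
       public static List<Integer> globalBatchSizes(double[] records, int batchSize, int parallelism) {
           List<List<List<Double>>> perSubtask = new ArrayList<>();
           for (List<Event> channel : distribute(records, batchSize, parallelism)) {
               perSubtask.add(partialBatches(channel));
           }
           List<Integer> sizes = new ArrayList<>();
           for (int j = 0; j < perSubtask.get(0).size(); j++) {
               int size = 0;
               for (List<List<Double>> batches : perSubtask) {
                   size += batches.get(j).size();
               }
               sizes.add(size);
           }
           return sizes;
       }

       public static void main(String[] args) {
           double[] records = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13};
           // Batch size 6 on 4 subtasks: two complete batches, one record still buffered.
           System.out.println(globalBatchSizes(records, 6, 4)); // [6, 6]
       }
   }
   ```

   Because each channel delivers its events in order, the j-th barrier seen by every subtask delimits the same global batch, which is what lets the batch boundaries survive a parallelism larger than 1.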

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/StreamingKMeans.java
##########
@@ -0,0 +1,404 @@
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            initModelDataStream = createRandomCentroids(env, getDims(), getK(), getSeed());
+        } else {
+            initModelDataStream = KMeansModelData.getModelDataStream(initModelDataTable);
+        }
+        DataStream<Tuple2<KMeansModelData, DenseVector>> initModelDataWithWeightsStream =
+                initModelDataStream.map(new InitWeightAssigner(getInitWeights()));
+        initModelDataWithWeightsStream.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new StreamingKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getBatchSize(),
+                        getK());
+
+        DataStream<KMeansModelData> finalModelDataStream =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelDataWithWeightsStream),
+                                DataStreamList.of(points),
+                                body)
+                        .get(0);
+        finalModelDataStream = finalModelDataStream.union(initModelDataStream);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelDataStream);
+        StreamingKMeansModel model = new StreamingKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class InitWeightAssigner
+            implements MapFunction<KMeansModelData, Tuple2<KMeansModelData, DenseVector>> {
+        private final double[] initWeights;
+
+        private InitWeightAssigner(Double[] initWeights) {
+            this.initWeights = ArrayUtils.toPrimitive(initWeights);
+        }
+
+        @Override
+        public Tuple2<KMeansModelData, DenseVector> map(KMeansModelData modelData)
+                throws Exception {
+            return Tuple2.of(modelData, Vectors.dense(initWeights));
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static StreamingKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        StreamingKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        Path initModelDataPath = Paths.get(path, "data");
+        if (Files.exists(initModelDataPath)) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+            kMeans.setInitMode("direct");
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class StreamingKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+        private final int k;
+
+        public StreamingKMeansIterationBody(
+                DistanceMeasure distanceMeasure, double decayFactor, int batchSize, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+            this.k = k;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<Tuple2<KMeansModelData, DenseVector>> modelDataWithWeights =
+                    variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            DataStream<Tuple2<KMeansModelData, DenseVector>> newModelDataWithWeights =
+                    points.countWindowAll(batchSize)
+                            .aggregate(new MiniBatchCreator())
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "UpdateModelData",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new UpdateModelDataOperator(distanceMeasure, decayFactor, k))
+                            .setParallelism(1);
+
+            DataStream<KMeansModelData> newModelData =
+                    newModelDataWithWeights.map(
+                            (MapFunction<Tuple2<KMeansModelData, DenseVector>, KMeansModelData>)
+                                    x -> x.f0);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelDataWithWeights), DataStreamList.of(newModelData));
+        }
+    }
+
+    // TODO: change this single-threaded implementation to support training in a distributed way,
+    // after model data version mechanism is implemented.
+    private static class UpdateModelDataOperator
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, DenseVector>>
+            implements TwoInputStreamOperator<
+                    DenseVector[],
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int k;
+        private ListState<DenseVector[]> miniBatchState;
+        private ListState<KMeansModelData> modelDataState;
+        private ListState<DenseVector> weightsState;
+
+        public UpdateModelDataOperator(DistanceMeasure distanceMeasure, double decayFactor, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.k = k;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            miniBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("miniBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+
+            weightsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "weights", DenseVectorTypeInfo.INSTANCE));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> streamRecord) throws Exception {
+            miniBatchState.add(streamRecord.getValue());
+            processElement();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<Tuple2<KMeansModelData, DenseVector>> streamRecord)
+                throws Exception {
+            modelDataState.add(streamRecord.getValue().f0);
+            weightsState.add(streamRecord.getValue().f1);
+            processElement();
+        }
+
+        private void processElement() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !miniBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            // Retrieves data from states.
+            List<KMeansModelData> modelDataList =
+                    IteratorUtils.toList(modelDataState.get().iterator());
+            if (modelDataList.size() != 1) {
+                throw new RuntimeException(
+                        "The operator received "
+                                + modelDataList.size()
+                                + " list of model data in this round");
+            }
+            DenseVector[] centroids = modelDataList.get(0).centroids;
+            modelDataState.clear();

Review comment:
       The behavior of OnlineKMeans, as I have conceived it, is as follows. In each iteration, the algorithm should consume one batch of training data and one set of model data received from the feedback edge. If there are multiple batches of data waiting to be consumed, the operator still consumes only one batch at a time. If there are zero batches of data when model data records are received, the operator just caches the model data and waits for training data to come in.
   
   So when model data comes in, it does not consume all the input points received before it. It consumes only the next batch of data, which may or may not have arrived yet.
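   A minimal, hypothetical sketch of the pairing rule described above, with mini-batches and model data reduced to strings. The class and method names are illustrative only and are not part of the PR; the sketch mirrors the control flow of `processElement()`: wait until both inputs are present, consume the oldest pending batch, and clear the cached model.

   ```java
   import java.util.ArrayDeque;
   import java.util.ArrayList;
   import java.util.Deque;
   import java.util.List;

   /** Hypothetical model of "one batch + one model data record per round". */
   class OneBatchPerRoundPairer {
       private final Deque<String> pendingBatches = new ArrayDeque<>();
       private String cachedModel = null;
       private final List<String> rounds = new ArrayList<>();

       /** Called when a mini-batch of training data arrives (first input). */
       void onBatch(String batch) {
           pendingBatches.addLast(batch);
           tryRunRound();
       }

       /** Called when model data arrives from the feedback edge (second input). */
       void onModelData(String model) {
           // Mirrors the "received N list of model data in this round" check.
           if (cachedModel != null) {
               throw new IllegalStateException("Received more than one model data record in this round");
           }
           cachedModel = model;
           tryRunRound();
       }

       private void tryRunRound() {
           // Wait until both a model and at least one batch are available.
           if (cachedModel == null || pendingBatches.isEmpty()) {
               return;
           }
           // Consume only the oldest batch; later batches wait for the next model update.
           String batch = pendingBatches.pollFirst();
           rounds.add(cachedModel + "+" + batch);
           cachedModel = null;
       }

       List<String> rounds() {
           return rounds;
       }
   }
   ```

   In the real iteration, the model produced by each round is what arrives next on the feedback edge, so calling `onModelData` once per round simulates it: two batches queued before `m0` yield the rounds `m0+b1` and `m1+b2`, and a third model `m2` stays cached until `b3` arrives.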

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/StreamingKMeans.java
##########
@@ -0,0 +1,404 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * StreamingKMeans extends the function of {@link KMeans}, supporting to train a K-Means model
+ * continuously according to an unbounded stream of train data.
+ */
+public class StreamingKMeans
+        implements Estimator<StreamingKMeans, StreamingKMeansModel>,
+                StreamingKMeansParams<StreamingKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public StreamingKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public StreamingKMeans(Table... initModelDataTables) {
+        Preconditions.checkArgument(initModelDataTables.length == 1);
+        this.initModelDataTable = initModelDataTables[0];
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public StreamingKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        Preconditions.checkArgument(HasBatchStrategy.COUNT_STRATEGY.equals(getBatchStrategy()));
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+        points.getTransformation().setParallelism(1);
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            initModelDataStream = createRandomCentroids(env, getDims(), getK(), getSeed());
+        } else {
+            initModelDataStream = KMeansModelData.getModelDataStream(initModelDataTable);
+        }
+        DataStream<Tuple2<KMeansModelData, DenseVector>> initModelDataWithWeightsStream =
+                initModelDataStream.map(new InitWeightAssigner(getInitWeights()));
+        initModelDataWithWeightsStream.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new StreamingKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getBatchSize(),
+                        getK());
+
+        DataStream<KMeansModelData> finalModelDataStream =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelDataWithWeightsStream),
+                                DataStreamList.of(points),
+                                body)
+                        .get(0);
+        finalModelDataStream = finalModelDataStream.union(initModelDataStream);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelDataStream);
+        StreamingKMeansModel model = new StreamingKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class InitWeightAssigner
+            implements MapFunction<KMeansModelData, Tuple2<KMeansModelData, DenseVector>> {
+        private final double[] initWeights;
+
+        private InitWeightAssigner(Double[] initWeights) {
+            this.initWeights = ArrayUtils.toPrimitive(initWeights);
+        }
+
+        @Override
+        public Tuple2<KMeansModelData, DenseVector> map(KMeansModelData modelData)
+                throws Exception {
+            return Tuple2.of(modelData, Vectors.dense(initWeights));
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static StreamingKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        StreamingKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        Path initModelDataPath = Paths.get(path, "data");
+        if (Files.exists(initModelDataPath)) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+            kMeans.setInitMode("direct");
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class StreamingKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+        private final int k;
+
+        public StreamingKMeansIterationBody(
+                DistanceMeasure distanceMeasure, double decayFactor, int batchSize, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+            this.k = k;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<Tuple2<KMeansModelData, DenseVector>> modelDataWithWeights =
+                    variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            DataStream<Tuple2<KMeansModelData, DenseVector>> newModelDataWithWeights =
+                    points.countWindowAll(batchSize)
+                            .aggregate(new MiniBatchCreator())
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "UpdateModelData",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new UpdateModelDataOperator(distanceMeasure, decayFactor, k))
+                            .setParallelism(1);
+
+            DataStream<KMeansModelData> newModelData =
+                    newModelDataWithWeights.map(
+                            (MapFunction<Tuple2<KMeansModelData, DenseVector>, KMeansModelData>)
+                                    x -> x.f0);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelDataWithWeights), DataStreamList.of(newModelData));
+        }
+    }
+
+    // TODO: change this single-threaded implementation to support training in a distributed way,
+    // after model data version mechanism is implemented.
+    private static class UpdateModelDataOperator
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, DenseVector>>
+            implements TwoInputStreamOperator<
+                    DenseVector[],
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int k;
+        private ListState<DenseVector[]> miniBatchState;
+        private ListState<KMeansModelData> modelDataState;
+        private ListState<DenseVector> weightsState;
+
+        public UpdateModelDataOperator(DistanceMeasure distanceMeasure, double decayFactor, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.k = k;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            miniBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("miniBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+
+            weightsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "weights", DenseVectorTypeInfo.INSTANCE));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> streamRecord) throws Exception {
+            miniBatchState.add(streamRecord.getValue());
+            processElement();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<Tuple2<KMeansModelData, DenseVector>> streamRecord)
+                throws Exception {
+            modelDataState.add(streamRecord.getValue().f0);
+            weightsState.add(streamRecord.getValue().f1);
+            processElement();
+        }
+
+        private void processElement() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !miniBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            // Retrieves data from states.
+            List<KMeansModelData> modelDataList =
+                    IteratorUtils.toList(modelDataState.get().iterator());
+            if (modelDataList.size() != 1) {
+                throw new RuntimeException(
+                        "The operator received "
+                                + modelDataList.size()
+                                + " list of model data in this round");
+            }
+            DenseVector[] centroids = modelDataList.get(0).centroids;
+            modelDataState.clear();
+
+            List<DenseVector> weightsList = IteratorUtils.toList(weightsState.get().iterator());
+            if (weightsList.size() != 1) {
+                throw new RuntimeException(
+                        "The operator received "
+                                + weightsList.size()
+                                + " list of weights in this round");
+            }
+            DenseVector weights = weightsList.get(0);
+            weightsState.clear();
+
+            List<DenseVector[]> pointsList = IteratorUtils.toList(miniBatchState.get().iterator());
+            DenseVector[] points = pointsList.get(0);
+            pointsList.remove(0);
+            miniBatchState.clear();
+            miniBatchState.addAll(pointsList);

Review comment:
       As described above, I think that in an online algorithm the training process should produce one model data update for each batch of training data. If there are multiple batches of training data, they should not be merged into one. If there are zero batches of training data, the iteration body should not continue training on an empty batch either.
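   To make the argument against merging concrete, here is a worked derivation that assumes a Spark-style per-cluster update rule (the rule is an assumption; this excerpt does not show the update math inside `UpdateModelDataOperator`). For one cluster, let $c_t$ and $n_t$ be the centroid and its weight after batch $t$, $m_t$ the number of points of batch $t$ assigned to it, $x_t$ their mean, and $\alpha$ the decay factor.

   ```latex
   % Assumed update rule, applied once per batch t:
   \begin{align*}
     n_t &= \alpha n_{t-1} + m_t, \qquad
     c_t = \frac{\alpha n_{t-1} c_{t-1} + m_t x_t}{n_t} \\
   % Two sequential updates (batch 1, then batch 2) starting from (c_0, n_0):
     n_2 &= \alpha^2 n_0 + \alpha m_1 + m_2, \qquad
     c_2 = \frac{\alpha^2 n_0 c_0 + \alpha m_1 x_1 + m_2 x_2}{n_2} \\
   % One update on the merged batch (m_1 + m_2 points with mean (m_1 x_1 + m_2 x_2)/(m_1 + m_2)):
     n' &= \alpha n_0 + m_1 + m_2, \qquad
     c' = \frac{\alpha n_0 c_0 + m_1 x_1 + m_2 x_2}{n'}
   \end{align*}
   ```

   The two results agree for every input when $\alpha = 1$. For $\alpha < 1$ they generally differ, because the sequential updates discount batch 1 by $\alpha$ and $c_0$ by $\alpha^2$, while the merged update discounts only $c_0$, and only once.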

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/StreamingKMeans.java
##########
@@ -0,0 +1,404 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * StreamingKMeans extends the function of {@link KMeans}, supporting to train a K-Means model

Review comment:
       I agree. I'll refer to Spark's documentation and add similar JavaDocs here.
   
   Note, however, that Spark's JavaDoc is mainly on `StreamingKMeansModel` rather than `StreamingKMeans`. I think `StreamingKMeans` is the class that should be explained in more detail: like Spark's JavaDoc, my added documentation will mainly describe the training process, so I'll put the detailed documentation on `OnlineKMeans`.

##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/util/MockKVStore.java
##########
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util;
+
+import java.util.HashMap;
+import java.util.Map;
+
+/** Class that manages global key-value pairs used in unit tests. */
+@SuppressWarnings({"unchecked"})
+public class MockKVStore {
+    private static final Map<String, Object> map = new HashMap<>();
+
+    private static long counter = 0;
+
+    /**
+     * Returns a prefix string to make sure key-value pairs created in different test cases would
+     * not have the same key.
+     */
+    public static synchronized String createNonDuplicatePrefix() {

Review comment:
       I agree that it would be better to remove this method. However, a random int64 value carries a risk here, since test cases running in parallel might generate identical keys. I'll try to avoid this problem by using a more complex key pattern, such as combining the class name with a monotonically increasing int64.
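   A minimal sketch of the proposed key pattern, combining the class name with a JVM-wide `AtomicLong` counter. `TestKeyPrefixes` and `createPrefix` are hypothetical names used for illustration, not the PR's API.

   ```java
   import java.util.concurrent.atomic.AtomicLong;

   /** Hypothetical helper sketching the "class name + increasing counter" key pattern. */
   final class TestKeyPrefixes {
       // Shared by all test cases running in the same JVM.
       private static final AtomicLong COUNTER = new AtomicLong();

       private TestKeyPrefixes() {}

       /** Returns a prefix of the form "<fully qualified class name>-<counter>-". */
       static String createPrefix(Class<?> testClass) {
           // getAndIncrement() is atomic, so concurrent callers never see the same value.
           return testClass.getName() + "-" + COUNTER.getAndIncrement() + "-";
       }
   }
   ```

   Because `MockKVStore` keeps its map in a static field, the map is shared only within one JVM, so a counter that is unique per JVM is enough to keep keys from parallel test cases in that JVM apart, and `AtomicLong.getAndIncrement()` removes the need for `synchronized`.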

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasDecayFactor.java
##########
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
+
+/** Interface for the shared decay factor param. */
+public interface HasDecayFactor<T> extends WithParams<T> {
+    Param<Double> DECAY_FACTOR =
+            new DoubleParam(
+                    "decayFactor",
+                    "The forgetfulness of the previous centroids.",
+                    0.,
+                    ParamValidators.gtEq(0));

Review comment:
       Spark does not set an upper bound on this parameter. It only requires decayFactor to be nonnegative, as follows.
   ```scala
     def setDecayFactor(a: Double): this.type = {
       require(a >= 0,
         s"Decay factor must be nonnegative but got ${a}")
       this.decayFactor = a
       this
     }
   ```
   I also understand your concern. If this value is larger than 1, it can no longer represent "forgetfulness", because the weight of the initial model data keeps being strengthened with every batch. Maybe we can require it to be smaller than 1 for now, and remove this limit if we find relevant use cases. What do you think?
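   The effect of the bound can be illustrated with a hypothetical Java sketch of a Spark-style decayed update for a single cluster, in which the existing weight is multiplied by `decayFactor` before each batch is added. This is not the Flink ML implementation and the class name is illustrative; like the current `ParamValidators.gtEq(0)` check and Spark's `require`, it only rejects negative values.

   ```java
   /** Hypothetical sketch: decayed mini-batch update of ONE centroid and its weight. */
   final class DecayedCentroidUpdate {
       double[] centroid;
       double weight;

       DecayedCentroidUpdate(double[] centroid, double weight) {
           this.centroid = centroid.clone();
           this.weight = weight;
       }

       /** Applies one batch; batchSum is the sum of the m points assigned to this cluster. */
       void update(double[] batchSum, long m, double decayFactor) {
           if (decayFactor < 0) {
               throw new IllegalArgumentException(
                       "Decay factor must be nonnegative but got " + decayFactor);
           }
           double discounted = weight * decayFactor; // alpha * n_t: weight of the history
           double newWeight = discounted + m;        // n_{t+1} = alpha * n_t + m_t
           if (newWeight > 0) {                      // skip 0/0 when there is no history and no points
               for (int i = 0; i < centroid.length; i++) {
                   // c_{t+1} = (alpha * n_t * c_t + sum of batch points) / n_{t+1}
                   centroid[i] = (discounted * centroid[i] + batchSum[i]) / newWeight;
               }
           }
           weight = newWeight;
       }
   }
   ```

   With `decayFactor = 0` the centroid becomes the batch mean, and with `1` it is the cumulative mean. With `2` and one point per batch, the weight grows from 1 to 3, 7 and then 15, so the history's share of each update (2/3, 6/7, 14/15) keeps increasing, which is the strengthening of the initial model data described above.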

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/StreamingKMeans.java
##########
@@ -0,0 +1,404 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * StreamingKMeans extends the function of {@link KMeans}, supporting to train a K-Means model
+ * continuously according to an unbounded stream of train data.
+ */
+public class StreamingKMeans
+        implements Estimator<StreamingKMeans, StreamingKMeansModel>,
+                StreamingKMeansParams<StreamingKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public StreamingKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public StreamingKMeans(Table... initModelDataTables) {
+        Preconditions.checkArgument(initModelDataTables.length == 1);
+        this.initModelDataTable = initModelDataTables[0];
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public StreamingKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        Preconditions.checkArgument(HasBatchStrategy.COUNT_STRATEGY.equals(getBatchStrategy()));
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+        points.getTransformation().setParallelism(1);
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            initModelDataStream = createRandomCentroids(env, getDims(), getK(), getSeed());

Review comment:
       I agree. In fact, we only need to check `initModelDataTable == null` when initMode is `random`; if initMode is `direct`, the subsequent implementation would fail on its own unless `initModelDataTable != null`.
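       The check described above can be sketched in plain Java as follows. This is a hypothetical helper, not code from the PR: `initModelData` stands in for the `initModelDataTable` field (any non-null object means a table was supplied), and the error messages and the rejection of unknown modes are assumptions. It throws `IllegalArgumentException`, the same exception type that Flink's `Preconditions.checkArgument` throws.

```java
import java.util.Objects;

/** Hypothetical sketch of the init-mode argument check discussed in the review comment. */
final class InitModeCheck {
    private InitModeCheck() {}

    static void validate(String initMode, Object initModelData) {
        Objects.requireNonNull(initMode, "initMode");
        switch (initMode) {
            case "random":
                // Random init generates its own centroids, so a supplied table would be silently ignored.
                if (initModelData != null) {
                    throw new IllegalArgumentException(
                            "initModelDataTable must not be set when initMode is 'random'");
                }
                break;
            case "direct":
                // No explicit check: reading model data from a null table fails later anyway.
                break;
            default:
                // Assumption: only "random" and "direct" are valid modes.
                throw new IllegalArgumentException("Unsupported initMode: " + initMode);
        }
    }
}
```

Only the `random` branch needs a guard, which matches the point made in the comment.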

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/StreamingKMeans.java
##########
@@ -0,0 +1,404 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * StreamingKMeans extends the function of {@link KMeans}, supporting to train a K-Means model
+ * continuously according to an unbounded stream of train data.
+ */
+public class StreamingKMeans
+        implements Estimator<StreamingKMeans, StreamingKMeansModel>,
+                StreamingKMeansParams<StreamingKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public StreamingKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public StreamingKMeans(Table... initModelDataTables) {
+        Preconditions.checkArgument(initModelDataTables.length == 1);
+        this.initModelDataTable = initModelDataTables[0];
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public StreamingKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        Preconditions.checkArgument(HasBatchStrategy.COUNT_STRATEGY.equals(getBatchStrategy()));
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+        points.getTransformation().setParallelism(1);
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            initModelDataStream = createRandomCentroids(env, getDims(), getK(), getSeed());
+        } else {
+            initModelDataStream = KMeansModelData.getModelDataStream(initModelDataTable);
+        }
+        DataStream<Tuple2<KMeansModelData, DenseVector>> initModelDataWithWeightsStream =
+                initModelDataStream.map(new InitWeightAssigner(getInitWeights()));
+        initModelDataWithWeightsStream.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new StreamingKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getBatchSize(),
+                        getK());
+
+        DataStream<KMeansModelData> finalModelDataStream =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelDataWithWeightsStream),
+                                DataStreamList.of(points),
+                                body)
+                        .get(0);
+        finalModelDataStream = finalModelDataStream.union(initModelDataStream);

Review comment:
       Not with the current implementation. I tried using the output of the iteration body as the input feedback stream, and this modification can guarantee it.
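       To illustrate the feedback idea without Flink, the plain-Java sketch below feeds each update's output model back in as the input to the next update. Each element of `batches` plays the role of one `countWindowAll(batchSize)` window, and `weights` corresponds to the per-centroid weights that `InitWeightAssigner` attaches. The update rule itself is an assumption. It uses a decayed weighted average similar to Spark MLlib's streaming k-means; the PR's `UpdateModelDataOperator` body is cut off here and may differ. All names are hypothetical.

```java
import java.util.List;

/** Hypothetical plain-Java sketch of a decayed mini-batch k-means update with a feedback loop. */
final class StreamingKMeansSketch {
    final double[][] centroids;
    final double[] weights;

    StreamingKMeansSketch(double[][] centroids, double[] weights) {
        this.centroids = centroids;
        this.weights = weights;
    }

    /** Index of the closest centroid under squared Euclidean distance. */
    static int nearest(double[][] centroids, double[] point) {
        int best = 0;
        double bestDist = Double.POSITIVE_INFINITY;
        for (int i = 0; i < centroids.length; i++) {
            double d = 0;
            for (int j = 0; j < point.length; j++) {
                double diff = centroids[i][j] - point[j];
                d += diff * diff;
            }
            if (d < bestDist) {
                bestDist = d;
                best = i;
            }
        }
        return best;
    }

    /** Returns a new model; this model is left untouched, like an immutable feedback record. */
    StreamingKMeansSketch update(double[][] batch, double decayFactor) {
        int k = centroids.length;
        int dims = centroids[0].length;
        double[][] sums = new double[k][dims];
        double[] counts = new double[k];
        for (double[] p : batch) {
            int c = nearest(centroids, p);
            counts[c]++;
            for (int j = 0; j < dims; j++) {
                sums[c][j] += p[j];
            }
        }
        double[][] newCentroids = new double[k][];
        double[] newWeights = new double[k];
        for (int i = 0; i < k; i++) {
            double decayed = weights[i] * decayFactor; // older data counts less
            newWeights[i] = decayed + counts[i];
            newCentroids[i] = centroids[i].clone();
            if (counts[i] > 0) {
                for (int j = 0; j < dims; j++) {
                    // Weighted average of the decayed old centroid and the batch points.
                    newCentroids[i][j] = (decayed * centroids[i][j] + sums[i][j]) / newWeights[i];
                }
            }
        }
        return new StreamingKMeansSketch(newCentroids, newWeights);
    }

    /** Feedback loop: the output of one update is the input model of the next. */
    static StreamingKMeansSketch train(
            StreamingKMeansSketch init, List<double[][]> batches, double decayFactor) {
        StreamingKMeansSketch model = init;
        for (double[][] batch : batches) {
            model = model.update(batch, decayFactor);
        }
        return model;
    }
}
```

Because each update consumes exactly the model produced by the previous one, there is no ambiguity about which model version a mini-batch is applied to. That is the ordering guarantee the comment refers to.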

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/StreamingKMeansModel.java
##########
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.co.CoProcessFunction;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * StreamingKMeansModel can be regarded as an advanced {@link KMeansModel} operator which can update
+ * model data in a streaming format, using the model data provided by {@link StreamingKMeans}.
+ */
+public class StreamingKMeansModel
+        implements Model<StreamingKMeansModel>, KMeansModelParams<StreamingKMeansModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public StreamingKMeansModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public StreamingKMeansModel setModelData(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        modelDataTable = inputs[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), Types.INT),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol()));
+
+        DataStream<Row> predictionResult =
+                KMeansModelData.getModelDataStream(modelDataTable)
+                        .broadcast()
+                        .connect(tEnv.toDataStream(inputs[0]))
+                        .process(
+                                new PredictLabelFunction(
+                                        getFeaturesCol(),
+                                        DistanceMeasure.getInstance(getDistanceMeasure())),
+                                outputTypeInfo);
+
+        return new Table[] {tEnv.fromDataStream(predictionResult)};
+    }
+
+    /** A utility function used for prediction. */
+    private static class PredictLabelFunction extends CoProcessFunction<KMeansModelData, Row, Row> {
+        private final String featuresCol;
+
+        private final DistanceMeasure distanceMeasure;
+
+        private DenseVector[] centroids;
+
+        // TODO: replace this with a complete solution of reading first model data from unbounded
+        // model data stream before processing the first predict data.
+        private final List<Row> cache = new ArrayList<>();

Review comment:
       OK. I'll make the change.
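       The TODO in `PredictLabelFunction` describes buffering predict data until the first model data arrives. The plain-Java sketch below shows that buffering pattern. It is hypothetical: `onModelData` and `onPoint` stand in for `processElement1` (model data) and `processElement2` (rows) of the `CoProcessFunction`. Squared Euclidean distance replaces the configurable `DistanceMeasure`, and predictions are collected in a list instead of a `Collector`. It ignores fault tolerance. In a real operator, the cache would likely need to live in Flink state to survive a failover.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical sketch: cache rows until the first model arrives, then flush them. */
final class CachingPredictor {
    private double[][] centroids; // null until the first model data arrives
    private final List<double[]> cache = new ArrayList<>();
    private final List<Integer> output = new ArrayList<>();

    /** Counterpart of processElement1: new model data replaces the current centroids. */
    void onModelData(double[][] newCentroids) {
        centroids = newCentroids;
        // Flush rows that arrived before any model was available.
        for (double[] cached : cache) {
            output.add(predict(cached));
        }
        cache.clear();
    }

    /** Counterpart of processElement2: predict immediately if a model exists, else cache. */
    void onPoint(double[] features) {
        if (centroids == null) {
            cache.add(features);
        } else {
            output.add(predict(features));
        }
    }

    List<Integer> output() {
        return output;
    }

    /** Index of the nearest centroid under squared Euclidean distance. */
    private int predict(double[] features) {
        int best = 0;
        double bestDist = Double.POSITIVE_INFINITY;
        for (int i = 0; i < centroids.length; i++) {
            double dist = 0;
            for (int j = 0; j < features.length; j++) {
                double diff = centroids[i][j] - features[j];
                dist += diff * diff;
            }
            if (dist < bestDist) {
                bestDist = dist;
                best = i;
            }
        }
        return best;
    }
}
```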

##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java
##########
@@ -207,25 +211,21 @@ public void testSaveLoadAndPredict() throws Exception {
         KMeansModel loadedModel =
                 StageTestUtils.saveAndReload(env, model, tempFolder.newFolder().getAbsolutePath());
         Table output = loadedModel.transform(dataTable)[0];
-        assertEquals(
-                Collections.singletonList("centroids"),
-                loadedModel.getModelData()[0].getResolvedSchema().getColumnNames());
         assertEquals(
                 Arrays.asList("features", "prediction"),
                 output.getResolvedSchema().getColumnNames());
 
+        List<Row> results = IteratorUtils.toList(output.execute().collect());
         List<Set<DenseVector>> actualGroups =
-                executeAndCollect(output, kmeans.getFeaturesCol(), kmeans.getPredictionCol());
+                groupFeaturesByPrediction(
+                        results, kmeans.getFeaturesCol(), kmeans.getPredictionCol());
         assertTrue(CollectionUtils.isEqualCollection(expectedGroups, actualGroups));
     }
 
     @Test
     public void testGetModelData() throws Exception {
         KMeans kmeans = new KMeans().setMaxIter(2).setK(2);
         KMeansModel model = kmeans.fit(dataTable);
-        assertEquals(
-                Collections.singletonList("centroids"),

Review comment:
       I had previously changed the design of `KMeansModelData` so that it contained more than centroids (adding a weights field), and this removal was a result of that change. I have since restored `KMeansModelData`'s original structure but forgot to add this check back. I'll fix it.

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/StreamingKMeans.java
##########
@@ -0,0 +1,404 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * StreamingKMeans extends the function of {@link KMeans}, supporting to train a K-Means model
+ * continuously according to an unbounded stream of train data.
+ */
+public class StreamingKMeans
+        implements Estimator<StreamingKMeans, StreamingKMeansModel>,
+                StreamingKMeansParams<StreamingKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public StreamingKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public StreamingKMeans(Table... initModelDataTables) {
+        Preconditions.checkArgument(initModelDataTables.length == 1);
+        this.initModelDataTable = initModelDataTables[0];
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public StreamingKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        Preconditions.checkArgument(HasBatchStrategy.COUNT_STRATEGY.equals(getBatchStrategy()));
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+        points.getTransformation().setParallelism(1);
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            initModelDataStream = createRandomCentroids(env, getDims(), getK(), getSeed());
+        } else {
+            initModelDataStream = KMeansModelData.getModelDataStream(initModelDataTable);
+        }
+        DataStream<Tuple2<KMeansModelData, DenseVector>> initModelDataWithWeightsStream =
+                initModelDataStream.map(new InitWeightAssigner(getInitWeights()));
+        initModelDataWithWeightsStream.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new StreamingKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getBatchSize(),
+                        getK());
+
+        DataStream<KMeansModelData> finalModelDataStream =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelDataWithWeightsStream),
+                                DataStreamList.of(points),
+                                body)
+                        .get(0);
+        finalModelDataStream = finalModelDataStream.union(initModelDataStream);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelDataStream);
+        StreamingKMeansModel model = new StreamingKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class InitWeightAssigner
+            implements MapFunction<KMeansModelData, Tuple2<KMeansModelData, DenseVector>> {
+        private final double[] initWeights;
+
+        private InitWeightAssigner(Double[] initWeights) {
+            this.initWeights = ArrayUtils.toPrimitive(initWeights);
+        }
+
+        @Override
+        public Tuple2<KMeansModelData, DenseVector> map(KMeansModelData modelData)
+                throws Exception {
+            return Tuple2.of(modelData, Vectors.dense(initWeights));
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static StreamingKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        StreamingKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        Path initModelDataPath = Paths.get(path, "data");
+        if (Files.exists(initModelDataPath)) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+            kMeans.setInitMode("direct");
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class StreamingKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+        private final int k;
+
+        public StreamingKMeansIterationBody(
+                DistanceMeasure distanceMeasure, double decayFactor, int batchSize, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+            this.k = k;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<Tuple2<KMeansModelData, DenseVector>> modelDataWithWeights =
+                    variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            DataStream<Tuple2<KMeansModelData, DenseVector>> newModelDataWithWeights =
+                    points.countWindowAll(batchSize)
+                            .aggregate(new MiniBatchCreator())
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "UpdateModelData",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new UpdateModelDataOperator(distanceMeasure, decayFactor, k))
+                            .setParallelism(1);
+
+            DataStream<KMeansModelData> newModelData =
+                    newModelDataWithWeights.map(
+                            (MapFunction<Tuple2<KMeansModelData, DenseVector>, KMeansModelData>)
+                                    x -> x.f0);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelDataWithWeights), DataStreamList.of(newModelData));
+        }
+    }
+
+    // TODO: change this single-threaded implementation to support training in a distributed way,
+    // after the model data version mechanism is implemented.
+    private static class UpdateModelDataOperator
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, DenseVector>>
+            implements TwoInputStreamOperator<
+                    DenseVector[],
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int k;
+        private ListState<DenseVector[]> miniBatchState;
+        private ListState<KMeansModelData> modelDataState;
+        private ListState<DenseVector> weightsState;
+
+        public UpdateModelDataOperator(DistanceMeasure distanceMeasure, double decayFactor, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.k = k;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            miniBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("miniBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+
+            weightsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "weights", DenseVectorTypeInfo.INSTANCE));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> streamRecord) throws Exception {
+            miniBatchState.add(streamRecord.getValue());
+            processElement();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<Tuple2<KMeansModelData, DenseVector>> streamRecord)
+                throws Exception {
+            modelDataState.add(streamRecord.getValue().f0);
+            weightsState.add(streamRecord.getValue().f1);
+            processElement();
+        }
+
+        private void processElement() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !miniBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            // Retrieves data from states.
+            List<KMeansModelData> modelDataList =
+                    IteratorUtils.toList(modelDataState.get().iterator());
+            if (modelDataList.size() != 1) {

Review comment:
       Got it. I'll also apply it to `KMeans`.
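       The quoted `processElement()` returns early until the operator holds both model data and a mini-batch, and then checks that exactly one model is present. The plain-Java sketch below illustrates that "wait for both inputs" pattern. It is hypothetical. It assumes that a model is consumed once a batch has been applied, and that the updated model returns through the feedback edge, which is why a second pending model is treated as an error. The PR's actual handling of state after the update is not shown here. `M` and `B` are placeholder types for the model and the mini-batch.

```java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.function.BiFunction;

/** Hypothetical sketch of a two-input operator that waits for one model and one batch. */
final class TwoInputJoiner<M, B> {
    private M model; // at most one pending model, mirroring the size() != 1 check
    private final Deque<B> pendingBatches = new ArrayDeque<>();
    private final BiFunction<M, B, M> updater;

    TwoInputJoiner(BiFunction<M, B, M> updater) {
        this.updater = updater;
    }

    /** Mirrors processElement1: buffer a mini-batch; returns the updated model or null. */
    M onBatch(B batch) {
        pendingBatches.addLast(batch);
        return tryUpdate();
    }

    /** Mirrors processElement2: receive a (fed-back) model; returns the updated model or null. */
    M onModel(M newModel) {
        if (model != null) {
            throw new IllegalStateException("Expected exactly one pending model, but received two");
        }
        model = newModel;
        return tryUpdate();
    }

    private M tryUpdate() {
        if (model == null || pendingBatches.isEmpty()) {
            return null; // same early return as processElement() when either side is empty
        }
        M updated = updater.apply(model, pendingBatches.pollFirst());
        model = null; // consumed; the next model arrives through the feedback edge
        return updated;
    }
}
```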







[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Support Streaming KMeans in Flink ML

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r829718183



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/StreamingKMeans.java
##########
@@ -0,0 +1,404 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * StreamingKMeans extends the functionality of {@link KMeans} by supporting continuous training of
+ * a K-Means model on an unbounded stream of training data.
+ */
+public class StreamingKMeans
+        implements Estimator<StreamingKMeans, StreamingKMeansModel>,
+                StreamingKMeansParams<StreamingKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public StreamingKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public StreamingKMeans(Table... initModelDataTables) {
+        Preconditions.checkArgument(initModelDataTables.length == 1);
+        this.initModelDataTable = initModelDataTables[0];
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public StreamingKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        Preconditions.checkArgument(HasBatchStrategy.COUNT_STRATEGY.equals(getBatchStrategy()));
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+        points.getTransformation().setParallelism(1);
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            initModelDataStream = createRandomCentroids(env, getDims(), getK(), getSeed());
+        } else {
+            initModelDataStream = KMeansModelData.getModelDataStream(initModelDataTable);
+        }
+        DataStream<Tuple2<KMeansModelData, DenseVector>> initModelDataWithWeightsStream =
+                initModelDataStream.map(new InitWeightAssigner(getInitWeights()));
+        initModelDataWithWeightsStream.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new StreamingKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getBatchSize(),
+                        getK());
+
+        DataStream<KMeansModelData> finalModelDataStream =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelDataWithWeightsStream),
+                                DataStreamList.of(points),
+                                body)
+                        .get(0);
+        finalModelDataStream = finalModelDataStream.union(initModelDataStream);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelDataStream);
+        StreamingKMeansModel model = new StreamingKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class InitWeightAssigner
+            implements MapFunction<KMeansModelData, Tuple2<KMeansModelData, DenseVector>> {
+        private final double[] initWeights;
+
+        private InitWeightAssigner(Double[] initWeights) {
+            this.initWeights = ArrayUtils.toPrimitive(initWeights);
+        }
+
+        @Override
+        public Tuple2<KMeansModelData, DenseVector> map(KMeansModelData modelData)
+                throws Exception {
+            return Tuple2.of(modelData, Vectors.dense(initWeights));
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static StreamingKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        StreamingKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        Path initModelDataPath = Paths.get(path, "data");
+        if (Files.exists(initModelDataPath)) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+            kMeans.setInitMode("direct");
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class StreamingKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+        private final int k;
+
+        public StreamingKMeansIterationBody(
+                DistanceMeasure distanceMeasure, double decayFactor, int batchSize, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+            this.k = k;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<Tuple2<KMeansModelData, DenseVector>> modelDataWithWeights =
+                    variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            DataStream<Tuple2<KMeansModelData, DenseVector>> newModelDataWithWeights =
+                    points.countWindowAll(batchSize)
+                            .aggregate(new MiniBatchCreator())
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "UpdateModelData",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new UpdateModelDataOperator(distanceMeasure, decayFactor, k))
+                            .setParallelism(1);
+
+            DataStream<KMeansModelData> newModelData =
+                    newModelDataWithWeights.map(
+                            (MapFunction<Tuple2<KMeansModelData, DenseVector>, KMeansModelData>)
+                                    x -> x.f0);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelDataWithWeights), DataStreamList.of(newModelData));
+        }
+    }
+
+    // TODO: change this single-threaded implementation to support training in a distributed way
+    // after the model data version mechanism is implemented.
+    private static class UpdateModelDataOperator
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, DenseVector>>
+            implements TwoInputStreamOperator<
+                    DenseVector[],
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int k;
+        private ListState<DenseVector[]> miniBatchState;
+        private ListState<KMeansModelData> modelDataState;
+        private ListState<DenseVector> weightsState;
+
+        public UpdateModelDataOperator(DistanceMeasure distanceMeasure, double decayFactor, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.k = k;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            miniBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("miniBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+
+            weightsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "weights", DenseVectorTypeInfo.INSTANCE));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> streamRecord) throws Exception {
+            miniBatchState.add(streamRecord.getValue());
+            processElement();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<Tuple2<KMeansModelData, DenseVector>> streamRecord)
+                throws Exception {
+            modelDataState.add(streamRecord.getValue().f0);
+            weightsState.add(streamRecord.getValue().f1);
+            processElement();
+        }
+
+        private void processElement() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !miniBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            // Retrieves data from states.
+            List<KMeansModelData> modelDataList =
+                    IteratorUtils.toList(modelDataState.get().iterator());
+            if (modelDataList.size() != 1) {
+                throw new RuntimeException(
+                        "The operator received "
+                                + modelDataList.size()
+                                + " model data records in this round");
+            }
+            DenseVector[] centroids = modelDataList.get(0).centroids;
+            modelDataState.clear();
+
+            List<DenseVector> weightsList = IteratorUtils.toList(weightsState.get().iterator());
+            if (weightsList.size() != 1) {
+                throw new RuntimeException(
+                        "The operator received "
+                                + weightsList.size()
+                                + " weight vectors in this round");
+            }
+            DenseVector weights = weightsList.get(0);
+            weightsState.clear();
+
+            List<DenseVector[]> pointsList = IteratorUtils.toList(miniBatchState.get().iterator());
+            DenseVector[] points = pointsList.get(0);
+            pointsList.remove(0);
+            miniBatchState.clear();
+            miniBatchState.addAll(pointsList);

Review comment:
       As described above, in an online algorithm I expect the training process to produce one model data update for each batch of training data. If there are multiple batches of training data, they should not be merged into one. If there are zero batches of training data, the iteration body should not continue training on an empty batch either.
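Here is a minimal sketch of why merging pending batches would change the model. It makes one assumption: a decay-based mini-batch update in the style of Spark MLlib's StreamingKMeans, in which each centroid's accumulated weight is multiplied by the decay factor and the centroid then becomes the weighted average of that history and the batch points assigned to it. The class `DecayedKMeansUpdate` is hypothetical. It is not the update code of `UpdateModelDataOperator`, which is truncated in the diff above. Points are assigned by squared Euclidean distance.

```java
/**
 * Hypothetical sketch of a decay-based mini-batch k-means update (Spark MLlib style),
 * not the Flink ML implementation. Centroids and weights are updated in place.
 */
class DecayedKMeansUpdate {
    /** Returns the index of the centroid closest to the point (squared Euclidean distance). */
    static int nearest(double[][] centroids, double[] point) {
        int best = 0;
        double bestDist = Double.MAX_VALUE;
        for (int i = 0; i < centroids.length; i++) {
            double dist = 0;
            for (int j = 0; j < point.length; j++) {
                double diff = centroids[i][j] - point[j];
                dist += diff * diff;
            }
            if (dist < bestDist) {
                bestDist = dist;
                best = i;
            }
        }
        return best;
    }

    /**
     * Applies exactly one model update for one mini-batch. weights[i] is the decayed number
     * of points that centroid i has absorbed so far.
     */
    static void update(
            double[][] centroids, double[] weights, double[][] batch, double decayFactor) {
        if (batch.length == 0) {
            throw new IllegalArgumentException("Refusing to train on an empty batch");
        }
        int k = centroids.length;
        int dims = centroids[0].length;
        double[][] sums = new double[k][dims];
        double[] counts = new double[k];
        // Assign every point in the batch to its nearest current centroid.
        for (double[] point : batch) {
            int c = nearest(centroids, point);
            counts[c] += 1;
            for (int j = 0; j < dims; j++) {
                sums[c][j] += point[j];
            }
        }
        for (int i = 0; i < k; i++) {
            double history = weights[i] * decayFactor; // discount older data
            double total = history + counts[i];
            weights[i] = total;
            if (counts[i] == 0) {
                continue; // no new points: the centroid stays, only its weight decays
            }
            for (int j = 0; j < dims; j++) {
                centroids[i][j] = (centroids[i][j] * history + sums[i][j]) / total;
            }
        }
    }
}
```

Take k = 1, a centroid at 0.0 with weight 1.0, and a decay factor of 0.5. Applying the batches {2.0} and {4.0} one after the other moves the centroid to 5.0 / 1.75 ≈ 2.857. Merging them into a single batch yields 2.4 instead. Under this rule, separate updates let the later batch outweigh the decayed history, so merging changes the result.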




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Support Streaming KMeans in Flink ML

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r829717356



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/StreamingKMeans.java
##########
@@ -0,0 +1,404 @@
+        private void processElement() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !miniBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            // Retrieves data from states.
+            List<KMeansModelData> modelDataList =
+                    IteratorUtils.toList(modelDataState.get().iterator());
+            if (modelDataList.size() != 1) {
+                throw new RuntimeException(
+                        "The operator received "
+                                + modelDataList.size()
+                                + " list of model data in this round");
+            }
+            DenseVector[] centroids = modelDataList.get(0).centroids;
+            modelDataState.clear();

Review comment:
       The behavior of OnlineKMeans as I have conceived it is as follows. In each iteration, the algorithm should consume one batch of training data and one set of model data received from the feedback edge. If there are multiple batches of data waiting to be consumed, the operator still consumes one batch at a time. If there are zero batches of data when model data records are received, the operator just caches the model data and waits for training data to come in.
   
   So when model data comes in, it does not consume all of the input points received before it; it consumes only the next batch of data, which may or may not have arrived yet.
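The consumption order described above can be sketched as a small state machine in plain Java. This is an illustration under assumptions, not the Flink operator. `BatchModelPairer` is a hypothetical class. A deque and a nullable field stand in for the `ListState` fields of `UpdateModelDataOperator`, and the string appended to the update log stands in for the actual centroid update.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;

/**
 * Hypothetical sketch of the pairing rule: each model-data record is combined with exactly one
 * mini-batch, in arrival order. Plain Java collections stand in for Flink operator state.
 */
class BatchModelPairer<B, M> {
    private final Deque<B> pendingBatches = new ArrayDeque<>();
    private M pendingModel; // at most one model-data record per round
    private final List<String> updates = new ArrayList<>();

    /** Training input: a new mini-batch arrives. */
    void onBatch(B batch) {
        pendingBatches.addLast(batch);
        tryUpdate();
    }

    /** Feedback input: the initial or most recently updated model data arrives. */
    void onModel(M model) {
        if (pendingModel != null) {
            throw new IllegalStateException("Received more than one model data record in this round");
        }
        pendingModel = model;
        tryUpdate();
    }

    private void tryUpdate() {
        // Update only when both a model and at least one batch are present.
        if (pendingModel == null || pendingBatches.isEmpty()) {
            return;
        }
        B batch = pendingBatches.removeFirst(); // consume exactly one batch, never merge
        updates.add(pendingModel + "+" + batch); // stand-in for the real centroid update
        pendingModel = null; // wait for the updated model from the feedback edge
    }

    List<String> updates() {
        return updates;
    }

    int pendingBatchCount() {
        return pendingBatches.size();
    }
}
```

Suppose the sketch receives model `m0`, then batches `b1`, `b2` and `b3`, then model `m1`. It records the updates `m0+b1` and `m1+b2` and leaves `b3` pending. Model data that arrives first simply waits, and each model record consumes exactly one batch.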







[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Support Online KMeans in Flink ML

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r829796038



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/StreamingKMeansModel.java
##########
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.co.CoProcessFunction;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * StreamingKMeansModel can be regarded as an advanced {@link KMeansModel} operator that updates its
+ * model data dynamically from a stream, using the model data produced by {@link StreamingKMeans}.
+ */
+public class StreamingKMeansModel
+        implements Model<StreamingKMeansModel>, KMeansModelParams<StreamingKMeansModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public StreamingKMeansModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public StreamingKMeansModel setModelData(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        modelDataTable = inputs[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), Types.INT),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol()));
+
+        DataStream<Row> predictionResult =
+                KMeansModelData.getModelDataStream(modelDataTable)
+                        .broadcast()
+                        .connect(tEnv.toDataStream(inputs[0]))
+                        .process(
+                                new PredictLabelFunction(
+                                        getFeaturesCol(),
+                                        DistanceMeasure.getInstance(getDistanceMeasure())),
+                                outputTypeInfo);
+
+        return new Table[] {tEnv.fromDataStream(predictionResult)};
+    }
+
+    /** A utility function used for prediction. */
+    private static class PredictLabelFunction extends CoProcessFunction<KMeansModelData, Row, Row> {
+        private final String featuresCol;
+
+        private final DistanceMeasure distanceMeasure;
+
+        private DenseVector[] centroids;
+
+        // TODO: replace this with a complete solution that reads the first model data from the
+        // unbounded model data stream before processing the first predict data.
+        private final List<Row> cache = new ArrayList<>();

Review comment:
       OK. I'll make the change.
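The TODO in `PredictLabelFunction` caches predict rows that arrive before the first model data. The sketch below isolates that caching pattern from the Flink types so it is easy to follow. Everything in it is a hypothetical illustration rather than the PR's code: the class name `ModelGatedPredictor`, plain `double[]` points in place of `Row`/`DenseVector`, and squared Euclidean distance in place of the configurable `DistanceMeasure`. `onModelData` plays the role of `processElement1`, because model data is the first (broadcast) input of the `connect(...)` call, and `onPredictData` plays the role of `processElement2`.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical, Flink-free sketch of the caching pattern in PredictLabelFunction: predict inputs
 * that arrive before the first model are buffered, then flushed once centroids are available.
 */
class ModelGatedPredictor {
    private double[][] centroids; // null until the first model data arrives
    private final List<double[]> cache = new ArrayList<>();

    /** Mirrors processElement1: installs the new model and predicts every buffered point. */
    List<Integer> onModelData(double[][] newCentroids) {
        centroids = newCentroids;
        List<Integer> flushed = new ArrayList<>();
        for (double[] point : cache) {
            flushed.add(predict(point));
        }
        cache.clear();
        return flushed;
    }

    /** Mirrors processElement2: predicts if a model exists, otherwise buffers and returns null. */
    Integer onPredictData(double[] point) {
        if (centroids == null) {
            cache.add(point);
            return null;
        }
        return predict(point);
    }

    /** Returns the index of the closest centroid under squared Euclidean distance. */
    private int predict(double[] point) {
        int best = 0;
        double bestDist = Double.MAX_VALUE;
        for (int i = 0; i < centroids.length; i++) {
            double dist = 0;
            for (int j = 0; j < point.length; j++) {
                double diff = point[j] - centroids[i][j];
                dist += diff * diff;
            }
            if (dist < bestDist) {
                bestDist = dist;
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        ModelGatedPredictor predictor = new ModelGatedPredictor();
        // No model yet, so the point is buffered: prints "null".
        System.out.println(predictor.onPredictData(new double[] {10.0, 10.0}));
        // The first model flushes the buffer: prints "[0]".
        System.out.println(predictor.onModelData(new double[][] {{10.1, 0.1}, {-10.2, 0.2}}));
        // With a model in place, prediction is immediate: prints "1".
        System.out.println(predictor.onPredictData(new double[] {-10.0, 10.0}));
    }
}
```

One caveat that such a cache leaves open: in the quoted code the buffer is a plain Java field rather than Flink operator state, so unless it is also snapshotted elsewhere in the class, rows held in it would not survive a failover.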




[GitHub] [flink-ml] lindong28 commented on a change in pull request #70: [FLINK-26313] Support Online KMeans in Flink ML

Posted by GitBox <gi...@apache.org>.
lindong28 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r830048499



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,409 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;

Review comment:
       Since OnlineKMeans is a different algorithm from KMeans, should we move its files to a dedicated folder, e.g. `onlinekmeans`?

##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/StreamingKMeansTest.java
##########
@@ -0,0 +1,527 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.core.execution.JobClient;
+import org.apache.flink.ml.clustering.kmeans.KMeans;
+import org.apache.flink.ml.clustering.kmeans.KMeansModel;
+import org.apache.flink.ml.clustering.kmeans.KMeansModelData;
+import org.apache.flink.ml.clustering.kmeans.StreamingKMeans;
+import org.apache.flink.ml.clustering.kmeans.StreamingKMeansModel;
+import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.util.MockKVStore;
+import org.apache.flink.ml.util.MockMessageQueues;
+import org.apache.flink.ml.util.MockSinkFunction;
+import org.apache.flink.ml.util.MockSourceFunction;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.ml.util.TestMetricReporter;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.CollectionUtils;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import static org.apache.flink.ml.clustering.KMeansTest.groupFeaturesByPrediction;
+
+/** Tests {@link StreamingKMeans} and {@link StreamingKMeansModel}. */
+public class StreamingKMeansTest {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private static final DenseVector[] trainData1 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 0.0),
+                Vectors.dense(10.0, 0.3),
+                Vectors.dense(10.3, 0.0),
+                Vectors.dense(-10.0, 0.0),
+                Vectors.dense(-10.0, 0.6),
+                Vectors.dense(-10.6, 0.0)
+            };
+    private static final DenseVector[] trainData2 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 100.0),
+                Vectors.dense(10.0, 100.3),
+                Vectors.dense(10.3, 100.0),
+                Vectors.dense(-10.0, -100.0),
+                Vectors.dense(-10.0, -100.6),
+                Vectors.dense(-10.6, -100.0)
+            };
+    private static final DenseVector[] predictData =
+            new DenseVector[] {
+                Vectors.dense(10.0, 10.0),
+                Vectors.dense(10.3, 10.0),
+                Vectors.dense(10.0, 10.3),
+                Vectors.dense(-10.0, 10.0),
+                Vectors.dense(-10.3, 10.0),
+                Vectors.dense(-10.0, 10.3)
+            };
+    private static final List<Set<DenseVector>> expectedGroups1 =
+            Arrays.asList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3))),
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+    private static final List<Set<DenseVector>> expectedGroups2 =
+            Collections.singletonList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3),
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+
+    private String trainId;
+    private String predictId;
+    private String outputId;
+    private String modelDataId;
+    private String metricReporterPrefix;
+    private String modelDataVersionGaugeKey;
+    private String currentModelDataVersion;
+    private List<String> queueIds;
+    private List<String> kvStoreKeys;
+    private List<JobClient> clients;
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table offlineTrainTable;
+    private Table trainTable;
+    private Table predictTable;
+
+    @Before
+    public void before() {
+        metricReporterPrefix = MockKVStore.createNonDuplicatePrefix();
+        kvStoreKeys = new ArrayList<>();
+        kvStoreKeys.add(metricReporterPrefix);
+
+        modelDataVersionGaugeKey =
+                TestMetricReporter.getKey(metricReporterPrefix, "modelDataVersion");
+
+        currentModelDataVersion = "0";
+
+        trainId = MockMessageQueues.createMessageQueue();
+        predictId = MockMessageQueues.createMessageQueue();
+        outputId = MockMessageQueues.createMessageQueue();
+        modelDataId = MockMessageQueues.createMessageQueue();
+        queueIds = new ArrayList<>();
+        queueIds.addAll(Arrays.asList(trainId, predictId, outputId, modelDataId));
+
+        clients = new ArrayList<>();
+
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        config.setString("metrics.reporters", "test_reporter");
+        config.setString(
+                "metrics.reporter.test_reporter.class", TestMetricReporter.class.getName());
+        config.setString("metrics.reporter.test_reporter.interval", "100 MILLISECONDS");
+        config.setString("metrics.reporter.test_reporter.prefix", metricReporterPrefix);
+
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        Schema schema = Schema.newBuilder().column("f0", DataTypes.of(DenseVector.class)).build();
+
+        offlineTrainTable =
+                tEnv.fromDataStream(env.fromElements(trainData1), schema).as("features");
+        trainTable =
+                tEnv.fromDataStream(
+                                env.addSource(
+                                        new MockSourceFunction<>(trainId),
+                                        DenseVectorTypeInfo.INSTANCE),
+                                schema)
+                        .as("features");
+        predictTable =
+                tEnv.fromDataStream(
+                                env.addSource(
+                                        new MockSourceFunction<>(predictId),
+                                        DenseVectorTypeInfo.INSTANCE),
+                                schema)
+                        .as("features");
+    }
+
+    @After
+    public void after() {
+        for (JobClient client : clients) {
+            try {
+                client.cancel();
+            } catch (IllegalStateException e) {
+                if (!e.getMessage()
+                        .equals("MiniCluster is not yet running or has already been shut down.")) {
+                    throw e;
+                }
+            }
+        }
+        clients.clear();
+
+        for (String queueId : queueIds) {
+            MockMessageQueues.deleteBlockingQueue(queueId);
+        }
+        queueIds.clear();
+
+        for (String key : kvStoreKeys) {
+            MockKVStore.remove(key);
+        }
+        kvStoreKeys.clear();
+    }
+
+    /** Adds sinks for StreamingKMeansModel's transform output and model data. */
+    private void configModelSink(StreamingKMeansModel streamingModel) {
+        Table outputTable = streamingModel.transform(predictTable)[0];
+        DataStream<Row> output = tEnv.toDataStream(outputTable);
+        output.addSink(new MockSinkFunction<>(outputId));
+
+        Table modelDataTable = streamingModel.getModelData()[0];
+        DataStream<KMeansModelData> modelDataStream =
+                KMeansModelData.getModelDataStream(modelDataTable);
+        modelDataStream.addSink(new MockSinkFunction<>(modelDataId));
+    }
+
+    /** Blocks the thread until the Model has set up the initial model data. */
+    private void waitInitModelDataSetup() {
+        while (!MockKVStore.containsKey(modelDataVersionGaugeKey)) {
+            Thread.yield();
+        }
+        waitModelDataUpdate();
+    }
+
+    /** Blocks the thread until the Model has received the next model-data-update event. */
+    private void waitModelDataUpdate() {
+        do {
+            String tmpModelDataVersion = MockKVStore.get(modelDataVersionGaugeKey);
+            if (tmpModelDataVersion.equals(currentModelDataVersion)) {
+                Thread.yield();
+            } else {
+                currentModelDataVersion = tmpModelDataVersion;
+                break;
+            }
+        } while (true);
+    }
+
+    /**
+     * Inserts default predict data into the predict queue, fetches the prediction results, and
+     * asserts that the grouping result is as expected.
+     *
+     * @param expectedGroups A list containing sets of features, which is the expected group result
+     * @param featuresCol Name of the column in the table that contains the features
+     * @param predictionCol Name of the column in the table that contains the prediction result
+     */
+    private void predictAndAssert(
+            List<Set<DenseVector>> expectedGroups, String featuresCol, String predictionCol)
+            throws Exception {
+        MockMessageQueues.offerAll(predictId, StreamingKMeansTest.predictData);
+        List<Row> rawResult =
+                MockMessageQueues.poll(outputId, StreamingKMeansTest.predictData.length);
+        List<Set<DenseVector>> actualGroups =
+                groupFeaturesByPrediction(rawResult, featuresCol, predictionCol);
+        Assert.assertTrue(CollectionUtils.isEqualCollection(expectedGroups, actualGroups));
+    }
+
+    @Test
+    public void testParam() {
+        StreamingKMeans streamingKMeans = new StreamingKMeans();
+        Assert.assertEquals("features", streamingKMeans.getFeaturesCol());
+        Assert.assertEquals("prediction", streamingKMeans.getPredictionCol());
+        Assert.assertEquals(EuclideanDistanceMeasure.NAME, streamingKMeans.getDistanceMeasure());
+        Assert.assertEquals("random", streamingKMeans.getInitMode());
+        Assert.assertEquals(2, streamingKMeans.getK());
+        Assert.assertEquals(1, streamingKMeans.getDims());
+        Assert.assertEquals("count", streamingKMeans.getBatchStrategy());
+        Assert.assertEquals(1, streamingKMeans.getBatchSize());
+        Assert.assertEquals(0., streamingKMeans.getDecayFactor(), 1e-5);
+        Assert.assertEquals("random", streamingKMeans.getInitMode());
+        Assert.assertEquals(StreamingKMeans.class.getName().hashCode(), streamingKMeans.getSeed());
+
+        streamingKMeans
+                .setK(9)
+                .setFeaturesCol("test_feature")
+                .setPredictionCol("test_prediction")
+                .setK(3)
+                .setDims(5)
+                .setBatchSize(5)
+                .setDecayFactor(0.25)
+                .setInitMode("direct")
+                .setSeed(100);
+
+        Assert.assertEquals("test_feature", streamingKMeans.getFeaturesCol());
+        Assert.assertEquals("test_prediction", streamingKMeans.getPredictionCol());
+        Assert.assertEquals(3, streamingKMeans.getK());
+        Assert.assertEquals(5, streamingKMeans.getDims());
+        Assert.assertEquals("count", streamingKMeans.getBatchStrategy());
+        Assert.assertEquals(5, streamingKMeans.getBatchSize());
+        Assert.assertEquals(0.25, streamingKMeans.getDecayFactor(), 1e-5);
+        Assert.assertEquals("direct", streamingKMeans.getInitMode());
+        Assert.assertEquals(100, streamingKMeans.getSeed());
+    }
+
+    @Test
+    public void testFitAndPredict() throws Exception {
+        StreamingKMeans streamingKMeans =
+                new StreamingKMeans()
+                        .setInitMode("random")
+                        .setDims(2)
+                        .setInitWeights(new Double[] {0., 0.})
+                        .setBatchSize(6)
+                        .setFeaturesCol("features")
+                        .setPredictionCol("prediction");
+        StreamingKMeansModel streamingModel = streamingKMeans.fit(trainTable);
+        configModelSink(streamingModel);
+
+        clients.add(env.executeAsync());
+        waitInitModelDataSetup();
+
+        MockMessageQueues.offerAll(trainId, trainData1);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups1,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+
+        MockMessageQueues.offerAll(trainId, trainData2);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups2,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+    }
+
+    @Test
+    public void testInitWithOfflineKMeans() throws Exception {
+        KMeans offlineKMeans =
+                new KMeans().setFeaturesCol("features").setPredictionCol("prediction");
+        KMeansModel offlineModel = offlineKMeans.fit(offlineTrainTable);
+
+        StreamingKMeans streamingKMeans =
+                new StreamingKMeans(offlineModel.getModelData())
+                        .setInitMode("direct")
+                        .setDims(2)
+                        .setInitWeights(new Double[] {0., 0.})
+                        .setBatchSize(6);
+        ReadWriteUtils.updateExistingParams(streamingKMeans, offlineKMeans.getParamMap());
+
+        StreamingKMeansModel streamingModel = streamingKMeans.fit(trainTable);
+        configModelSink(streamingModel);
+
+        clients.add(env.executeAsync());
+        waitInitModelDataSetup();
+        predictAndAssert(
+                expectedGroups1,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+
+        MockMessageQueues.offerAll(trainId, trainData2);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups2,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+    }
+
+    @Test
+    public void testDecayFactor() throws Exception {
+        StreamingKMeans streamingKMeans =
+                new StreamingKMeans()
+                        .setInitMode("random")
+                        .setDims(2)
+                        .setInitWeights(new Double[] {0., 0.})
+                        .setDecayFactor(100.0)
+                        .setBatchSize(6)
+                        .setFeaturesCol("features")
+                        .setPredictionCol("prediction");
+        StreamingKMeansModel streamingModel = streamingKMeans.fit(trainTable);
+        configModelSink(streamingModel);
+
+        clients.add(env.executeAsync());
+        waitInitModelDataSetup();
+
+        MockMessageQueues.offerAll(trainId, trainData1);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups1,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+
+        MockMessageQueues.offerAll(trainId, trainData2);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups1,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+    }
+
+    @Test
+    public void testSaveAndReload() throws Exception {
+        KMeans offlineKMeans =
+                new KMeans().setFeaturesCol("features").setPredictionCol("prediction");
+        KMeansModel offlineModel = offlineKMeans.fit(offlineTrainTable);
+
+        StreamingKMeans streamingKMeans =
+                new StreamingKMeans(offlineModel.getModelData())
+                        .setDims(2)
+                        .setBatchSize(6)
+                        .setInitWeights(new Double[] {0., 0.});
+        ReadWriteUtils.updateExistingParams(streamingKMeans, offlineKMeans.getParamMap());
+
+        StreamingKMeans loadedKMeans =
+                StageTestUtils.saveAndReload(
+                        env, streamingKMeans, tempFolder.newFolder().getAbsolutePath());
+
+        StreamingKMeansModel streamingModel = loadedKMeans.fit(trainTable);
+
+        String modelDataPassId = MockMessageQueues.createMessageQueue();
+        queueIds.add(modelDataPassId);
+
+        String modelSavePath = tempFolder.newFolder().getAbsolutePath();
+        streamingModel.save(modelSavePath);
+        KMeansModelData.getModelDataStream(streamingModel.getModelData()[0])
+                .addSink(new MockSinkFunction<>(modelDataPassId));
+        clients.add(env.executeAsync());

Review comment:
       Sounds good.
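The `testDecayFactor` case above expects that with `setDecayFactor(100.0)` the centroids learned from `trainData1` survive the `trainData2` batch (so predictions still split into `expectedGroups1`), while with the default factor of 0 they are replaced (`expectedGroups2`). The body of `UpdateModelDataOperator.processElement` is cut off in this thread, so the following is only a sketch assuming a Spark MLlib-style decayed mini-batch rule: every cluster weight `w_i` is first multiplied by the decay factor `a`, then a cluster that received `n_i` new points summing to `s_i` moves to `c_i' = (c_i * w_i * a + s_i) / (w_i * a + n_i)` with new weight `w_i * a + n_i`. The class `DecayedMiniBatchUpdate` and its methods are hypothetical, and the fixed starting centroids stand in for random initialization; with zero initial weights, as in the tests' `setInitWeights(new Double[] {0., 0.})`, the first batch fully overrides them anyway.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch, not the PR's code: a Spark-MLlib-style decayed mini-batch K-Means update.
 * Cluster weights are discounted by decayFactor before each new batch is folded in.
 */
class DecayedMiniBatchUpdate {

    /** Returns the index of the closest centroid under squared Euclidean distance. */
    static int closest(double[][] centroids, double[] point) {
        int best = 0;
        double bestDist = Double.MAX_VALUE;
        for (int i = 0; i < centroids.length; i++) {
            double dist = 0;
            for (int j = 0; j < point.length; j++) {
                double diff = point[j] - centroids[i][j];
                dist += diff * diff;
            }
            if (dist < bestDist) {
                bestDist = dist;
                best = i;
            }
        }
        return best;
    }

    /** Folds one mini-batch into the centroids and weights, updating both arrays in place. */
    static void update(double[][] centroids, double[] weights, double[][] batch, double decayFactor) {
        int k = centroids.length;
        int dims = centroids[0].length;
        double[][] sums = new double[k][dims];
        double[] counts = new double[k];
        // Assign every point of the batch to its closest current centroid.
        for (double[] point : batch) {
            int c = closest(centroids, point);
            counts[c]++;
            for (int j = 0; j < dims; j++) {
                sums[c][j] += point[j];
            }
        }
        for (int i = 0; i < k; i++) {
            weights[i] *= decayFactor; // discount the history of every cluster
            if (counts[i] == 0) {
                continue; // no new points: the centroid stays where it is
            }
            double newWeight = weights[i] + counts[i];
            for (int j = 0; j < dims; j++) {
                // Weighted average of the discounted old centroid and the newly assigned points.
                centroids[i][j] = (centroids[i][j] * weights[i] + sums[i][j]) / newWeight;
            }
            weights[i] = newWeight;
        }
    }

    /** Replays trainData1 and then trainData2 from StreamingKMeansTest as two mini-batches. */
    static double[][] trainTwoBatches(double decayFactor) {
        double[][] centroids = {{1.0, 0.0}, {-1.0, 0.0}}; // assumed fixed starting centroids
        double[] weights = {0.0, 0.0};
        double[][] trainData1 = {
            {10.0, 0.0}, {10.0, 0.3}, {10.3, 0.0}, {-10.0, 0.0}, {-10.0, 0.6}, {-10.6, 0.0}
        };
        double[][] trainData2 = {
            {10.0, 100.0}, {10.0, 100.3}, {10.3, 100.0},
            {-10.0, -100.0}, {-10.0, -100.6}, {-10.6, -100.0}
        };
        update(centroids, weights, trainData1, decayFactor);
        update(centroids, weights, trainData2, decayFactor);
        return centroids;
    }

    public static void main(String[] args) {
        for (double decay : new double[] {0.0, 100.0}) {
            double[][] c = trainTwoBatches(decay);
            System.out.println("decayFactor=" + decay
                    + " centroids=" + Arrays.deepToString(c)
                    + " predict(10,10)=" + closest(c, new double[] {10.0, 10.0})
                    + " predict(-10,10)=" + closest(c, new double[] {-10.0, 10.0}));
        }
    }
}
```

Under this rule a factor of 0 discards all history, so after `trainData2` both predict clusters fall to the centroid near `(10.1, 100.1)`; a factor above 1, such as the test's 100.0, simply weights history more heavily, keeping the first centroid near `(10.1, 1.09)` and preserving the two-group split.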




[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r831988352



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/StreamingKMeans.java
##########
@@ -0,0 +1,404 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * StreamingKMeans extends the functionality of {@link KMeans} by supporting continuous training
+ * of a K-Means model on an unbounded stream of training data.
+ */
+public class StreamingKMeans
+        implements Estimator<StreamingKMeans, StreamingKMeansModel>,
+                StreamingKMeansParams<StreamingKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public StreamingKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public StreamingKMeans(Table... initModelDataTables) {
+        Preconditions.checkArgument(initModelDataTables.length == 1);
+        this.initModelDataTable = initModelDataTables[0];
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public StreamingKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        Preconditions.checkArgument(HasBatchStrategy.COUNT_STRATEGY.equals(getBatchStrategy()));
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+        points.getTransformation().setParallelism(1);
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            initModelDataStream = createRandomCentroids(env, getDims(), getK(), getSeed());
+        } else {
+            initModelDataStream = KMeansModelData.getModelDataStream(initModelDataTable);
+        }
+        DataStream<Tuple2<KMeansModelData, DenseVector>> initModelDataWithWeightsStream =
+                initModelDataStream.map(new InitWeightAssigner(getInitWeights()));
+        initModelDataWithWeightsStream.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new StreamingKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getBatchSize(),
+                        getK());
+
+        DataStream<KMeansModelData> finalModelDataStream =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelDataWithWeightsStream),
+                                DataStreamList.of(points),
+                                body)
+                        .get(0);
+        finalModelDataStream = finalModelDataStream.union(initModelDataStream);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelDataStream);
+        StreamingKMeansModel model = new StreamingKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class InitWeightAssigner
+            implements MapFunction<KMeansModelData, Tuple2<KMeansModelData, DenseVector>> {
+        private final double[] initWeights;
+
+        private InitWeightAssigner(Double[] initWeights) {
+            this.initWeights = ArrayUtils.toPrimitive(initWeights);
+        }
+
+        @Override
+        public Tuple2<KMeansModelData, DenseVector> map(KMeansModelData modelData)
+                throws Exception {
+            return Tuple2.of(modelData, Vectors.dense(initWeights));
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static StreamingKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        StreamingKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        Path initModelDataPath = Paths.get(path, "data");
+        if (Files.exists(initModelDataPath)) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+            kMeans.setInitMode("direct");
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class StreamingKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+        private final int k;
+
+        public StreamingKMeansIterationBody(
+                DistanceMeasure distanceMeasure, double decayFactor, int batchSize, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+            this.k = k;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<Tuple2<KMeansModelData, DenseVector>> modelDataWithWeights =
+                    variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            DataStream<Tuple2<KMeansModelData, DenseVector>> newModelDataWithWeights =
+                    points.countWindowAll(batchSize)
+                            .aggregate(new MiniBatchCreator())
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "UpdateModelData",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new UpdateModelDataOperator(distanceMeasure, decayFactor, k))
+                            .setParallelism(1);
+
+            DataStream<KMeansModelData> newModelData =
+                    newModelDataWithWeights.map(
+                            (MapFunction<Tuple2<KMeansModelData, DenseVector>, KMeansModelData>)
+                                    x -> x.f0);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelDataWithWeights), DataStreamList.of(newModelData));
+        }
+    }
+
+    // TODO: change this single-threaded implementation to support training in a distributed way,
+    // after model data version mechanism is implemented.
+    private static class UpdateModelDataOperator
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, DenseVector>>
+            implements TwoInputStreamOperator<
+                    DenseVector[],
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int k;
+        private ListState<DenseVector[]> miniBatchState;
+        private ListState<KMeansModelData> modelDataState;
+        private ListState<DenseVector> weightsState;
+
+        public UpdateModelDataOperator(DistanceMeasure distanceMeasure, double decayFactor, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.k = k;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            miniBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("miniBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+
+            weightsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "weights", DenseVectorTypeInfo.INSTANCE));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> streamRecord) throws Exception {
+            miniBatchState.add(streamRecord.getValue());
+            processElement();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<Tuple2<KMeansModelData, DenseVector>> streamRecord)
+                throws Exception {
+            modelDataState.add(streamRecord.getValue().f0);
+            weightsState.add(streamRecord.getValue().f1);
+            processElement();
+        }
+
+        private void processElement() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !miniBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            // Retrieves data from states.
+            List<KMeansModelData> modelDataList =
+                    IteratorUtils.toList(modelDataState.get().iterator());
+            if (modelDataList.size() != 1) {
+                throw new RuntimeException(
+                        "The operator received "
+                                + modelDataList.size()
+                                + " list of model data in this round");
+            }
+            DenseVector[] centroids = modelDataList.get(0).centroids;
+            modelDataState.clear();

Review comment:
       Same as the previous comment: having a single-threaded operator distribute the mini-batches resolves this discussion.
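To make the discussion concrete, here is a minimal, Flink-free sketch of the gating logic that a single-threaded operator like `UpdateModelDataOperator` applies: it buffers mini-batches from one input and model data from the other, and only releases a (model data, mini-batch) pair once both are present. The class name `MiniBatchModelJoiner` is hypothetical, plain Java collections stand in for Flink's `ListState`, and consuming the oldest buffered mini-batch first is an assumption, since the quoted code is cut off before that step.

```java
import java.util.AbstractMap;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;
import java.util.Map;
import java.util.Optional;

/**
 * Hypothetical sketch of the two-input gating logic: mini-batches (input 1) and model data
 * (input 2) are buffered, and a (model, mini-batch) pair is released only when both exist.
 * Assumes a single-threaded caller, as with a parallelism-1 operator.
 */
final class MiniBatchModelJoiner<M, B> {
    private final Deque<B> pendingBatches = new ArrayDeque<>();
    private final List<M> pendingModels = new ArrayList<>();

    /** Input 1: a mini-batch of training points. */
    Optional<Map.Entry<M, B>> onMiniBatch(B batch) {
        pendingBatches.addLast(batch);
        return tryRelease();
    }

    /** Input 2: the model data fed back from the previous round. */
    Optional<Map.Entry<M, B>> onModelData(M model) {
        pendingModels.add(model);
        return tryRelease();
    }

    private Optional<Map.Entry<M, B>> tryRelease() {
        if (pendingModels.isEmpty() || pendingBatches.isEmpty()) {
            return Optional.empty(); // wait until both inputs have data
        }
        if (pendingModels.size() != 1) {
            // Mirrors the sanity check in processElement(): exactly one model per round.
            throw new IllegalStateException(
                    "Received " + pendingModels.size() + " model data items in this round");
        }
        M model = pendingModels.remove(0);
        B batch = pendingBatches.pollFirst(); // assumption: oldest mini-batch is consumed first
        Map.Entry<M, B> pair = new AbstractMap.SimpleImmutableEntry<>(model, batch);
        return Optional.of(pair);
    }
}
```

Because only one model is in flight per round, extra mini-batches simply wait in the buffer until the next model update is fed back.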




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r832102967



##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/util/MockMessageQueues.java
##########
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
+
+/** A class that manages global message queues used in unit tests. */
+@SuppressWarnings({"unchecked", "rawtypes"})
+public class MockMessageQueues {
+    private static final Map<String, BlockingQueue> queueMap = new HashMap<>();
+    private static long counter = 0;
+
+    public static synchronized String createMessageQueue() {
+        String id = String.valueOf(counter);
+        queueMap.put(id, new LinkedBlockingQueue<>());
+        counter++;
+        return id;
+    }
+
+    @SafeVarargs
+    public static <T> void offerAll(String id, T... values) throws InterruptedException {
+        for (T value : values) {
+            offer(id, value);
+        }
+    }
+
+    public static <T> void offer(String id, T value) throws InterruptedException {
+        offer(id, value, 1, TimeUnit.MINUTES);
+    }
+
+    public static <T> void offer(String id, T value, long timeout, TimeUnit unit)
+            throws InterruptedException {
+        boolean success = queueMap.get(id).offer(value, timeout, unit);
+        if (!success) {
+            throw new RuntimeException(
+                    "Failed to offer " + value + " to blocking queue " + id + ".");
+        }
+    }
+
+    public static <T> List<T> poll(String id, int num) throws InterruptedException {
+        List<T> result = new ArrayList<>();
+        for (int i = 0; i < num; i++) {
+            result.add(poll(id));
+        }
+        return result;
+    }
+
+    public static <T> T poll(String id) throws InterruptedException {
+        return poll(id, 1, TimeUnit.MINUTES);
+    }
+
+    public static <T> T poll(String id, long timeout, TimeUnit unit) throws InterruptedException {
+        T value = (T) queueMap.get(id).poll(timeout, unit);
+        if (value == null) {
+            throw new RuntimeException("Failed to poll next value from blocking queue " + id + ".");
+        }
+        return value;
+    }
+
+    public static void deleteBlockingQueue(String id) {

Review comment:
       I tried mimicking the implementation of `InMemoryReporter` and introduced `InMemorySourceFunction` and `InMemorySinkFunction`. Users now only need to create instances of these two classes in their test classes, and the static queues are hidden from them. The statically created queues are deleted automatically when the sources/sinks are closed, so users need not delete each queue or delete any instances of the `group` concept. As Flink does not seem to provide built-in classes like these (as it does with `InMemoryReporter`), this is the most suitable solution I can think of for now. Please take a look and see whether it is clean enough.
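A rough sketch of the pattern described above, in plain Java without Flink classes: the test holds an object handle, the static map is an implementation detail keyed by a random id (so that serialized copies of the handle could locate the same queue), and closing the handle removes its queue. `InMemoryQueueHandle` and its methods are hypothetical and are not the PR's `InMemorySourceFunction`/`InMemorySinkFunction`. Like any static-queue approach, this only works when the job runs in the same JVM as the test, as with a local mini-cluster.

```java
import java.io.Serializable;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;

/** Hypothetical handle that hides a static, id-keyed queue and deletes it on close(). */
final class InMemoryQueueHandle<T> implements AutoCloseable, Serializable {
    private static final long serialVersionUID = 1L;
    // Static so that deserialized copies of a handle in the same JVM share the queue.
    private static final Map<String, BlockingQueue<Object>> QUEUES = new ConcurrentHashMap<>();

    private final String id = UUID.randomUUID().toString();

    InMemoryQueueHandle() {
        QUEUES.put(id, new LinkedBlockingQueue<>());
    }

    void offer(T value) throws InterruptedException {
        if (!queue().offer(value, 1, TimeUnit.MINUTES)) {
            throw new IllegalStateException("Failed to offer " + value + " to queue " + id);
        }
    }

    @SuppressWarnings("unchecked")
    T poll(long timeout, TimeUnit unit) throws InterruptedException {
        Object value = queue().poll(timeout, unit);
        if (value == null) {
            throw new IllegalStateException("Timed out polling queue " + id);
        }
        return (T) value;
    }

    /** Number of queues still registered; useful for checking cleanup. */
    static int liveQueueCount() {
        return QUEUES.size();
    }

    private BlockingQueue<Object> queue() {
        BlockingQueue<Object> queue = QUEUES.get(id);
        if (queue == null) {
            throw new IllegalStateException("Queue " + id + " was already closed");
        }
        return queue;
    }

    @Override
    public void close() {
        QUEUES.remove(id); // automatic cleanup: tests never delete queues by id
    }
}
```

Using try-with-resources on the handle means cleanup happens even when a test fails midway.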







[GitHub] [flink-ml] lindong28 commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
lindong28 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r832792959



##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/util/MockMessageQueues.java
##########
+    public static void deleteBlockingQueue(String id) {

Review comment:
       The approach using `InMemoryReporter`/`InMemorySourceFunction`/`InMemorySinkFunction` is much cleaner and more readable than the previous one. Thanks!







[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r832001871



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,409 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;

Review comment:
       I think keeping `OnlineKMeans` in the `kmeans` package is more suitable, because:
   - I don't think `OnlineKMeans` and `KMeans` are different algorithms. They are the same algorithm applied under different external conditions.
   - In implementation, `OnlineKMeans` and `KMeans` share some infrastructure code exclusively with each other, like `KMeansModelData`, `KMeansParams` and `findClosestCentroidId`. If they were put into different packages, I am worried that there would be tight coupling between the `kmeans` and `onlinekmeans` packages.
   
   The reasons above are not definitive, and I would still be glad to make the change if there are more important reasons.
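For illustration, a helper like `findClosestCentroidId`, mentioned above as shared between the two classes, boils down to an argmin over centroid distances. This hypothetical sketch hard-codes squared Euclidean distance on `double[]` arrays, whereas the quoted `KMeans.findClosestCentroidId` takes `DenseVector`s and a `DistanceMeasure`.

```java
/** Hypothetical closest-centroid lookup using squared Euclidean distance. */
final class ClosestCentroid {
    private ClosestCentroid() {}

    /** Returns the index of the nearest centroid, or -1 if there are no centroids. */
    static int findClosestCentroidId(double[][] centroids, double[] point) {
        int bestId = -1;
        double bestDistance = Double.POSITIVE_INFINITY;
        for (int i = 0; i < centroids.length; i++) {
            double distance = 0.0;
            for (int j = 0; j < point.length; j++) {
                double diff = centroids[i][j] - point[j];
                distance += diff * diff; // squaring preserves the argmin, so no sqrt needed
            }
            if (distance < bestDistance) {
                bestDistance = distance;
                bestId = i;
            }
        }
        return bestId;
    }
}
```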







[GitHub] [flink-ml] lindong28 commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
lindong28 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r832825345



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeansModel.java
##########
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.co.CoProcessFunction;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * OnlineKMeansModel can be regarded as an advanced {@link KMeansModel} operator which can update
+ * model data in a streaming format, using the model data provided by {@link OnlineKMeans}.
+ */
+public class OnlineKMeansModel
+        implements Model<OnlineKMeansModel>, KMeansModelParams<OnlineKMeansModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public OnlineKMeansModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel setModelData(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        modelDataTable = inputs[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), Types.INT),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol()));
+
+        DataStream<Row> predictionResult =
+                KMeansModelData.getModelDataStream(modelDataTable)
+                        .broadcast()
+                        .connect(tEnv.toDataStream(inputs[0]))
+                        .process(
+                                new PredictLabelFunction(
+                                        getFeaturesCol(),
+                                        DistanceMeasure.getInstance(getDistanceMeasure())),
+                                outputTypeInfo);
+
+        return new Table[] {tEnv.fromDataStream(predictionResult)};
+    }
+
+    /** A utility function used for prediction. */
+    private static class PredictLabelFunction extends CoProcessFunction<KMeansModelData, Row, Row> {
+        private final String featuresCol;
+
+        private final DistanceMeasure distanceMeasure;
+
+        private DenseVector[] centroids;
+
+        // TODO: replace this with a complete solution of reading first model data from unbounded
+        // model data stream before processing the first predict data.
+        private final List<Row> bufferedPoints = new ArrayList<>();
+
+        // TODO: replace this simple implementation of model data version with the formal API to
+        // track model version after its design is settled.
+        private int modelDataVersion;
+
+        public PredictLabelFunction(String featuresCol, DistanceMeasure distanceMeasure) {
+            this.featuresCol = featuresCol;
+            this.distanceMeasure = distanceMeasure;
+        }
+
+        @Override
+        public void open(Configuration parameters) throws Exception {
+            super.open(parameters);
+
+            getRuntimeContext()
+                    .getMetricGroup()
+                    .gauge(
+                            "modelDataVersion",
+                            (Gauge<String>) () -> Integer.toString(modelDataVersion));
+        }
+
+        @Override
+        public void processElement1(
+                KMeansModelData modelData,
+                CoProcessFunction<KMeansModelData, Row, Row>.Context context,
+                Collector<Row> collector) {
+            centroids = modelData.centroids;
+            modelDataVersion++;
+            for (Row dataPoint : bufferedPoints) {
+                processElement2(dataPoint, context, collector);
+            }
+            bufferedPoints.clear();
+        }
+
+        @Override
+        public void processElement2(
+                Row dataPoint,
+                CoProcessFunction<KMeansModelData, Row, Row>.Context context,
+                Collector<Row> collector) {
+            if (centroids == null) {
+                bufferedPoints.add(dataPoint);
+                return;
+            }
+            DenseVector point = (DenseVector) dataPoint.getField(featuresCol);
+            int closestCentroidId = KMeans.findClosestCentroidId(centroids, point, distanceMeasure);
+            collector.collect(Row.join(dataPoint, Row.of(closestCentroidId)));
+        }
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    @Override
+    public void save(String path) throws IOException {

Review comment:
       Would it be useful to add Javadoc explaining what is saved here? It might be useful to mention that we don't store the model data because it is unbounded.
   
   Could you also update the Javadoc of `Stage::save()` to explicitly mention that it may save metadata and bounded model data, and that unbounded model data is expected to be fetched using `Model::getModelData()`?
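One possible shape for the suggested Javadoc, shown on a hypothetical, Flink-free stand-in (`StreamingModelSketch`) whose `save` writes only the parameters to a `metadata` file. The file name and `key=value` format are assumptions for this sketch, not what Flink ML's `ReadWriteUtils` produces.

```java
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Map;
import java.util.TreeMap;

/** Hypothetical stand-in for a model whose model data is an unbounded stream. */
final class StreamingModelSketch {
    private final Map<String, String> params = new TreeMap<>(); // sorted for stable output

    StreamingModelSketch setParam(String name, String value) {
        params.put(name, value);
        return this;
    }

    /**
     * Saves the metadata of this model, i.e. its parameters, to the given directory.
     *
     * <p>The model data is intentionally not saved: it is an unbounded stream, so there is no
     * point at which it could be written out in full. Callers that need the model data should
     * obtain it via {@code Model#getModelData()} instead.
     */
    void save(Path dir) throws IOException {
        Files.createDirectories(dir);
        StringBuilder metadata = new StringBuilder();
        for (Map.Entry<String, String> entry : params.entrySet()) {
            metadata.append(entry.getKey()).append('=').append(entry.getValue()).append('\n');
        }
        Files.write(dir.resolve("metadata"), metadata.toString().getBytes(StandardCharsets.UTF_8));
    }
}
```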
   

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,562 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Random;
+
+/**
+ * OnlineKMeans extends the function of {@link KMeans}, supporting to train a K-Means model
+ * continuously according to an unbounded stream of train data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids from the weighted average between the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of points assigned to
+ * them. The weight of the original centroids is also the number of points, but additionally
+ * multiplying with the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If decay
+ * factor is 1, all batches are weighted equally. If decay factor is 0, new centroids are determined
+ * entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public OnlineKMeans(Table initModelDataTable) {
+        this.initModelDataTable = initModelDataTable;
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            Preconditions.checkState(initModelDataTable == null);
+            initModelDataStream = createRandomCentroids(env, getDim(), getK(), getSeed());
+        } else {
+            initModelDataStream = KMeansModelData.getModelDataStream(initModelDataTable);
+        }
+        DataStream<Tuple2<KMeansModelData, DenseVector>> initModelDataWithWeightsStream =
+                initModelDataStream.map(new InitWeightAssigner(getInitWeights()));
+        initModelDataWithWeightsStream.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize(),
+                        getK(),
+                        getDim());
+
+        DataStream<KMeansModelData> finalModelDataStream =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelDataWithWeightsStream),
+                                DataStreamList.of(points),
+                                body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelDataStream);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class InitWeightAssigner
+            implements MapFunction<KMeansModelData, Tuple2<KMeansModelData, DenseVector>> {
+        private final double[] initWeights;
+
+        private InitWeightAssigner(Double[] initWeights) {
+            this.initWeights = ArrayUtils.toPrimitive(initWeights);
+        }
+
+        @Override
+        public Tuple2<KMeansModelData, DenseVector> map(KMeansModelData modelData)
+                throws Exception {
+            return Tuple2.of(modelData, Vectors.dense(initWeights));
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {

Review comment:
       Would it be useful to add Javadoc explaining what is saved here? It might be useful to mention the metadata and the initial model data, if it exists.
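The mini-batch rule described in the `OnlineKMeans` Javadoc quoted above can be sketched as follows: each original centroid is weighted by its accumulated weight times the decay factor, each batch estimate by the number of points assigned to it, and the new centroid is their weighted average. This is a hypothetical, array-based sketch that assumes squared Euclidean distance for assignment and leaves a centroid unchanged (while still decaying its weight) when no point in the batch is assigned to it; it is not the PR's `UpdateModelDataOperator`.

```java
/** Hypothetical decayed mini-batch K-Means step; mutates centroids and weights in place. */
final class DecayedMiniBatchKMeans {
    private DecayedMiniBatchKMeans() {}

    static void update(double[][] centroids, double[] weights, double[][] batch, double decayFactor) {
        int k = centroids.length;
        int dim = centroids[0].length;
        double[][] sums = new double[k][dim];
        int[] counts = new int[k];
        // Assign each point to its closest centroid and accumulate per-cluster sums.
        for (double[] point : batch) {
            int best = 0;
            double bestDistance = Double.POSITIVE_INFINITY;
            for (int i = 0; i < k; i++) {
                double distance = 0.0;
                for (int j = 0; j < dim; j++) {
                    double diff = centroids[i][j] - point[j];
                    distance += diff * diff;
                }
                if (distance < bestDistance) {
                    bestDistance = distance;
                    best = i;
                }
            }
            counts[best]++;
            for (int j = 0; j < dim; j++) {
                sums[best][j] += point[j];
            }
        }
        // Weighted average: old centroid has weight (weight * decay), batch mean has weight count.
        for (int i = 0; i < k; i++) {
            double oldWeight = weights[i] * decayFactor;
            double newWeight = oldWeight + counts[i];
            if (counts[i] > 0) {
                for (int j = 0; j < dim; j++) {
                    // count * batchMean == sum, so the batch mean never has to be formed.
                    centroids[i][j] = (oldWeight * centroids[i][j] + sums[i][j]) / newWeight;
                }
            }
            weights[i] = newWeight;
        }
    }
}
```

With `decayFactor = 0`, a centroid that receives points becomes exactly the mean of its batch points, matching the Javadoc's statement that new centroids are then determined entirely by recent data.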

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeansParams.java
##########
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.common.param.HasDecayFactor;
+import org.apache.flink.ml.common.param.HasGlobalBatchSize;
+import org.apache.flink.ml.common.param.HasSeed;
+import org.apache.flink.ml.param.DoubleArrayParam;
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringParam;
+
+/**
+ * Params of {@link OnlineKMeans}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface OnlineKMeansParams<T>
+        extends HasBatchStrategy<T>,
+                HasGlobalBatchSize<T>,
+                HasDecayFactor<T>,
+                HasSeed<T>,
+                KMeansModelParams<T> {
+    Param<String> INIT_MODE =
+            new StringParam(
+                    "initMode",
+                    "How to initialize the model data of the online KMeans algorithm. Supported options: 'random', 'direct'.",
+                    "random",
+                    ParamValidators.inArray("random", "direct"));
+
+    Param<Integer> DIM =
+            new IntParam(
+                    "dim",
+                    "The number of dimensions of centroids. Used when initializing random centroids.",
+                    1,
+                    ParamValidators.gt(0));
+
+    Param<Double[]> INIT_WEIGHTS =

Review comment:
       After thinking about this more, I think it is probably a bad idea to pass the initial weights as a parameter, for the following reasons:
   
   - The length of the initial weights scales with K, which could potentially be a very large number (e.g. 1000+). It would be more performant to store this information like any other model data instead of in the metadata file.
   - Currently we require users to always provide this list of weight values even if they just want to initialize 1000 centroids with `initMode = random`. It is probably more usable for users to specify just one weight value when `initMode = random`.
   
   And if users only need to provide this list of weights when `initMode == direct` AND model data is provided to the constructor, it might be simpler to just have them provide all of this information in one call.
   
   Here is an alternative approach, similar to Spark's StreamingKMeans:
   - Remove the `DIM`, `INIT_MODE` and `INIT_WEIGHTS` parameters.
   - Remove the constructor that takes `initModelDataTable`.
   - Add the method `setInitialCentroids(Table modelData, Double[] weights)`.
   - Add the method `setRandomCentroids(int dim, double weight)`.
   
   What do you think?
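   
   The sketch below is a minimal plain-Java illustration of the two initialization paths proposed above. It also includes the decayed weighted-average update from the `OnlineKMeans` Javadoc, which is where the initial weights are consumed. Everything in it is hypothetical:
   - `InitialCentroidsSketch` and its `random`/`direct` factories only mirror the shape of the proposed `setRandomCentroids(int dim, double weight)` and `setInitialCentroids(Table modelData, Double[] weights)`.
   - Plain arrays stand in for `DenseVector`/`KMeansModelData`.
   - Sampling random centroids from a standard Gaussian is an assumption, not necessarily what `createRandomCentroids` does.

```java
import java.util.Arrays;
import java.util.Random;

/**
 * Hypothetical sketch (not Flink ML API): centroids plus one weight per centroid, created either
 * randomly with a single shared weight or directly from user-provided data and weights.
 */
class InitialCentroidsSketch {
    final double[][] centroids; // k centroids, each of length dim
    final double[] weights; // one weight per centroid

    private InitialCentroidsSketch(double[][] centroids, double[] weights) {
        this.centroids = centroids;
        this.weights = weights;
    }

    /** Shaped like the proposed setRandomCentroids(dim, weight): one weight shared by all k. */
    static InitialCentroidsSketch random(int k, int dim, double weight, long seed) {
        Random random = new Random(seed);
        double[][] centroids = new double[k][dim];
        for (double[] centroid : centroids) {
            for (int j = 0; j < dim; j++) {
                centroid[j] = random.nextGaussian(); // assumed sampling distribution
            }
        }
        double[] weights = new double[k];
        Arrays.fill(weights, weight);
        return new InitialCentroidsSketch(centroids, weights);
    }

    /** Shaped like the proposed setInitialCentroids(modelData, weights): one weight per centroid. */
    static InitialCentroidsSketch direct(double[][] centroids, double[] weights) {
        if (centroids.length != weights.length) {
            throw new IllegalArgumentException("Expected exactly one weight per centroid.");
        }
        double[][] copy = new double[centroids.length][];
        for (int i = 0; i < centroids.length; i++) {
            copy[i] = centroids[i].clone();
        }
        return new InitialCentroidsSketch(copy, weights.clone());
    }

    /**
     * Merges one mini-batch estimate into centroid i. The old centroid is weighted by
     * weight * decayFactor and the batch mean by the number of batch points assigned to it.
     */
    void update(int i, double[] batchMean, double batchCount, double decayFactor) {
        double oldWeight = weights[i] * decayFactor;
        double newWeight = oldWeight + batchCount;
        for (int j = 0; j < centroids[i].length; j++) {
            // Same 1e-16 floor as ModelDataGlobalUpdater, to avoid dividing by zero.
            centroids[i][j] =
                    (centroids[i][j] * oldWeight + batchMean[j] * batchCount)
                            / Math.max(newWeight, 1e-16);
        }
        weights[i] = newWeight;
    }
}
```

   Consider a `direct` centroid at (0, 0) with initial weight 2 and decay factor 0.5. One batch point at (3, 3) moves the centroid to (1.5, 1.5) and leaves its weight at 2. A larger initial weight makes the initial centroid move less.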
   

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,562 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Random;
+
+/**
+ * OnlineKMeans extends the functionality of {@link KMeans} to support training a K-Means model
+ * continuously on an unbounded stream of training data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids of the current batch have been estimated,
+ * OnlineKMeans computes the new centroids as the weighted average of the original and the
+ * estimated centroids. The weight of each estimated centroid is the number of points assigned to
+ * it in the current batch. The weight of each original centroid is likewise the number of points
+ * assigned to it, additionally multiplied by the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters estimated so far. If the decay
+ * factor is 1, all batches are weighted equally. If the decay factor is 0, the new centroids are
+ * determined entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public OnlineKMeans(Table initModelDataTable) {
+        this.initModelDataTable = initModelDataTable;
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            Preconditions.checkState(initModelDataTable == null);
+            initModelDataStream = createRandomCentroids(env, getDim(), getK(), getSeed());
+        } else {
+            initModelDataStream = KMeansModelData.getModelDataStream(initModelDataTable);
+        }
+        DataStream<Tuple2<KMeansModelData, DenseVector>> initModelDataWithWeightsStream =
+                initModelDataStream.map(new InitWeightAssigner(getInitWeights()));
+        initModelDataWithWeightsStream.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize(),
+                        getK(),
+                        getDim());
+
+        DataStream<KMeansModelData> finalModelDataStream =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelDataWithWeightsStream),
+                                DataStreamList.of(points),
+                                body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelDataStream);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class InitWeightAssigner
+            implements MapFunction<KMeansModelData, Tuple2<KMeansModelData, DenseVector>> {
+        private final double[] initWeights;
+
+        private InitWeightAssigner(Double[] initWeights) {
+            this.initWeights = ArrayUtils.toPrimitive(initWeights);
+        }
+
+        @Override
+        public Tuple2<KMeansModelData, DenseVector> map(KMeansModelData modelData)
+                throws Exception {
+            return Tuple2.of(modelData, Vectors.dense(initWeights));
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        Path initModelDataPath = Paths.get(path, "data");
+        if (Files.exists(initModelDataPath)) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+            kMeans.setInitMode("direct");
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+        private final int k;
+        private final int dim;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure,
+                double decayFactor,
+                int batchSize,
+                int k,
+                int dim) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+            this.k = k;
+            this.dim = dim;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<Tuple2<KMeansModelData, DenseVector>> modelDataWithWeights =
+                    variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<Tuple2<KMeansModelData, DenseVector>> newModelDataWithWeights =
+                    points.countWindowAll(batchSize)
+                            .aggregate(new MiniBatchCreator())
+                            .flatMap(new MiniBatchDistributor(parallelism))
+                            .rebalance()
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "ModelDataPartialUpdater",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new ModelDataPartialUpdater(distanceMeasure, k))
+                            .setParallelism(parallelism)
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "ModelDataGlobalUpdater",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new ModelDataGlobalUpdater(k, dim, parallelism, decayFactor))
+                            .setParallelism(1);
+
+            DataStream<KMeansModelData> outputModelData =
+                    modelDataWithWeights.map(
+                            (MapFunction<Tuple2<KMeansModelData, DenseVector>, KMeansModelData>)
+                                    x -> x.f0);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelDataWithWeights), DataStreamList.of(outputModelData));
+        }
+    }
+
+    private static class ModelDataGlobalUpdater
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, DenseVector>>
+            implements TwoInputStreamOperator<
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final int k;
+        private final int dim;
+        private final int upstreamParallelism;
+        private final double decayFactor;
+
+        private ListState<Integer> partialModelDataReceivingState;
+        private ListState<Boolean> initModelDataReceivingState;
+        private ListState<KMeansModelData> modelDataState;
+        private ListState<DenseVector> weightsState;
+
+        private ModelDataGlobalUpdater(
+                int k, int dim, int upstreamParallelism, double decayFactor) {
+            this.k = k;
+            this.dim = dim;
+            this.upstreamParallelism = upstreamParallelism;
+            this.decayFactor = decayFactor;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            partialModelDataReceivingState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "partialModelDataReceiving",
+                                            BasicTypeInfo.INT_TYPE_INFO));
+
+            initModelDataReceivingState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "initModelDataReceiving",
+                                            BasicTypeInfo.BOOLEAN_TYPE_INFO));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+
+            weightsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "weights", DenseVectorTypeInfo.INSTANCE));
+
+            initStateValues();
+        }
+
+        private void initStateValues() throws Exception {
+            partialModelDataReceivingState.update(Collections.singletonList(0));
+            initModelDataReceivingState.update(Collections.singletonList(false));
+            DenseVector[] emptyCentroids = new DenseVector[k];
+            for (int i = 0; i < k; i++) {
+                emptyCentroids[i] = new DenseVector(dim);
+            }
+            modelDataState.update(Collections.singletonList(new KMeansModelData(emptyCentroids)));
+            weightsState.update(Collections.singletonList(new DenseVector(k)));
+        }
+
+        @Override
+        public void processElement1(
+                StreamRecord<Tuple2<KMeansModelData, DenseVector>> partialModelDataUpdateRecord)
+                throws Exception {
+            int partialModelDataReceiving =
+                    getUniqueElement(partialModelDataReceivingState, "partialModelDataReceiving");
+            Preconditions.checkState(partialModelDataReceiving < upstreamParallelism);
+            partialModelDataReceivingState.update(
+                    Collections.singletonList(partialModelDataReceiving + 1));
+            processElement(
+                    partialModelDataUpdateRecord.getValue().f0.centroids,
+                    partialModelDataUpdateRecord.getValue().f1,
+                    1.0);
+        }
+
+        @Override
+        public void processElement2(
+                StreamRecord<Tuple2<KMeansModelData, DenseVector>> initModelDataRecord)
+                throws Exception {
+            boolean initModelDataReceiving =
+                    getUniqueElement(initModelDataReceivingState, "initModelDataReceiving");
+            Preconditions.checkState(!initModelDataReceiving);
+            initModelDataReceivingState.update(Collections.singletonList(true));
+            processElement(
+                    initModelDataRecord.getValue().f0.centroids,
+                    initModelDataRecord.getValue().f1,
+                    decayFactor);
+        }
+
+        private void processElement(
+                DenseVector[] newCentroids, DenseVector newWeights, double decayFactor)
+                throws Exception {
+            DenseVector weights = getUniqueElement(weightsState, "weights");
+            DenseVector[] centroids = getUniqueElement(modelDataState, "modelData").centroids;
+
+            for (int i = 0; i < k; i++) {
+                newWeights.values[i] *= decayFactor;
+                for (int j = 0; j < dim; j++) {
+                    centroids[i].values[j] =
+                            (centroids[i].values[j] * weights.values[i]
+                                            + newCentroids[i].values[j] * newWeights.values[i])
+                                    / Math.max(weights.values[i] + newWeights.values[i], 1e-16);
+                }
+                weights.values[i] += newWeights.values[i];
+            }
+
+            if (getUniqueElement(initModelDataReceivingState, "initModelDataReceiving")
+                    && getUniqueElement(partialModelDataReceivingState, "partialModelDataReceiving")
+                            >= upstreamParallelism) {
+                output.collect(
+                        new StreamRecord<>(Tuple2.of(new KMeansModelData(centroids), weights)));
+                initStateValues();
+            } else {
+                modelDataState.update(Collections.singletonList(new KMeansModelData(centroids)));
+                weightsState.update(Collections.singletonList(weights));
+            }
+        }
+    }
+
+    private static <T> T getUniqueElement(ListState<T> state, String stateName) throws Exception {
+        T value = OperatorStateUtils.getUniqueElement(state, stateName).orElse(null);
+        return Objects.requireNonNull(value);
+    }
+
+    private static class ModelDataPartialUpdater
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, DenseVector>>
+            implements TwoInputStreamOperator<
+                    DenseVector[],
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final DistanceMeasure distanceMeasure;
+        private final int k;
+        private ListState<DenseVector[]> miniBatchState;
+        private ListState<KMeansModelData> modelDataState;
+
+        private ModelDataPartialUpdater(DistanceMeasure distanceMeasure, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.k = k;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            miniBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("miniBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> pointsRecord) throws Exception {
+            miniBatchState.add(pointsRecord.getValue());
+            processElement();
+        }
+
+        @Override
+        public void processElement2(
+                StreamRecord<Tuple2<KMeansModelData, DenseVector>> modelDataAndWeightsRecord)
+                throws Exception {
+            modelDataState.add(modelDataAndWeightsRecord.getValue().f0);
+            processElement();
+        }
+
+        private void processElement() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !miniBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            DenseVector[] centroids = getUniqueElement(modelDataState, "modelData").centroids;
+            modelDataState.clear();
+
+            List<DenseVector[]> pointsList = IteratorUtils.toList(miniBatchState.get().iterator());
+            DenseVector[] points = pointsList.get(0);
+            pointsList.remove(0);
+            miniBatchState.clear();
+            miniBatchState.addAll(pointsList);
+
+            // Computes new centroids.
+            DenseVector[] sums = new DenseVector[k];
+            int dim = centroids[0].size();
+            DenseVector counts = new DenseVector(k);
+
+            for (int i = 0; i < k; i++) {
+                sums[i] = new DenseVector(dim);
+                counts.values[i] = 0;
+            }
+            for (DenseVector point : points) {
+                int closestCentroidId =
+                        KMeans.findClosestCentroidId(centroids, point, distanceMeasure);
+                counts.values[closestCentroidId]++;
+                for (int j = 0; j < dim; j++) {
+                    sums[closestCentroidId].values[j] += point.values[j];
+                }
+            }
+
+            for (int i = 0; i < k; i++) {
+                if (counts.values[i] < 1e-5) {
+                    continue;
+                }
+                BLAS.scal(1.0 / counts.values[i], sums[i]);
+            }
+
+            output.collect(new StreamRecord<>(Tuple2.of(new KMeansModelData(sums), counts)));
+        }
+    }
+
+    private static class FeaturesExtractor implements MapFunction<Row, DenseVector> {
+        private final String featuresCol;
+
+        private FeaturesExtractor(String featuresCol) {
+            this.featuresCol = featuresCol;
+        }
+
+        @Override
+        public DenseVector map(Row row) throws Exception {
+            return (DenseVector) row.getField(featuresCol);
+        }
+    }
+
+    private static class MiniBatchDistributor
+            implements FlatMapFunction<DenseVector[], DenseVector[]> {
+        private final int downStreamParallelism;
+        private int shift = 0;
+
+        private MiniBatchDistributor(int downStreamParallelism) {
+            this.downStreamParallelism = downStreamParallelism;
+        }
+
+        @Override
+        public void flatMap(DenseVector[] values, Collector<DenseVector[]> collector) {
+            // Calculate the batch sizes to be distributed on each subtask.
+            List<Integer> sizes = new ArrayList<>();
+            for (int i = 0; i < downStreamParallelism; i++) {
+                int start = i * values.length / downStreamParallelism;
+                int end = (i + 1) * values.length / downStreamParallelism;
+                sizes.add(end - start);
+            }
+
+            // Reduce accumulated imbalance among distributed batches.
+            Collections.rotate(sizes, shift);
+            shift++;

Review comment:
       Nit: It seems simpler to just do `shift = (shift + 1) % downStreamParallelism`.
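   
   The suggestion can be checked in isolation. The sketch below is a minimal plain-Java version of the size computation in `MiniBatchDistributor`, using the modulo update. The class name `MiniBatchSizes` and its `next` method are hypothetical. `Collections.rotate` already reduces the distance modulo the list size, so the modulo form produces the same sequence of rotations as `shift++`. It also keeps `shift` bounded, so the field cannot overflow on a long-running stream.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/** Hypothetical stand-alone copy of MiniBatchDistributor's size logic (not Flink ML API). */
class MiniBatchSizes {
    private final int downStreamParallelism;
    private int shift = 0;

    MiniBatchSizes(int downStreamParallelism) {
        this.downStreamParallelism = downStreamParallelism;
    }

    /** Returns how many of the batchSize points each downstream subtask receives. */
    List<Integer> next(int batchSize) {
        List<Integer> sizes = new ArrayList<>();
        for (int i = 0; i < downStreamParallelism; i++) {
            int start = i * batchSize / downStreamParallelism;
            int end = (i + 1) * batchSize / downStreamParallelism;
            sizes.add(end - start);
        }
        // Rotate so that the larger slices land on different subtasks in successive batches.
        Collections.rotate(sizes, shift);
        // The suggested update: shift always stays in [0, downStreamParallelism).
        shift = (shift + 1) % downStreamParallelism;
        return sizes;
    }
}
```

   With 5 points per batch and a parallelism of 3, the sizes cycle through [1, 2, 2], [2, 1, 2] and [2, 2, 1]. Over three batches, each subtask therefore receives 5 points.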

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,562 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Random;
+
+/**
+ * OnlineKMeans extends the functionality of {@link KMeans} to support training a K-Means model
+ * continuously on an unbounded stream of training data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids of the current batch have been estimated,
+ * OnlineKMeans computes the new centroids as the weighted average of the original and the
+ * estimated centroids. The weight of each estimated centroid is the number of points assigned to
+ * it in the current batch. The weight of each original centroid is likewise the number of points
+ * assigned to it, additionally multiplied by the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters estimated so far. If the decay
+ * factor is 1, all batches are weighted equally. If the decay factor is 0, the new centroids are
+ * determined entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public OnlineKMeans(Table initModelDataTable) {
+        this.initModelDataTable = initModelDataTable;
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            Preconditions.checkState(initModelDataTable == null);
+            initModelDataStream = createRandomCentroids(env, getDim(), getK(), getSeed());
+        } else {
+            initModelDataStream = KMeansModelData.getModelDataStream(initModelDataTable);
+        }
+        DataStream<Tuple2<KMeansModelData, DenseVector>> initModelDataWithWeightsStream =
+                initModelDataStream.map(new InitWeightAssigner(getInitWeights()));
+        initModelDataWithWeightsStream.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize(),
+                        getK(),
+                        getDim());
+
+        DataStream<KMeansModelData> finalModelDataStream =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelDataWithWeightsStream),
+                                DataStreamList.of(points),
+                                body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelDataStream);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class InitWeightAssigner
+            implements MapFunction<KMeansModelData, Tuple2<KMeansModelData, DenseVector>> {
+        private final double[] initWeights;
+
+        private InitWeightAssigner(Double[] initWeights) {
+            this.initWeights = ArrayUtils.toPrimitive(initWeights);
+        }
+
+        @Override
+        public Tuple2<KMeansModelData, DenseVector> map(KMeansModelData modelData)
+                throws Exception {
+            return Tuple2.of(modelData, Vectors.dense(initWeights));
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        Path initModelDataPath = Paths.get(path, "data");
+        if (Files.exists(initModelDataPath)) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+            kMeans.setInitMode("direct");
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+        private final int k;
+        private final int dim;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure,
+                double decayFactor,
+                int batchSize,
+                int k,
+                int dim) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+            this.k = k;
+            this.dim = dim;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<Tuple2<KMeansModelData, DenseVector>> modelDataWithWeights =
+                    variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<Tuple2<KMeansModelData, DenseVector>> newModelDataWithWeights =
+                    points.countWindowAll(batchSize)
+                            .aggregate(new MiniBatchCreator())
+                            .flatMap(new MiniBatchDistributor(parallelism))
+                            .rebalance()
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "ModelDataPartialUpdater",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new ModelDataPartialUpdater(distanceMeasure, k))
+                            .setParallelism(parallelism)
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "ModelDataGlobalUpdater",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new ModelDataGlobalUpdater(k, dim, parallelism, decayFactor))
+                            .setParallelism(1);
+
+            DataStream<KMeansModelData> outputModelData =
+                    modelDataWithWeights.map(
+                            (MapFunction<Tuple2<KMeansModelData, DenseVector>, KMeansModelData>)
+                                    x -> x.f0);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelDataWithWeights), DataStreamList.of(outputModelData));
+        }
+    }
+
+    private static class ModelDataGlobalUpdater
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, DenseVector>>
+            implements TwoInputStreamOperator<
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final int k;
+        private final int dim;
+        private final int upstreamParallelism;
+        private final double decayFactor;
+
+        private ListState<Integer> partialModelDataReceivingState;
+        private ListState<Boolean> initModelDataReceivingState;
+        private ListState<KMeansModelData> modelDataState;
+        private ListState<DenseVector> weightsState;
+
+        private ModelDataGlobalUpdater(
+                int k, int dim, int upstreamParallelism, double decayFactor) {
+            this.k = k;
+            this.dim = dim;
+            this.upstreamParallelism = upstreamParallelism;
+            this.decayFactor = decayFactor;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            partialModelDataReceivingState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "partialModelDataReceiving",
+                                            BasicTypeInfo.INT_TYPE_INFO));
+
+            initModelDataReceivingState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "initModelDataReceiving",
+                                            BasicTypeInfo.BOOLEAN_TYPE_INFO));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+
+            weightsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "weights", DenseVectorTypeInfo.INSTANCE));
+
+            initStateValues();
+        }
+
+        private void initStateValues() throws Exception {
+            partialModelDataReceivingState.update(Collections.singletonList(0));
+            initModelDataReceivingState.update(Collections.singletonList(false));
+            DenseVector[] emptyCentroids = new DenseVector[k];
+            for (int i = 0; i < k; i++) {
+                emptyCentroids[i] = new DenseVector(dim);
+            }
+            modelDataState.update(Collections.singletonList(new KMeansModelData(emptyCentroids)));
+            weightsState.update(Collections.singletonList(new DenseVector(k)));
+        }
+
+        @Override
+        public void processElement1(
+                StreamRecord<Tuple2<KMeansModelData, DenseVector>> partialModelDataUpdateRecord)
+                throws Exception {
+            int partialModelDataReceiving =
+                    getUniqueElement(partialModelDataReceivingState, "partialModelDataReceiving");
+            Preconditions.checkState(partialModelDataReceiving < upstreamParallelism);
+            partialModelDataReceivingState.update(
+                    Collections.singletonList(partialModelDataReceiving + 1));
+            processElement(
+                    partialModelDataUpdateRecord.getValue().f0.centroids,
+                    partialModelDataUpdateRecord.getValue().f1,
+                    1.0);
+        }
+
+        @Override
+        public void processElement2(
+                StreamRecord<Tuple2<KMeansModelData, DenseVector>> initModelDataRecord)
+                throws Exception {
+            boolean initModelDataReceiving =
+                    getUniqueElement(initModelDataReceivingState, "initModelDataReceiving");
+            Preconditions.checkState(!initModelDataReceiving);
+            initModelDataReceivingState.update(Collections.singletonList(true));
+            processElement(
+                    initModelDataRecord.getValue().f0.centroids,
+                    initModelDataRecord.getValue().f1,
+                    decayFactor);
+        }
+
+        private void processElement(
+                DenseVector[] newCentroids, DenseVector newWeights, double decayFactor)
+                throws Exception {
+            DenseVector weights = getUniqueElement(weightsState, "weights");
+            DenseVector[] centroids = getUniqueElement(modelDataState, "modelData").centroids;
+
+            for (int i = 0; i < k; i++) {
+                newWeights.values[i] *= decayFactor;
+                for (int j = 0; j < dim; j++) {
+                    centroids[i].values[j] =
+                            (centroids[i].values[j] * weights.values[i]
+                                            + newCentroids[i].values[j] * newWeights.values[i])
+                                    / Math.max(weights.values[i] + newWeights.values[i], 1e-16);
+                }
+                weights.values[i] += newWeights.values[i];
+            }
+
+            if (getUniqueElement(initModelDataReceivingState, "initModelDataReceiving")
+                    && getUniqueElement(partialModelDataReceivingState, "partialModelDataReceiving")
+                            >= upstreamParallelism) {
+                output.collect(
+                        new StreamRecord<>(Tuple2.of(new KMeansModelData(centroids), weights)));
+                initStateValues();
+            } else {
+                modelDataState.update(Collections.singletonList(new KMeansModelData(centroids)));
+                weightsState.update(Collections.singletonList(weights));
+            }
+        }
+    }
+
+    private static <T> T getUniqueElement(ListState<T> state, String stateName) throws Exception {
+        T value = OperatorStateUtils.getUniqueElement(state, stateName).orElse(null);
+        return Objects.requireNonNull(value);
+    }
+
+    private static class ModelDataPartialUpdater
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, DenseVector>>
+            implements TwoInputStreamOperator<
+                    DenseVector[],
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final DistanceMeasure distanceMeasure;
+        private final int k;
+        private ListState<DenseVector[]> miniBatchState;
+        private ListState<KMeansModelData> modelDataState;
+
+        private ModelDataPartialUpdater(DistanceMeasure distanceMeasure, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.k = k;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            miniBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("miniBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> pointsRecord) throws Exception {
+            miniBatchState.add(pointsRecord.getValue());
+            processElement();
+        }
+
+        @Override
+        public void processElement2(
+                StreamRecord<Tuple2<KMeansModelData, DenseVector>> modelDataAndWeightsRecord)
+                throws Exception {
+            modelDataState.add(modelDataAndWeightsRecord.getValue().f0);
+            processElement();
+        }
+
+        private void processElement() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !miniBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            DenseVector[] centroids = getUniqueElement(modelDataState, "modelData").centroids;
+            modelDataState.clear();
+
+            List<DenseVector[]> pointsList = IteratorUtils.toList(miniBatchState.get().iterator());
+            DenseVector[] points = pointsList.get(0);

Review comment:
       Would it be simpler to just do `DenseVector[] points = pointsList.remove(0)`?
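       A minimal sketch of the suggested simplification, using plain `java.util` types with `int[]` standing in for `DenseVector[]` (the class and method names are hypothetical, not part of the PR): `List.remove(int)` returns the element it removes, so the separate `get(0)` call is not needed.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class RemoveHeadSketch {
    // Hypothetical helper: pops the oldest pending mini-batch and returns it.
    // List.remove(int) both removes the element at that index and returns it.
    static int[] popFirst(List<int[]> pending) {
        return pending.remove(0);
    }

    public static void main(String[] args) {
        List<int[]> pending =
                new ArrayList<>(Arrays.asList(new int[] {1, 2}, new int[] {3}));
        int[] first = popFirst(pending);
        System.out.println(Arrays.toString(first) + " remaining=" + pending.size());
        // prints "[1, 2] remaining=1"
    }
}
```

       Since the code under review already calls `pointsList.remove(0)` right after `get(0)`, folding the two calls into one should not change behavior.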

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,562 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Random;
+
+/**
+ * OnlineKMeans extends the functionality of {@link KMeans} to support training a K-Means model
+ * continuously on an unbounded stream of training data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids as the weighted average of the original and the
+ * estimated centroids. The weight of each estimated centroid is the number of points assigned to
+ * it in the current batch. The weight of each original centroid is its weight accumulated from
+ * previous batches, multiplied by the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If the decay
+ * factor is 1, all batches are weighted equally. If the decay factor is 0, the new centroids are
+ * determined entirely by the most recent batch. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public OnlineKMeans(Table initModelDataTable) {
+        this.initModelDataTable = initModelDataTable;
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            Preconditions.checkState(initModelDataTable == null);
+            initModelDataStream = createRandomCentroids(env, getDim(), getK(), getSeed());
+        } else {
+            initModelDataStream = KMeansModelData.getModelDataStream(initModelDataTable);
+        }
+        DataStream<Tuple2<KMeansModelData, DenseVector>> initModelDataWithWeightsStream =
+                initModelDataStream.map(new InitWeightAssigner(getInitWeights()));
+        initModelDataWithWeightsStream.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize(),
+                        getK(),
+                        getDim());
+
+        DataStream<KMeansModelData> finalModelDataStream =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelDataWithWeightsStream),
+                                DataStreamList.of(points),
+                                body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelDataStream);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class InitWeightAssigner
+            implements MapFunction<KMeansModelData, Tuple2<KMeansModelData, DenseVector>> {
+        private final double[] initWeights;
+
+        private InitWeightAssigner(Double[] initWeights) {
+            this.initWeights = ArrayUtils.toPrimitive(initWeights);
+        }
+
+        @Override
+        public Tuple2<KMeansModelData, DenseVector> map(KMeansModelData modelData)
+                throws Exception {
+            return Tuple2.of(modelData, Vectors.dense(initWeights));
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        Path initModelDataPath = Paths.get(path, "data");
+        if (Files.exists(initModelDataPath)) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+            kMeans.setInitMode("direct");
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+        private final int k;
+        private final int dim;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure,
+                double decayFactor,
+                int batchSize,
+                int k,
+                int dim) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+            this.k = k;
+            this.dim = dim;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<Tuple2<KMeansModelData, DenseVector>> modelDataWithWeights =
+                    variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<Tuple2<KMeansModelData, DenseVector>> newModelDataWithWeights =
+                    points.countWindowAll(batchSize)
+                            .aggregate(new MiniBatchCreator())
+                            .flatMap(new MiniBatchDistributor(parallelism))
+                            .rebalance()
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "ModelDataPartialUpdater",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new ModelDataPartialUpdater(distanceMeasure, k))
+                            .setParallelism(parallelism)
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "ModelDataGlobalUpdater",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new ModelDataGlobalUpdater(k, dim, parallelism, decayFactor))
+                            .setParallelism(1);
+
+            DataStream<KMeansModelData> outputModelData =
+                    modelDataWithWeights.map(
+                            (MapFunction<Tuple2<KMeansModelData, DenseVector>, KMeansModelData>)
+                                    x -> x.f0);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelDataWithWeights), DataStreamList.of(outputModelData));
+        }
+    }
+
+    private static class ModelDataGlobalUpdater
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, DenseVector>>
+            implements TwoInputStreamOperator<
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final int k;
+        private final int dim;
+        private final int upstreamParallelism;
+        private final double decayFactor;
+
+        private ListState<Integer> partialModelDataReceivingState;
+        private ListState<Boolean> initModelDataReceivingState;
+        private ListState<KMeansModelData> modelDataState;
+        private ListState<DenseVector> weightsState;
+
+        private ModelDataGlobalUpdater(
+                int k, int dim, int upstreamParallelism, double decayFactor) {
+            this.k = k;
+            this.dim = dim;
+            this.upstreamParallelism = upstreamParallelism;
+            this.decayFactor = decayFactor;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            partialModelDataReceivingState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "partialModelDataReceiving",
+                                            BasicTypeInfo.INT_TYPE_INFO));
+
+            initModelDataReceivingState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "initModelDataReceiving",
+                                            BasicTypeInfo.BOOLEAN_TYPE_INFO));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+
+            weightsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "weights", DenseVectorTypeInfo.INSTANCE));
+
+            initStateValues();
+        }
+
+        private void initStateValues() throws Exception {
+            partialModelDataReceivingState.update(Collections.singletonList(0));
+            initModelDataReceivingState.update(Collections.singletonList(false));
+            DenseVector[] emptyCentroids = new DenseVector[k];
+            for (int i = 0; i < k; i++) {
+                emptyCentroids[i] = new DenseVector(dim);
+            }
+            modelDataState.update(Collections.singletonList(new KMeansModelData(emptyCentroids)));
+            weightsState.update(Collections.singletonList(new DenseVector(k)));
+        }
+
+        @Override
+        public void processElement1(
+                StreamRecord<Tuple2<KMeansModelData, DenseVector>> partialModelDataUpdateRecord)
+                throws Exception {
+            int partialModelDataReceiving =
+                    getUniqueElement(partialModelDataReceivingState, "partialModelDataReceiving");
+            Preconditions.checkState(partialModelDataReceiving < upstreamParallelism);
+            partialModelDataReceivingState.update(
+                    Collections.singletonList(partialModelDataReceiving + 1));
+            processElement(
+                    partialModelDataUpdateRecord.getValue().f0.centroids,
+                    partialModelDataUpdateRecord.getValue().f1,
+                    1.0);
+        }
+
+        @Override
+        public void processElement2(
+                StreamRecord<Tuple2<KMeansModelData, DenseVector>> initModelDataRecord)
+                throws Exception {
+            boolean initModelDataReceiving =
+                    getUniqueElement(initModelDataReceivingState, "initModelDataReceiving");
+            Preconditions.checkState(!initModelDataReceiving);
+            initModelDataReceivingState.update(Collections.singletonList(true));
+            processElement(
+                    initModelDataRecord.getValue().f0.centroids,
+                    initModelDataRecord.getValue().f1,
+                    decayFactor);
+        }
+
+        private void processElement(
+                DenseVector[] newCentroids, DenseVector newWeights, double decayFactor)
+                throws Exception {
+            DenseVector weights = getUniqueElement(weightsState, "weights");
+            DenseVector[] centroids = getUniqueElement(modelDataState, "modelData").centroids;
+
+            for (int i = 0; i < k; i++) {
+                newWeights.values[i] *= decayFactor;
+                for (int j = 0; j < dim; j++) {
+                    centroids[i].values[j] =
+                            (centroids[i].values[j] * weights.values[i]
+                                            + newCentroids[i].values[j] * newWeights.values[i])
+                                    / Math.max(weights.values[i] + newWeights.values[i], 1e-16);
+                }
+                weights.values[i] += newWeights.values[i];
+            }
+
+            if (getUniqueElement(initModelDataReceivingState, "initModelDataReceiving")
+                    && getUniqueElement(partialModelDataReceivingState, "partialModelDataReceiving")
+                            >= upstreamParallelism) {
+                output.collect(
+                        new StreamRecord<>(Tuple2.of(new KMeansModelData(centroids), weights)));
+                initStateValues();
+            } else {
+                modelDataState.update(Collections.singletonList(new KMeansModelData(centroids)));
+                weightsState.update(Collections.singletonList(weights));
+            }
+        }
+    }
+
+    private static <T> T getUniqueElement(ListState<T> state, String stateName) throws Exception {
+        T value = OperatorStateUtils.getUniqueElement(state, stateName).orElse(null);
+        return Objects.requireNonNull(value);
+    }
+
+    private static class ModelDataPartialUpdater
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, DenseVector>>
+            implements TwoInputStreamOperator<
+                    DenseVector[],
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final DistanceMeasure distanceMeasure;
+        private final int k;
+        private ListState<DenseVector[]> miniBatchState;
+        private ListState<KMeansModelData> modelDataState;
+
+        private ModelDataPartialUpdater(DistanceMeasure distanceMeasure, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.k = k;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            miniBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("miniBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> pointsRecord) throws Exception {
+            miniBatchState.add(pointsRecord.getValue());
+            processElement();
+        }
+
+        @Override
+        public void processElement2(
+                StreamRecord<Tuple2<KMeansModelData, DenseVector>> modelDataAndWeightsRecord)
+                throws Exception {
+            modelDataState.add(modelDataAndWeightsRecord.getValue().f0);
+            processElement();
+        }
+
+        private void processElement() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !miniBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            DenseVector[] centroids = getUniqueElement(modelDataState, "modelData").centroids;
+            modelDataState.clear();
+
+            List<DenseVector[]> pointsList = IteratorUtils.toList(miniBatchState.get().iterator());
+            DenseVector[] points = pointsList.get(0);
+            pointsList.remove(0);
+            miniBatchState.clear();
+            miniBatchState.addAll(pointsList);
+
+            // Computes new centroids.
+            DenseVector[] sums = new DenseVector[k];
+            int dim = centroids[0].size();
+            DenseVector counts = new DenseVector(k);
+
+            for (int i = 0; i < k; i++) {
+                sums[i] = new DenseVector(dim);
+                counts.values[i] = 0;
+            }
+            for (DenseVector point : points) {

Review comment:
       Could we also verify that `point.size() == dim`? Maybe we can add this check in `EuclideanDistanceMeasure::distance`.
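       A minimal sketch of where such a check could live, assuming plain `double[]` arrays in place of Flink ML's `DenseVector` and a hypothetical `squaredDistance` helper rather than the real `EuclideanDistanceMeasure::distance`. It mirrors the per-batch assignment loop quoted above (accumulating per-cluster sums and counts); `MiniBatchAssignmentSketch`, `closest` and `accumulate` are illustrative names only. Putting the check in the distance function means a point of the wrong size is rejected with a clear message instead of producing wrong sums or an index error further down.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch of the mini-batch assignment step: assign each point to its closest
 * centroid and accumulate per-cluster sums and counts. Plain arrays stand in for DenseVector.
 */
public class MiniBatchAssignmentSketch {

    /** Squared Euclidean distance that rejects vectors of mismatched dimension. */
    static double squaredDistance(double[] a, double[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException(
                    "Dimension mismatch: " + a.length + " vs " + b.length);
        }
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return sum;
    }

    /** Returns the index of the centroid closest to the point. */
    static int closest(double[] point, double[][] centroids) {
        int best = 0;
        double bestDist = Double.POSITIVE_INFINITY;
        for (int i = 0; i < centroids.length; i++) {
            double dist = squaredDistance(point, centroids[i]);
            if (dist < bestDist) {
                bestDist = dist;
                best = i;
            }
        }
        return best;
    }

    /** Adds each point into sums[c] and increments counts[c] for its closest centroid c. */
    static void accumulate(
            double[][] points, double[][] centroids, double[][] sums, double[] counts) {
        for (double[] point : points) {
            // The dimension check happens inside squaredDistance, before sums[c] is touched.
            int c = closest(point, centroids);
            counts[c] += 1;
            for (int j = 0; j < point.length; j++) {
                sums[c][j] += point[j];
            }
        }
    }

    public static void main(String[] args) {
        double[][] centroids = {{0.0, 0.0}, {10.0, 10.0}};
        double[][] points = {{1.0, 1.0}, {9.0, 9.0}, {11.0, 11.0}};
        double[][] sums = new double[2][2];
        double[] counts = new double[2];
        accumulate(points, centroids, sums, counts);
        System.out.println(Arrays.toString(counts)); // [1.0, 2.0]
        System.out.println(Arrays.deepToString(sums)); // [[1.0, 1.0], [20.0, 20.0]]
        try {
            squaredDistance(new double[] {1.0}, new double[] {1.0, 2.0});
        } catch (IllegalArgumentException e) {
            System.out.println("rejected: " + e.getMessage());
        }
    }
}
```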

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,562 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Random;
+
+/**
+ * OnlineKMeans extends the functionality of {@link KMeans} by supporting continuous training of a
+ * K-Means model on an unbounded stream of training data.
+ *
+ * <p>OnlineKMeans updates the model with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids of the current batch are estimated, OnlineKMeans
+ * computes the new centroids as a weighted average of the original and the estimated centroids.
+ * The weight of the estimated centroids is the number of points assigned to them in the batch. The
+ * weight of the original centroids is the accumulated number of points assigned to them so far,
+ * multiplied by the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters estimated so far. If the decay factor
+ * is 1, all batches are weighted equally. If the decay factor is 0, the new centroids are
+ * determined entirely by the most recent batch. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public OnlineKMeans(Table initModelDataTable) {
+        this.initModelDataTable = initModelDataTable;
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            Preconditions.checkState(initModelDataTable == null);
+            initModelDataStream = createRandomCentroids(env, getDim(), getK(), getSeed());
+        } else {
+            initModelDataStream = KMeansModelData.getModelDataStream(initModelDataTable);
+        }
+        DataStream<Tuple2<KMeansModelData, DenseVector>> initModelDataWithWeightsStream =
+                initModelDataStream.map(new InitWeightAssigner(getInitWeights()));
+        initModelDataWithWeightsStream.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize(),
+                        getK(),
+                        getDim());
+
+        DataStream<KMeansModelData> finalModelDataStream =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelDataWithWeightsStream),
+                                DataStreamList.of(points),
+                                body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelDataStream);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class InitWeightAssigner
+            implements MapFunction<KMeansModelData, Tuple2<KMeansModelData, DenseVector>> {
+        private final double[] initWeights;
+
+        private InitWeightAssigner(Double[] initWeights) {
+            this.initWeights = ArrayUtils.toPrimitive(initWeights);
+        }
+
+        @Override
+        public Tuple2<KMeansModelData, DenseVector> map(KMeansModelData modelData)
+                throws Exception {
+            return Tuple2.of(modelData, Vectors.dense(initWeights));
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        Path initModelDataPath = Paths.get(path, "data");
+        if (Files.exists(initModelDataPath)) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+            kMeans.setInitMode("direct");
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+        private final int k;
+        private final int dim;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure,
+                double decayFactor,
+                int batchSize,
+                int k,
+                int dim) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+            this.k = k;
+            this.dim = dim;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<Tuple2<KMeansModelData, DenseVector>> modelDataWithWeights =
+                    variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<Tuple2<KMeansModelData, DenseVector>> newModelDataWithWeights =
+                    points.countWindowAll(batchSize)
+                            .aggregate(new MiniBatchCreator())
+                            .flatMap(new MiniBatchDistributor(parallelism))
+                            .rebalance()
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "ModelDataPartialUpdater",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new ModelDataPartialUpdater(distanceMeasure, k))
+                            .setParallelism(parallelism)
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "ModelDataGlobalUpdater",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new ModelDataGlobalUpdater(k, dim, parallelism, decayFactor))
+                            .setParallelism(1);
+
+            DataStream<KMeansModelData> outputModelData =
+                    modelDataWithWeights.map(
+                            (MapFunction<Tuple2<KMeansModelData, DenseVector>, KMeansModelData>)
+                                    x -> x.f0);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelDataWithWeights), DataStreamList.of(outputModelData));
+        }
+    }
+
+    private static class ModelDataGlobalUpdater
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, DenseVector>>
+            implements TwoInputStreamOperator<
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final int k;
+        private final int dim;
+        private final int upstreamParallelism;
+        private final double decayFactor;
+
+        private ListState<Integer> partialModelDataReceivingState;
+        private ListState<Boolean> initModelDataReceivingState;
+        private ListState<KMeansModelData> modelDataState;
+        private ListState<DenseVector> weightsState;
+
+        private ModelDataGlobalUpdater(
+                int k, int dim, int upstreamParallelism, double decayFactor) {
+            this.k = k;
+            this.dim = dim;
+            this.upstreamParallelism = upstreamParallelism;
+            this.decayFactor = decayFactor;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            partialModelDataReceivingState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "partialModelDataReceiving",
+                                            BasicTypeInfo.INT_TYPE_INFO));
+
+            initModelDataReceivingState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "initModelDataReceiving",
+                                            BasicTypeInfo.BOOLEAN_TYPE_INFO));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+
+            weightsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "weights", DenseVectorTypeInfo.INSTANCE));
+
+            initStateValues();
+        }
+
+        private void initStateValues() throws Exception {
+            partialModelDataReceivingState.update(Collections.singletonList(0));
+            initModelDataReceivingState.update(Collections.singletonList(false));
+            DenseVector[] emptyCentroids = new DenseVector[k];
+            for (int i = 0; i < k; i++) {
+                emptyCentroids[i] = new DenseVector(dim);
+            }
+            modelDataState.update(Collections.singletonList(new KMeansModelData(emptyCentroids)));
+            weightsState.update(Collections.singletonList(new DenseVector(k)));
+        }
+
+        @Override
+        public void processElement1(
+                StreamRecord<Tuple2<KMeansModelData, DenseVector>> partialModelDataUpdateRecord)
+                throws Exception {
+            int partialModelDataReceiving =
+                    getUniqueElement(partialModelDataReceivingState, "partialModelDataReceiving");
+            Preconditions.checkState(partialModelDataReceiving < upstreamParallelism);
+            partialModelDataReceivingState.update(
+                    Collections.singletonList(partialModelDataReceiving + 1));
+            processElement(
+                    partialModelDataUpdateRecord.getValue().f0.centroids,
+                    partialModelDataUpdateRecord.getValue().f1,
+                    1.0);
+        }
+
+        @Override
+        public void processElement2(
+                StreamRecord<Tuple2<KMeansModelData, DenseVector>> initModelDataRecord)
+                throws Exception {
+            boolean initModelDataReceiving =
+                    getUniqueElement(initModelDataReceivingState, "initModelDataReceiving");
+            Preconditions.checkState(!initModelDataReceiving);
+            initModelDataReceivingState.update(Collections.singletonList(true));
+            processElement(
+                    initModelDataRecord.getValue().f0.centroids,
+                    initModelDataRecord.getValue().f1,
+                    decayFactor);
+        }
+
+        private void processElement(
+                DenseVector[] newCentroids, DenseVector newWeights, double decayFactor)
+                throws Exception {
+            DenseVector weights = getUniqueElement(weightsState, "weights");
+            DenseVector[] centroids = getUniqueElement(modelDataState, "modelData").centroids;
+
+            for (int i = 0; i < k; i++) {
+                newWeights.values[i] *= decayFactor;
+                for (int j = 0; j < dim; j++) {
+                    centroids[i].values[j] =
+                            (centroids[i].values[j] * weights.values[i]
+                                            + newCentroids[i].values[j] * newWeights.values[i])
+                                    / Math.max(weights.values[i] + newWeights.values[i], 1e-16);
+                }
+                weights.values[i] += newWeights.values[i];
+            }
+
+            if (getUniqueElement(initModelDataReceivingState, "initModelDataReceiving")
+                    && getUniqueElement(partialModelDataReceivingState, "partialModelDataReceiving")
+                            >= upstreamParallelism) {
+                output.collect(
+                        new StreamRecord<>(Tuple2.of(new KMeansModelData(centroids), weights)));
+                initStateValues();
+            } else {
+                modelDataState.update(Collections.singletonList(new KMeansModelData(centroids)));
+                weightsState.update(Collections.singletonList(weights));
+            }
+        }
+    }
+
+    private static <T> T getUniqueElement(ListState<T> state, String stateName) throws Exception {
+        T value = OperatorStateUtils.getUniqueElement(state, stateName).orElse(null);
+        return Objects.requireNonNull(value);
+    }
+
+    private static class ModelDataPartialUpdater
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, DenseVector>>
+            implements TwoInputStreamOperator<
+                    DenseVector[],
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final DistanceMeasure distanceMeasure;
+        private final int k;
+        private ListState<DenseVector[]> miniBatchState;
+        private ListState<KMeansModelData> modelDataState;
+
+        private ModelDataPartialUpdater(DistanceMeasure distanceMeasure, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.k = k;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            miniBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("miniBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> pointsRecord) throws Exception {
+            miniBatchState.add(pointsRecord.getValue());
+            processElement();
+        }
+
+        @Override
+        public void processElement2(
+                StreamRecord<Tuple2<KMeansModelData, DenseVector>> modelDataAndWeightsRecord)
+                throws Exception {
+            modelDataState.add(modelDataAndWeightsRecord.getValue().f0);
+            processElement();
+        }
+
+        private void processElement() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !miniBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            DenseVector[] centroids = getUniqueElement(modelDataState, "modelData").centroids;
+            modelDataState.clear();
+
+            List<DenseVector[]> pointsList = IteratorUtils.toList(miniBatchState.get().iterator());
+            DenseVector[] points = pointsList.get(0);
+            pointsList.remove(0);
+            miniBatchState.clear();
+            miniBatchState.addAll(pointsList);

Review comment:
       Would it be simpler to just do `miniBatchState.update(pointsList)`?
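       Flink's `ListState` does provide `update(List<T>)`, which replaces the current contents with the given list, so the `clear()` + `addAll()` pair could collapse into a single call. The sketch below illustrates the equivalence using a hypothetical in-memory stand-in (`InMemoryListState` is not a Flink class) and a hypothetical `pollFirst` helper that pops the oldest mini-batch and writes the remainder back with one `update()`.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Compares "clear() + addAll()" with a single "update()" on a hypothetical list state. */
public class ListStateUpdateSketch {

    /** Hypothetical in-memory stand-in for a list state; not Flink's ListState. */
    static final class InMemoryListState<T> {
        private final List<T> values = new ArrayList<>();

        Iterable<T> get() {
            return new ArrayList<>(values);
        }

        void add(T value) {
            values.add(value);
        }

        void addAll(List<T> more) {
            values.addAll(more);
        }

        void clear() {
            values.clear();
        }

        /** Replaces all current values with the given list. */
        void update(List<T> newValues) {
            values.clear();
            values.addAll(newValues);
        }

        List<T> snapshot() {
            return new ArrayList<>(values);
        }
    }

    /** Pops the oldest element, then writes the remainder back with one update(). */
    static <T> T pollFirst(InMemoryListState<T> state) {
        List<T> remaining = new ArrayList<>();
        state.get().forEach(remaining::add);
        T first = remaining.remove(0);
        state.update(remaining);
        return first;
    }

    public static void main(String[] args) {
        InMemoryListState<String> a = new InMemoryListState<>();
        InMemoryListState<String> b = new InMemoryListState<>();
        for (String s : Arrays.asList("batch0", "batch1", "batch2")) {
            a.add(s);
            b.add(s);
        }
        // Pattern in the quoted code: clear() followed by addAll().
        List<String> rest = a.snapshot();
        String firstA = rest.remove(0);
        a.clear();
        a.addAll(rest);
        // Suggested pattern: a single update().
        String firstB = pollFirst(b);
        System.out.println(firstA + " " + firstB); // batch0 batch0
        System.out.println(a.snapshot().equals(b.snapshot())); // true
    }
}
```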

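       For reference, the per-cluster arithmetic in the quoted `ModelDataGlobalUpdater.processElement` can be isolated as below. This is a sketch, assuming plain arrays in place of `DenseVector`/`KMeansModelData`; `DecayedMergeSketch` and `merge` are hypothetical names, and the emission logic (waiting for the previous model plus `upstreamParallelism` partial updates) is left out. Because the accumulator starts from zero weight after `initStateValues()`, the result is a weighted average of all inputs and does not depend on arrival order, up to floating-point rounding.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch of the per-cluster merge rule in ModelDataGlobalUpdater.processElement;
 * plain arrays stand in for DenseVector.
 */
public class DecayedMergeSketch {

    /**
     * Folds one incoming (centroids, weights) pair, with its weights scaled by decayFactor, into
     * the running accumulator. Unlike the quoted code, newWeights is not mutated in place.
     */
    static void merge(
            double[][] centroids,
            double[] weights,
            double[][] newCentroids,
            double[] newWeights,
            double decayFactor) {
        for (int i = 0; i < centroids.length; i++) {
            double scaled = newWeights[i] * decayFactor;
            // Same 1e-16 floor as the quoted code, avoiding 0/0 when both weights are zero.
            double total = Math.max(weights[i] + scaled, 1e-16);
            for (int j = 0; j < centroids[i].length; j++) {
                centroids[i][j] =
                        (centroids[i][j] * weights[i] + newCentroids[i][j] * scaled) / total;
            }
            weights[i] += scaled;
        }
    }

    public static void main(String[] args) {
        // Empty accumulator, as set up by initStateValues(): zero centroids, zero weights.
        double[][] centroids = new double[2][2];
        double[] weights = new double[2];

        // Previous model (second input), decayed by 0.5.
        merge(
                centroids,
                weights,
                new double[][] {{2.0, 2.0}, {8.0, 8.0}},
                new double[] {4.0, 2.0},
                0.5);
        // Mini-batch estimate (first input): weights are point counts, no decay.
        merge(
                centroids,
                weights,
                new double[][] {{5.0, 5.0}, {0.0, 0.0}},
                new double[] {6.0, 0.0},
                1.0);

        System.out.println(Arrays.deepToString(centroids)); // [[4.25, 4.25], [8.0, 8.0]]
        System.out.println(Arrays.toString(weights)); // [8.0, 1.0]
    }
}
```

In this example the second cluster receives no points in the batch, so it keeps its previous centroid while its weight decays from 2.0 to 1.0.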
##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,562 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Random;
+
+/**
+ * OnlineKMeans extends the functionality of {@link KMeans} by supporting continuous training of a
+ * K-Means model on an unbounded stream of training data.
+ *
+ * <p>OnlineKMeans updates the model with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids of the current batch are estimated, OnlineKMeans
+ * computes the new centroids as a weighted average of the original and the estimated centroids.
+ * The weight of the estimated centroids is the number of points assigned to them in the batch. The
+ * weight of the original centroids is the accumulated number of points assigned to them so far,
+ * multiplied by the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters estimated so far. If the decay factor
+ * is 1, all batches are weighted equally. If the decay factor is 0, the new centroids are
+ * determined entirely by the most recent batch. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public OnlineKMeans(Table initModelDataTable) {
+        this.initModelDataTable = initModelDataTable;
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            Preconditions.checkState(initModelDataTable == null);
+            initModelDataStream = createRandomCentroids(env, getDim(), getK(), getSeed());
+        } else {
+            initModelDataStream = KMeansModelData.getModelDataStream(initModelDataTable);
+        }
+        DataStream<Tuple2<KMeansModelData, DenseVector>> initModelDataWithWeightsStream =
+                initModelDataStream.map(new InitWeightAssigner(getInitWeights()));
+        initModelDataWithWeightsStream.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize(),
+                        getK(),
+                        getDim());
+
+        DataStream<KMeansModelData> finalModelDataStream =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelDataWithWeightsStream),
+                                DataStreamList.of(points),
+                                body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelDataStream);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class InitWeightAssigner
+            implements MapFunction<KMeansModelData, Tuple2<KMeansModelData, DenseVector>> {
+        private final double[] initWeights;
+
+        private InitWeightAssigner(Double[] initWeights) {
+            this.initWeights = ArrayUtils.toPrimitive(initWeights);
+        }
+
+        @Override
+        public Tuple2<KMeansModelData, DenseVector> map(KMeansModelData modelData)
+                throws Exception {
+            return Tuple2.of(modelData, Vectors.dense(initWeights));
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        Path initModelDataPath = Paths.get(path, "data");
+        if (Files.exists(initModelDataPath)) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+            kMeans.setInitMode("direct");
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+        private final int k;
+        private final int dim;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure,
+                double decayFactor,
+                int batchSize,
+                int k,
+                int dim) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+            this.k = k;
+            this.dim = dim;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<Tuple2<KMeansModelData, DenseVector>> modelDataWithWeights =
+                    variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<Tuple2<KMeansModelData, DenseVector>> newModelDataWithWeights =
+                    points.countWindowAll(batchSize)
+                            .aggregate(new MiniBatchCreator())
+                            .flatMap(new MiniBatchDistributor(parallelism))
+                            .rebalance()
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "ModelDataPartialUpdater",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new ModelDataPartialUpdater(distanceMeasure, k))
+                            .setParallelism(parallelism)
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "ModelDataGlobalUpdater",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new ModelDataGlobalUpdater(k, dim, parallelism, decayFactor))
+                            .setParallelism(1);
+
+            DataStream<KMeansModelData> outputModelData =
+                    modelDataWithWeights.map(
+                            (MapFunction<Tuple2<KMeansModelData, DenseVector>, KMeansModelData>)
+                                    x -> x.f0);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelDataWithWeights), DataStreamList.of(outputModelData));
+        }
+    }
+
+    private static class ModelDataGlobalUpdater
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, DenseVector>>
+            implements TwoInputStreamOperator<
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final int k;
+        private final int dim;
+        private final int upstreamParallelism;
+        private final double decayFactor;
+
+        private ListState<Integer> partialModelDataReceivingState;
+        private ListState<Boolean> initModelDataReceivingState;
+        private ListState<KMeansModelData> modelDataState;
+        private ListState<DenseVector> weightsState;
+
+        private ModelDataGlobalUpdater(
+                int k, int dim, int upstreamParallelism, double decayFactor) {
+            this.k = k;
+            this.dim = dim;
+            this.upstreamParallelism = upstreamParallelism;
+            this.decayFactor = decayFactor;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            partialModelDataReceivingState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "partialModelDataReceiving",
+                                            BasicTypeInfo.INT_TYPE_INFO));
+
+            initModelDataReceivingState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "initModelDataReceiving",
+                                            BasicTypeInfo.BOOLEAN_TYPE_INFO));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+
+            weightsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "weights", DenseVectorTypeInfo.INSTANCE));
+
+            initStateValues();
+        }
+
+        private void initStateValues() throws Exception {
+            partialModelDataReceivingState.update(Collections.singletonList(0));
+            initModelDataReceivingState.update(Collections.singletonList(false));
+            DenseVector[] emptyCentroids = new DenseVector[k];
+            for (int i = 0; i < k; i++) {
+                emptyCentroids[i] = new DenseVector(dim);
+            }
+            modelDataState.update(Collections.singletonList(new KMeansModelData(emptyCentroids)));
+            weightsState.update(Collections.singletonList(new DenseVector(k)));
+        }
+
+        @Override
+        public void processElement1(
+                StreamRecord<Tuple2<KMeansModelData, DenseVector>> partialModelDataUpdateRecord)
+                throws Exception {
+            int partialModelDataReceiving =
+                    getUniqueElement(partialModelDataReceivingState, "partialModelDataReceiving");
+            Preconditions.checkState(partialModelDataReceiving < upstreamParallelism);
+            partialModelDataReceivingState.update(
+                    Collections.singletonList(partialModelDataReceiving + 1));
+            processElement(
+                    partialModelDataUpdateRecord.getValue().f0.centroids,
+                    partialModelDataUpdateRecord.getValue().f1,
+                    1.0);
+        }
+
+        @Override
+        public void processElement2(
+                StreamRecord<Tuple2<KMeansModelData, DenseVector>> initModelDataRecord)
+                throws Exception {
+            boolean initModelDataReceiving =
+                    getUniqueElement(initModelDataReceivingState, "initModelDataReceiving");
+            Preconditions.checkState(!initModelDataReceiving);
+            initModelDataReceivingState.update(Collections.singletonList(true));
+            processElement(
+                    initModelDataRecord.getValue().f0.centroids,
+                    initModelDataRecord.getValue().f1,
+                    decayFactor);
+        }
+
+        private void processElement(
+                DenseVector[] newCentroids, DenseVector newWeights, double decayFactor)
+                throws Exception {
+            DenseVector weights = getUniqueElement(weightsState, "weights");
+            DenseVector[] centroids = getUniqueElement(modelDataState, "modelData").centroids;
+
+            for (int i = 0; i < k; i++) {
+                newWeights.values[i] *= decayFactor;
+                for (int j = 0; j < dim; j++) {
+                    centroids[i].values[j] =
+                            (centroids[i].values[j] * weights.values[i]
+                                            + newCentroids[i].values[j] * newWeights.values[i])
+                                    / Math.max(weights.values[i] + newWeights.values[i], 1e-16);
+                }
+                weights.values[i] += newWeights.values[i];
+            }
+
+            if (getUniqueElement(initModelDataReceivingState, "initModelDataReceiving")
+                    && getUniqueElement(partialModelDataReceivingState, "partialModelDataReceiving")
+                            >= upstreamParallelism) {
+                output.collect(
+                        new StreamRecord<>(Tuple2.of(new KMeansModelData(centroids), weights)));
+                initStateValues();
+            } else {
+                modelDataState.update(Collections.singletonList(new KMeansModelData(centroids)));
+                weightsState.update(Collections.singletonList(weights));
+            }
+        }
+    }
+
+    private static <T> T getUniqueElement(ListState<T> state, String stateName) throws Exception {
+        T value = OperatorStateUtils.getUniqueElement(state, stateName).orElse(null);
+        return Objects.requireNonNull(value);
+    }
+
+    private static class ModelDataPartialUpdater
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, DenseVector>>
+            implements TwoInputStreamOperator<
+                    DenseVector[],
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final DistanceMeasure distanceMeasure;
+        private final int k;
+        private ListState<DenseVector[]> miniBatchState;
+        private ListState<KMeansModelData> modelDataState;
+
+        private ModelDataPartialUpdater(DistanceMeasure distanceMeasure, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.k = k;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            miniBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("miniBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> pointsRecord) throws Exception {
+            miniBatchState.add(pointsRecord.getValue());
+            processElement();
+        }
+
+        @Override
+        public void processElement2(
+                StreamRecord<Tuple2<KMeansModelData, DenseVector>> modelDataAndWeightsRecord)
+                throws Exception {
+            modelDataState.add(modelDataAndWeightsRecord.getValue().f0);
+            processElement();
+        }
+
+        private void processElement() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !miniBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            DenseVector[] centroids = getUniqueElement(modelDataState, "modelData").centroids;
+            modelDataState.clear();
+
+            List<DenseVector[]> pointsList = IteratorUtils.toList(miniBatchState.get().iterator());
+            DenseVector[] points = pointsList.get(0);
+            pointsList.remove(0);
+            miniBatchState.clear();
+            miniBatchState.addAll(pointsList);
+
+            // Computes new centroids.
+            DenseVector[] sums = new DenseVector[k];
+            int dim = centroids[0].size();
+            DenseVector counts = new DenseVector(k);
+
+            for (int i = 0; i < k; i++) {
+                sums[i] = new DenseVector(dim);
+                counts.values[i] = 0;
+            }
+            for (DenseVector point : points) {
+                int closestCentroidId =
+                        KMeans.findClosestCentroidId(centroids, point, distanceMeasure);
+                counts.values[closestCentroidId]++;
+                for (int j = 0; j < dim; j++) {
+                    sums[closestCentroidId].values[j] += point.values[j];
+                }
+            }
+
+            for (int i = 0; i < k; i++) {
+                if (counts.values[i] < 1e-5) {
+                    continue;
+                }
+                BLAS.scal(1.0 / counts.values[i], sums[i]);
+            }
+
+            output.collect(new StreamRecord<>(Tuple2.of(new KMeansModelData(sums), counts)));
+        }
+    }
+
+    private static class FeaturesExtractor implements MapFunction<Row, DenseVector> {
+        private final String featuresCol;
+
+        private FeaturesExtractor(String featuresCol) {
+            this.featuresCol = featuresCol;
+        }
+
+        @Override
+        public DenseVector map(Row row) throws Exception {
+            return (DenseVector) row.getField(featuresCol);
+        }
+    }
+
+    private static class MiniBatchDistributor

Review comment:
       Could we add Javadoc explaining what this operator does? For example, we can mention that it attempts to evenly distribute a list of globalBatchSize elements across downstream operators.
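Below is a minimal, standalone sketch of the behaviour such Javadoc could describe: evenly splitting one global mini-batch across `parallelism` downstream subtasks. It is not the PR's `MiniBatchDistributor`, whose body is not shown in this hunk. The class `EvenSplitSketch` and the method `splitEvenly` are hypothetical. The sketch assumes the goal is sub-batches whose sizes differ by at most one.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical illustration only; not the actual MiniBatchDistributor from this PR.
final class EvenSplitSketch {

    /**
     * Splits a global mini-batch into {@code parallelism} sub-batches whose sizes differ by at
     * most one, so that each downstream subtask receives a roughly equal share of the points.
     */
    static <T> List<T[]> splitEvenly(T[] globalBatch, int parallelism) {
        if (parallelism <= 0) {
            throw new IllegalArgumentException("parallelism must be positive");
        }
        List<T[]> subBatches = new ArrayList<>(parallelism);
        int base = globalBatch.length / parallelism;
        int remainder = globalBatch.length % parallelism;
        int start = 0;
        for (int i = 0; i < parallelism; i++) {
            // The first `remainder` sub-batches each take one extra element.
            int size = base + (i < remainder ? 1 : 0);
            subBatches.add(Arrays.copyOfRange(globalBatch, start, start + size));
            start += size;
        }
        return subBatches;
    }

    public static void main(String[] args) {
        Integer[] batch = {1, 2, 3, 4, 5, 6, 7};
        for (Integer[] sub : splitEvenly(batch, 3)) {
            System.out.println(Arrays.toString(sub));
        }
    }
}
```

With 7 points and a parallelism of 3, the first sub-batch receives 3 points and the other two receive 2 each. If the parallelism exceeds the batch size, the trailing sub-batches are empty.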

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,562 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Random;
+
+/**
+ * OnlineKMeans extends the functionality of {@link KMeans} to support training a K-Means model
+ * continuously on an unbounded stream of training data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids as the weighted average of the original and the
+ * estimated centroids. The weight of each estimated centroid is the number of points assigned to
+ * it in the current batch. The weight of each original centroid is the number of points assigned
+ * to it so far, multiplied by the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If the decay
+ * factor is 1, all batches are weighted equally. If the decay factor is 0, the new centroids are
+ * determined entirely by the most recent batch. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public OnlineKMeans(Table initModelDataTable) {
+        this.initModelDataTable = initModelDataTable;
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            Preconditions.checkState(initModelDataTable == null);
+            initModelDataStream = createRandomCentroids(env, getDim(), getK(), getSeed());
+        } else {
+            initModelDataStream = KMeansModelData.getModelDataStream(initModelDataTable);
+        }
+        DataStream<Tuple2<KMeansModelData, DenseVector>> initModelDataWithWeightsStream =
+                initModelDataStream.map(new InitWeightAssigner(getInitWeights()));
+        initModelDataWithWeightsStream.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize(),
+                        getK(),
+                        getDim());
+
+        DataStream<KMeansModelData> finalModelDataStream =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelDataWithWeightsStream),
+                                DataStreamList.of(points),
+                                body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelDataStream);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class InitWeightAssigner
+            implements MapFunction<KMeansModelData, Tuple2<KMeansModelData, DenseVector>> {
+        private final double[] initWeights;
+
+        private InitWeightAssigner(Double[] initWeights) {
+            this.initWeights = ArrayUtils.toPrimitive(initWeights);
+        }
+
+        @Override
+        public Tuple2<KMeansModelData, DenseVector> map(KMeansModelData modelData)
+                throws Exception {
+            return Tuple2.of(modelData, Vectors.dense(initWeights));
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        Path initModelDataPath = Paths.get(path, "data");
+        if (Files.exists(initModelDataPath)) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+            kMeans.setInitMode("direct");
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+        private final int k;
+        private final int dim;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure,
+                double decayFactor,
+                int batchSize,
+                int k,
+                int dim) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+            this.k = k;
+            this.dim = dim;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<Tuple2<KMeansModelData, DenseVector>> modelDataWithWeights =
+                    variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<Tuple2<KMeansModelData, DenseVector>> newModelDataWithWeights =
+                    points.countWindowAll(batchSize)
+                            .aggregate(new MiniBatchCreator())
+                            .flatMap(new MiniBatchDistributor(parallelism))
+                            .rebalance()
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "ModelDataPartialUpdater",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new ModelDataPartialUpdater(distanceMeasure, k))
+                            .setParallelism(parallelism)
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "ModelDataGlobalUpdater",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new ModelDataGlobalUpdater(k, dim, parallelism, decayFactor))
+                            .setParallelism(1);
+
+            DataStream<KMeansModelData> outputModelData =
+                    modelDataWithWeights.map(
+                            (MapFunction<Tuple2<KMeansModelData, DenseVector>, KMeansModelData>)
+                                    x -> x.f0);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelDataWithWeights), DataStreamList.of(outputModelData));
+        }
+    }
+
+    private static class ModelDataGlobalUpdater
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, DenseVector>>
+            implements TwoInputStreamOperator<
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final int k;
+        private final int dim;
+        private final int upstreamParallelism;
+        private final double decayFactor;
+
+        private ListState<Integer> partialModelDataReceivingState;
+        private ListState<Boolean> initModelDataReceivingState;
+        private ListState<KMeansModelData> modelDataState;
+        private ListState<DenseVector> weightsState;
+
+        private ModelDataGlobalUpdater(
+                int k, int dim, int upstreamParallelism, double decayFactor) {
+            this.k = k;
+            this.dim = dim;
+            this.upstreamParallelism = upstreamParallelism;
+            this.decayFactor = decayFactor;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            partialModelDataReceivingState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "partialModelDataReceiving",
+                                            BasicTypeInfo.INT_TYPE_INFO));
+
+            initModelDataReceivingState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "initModelDataReceiving",
+                                            BasicTypeInfo.BOOLEAN_TYPE_INFO));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+
+            weightsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "weights", DenseVectorTypeInfo.INSTANCE));
+
+            initStateValues();
+        }
+
+        private void initStateValues() throws Exception {
+            partialModelDataReceivingState.update(Collections.singletonList(0));
+            initModelDataReceivingState.update(Collections.singletonList(false));
+            DenseVector[] emptyCentroids = new DenseVector[k];
+            for (int i = 0; i < k; i++) {
+                emptyCentroids[i] = new DenseVector(dim);
+            }
+            modelDataState.update(Collections.singletonList(new KMeansModelData(emptyCentroids)));
+            weightsState.update(Collections.singletonList(new DenseVector(k)));
+        }
+
+        @Override
+        public void processElement1(
+                StreamRecord<Tuple2<KMeansModelData, DenseVector>> partialModelDataUpdateRecord)
+                throws Exception {
+            int partialModelDataReceiving =
+                    getUniqueElement(partialModelDataReceivingState, "partialModelDataReceiving");
+            Preconditions.checkState(partialModelDataReceiving < upstreamParallelism);
+            partialModelDataReceivingState.update(
+                    Collections.singletonList(partialModelDataReceiving + 1));
+            processElement(
+                    partialModelDataUpdateRecord.getValue().f0.centroids,
+                    partialModelDataUpdateRecord.getValue().f1,
+                    1.0);
+        }
+
+        @Override
+        public void processElement2(
+                StreamRecord<Tuple2<KMeansModelData, DenseVector>> initModelDataRecord)
+                throws Exception {
+            boolean initModelDataReceiving =
+                    getUniqueElement(initModelDataReceivingState, "initModelDataReceiving");
+            Preconditions.checkState(!initModelDataReceiving);
+            initModelDataReceivingState.update(Collections.singletonList(true));
+            processElement(
+                    initModelDataRecord.getValue().f0.centroids,
+                    initModelDataRecord.getValue().f1,
+                    decayFactor);
+        }
+
+        private void processElement(
+                DenseVector[] newCentroids, DenseVector newWeights, double decayFactor)
+                throws Exception {
+            DenseVector weights = getUniqueElement(weightsState, "weights");
+            DenseVector[] centroids = getUniqueElement(modelDataState, "modelData").centroids;
+
+            for (int i = 0; i < k; i++) {
+                newWeights.values[i] *= decayFactor;
+                for (int j = 0; j < dim; j++) {
+                    centroids[i].values[j] =
+                            (centroids[i].values[j] * weights.values[i]
+                                            + newCentroids[i].values[j] * newWeights.values[i])
+                                    / Math.max(weights.values[i] + newWeights.values[i], 1e-16);
+                }
+                weights.values[i] += newWeights.values[i];
+            }
+
+            if (getUniqueElement(initModelDataReceivingState, "initModelDataReceiving")
+                    && getUniqueElement(partialModelDataReceivingState, "partialModelDataReceiving")
+                            >= upstreamParallelism) {
+                output.collect(
+                        new StreamRecord<>(Tuple2.of(new KMeansModelData(centroids), weights)));
+                initStateValues();
+            } else {
+                modelDataState.update(Collections.singletonList(new KMeansModelData(centroids)));
+                weightsState.update(Collections.singletonList(weights));
+            }
+        }
+    }
+
+    private static <T> T getUniqueElement(ListState<T> state, String stateName) throws Exception {
+        T value = OperatorStateUtils.getUniqueElement(state, stateName).orElse(null);
+        return Objects.requireNonNull(value);
+    }
+
+    private static class ModelDataPartialUpdater
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, DenseVector>>
+            implements TwoInputStreamOperator<
+                    DenseVector[],
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final DistanceMeasure distanceMeasure;
+        private final int k;
+        private ListState<DenseVector[]> miniBatchState;
+        private ListState<KMeansModelData> modelDataState;
+
+        private ModelDataPartialUpdater(DistanceMeasure distanceMeasure, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.k = k;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            miniBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("miniBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> pointsRecord) throws Exception {
+            miniBatchState.add(pointsRecord.getValue());
+            processElement();
+        }
+
+        @Override
+        public void processElement2(
+                StreamRecord<Tuple2<KMeansModelData, DenseVector>> modelDataAndWeightsRecord)
+                throws Exception {
+            modelDataState.add(modelDataAndWeightsRecord.getValue().f0);
+            processElement();
+        }
+
+        private void processElement() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !miniBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            DenseVector[] centroids = getUniqueElement(modelDataState, "modelData").centroids;
+            modelDataState.clear();
+
+            List<DenseVector[]> pointsList = IteratorUtils.toList(miniBatchState.get().iterator());
+            DenseVector[] points = pointsList.get(0);
+            pointsList.remove(0);
+            miniBatchState.clear();
+            miniBatchState.addAll(pointsList);
+
+            // Computes new centroids.
+            DenseVector[] sums = new DenseVector[k];
+            int dim = centroids[0].size();
+            DenseVector counts = new DenseVector(k);
+
+            for (int i = 0; i < k; i++) {
+                sums[i] = new DenseVector(dim);
+                counts.values[i] = 0;
+            }
+            for (DenseVector point : points) {
+                int closestCentroidId =
+                        KMeans.findClosestCentroidId(centroids, point, distanceMeasure);
+                counts.values[closestCentroidId]++;
+                for (int j = 0; j < dim; j++) {
+                    sums[closestCentroidId].values[j] += point.values[j];
+                }
+            }
+
+            for (int i = 0; i < k; i++) {
+                if (counts.values[i] < 1e-5) {
+                    continue;
+                }
+                BLAS.scal(1.0 / counts.values[i], sums[i]);
+            }
+
+            output.collect(new StreamRecord<>(Tuple2.of(new KMeansModelData(sums), counts)));
+        }
+    }
+
+    private static class FeaturesExtractor implements MapFunction<Row, DenseVector> {
+        private final String featuresCol;
+
+        private FeaturesExtractor(String featuresCol) {
+            this.featuresCol = featuresCol;
+        }
+
+        @Override
+        public DenseVector map(Row row) throws Exception {
+            return (DenseVector) row.getField(featuresCol);
+        }
+    }
+
+    private static class MiniBatchDistributor
+            implements FlatMapFunction<DenseVector[], DenseVector[]> {
+        private final int downStreamParallelism;
+        private int shift = 0;
+
+        private MiniBatchDistributor(int downStreamParallelism) {
+            this.downStreamParallelism = downStreamParallelism;
+        }
+
+        @Override
+        public void flatMap(DenseVector[] values, Collector<DenseVector[]> collector) {
+            // Calculate the batch sizes to be distributed on each subtask.
+            List<Integer> sizes = new ArrayList<>();
+            for (int i = 0; i < downStreamParallelism; i++) {
+                int start = i * values.length / downStreamParallelism;
+                int end = (i + 1) * values.length / downStreamParallelism;
+                sizes.add(end - start);
+            }
+
+            // Reduce accumulated imbalance among distributed batches.
+            Collections.rotate(sizes, shift);
+            shift++;
+            shift %= downStreamParallelism;
+
+            int offset = 0;
+            for (Integer size : sizes) {
+                collector.collect(Arrays.copyOfRange(values, offset, offset + size));
+                offset += size;
+            }
+        }
+    }
+
+    private static class MiniBatchCreator

Review comment:
       It actually creates a global batch for each step, rather than a mini batch, right?
   
   Maybe we can rename `MiniBatchCreator` to `GlobalBatchCreator`, and rename `MiniBatchDistributor` to `MiniBatchCreator` or `GlobalBatchSplitter`. 
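To make the distinction concrete: the window aggregate collects one global batch of `batchSize` points per step, and `MiniBatchDistributor` then cuts it into one contiguous slice per downstream subtask. The sketch below isolates only that splitting step. `GlobalBatchSplitSketch` and `split` are hypothetical names, not part of the PR; the boundary arithmetic is copied from `MiniBatchDistributor#flatMap`, while the rotation is deliberately left out.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical standalone helper; it reuses the boundary arithmetic of MiniBatchDistributor#flatMap
// but omits the Collections.rotate step.
final class GlobalBatchSplitSketch {

    // Splits one global batch into `parallelism` contiguous mini-batches whose sizes differ by at most one.
    static <T> List<T[]> split(T[] globalBatch, int parallelism) {
        List<T[]> miniBatches = new ArrayList<>(parallelism);
        for (int i = 0; i < parallelism; i++) {
            // Slice boundaries at floor(i * n / p), the same values the PR derives its sizes from.
            int start = i * globalBatch.length / parallelism;
            int end = (i + 1) * globalBatch.length / parallelism;
            miniBatches.add(Arrays.copyOfRange(globalBatch, start, end));
        }
        return miniBatches;
    }

    public static void main(String[] args) {
        Integer[] globalBatch = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
        for (Integer[] miniBatch : split(globalBatch, 4)) {
            System.out.println(Arrays.toString(miniBatch));
        }
    }
}
```

With 10 points and a parallelism of 4, this prints mini-batches of sizes 2, 3, 2 and 3. Each output element is therefore a mini batch, while the window output is the global batch, which is the naming the comment above suggests.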
   

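For reference, the loop at the end of `ModelDataGlobalUpdater.processElement` (quoted in the diff above) merges the current model, whose weights are first multiplied by the decay factor, with the per-subtask batch statistics, each weighted by its point count. Merging the partial updates one by one gives the same weighted average as merging a single (batch mean, total count) pair, up to floating-point rounding. The sketch below therefore shows the per-centroid rule for one such pair. `DecayedCentroidUpdateSketch`, `mergeCentroid` and `mergeWeight` are hypothetical names and are not part of the PR.

```java
import java.util.Arrays;

// Hypothetical standalone sketch of the decayed weighted-average update for a single centroid.
final class DecayedCentroidUpdateSketch {

    // Returns the weighted average of the decayed old centroid and the mean of this batch's assigned points.
    static double[] mergeCentroid(
            double[] oldCentroid, double oldWeight, double[] batchMean, double batchCount, double decayFactor) {
        double decayedWeight = oldWeight * decayFactor;
        // Same guard as the PR against dividing by zero when both weights are 0.
        double denominator = Math.max(decayedWeight + batchCount, 1e-16);
        double[] merged = new double[oldCentroid.length];
        for (int j = 0; j < merged.length; j++) {
            merged[j] = (oldCentroid[j] * decayedWeight + batchMean[j] * batchCount) / denominator;
        }
        return merged;
    }

    // Returns the weight carried into the next step: the decayed old weight plus the new point count.
    static double mergeWeight(double oldWeight, double batchCount, double decayFactor) {
        return oldWeight * decayFactor + batchCount;
    }

    public static void main(String[] args) {
        double[] merged = mergeCentroid(new double[] {0.0, 0.0}, 2.0, new double[] {4.0, 4.0}, 3.0, 0.5);
        System.out.println(Arrays.toString(merged) + " weight=" + mergeWeight(2.0, 3.0, 0.5));
    }
}
```

In the example, the old weight 2 decays to 1 and is combined with 3 new points, so the centroid moves to `[3.0, 3.0]` and carries a weight of 4. With a decay factor of 1, old and new points count equally. With a decay factor of 0, the centroid becomes the batch mean.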
##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,562 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Random;
+
+/**
+ * OnlineKMeans extends the functionality of {@link KMeans}, supporting continuous training of a
+ * K-Means model on an unbounded stream of training data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids as the weighted average of the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of points assigned to
+ * them in the current batch. The weight of the original centroids is their accumulated weight from
+ * previous batches, multiplied by the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If the decay
+ * factor is 1, all batches are weighted equally. If the decay factor is 0, new centroids are
+ * determined entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public OnlineKMeans(Table initModelDataTable) {
+        this.initModelDataTable = initModelDataTable;
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            Preconditions.checkState(initModelDataTable == null);
+            initModelDataStream = createRandomCentroids(env, getDim(), getK(), getSeed());
+        } else {
+            initModelDataStream = KMeansModelData.getModelDataStream(initModelDataTable);
+        }
+        DataStream<Tuple2<KMeansModelData, DenseVector>> initModelDataWithWeightsStream =
+                initModelDataStream.map(new InitWeightAssigner(getInitWeights()));
+        initModelDataWithWeightsStream.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize(),
+                        getK(),
+                        getDim());
+
+        DataStream<KMeansModelData> finalModelDataStream =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelDataWithWeightsStream),
+                                DataStreamList.of(points),
+                                body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelDataStream);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class InitWeightAssigner
+            implements MapFunction<KMeansModelData, Tuple2<KMeansModelData, DenseVector>> {
+        private final double[] initWeights;
+
+        private InitWeightAssigner(Double[] initWeights) {
+            this.initWeights = ArrayUtils.toPrimitive(initWeights);
+        }
+
+        @Override
+        public Tuple2<KMeansModelData, DenseVector> map(KMeansModelData modelData)
+                throws Exception {
+            return Tuple2.of(modelData, Vectors.dense(initWeights));
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        Path initModelDataPath = Paths.get(path, "data");
+        if (Files.exists(initModelDataPath)) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+            kMeans.setInitMode("direct");
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+        private final int k;
+        private final int dim;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure,
+                double decayFactor,
+                int batchSize,
+                int k,
+                int dim) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+            this.k = k;
+            this.dim = dim;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<Tuple2<KMeansModelData, DenseVector>> modelDataWithWeights =
+                    variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<Tuple2<KMeansModelData, DenseVector>> newModelDataWithWeights =
+                    points.countWindowAll(batchSize)
+                            .aggregate(new MiniBatchCreator())
+                            .flatMap(new MiniBatchDistributor(parallelism))
+                            .rebalance()
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "ModelDataPartialUpdater",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new ModelDataPartialUpdater(distanceMeasure, k))
+                            .setParallelism(parallelism)
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "ModelDataGlobalUpdater",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new ModelDataGlobalUpdater(k, dim, parallelism, decayFactor))
+                            .setParallelism(1);
+
+            DataStream<KMeansModelData> outputModelData =
+                    modelDataWithWeights.map(
+                            (MapFunction<Tuple2<KMeansModelData, DenseVector>, KMeansModelData>)
+                                    x -> x.f0);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelDataWithWeights), DataStreamList.of(outputModelData));
+        }
+    }
+
+    private static class ModelDataGlobalUpdater
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, DenseVector>>
+            implements TwoInputStreamOperator<
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final int k;
+        private final int dim;
+        private final int upstreamParallelism;
+        private final double decayFactor;
+
+        private ListState<Integer> partialModelDataReceivingState;
+        private ListState<Boolean> initModelDataReceivingState;
+        private ListState<KMeansModelData> modelDataState;
+        private ListState<DenseVector> weightsState;
+
+        private ModelDataGlobalUpdater(
+                int k, int dim, int upstreamParallelism, double decayFactor) {
+            this.k = k;
+            this.dim = dim;
+            this.upstreamParallelism = upstreamParallelism;
+            this.decayFactor = decayFactor;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            partialModelDataReceivingState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "partialModelDataReceiving",
+                                            BasicTypeInfo.INT_TYPE_INFO));
+
+            initModelDataReceivingState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "initModelDataReceiving",
+                                            BasicTypeInfo.BOOLEAN_TYPE_INFO));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+
+            weightsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "weights", DenseVectorTypeInfo.INSTANCE));
+
+            initStateValues();
+        }
+
+        private void initStateValues() throws Exception {
+            partialModelDataReceivingState.update(Collections.singletonList(0));
+            initModelDataReceivingState.update(Collections.singletonList(false));
+            DenseVector[] emptyCentroids = new DenseVector[k];
+            for (int i = 0; i < k; i++) {
+                emptyCentroids[i] = new DenseVector(dim);
+            }
+            modelDataState.update(Collections.singletonList(new KMeansModelData(emptyCentroids)));
+            weightsState.update(Collections.singletonList(new DenseVector(k)));
+        }
+
+        @Override
+        public void processElement1(
+                StreamRecord<Tuple2<KMeansModelData, DenseVector>> partialModelDataUpdateRecord)
+                throws Exception {
+            int partialModelDataReceiving =
+                    getUniqueElement(partialModelDataReceivingState, "partialModelDataReceiving");
+            Preconditions.checkState(partialModelDataReceiving < upstreamParallelism);
+            partialModelDataReceivingState.update(
+                    Collections.singletonList(partialModelDataReceiving + 1));
+            processElement(
+                    partialModelDataUpdateRecord.getValue().f0.centroids,
+                    partialModelDataUpdateRecord.getValue().f1,
+                    1.0);
+        }
+
+        @Override
+        public void processElement2(
+                StreamRecord<Tuple2<KMeansModelData, DenseVector>> initModelDataRecord)
+                throws Exception {
+            boolean initModelDataReceiving =
+                    getUniqueElement(initModelDataReceivingState, "initModelDataReceiving");
+            Preconditions.checkState(!initModelDataReceiving);
+            initModelDataReceivingState.update(Collections.singletonList(true));
+            processElement(
+                    initModelDataRecord.getValue().f0.centroids,
+                    initModelDataRecord.getValue().f1,
+                    decayFactor);
+        }
+
+        private void processElement(
+                DenseVector[] newCentroids, DenseVector newWeights, double decayFactor)
+                throws Exception {
+            DenseVector weights = getUniqueElement(weightsState, "weights");
+            DenseVector[] centroids = getUniqueElement(modelDataState, "modelData").centroids;
+
+            for (int i = 0; i < k; i++) {
+                newWeights.values[i] *= decayFactor;
+                for (int j = 0; j < dim; j++) {
+                    centroids[i].values[j] =
+                            (centroids[i].values[j] * weights.values[i]
+                                            + newCentroids[i].values[j] * newWeights.values[i])
+                                    / Math.max(weights.values[i] + newWeights.values[i], 1e-16);
+                }
+                weights.values[i] += newWeights.values[i];
+            }
+
+            if (getUniqueElement(initModelDataReceivingState, "initModelDataReceiving")
+                    && getUniqueElement(partialModelDataReceivingState, "partialModelDataReceiving")
+                            >= upstreamParallelism) {
+                output.collect(
+                        new StreamRecord<>(Tuple2.of(new KMeansModelData(centroids), weights)));
+                initStateValues();
+            } else {
+                modelDataState.update(Collections.singletonList(new KMeansModelData(centroids)));
+                weightsState.update(Collections.singletonList(weights));
+            }
+        }
+    }
+
+    private static <T> T getUniqueElement(ListState<T> state, String stateName) throws Exception {
+        T value = OperatorStateUtils.getUniqueElement(state, stateName).orElse(null);
+        return Objects.requireNonNull(value);
+    }
+
+    private static class ModelDataPartialUpdater
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, DenseVector>>
+            implements TwoInputStreamOperator<
+                    DenseVector[],
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final DistanceMeasure distanceMeasure;
+        private final int k;
+        private ListState<DenseVector[]> miniBatchState;
+        private ListState<KMeansModelData> modelDataState;
+
+        private ModelDataPartialUpdater(DistanceMeasure distanceMeasure, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.k = k;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            miniBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("miniBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> pointsRecord) throws Exception {
+            miniBatchState.add(pointsRecord.getValue());
+            processElement();
+        }
+
+        @Override
+        public void processElement2(
+                StreamRecord<Tuple2<KMeansModelData, DenseVector>> modelDataAndWeightsRecord)
+                throws Exception {
+            modelDataState.add(modelDataAndWeightsRecord.getValue().f0);
+            processElement();
+        }
+
+        private void processElement() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !miniBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            DenseVector[] centroids = getUniqueElement(modelDataState, "modelData").centroids;
+            modelDataState.clear();
+
+            List<DenseVector[]> pointsList = IteratorUtils.toList(miniBatchState.get().iterator());
+            DenseVector[] points = pointsList.get(0);
+            pointsList.remove(0);
+            miniBatchState.clear();
+            miniBatchState.addAll(pointsList);
+
+            // Computes new centroids.
+            DenseVector[] sums = new DenseVector[k];
+            int dim = centroids[0].size();
+            DenseVector counts = new DenseVector(k);
+
+            for (int i = 0; i < k; i++) {
+                sums[i] = new DenseVector(dim);
+                counts.values[i] = 0;
+            }
+            for (DenseVector point : points) {
+                int closestCentroidId =
+                        KMeans.findClosestCentroidId(centroids, point, distanceMeasure);
+                counts.values[closestCentroidId]++;
+                for (int j = 0; j < dim; j++) {
+                    sums[closestCentroidId].values[j] += point.values[j];
+                }
+            }
+
+            for (int i = 0; i < k; i++) {
+                if (counts.values[i] < 1e-5) {
+                    continue;
+                }
+                BLAS.scal(1.0 / counts.values[i], sums[i]);
+            }
+
+            output.collect(new StreamRecord<>(Tuple2.of(new KMeansModelData(sums), counts)));
+        }
+    }
+
+    private static class FeaturesExtractor implements MapFunction<Row, DenseVector> {
+        private final String featuresCol;
+
+        private FeaturesExtractor(String featuresCol) {
+            this.featuresCol = featuresCol;
+        }
+
+        @Override
+        public DenseVector map(Row row) throws Exception {
+            return (DenseVector) row.getField(featuresCol);
+        }
+    }
+
+    private static class MiniBatchDistributor
+            implements FlatMapFunction<DenseVector[], DenseVector[]> {
+        private final int downStreamParallelism;
+        private int shift = 0;
+
+        private MiniBatchDistributor(int downStreamParallelism) {
+            this.downStreamParallelism = downStreamParallelism;
+        }
+
+        @Override
+        public void flatMap(DenseVector[] values, Collector<DenseVector[]> collector) {
+            // Calculate the batch sizes to be distributed on each subtask.
+            List<Integer> sizes = new ArrayList<>();
+            for (int i = 0; i < downStreamParallelism; i++) {
+                int start = i * values.length / downStreamParallelism;
+                int end = (i + 1) * values.length / downStreamParallelism;
+                sizes.add(end - start);
+            }
+
+            // Reduce accumulated imbalance among distributed batches.
+            Collections.rotate(sizes, shift);

Review comment:
       Hmm... It seems that the algorithm would have exactly the same performance and accuracy even if we removed this statement, right?
   
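One way to check this question: every subtask assigns points against the same broadcast centroids, and the global updater weights each partial mean by its count. So the merged centroids should not depend on how a global batch is split, apart from floating-point summation order. Within a single step, the largest slice has the same size with or without rotation. The rotation only evens out the cumulative number of points each subtask processes across steps. The 1-D sketch below checks both properties. `SplitInvarianceSketch` and its methods are hypothetical, and the sketch ignores centroid assignment by treating all points as belonging to one cluster.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

// Hypothetical 1-D sketch: one cluster, sizes computed as in MiniBatchDistributor.
final class SplitInvarianceSketch {

    // Per-subtask slice sizes, computed as in MiniBatchDistributor and then rotated by `shift`.
    static List<Integer> sizes(int batchLength, int parallelism, int shift) {
        List<Integer> sizes = new ArrayList<>();
        for (int i = 0; i < parallelism; i++) {
            sizes.add((i + 1) * batchLength / parallelism - i * batchLength / parallelism);
        }
        Collections.rotate(sizes, shift);
        return sizes;
    }

    // Merges per-slice (mean, count) pairs into one mean, as the global updater does per centroid.
    static double mergedMean(double[] points, List<Integer> sizes) {
        double weightedSum = 0;
        double totalCount = 0;
        int offset = 0;
        for (int size : sizes) {
            if (size > 0) {
                double sum = 0;
                for (int j = offset; j < offset + size; j++) {
                    sum += points[j];
                }
                // Partial mean times partial count, mirroring how partial updates are weighted.
                weightedSum += (sum / size) * size;
                totalCount += size;
            }
            offset += size;
        }
        return weightedSum / totalCount;
    }

    // Total points assigned to each subtask over `steps` global batches, with or without rotation.
    static int[] cumulativeLoad(int batchLength, int parallelism, int steps, boolean rotate) {
        int[] load = new int[parallelism];
        int shift = 0;
        for (int step = 0; step < steps; step++) {
            List<Integer> stepSizes = sizes(batchLength, parallelism, rotate ? shift : 0);
            for (int i = 0; i < parallelism; i++) {
                load[i] += stepSizes.get(i);
            }
            shift = (shift + 1) % parallelism;
        }
        return load;
    }

    public static void main(String[] args) {
        double[] points = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
        System.out.println(mergedMean(points, sizes(10, 4, 0)) + " vs " + mergedMean(points, sizes(10, 4, 1)));
        System.out.println(Arrays.toString(cumulativeLoad(10, 4, 4, false)));
        System.out.println(Arrays.toString(cumulativeLoad(10, 4, 4, true)));
    }
}
```

Both splits give a mean of 5.5. Over four steps, the per-subtask totals are `[8, 12, 8, 12]` without rotation and `[10, 10, 10, 10]` with it. Since each step waits for all partial updates, the per-step maximum of 3 points is what bounds latency, and rotation does not change it.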




[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r833141053



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,562 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Random;
+
+/**
+ * OnlineKMeans extends the functionality of {@link KMeans}: it trains a K-Means model
+ * continuously on an unbounded stream of training data.
+ *
+ * <p>OnlineKMeans updates the model with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids as the weighted average of the original and the
+ * estimated centroids. The weight of each estimated centroid is the number of points assigned to
+ * it. The weight of each original centroid is likewise the number of points assigned to it,
+ * additionally multiplied by the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If the decay
+ * factor is 1, all batches are weighted equally. If the decay factor is 0, new centroids are
+ * determined entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public OnlineKMeans(Table initModelDataTable) {
+        this.initModelDataTable = initModelDataTable;
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            Preconditions.checkState(initModelDataTable == null);
+            initModelDataStream = createRandomCentroids(env, getDim(), getK(), getSeed());
+        } else {
+            initModelDataStream = KMeansModelData.getModelDataStream(initModelDataTable);
+        }
+        DataStream<Tuple2<KMeansModelData, DenseVector>> initModelDataWithWeightsStream =
+                initModelDataStream.map(new InitWeightAssigner(getInitWeights()));
+        initModelDataWithWeightsStream.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize(),
+                        getK(),
+                        getDim());
+
+        DataStream<KMeansModelData> finalModelDataStream =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelDataWithWeightsStream),
+                                DataStreamList.of(points),
+                                body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelDataStream);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class InitWeightAssigner
+            implements MapFunction<KMeansModelData, Tuple2<KMeansModelData, DenseVector>> {
+        private final double[] initWeights;
+
+        private InitWeightAssigner(Double[] initWeights) {
+            this.initWeights = ArrayUtils.toPrimitive(initWeights);
+        }
+
+        @Override
+        public Tuple2<KMeansModelData, DenseVector> map(KMeansModelData modelData)
+                throws Exception {
+            return Tuple2.of(modelData, Vectors.dense(initWeights));
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        Path initModelDataPath = Paths.get(path, "data");
+        if (Files.exists(initModelDataPath)) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+            kMeans.setInitMode("direct");
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+        private final int k;
+        private final int dim;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure,
+                double decayFactor,
+                int batchSize,
+                int k,
+                int dim) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+            this.k = k;
+            this.dim = dim;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<Tuple2<KMeansModelData, DenseVector>> modelDataWithWeights =
+                    variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<Tuple2<KMeansModelData, DenseVector>> newModelDataWithWeights =
+                    points.countWindowAll(batchSize)
+                            .aggregate(new MiniBatchCreator())
+                            .flatMap(new MiniBatchDistributor(parallelism))
+                            .rebalance()
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "ModelDataPartialUpdater",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new ModelDataPartialUpdater(distanceMeasure, k))
+                            .setParallelism(parallelism)
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "ModelDataGlobalUpdater",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new ModelDataGlobalUpdater(k, dim, parallelism, decayFactor))
+                            .setParallelism(1);
+
+            DataStream<KMeansModelData> outputModelData =
+                    modelDataWithWeights.map(
+                            (MapFunction<Tuple2<KMeansModelData, DenseVector>, KMeansModelData>)
+                                    x -> x.f0);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelDataWithWeights), DataStreamList.of(outputModelData));
+        }
+    }
+
+    private static class ModelDataGlobalUpdater
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, DenseVector>>
+            implements TwoInputStreamOperator<
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final int k;
+        private final int dim;
+        private final int upstreamParallelism;
+        private final double decayFactor;
+
+        private ListState<Integer> partialModelDataReceivingState;
+        private ListState<Boolean> initModelDataReceivingState;
+        private ListState<KMeansModelData> modelDataState;
+        private ListState<DenseVector> weightsState;
+
+        private ModelDataGlobalUpdater(
+                int k, int dim, int upstreamParallelism, double decayFactor) {
+            this.k = k;
+            this.dim = dim;
+            this.upstreamParallelism = upstreamParallelism;
+            this.decayFactor = decayFactor;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            partialModelDataReceivingState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "partialModelDataReceiving",
+                                            BasicTypeInfo.INT_TYPE_INFO));
+
+            initModelDataReceivingState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "initModelDataReceiving",
+                                            BasicTypeInfo.BOOLEAN_TYPE_INFO));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+
+            weightsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "weights", DenseVectorTypeInfo.INSTANCE));
+
+            initStateValues();
+        }
+
+        private void initStateValues() throws Exception {
+            partialModelDataReceivingState.update(Collections.singletonList(0));
+            initModelDataReceivingState.update(Collections.singletonList(false));
+            DenseVector[] emptyCentroids = new DenseVector[k];
+            for (int i = 0; i < k; i++) {
+                emptyCentroids[i] = new DenseVector(dim);
+            }
+            modelDataState.update(Collections.singletonList(new KMeansModelData(emptyCentroids)));
+            weightsState.update(Collections.singletonList(new DenseVector(k)));
+        }
+
+        @Override
+        public void processElement1(
+                StreamRecord<Tuple2<KMeansModelData, DenseVector>> partialModelDataUpdateRecord)
+                throws Exception {
+            int partialModelDataReceiving =
+                    getUniqueElement(partialModelDataReceivingState, "partialModelDataReceiving");
+            Preconditions.checkState(partialModelDataReceiving < upstreamParallelism);
+            partialModelDataReceivingState.update(
+                    Collections.singletonList(partialModelDataReceiving + 1));
+            processElement(
+                    partialModelDataUpdateRecord.getValue().f0.centroids,
+                    partialModelDataUpdateRecord.getValue().f1,
+                    1.0);
+        }
+
+        @Override
+        public void processElement2(
+                StreamRecord<Tuple2<KMeansModelData, DenseVector>> initModelDataRecord)
+                throws Exception {
+            boolean initModelDataReceiving =
+                    getUniqueElement(initModelDataReceivingState, "initModelDataReceiving");
+            Preconditions.checkState(!initModelDataReceiving);
+            initModelDataReceivingState.update(Collections.singletonList(true));
+            processElement(
+                    initModelDataRecord.getValue().f0.centroids,
+                    initModelDataRecord.getValue().f1,
+                    decayFactor);
+        }
+
+        private void processElement(
+                DenseVector[] newCentroids, DenseVector newWeights, double decayFactor)
+                throws Exception {
+            DenseVector weights = getUniqueElement(weightsState, "weights");
+            DenseVector[] centroids = getUniqueElement(modelDataState, "modelData").centroids;
+
+            for (int i = 0; i < k; i++) {
+                newWeights.values[i] *= decayFactor;
+                for (int j = 0; j < dim; j++) {
+                    centroids[i].values[j] =
+                            (centroids[i].values[j] * weights.values[i]
+                                            + newCentroids[i].values[j] * newWeights.values[i])
+                                    / Math.max(weights.values[i] + newWeights.values[i], 1e-16);
+                }
+                weights.values[i] += newWeights.values[i];
+            }
+
+            if (getUniqueElement(initModelDataReceivingState, "initModelDataReceiving")
+                    && getUniqueElement(partialModelDataReceivingState, "partialModelDataReceiving")
+                            >= upstreamParallelism) {
+                output.collect(
+                        new StreamRecord<>(Tuple2.of(new KMeansModelData(centroids), weights)));
+                initStateValues();
+            } else {
+                modelDataState.update(Collections.singletonList(new KMeansModelData(centroids)));
+                weightsState.update(Collections.singletonList(weights));
+            }
+        }
+    }
+
+    private static <T> T getUniqueElement(ListState<T> state, String stateName) throws Exception {
+        T value = OperatorStateUtils.getUniqueElement(state, stateName).orElse(null);
+        return Objects.requireNonNull(value);
+    }
+
+    private static class ModelDataPartialUpdater
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, DenseVector>>
+            implements TwoInputStreamOperator<
+                    DenseVector[],
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final DistanceMeasure distanceMeasure;
+        private final int k;
+        private ListState<DenseVector[]> miniBatchState;
+        private ListState<KMeansModelData> modelDataState;
+
+        private ModelDataPartialUpdater(DistanceMeasure distanceMeasure, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.k = k;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            miniBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("miniBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> pointsRecord) throws Exception {
+            miniBatchState.add(pointsRecord.getValue());
+            processElement();
+        }
+
+        @Override
+        public void processElement2(
+                StreamRecord<Tuple2<KMeansModelData, DenseVector>> modelDataAndWeightsRecord)
+                throws Exception {
+            modelDataState.add(modelDataAndWeightsRecord.getValue().f0);
+            processElement();
+        }
+
+        private void processElement() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !miniBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            DenseVector[] centroids = getUniqueElement(modelDataState, "modelData").centroids;
+            modelDataState.clear();
+
+            List<DenseVector[]> pointsList = IteratorUtils.toList(miniBatchState.get().iterator());
+            DenseVector[] points = pointsList.get(0);
+            pointsList.remove(0);
+            miniBatchState.clear();
+            miniBatchState.addAll(pointsList);
+
+            // Computes new centroids.
+            DenseVector[] sums = new DenseVector[k];
+            int dim = centroids[0].size();
+            DenseVector counts = new DenseVector(k);
+
+            for (int i = 0; i < k; i++) {
+                sums[i] = new DenseVector(dim);
+                counts.values[i] = 0;
+            }
+            for (DenseVector point : points) {

Review comment:
       OK. I'll add the check in `EuclideanDistanceMeasure`.
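The decay-weighted update described in the `OnlineKMeans` Javadoc (and split across `ModelDataPartialUpdater` and `ModelDataGlobalUpdater` in the diff) can be checked in isolation. Below is a minimal, dependency-free sketch; the class and method names (`MiniBatchKMeansSketch`, `nearest`, `update`) are hypothetical, and squared Euclidean distance stands in for the configurable `DistanceMeasure`. It assigns every point of one mini-batch to its nearest current centroid, then merges each batch mean (weight = count) with the old centroid (weight = old weight times the decay factor). Running the local step per subtask and merging the partial results by weighted average, as the PR does, should give the same result up to floating-point rounding, because a count-weighted average of partial means equals the mean of the whole batch. One assumption differs from the quoted code: when a cluster ends up with zero total weight, the sketch leaves its centroid unchanged, whereas `ModelDataGlobalUpdater` divides by `Math.max(..., 1e-16)`.

```java
/** Hypothetical, dependency-free sketch of one mini-batch K-Means step with decay. */
final class MiniBatchKMeansSketch {
    private MiniBatchKMeansSketch() {}

    /** Returns the index of the centroid closest to {@code point} (squared Euclidean distance). */
    static int nearest(double[][] centroids, double[] point) {
        int best = 0;
        double bestDist = Double.POSITIVE_INFINITY;
        for (int i = 0; i < centroids.length; i++) {
            double dist = 0;
            for (int j = 0; j < point.length; j++) {
                double diff = centroids[i][j] - point[j];
                dist += diff * diff;
            }
            if (dist < bestDist) {
                bestDist = dist;
                best = i;
            }
        }
        return best;
    }

    /**
     * Merges one mini-batch into the model in place. {@code weights[i]} holds the accumulated
     * (already decayed) weight of centroid i.
     */
    static void update(
            double[][] centroids, double[] weights, double[][] batch, double decayFactor) {
        int k = centroids.length;
        int dim = centroids[0].length;
        double[][] sums = new double[k][dim];
        double[] counts = new double[k];

        // Local step: assign every point to its nearest centroid, using the centroids as they
        // were before this batch.
        for (double[] point : batch) {
            int c = nearest(centroids, point);
            counts[c] += 1;
            for (int j = 0; j < dim; j++) {
                sums[c][j] += point[j];
            }
        }

        // Global step: weighted average of the old centroid (weight * decay) and the batch mean
        // (weight = count). Note that batchMean * count == sums.
        for (int i = 0; i < k; i++) {
            double oldWeight = weights[i] * decayFactor;
            double total = oldWeight + counts[i];
            if (total == 0) {
                continue; // Assumption of this sketch: a weightless cluster keeps its centroid.
            }
            for (int j = 0; j < dim; j++) {
                centroids[i][j] = (centroids[i][j] * oldWeight + sums[i][j]) / total;
            }
            weights[i] = total;
        }
    }
}
```

With `decayFactor = 0` the old weights vanish and every centroid that received points moves to its batch mean; with `decayFactor = 1` the weights keep accumulating, so every point seen so far counts equally.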







[GitHub] [flink-ml] lindong28 commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
lindong28 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r832171267



##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java
##########
@@ -0,0 +1,487 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.RestOptions;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.clustering.kmeans.KMeans;
+import org.apache.flink.ml.clustering.kmeans.KMeansModel;
+import org.apache.flink.ml.clustering.kmeans.KMeansModelData;
+import org.apache.flink.ml.clustering.kmeans.OnlineKMeans;
+import org.apache.flink.ml.clustering.kmeans.OnlineKMeansModel;
+import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.util.InMemorySinkFunction;
+import org.apache.flink.ml.util.InMemorySourceFunction;
+import org.apache.flink.runtime.minicluster.MiniCluster;
+import org.apache.flink.runtime.minicluster.MiniClusterConfiguration;
+import org.apache.flink.runtime.testutils.InMemoryReporter;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.CollectionUtils;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import static org.apache.flink.ml.clustering.KMeansTest.groupFeaturesByPrediction;
+
+/** Tests {@link OnlineKMeans} and {@link OnlineKMeansModel}. */
+public class OnlineKMeansTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private static final DenseVector[] trainData1 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 0.0),
+                Vectors.dense(10.0, 0.3),
+                Vectors.dense(10.3, 0.0),
+                Vectors.dense(-10.0, 0.0),
+                Vectors.dense(-10.0, 0.6),
+                Vectors.dense(-10.6, 0.0)
+            };
+    private static final DenseVector[] trainData2 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 100.0),
+                Vectors.dense(10.0, 100.3),
+                Vectors.dense(10.3, 100.0),
+                Vectors.dense(-10.0, -100.0),
+                Vectors.dense(-10.0, -100.6),
+                Vectors.dense(-10.6, -100.0)
+            };
+    private static final DenseVector[] predictData =
+            new DenseVector[] {
+                Vectors.dense(10.0, 10.0),
+                Vectors.dense(10.3, 10.0),
+                Vectors.dense(10.0, 10.3),
+                Vectors.dense(-10.0, 10.0),
+                Vectors.dense(-10.3, 10.0),
+                Vectors.dense(-10.0, 10.3)
+            };
+    private static final List<Set<DenseVector>> expectedGroups1 =
+            Arrays.asList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3))),
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+    private static final List<Set<DenseVector>> expectedGroups2 =
+            Collections.singletonList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3),
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+
+    private static final int defaultParallelism = 4;
+    private static final int numTaskManagers = 2;
+    private static final int numSlotsPerTaskManager = 2;
+
+    private String currentModelDataVersion;
+
+    private InMemorySourceFunction<DenseVector> trainSource;
+    private InMemorySourceFunction<DenseVector> predictSource;
+    private InMemorySinkFunction<Row> outputSink;
+    private InMemorySinkFunction<KMeansModelData> modelDataSink;
+
+    private InMemoryReporter reporter;
+    private MiniCluster miniCluster;
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+
+    private Table offlineTrainTable;
+    private Table trainTable;
+    private Table predictTable;
+
+    @Before
+    public void before() throws Exception {
+        currentModelDataVersion = "0";
+
+        trainSource = new InMemorySourceFunction<>();
+        predictSource = new InMemorySourceFunction<>();
+        outputSink = new InMemorySinkFunction<>();
+        modelDataSink = new InMemorySinkFunction<>();
+
+        Configuration config = new Configuration();
+        config.set(RestOptions.BIND_PORT, "18081-19091");
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        reporter = InMemoryReporter.createWithRetainedMetrics();
+        reporter.addToConfiguration(config);
+
+        miniCluster =
+                new MiniCluster(
+                        new MiniClusterConfiguration.Builder()
+                                .setConfiguration(config)
+                                .setNumTaskManagers(numTaskManagers)
+                                .setNumSlotsPerTaskManager(numSlotsPerTaskManager)
+                                .build());
+        miniCluster.start();
+
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(defaultParallelism);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        Schema schema = Schema.newBuilder().column("f0", DataTypes.of(DenseVector.class)).build();
+
+        offlineTrainTable =
+                tEnv.fromDataStream(env.fromElements(trainData1), schema).as("features");
+        trainTable =
+                tEnv.fromDataStream(
+                                env.addSource(trainSource, DenseVectorTypeInfo.INSTANCE), schema)
+                        .as("features");
+        predictTable =
+                tEnv.fromDataStream(
+                                env.addSource(predictSource, DenseVectorTypeInfo.INSTANCE), schema)
+                        .as("features");
+    }
+
+    @After
+    public void after() throws Exception {
+        miniCluster.close();
+    }
+
+    /**
+     * Performs transform() on the provided model with predictTable, and adds sinks for
+     * OnlineKMeansModel's transform output and model data.
+     */
+    private void configTransformAndSink(OnlineKMeansModel onlineModel) {

Review comment:
       nit: it is not clear what `config` in the method name refers to.
   
   How about renaming it to `transformAndOutputData(...)`?

##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/util/InMemorySourceFunction.java
##########
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util;
+
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.streaming.api.functions.source.RichSourceFunction;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+
+import java.util.Map;
+import java.util.UUID;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
+
+/** A {@link SourceFunction} implementation that can directly receive records from tests. */
+@SuppressWarnings({"unchecked", "rawtypes"})
+public class InMemorySourceFunction<T> extends RichSourceFunction<T> {
+    private static final Map<UUID, BlockingQueue> queueMap = new ConcurrentHashMap<>();
+    private final UUID id;
+    private BlockingQueue<T> queue;
+    private boolean isRunning = true;
+
+    public InMemorySourceFunction() {
+        id = UUID.randomUUID();
+        queue = new LinkedBlockingQueue();
+        queueMap.put(id, queue);
+    }
+
+    @Override
+    public void open(Configuration parameters) throws Exception {
+        super.open(parameters);
+        queue = queueMap.get(id);
+    }
+
+    @Override
+    public void close() throws Exception {
+        super.close();
+        queueMap.remove(id);
+    }
+
+    @Override
+    public void run(SourceContext<T> context) {
+        while (isRunning) {
+            T value = queue.poll();
+            if (value == null) {
+                Thread.yield();

Review comment:
       Would it be more performant and simpler to use `queue.take()` to get values from the queue? It is a blocking call, which reduces the chance of a busy loop.
   
   If we agree to do this, we need to make sure that `InMemorySourceFunction` does not block on the queue forever when the test finishes. One approach is to add a dummy value to the queue when `cancel()` is invoked.
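The suggestion above (block in `queue.take()` and unblock it from `cancel()` by enqueuing a dummy value) can be illustrated without Flink. Below is a minimal sketch, not the PR's code. `BlockingQueueSource` and `addRecord` are hypothetical names, and a `java.util.function.Consumer<T>` stands in for Flink's `SourceContext#collect`. Because `T` may have no spare value to use as a marker, the queue holds `Object` and the dummy value is a private sentinel compared by identity, so it can never collide with a real record. `cancel()` runs on a different thread than `run()`, which is why the flag is `volatile`.

```java
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.function.Consumer;

/** Hypothetical stand-in for InMemorySourceFunction using a blocking take() plus a sentinel. */
final class BlockingQueueSource<T> {
    // Unique marker object; compared by identity, so it cannot clash with a real record.
    private static final Object END_OF_STREAM = new Object();

    private final BlockingQueue<Object> queue = new LinkedBlockingQueue<>();
    private volatile boolean isRunning = true;

    /** Called by the test to feed a record into the source. */
    void addRecord(T value) {
        queue.add(value);
    }

    /** Emits records until cancel() is called; blocks instead of spinning while idle. */
    @SuppressWarnings("unchecked")
    void run(Consumer<T> collector) throws InterruptedException {
        while (isRunning) {
            Object value = queue.take(); // Blocks until a record or the sentinel arrives.
            if (value == END_OF_STREAM) {
                break;
            }
            collector.accept((T) value);
        }
    }

    /** Stops run(): clears the flag, then wakes up a take() call that may be blocked. */
    void cancel() {
        isRunning = false;
        queue.add(END_OF_STREAM);
    }
}
```

Without the sentinel, a `run()` thread parked in `take()` on an empty queue would never observe `isRunning == false`, and the test could hang on shutdown.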

##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java
##########
@@ -0,0 +1,487 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.RestOptions;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.clustering.kmeans.KMeans;
+import org.apache.flink.ml.clustering.kmeans.KMeansModel;
+import org.apache.flink.ml.clustering.kmeans.KMeansModelData;
+import org.apache.flink.ml.clustering.kmeans.OnlineKMeans;
+import org.apache.flink.ml.clustering.kmeans.OnlineKMeansModel;
+import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.util.InMemorySinkFunction;
+import org.apache.flink.ml.util.InMemorySourceFunction;
+import org.apache.flink.runtime.minicluster.MiniCluster;
+import org.apache.flink.runtime.minicluster.MiniClusterConfiguration;
+import org.apache.flink.runtime.testutils.InMemoryReporter;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.CollectionUtils;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import static org.apache.flink.ml.clustering.KMeansTest.groupFeaturesByPrediction;
+
+/** Tests {@link OnlineKMeans} and {@link OnlineKMeansModel}. */
+public class OnlineKMeansTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private static final DenseVector[] trainData1 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 0.0),
+                Vectors.dense(10.0, 0.3),
+                Vectors.dense(10.3, 0.0),
+                Vectors.dense(-10.0, 0.0),
+                Vectors.dense(-10.0, 0.6),
+                Vectors.dense(-10.6, 0.0)
+            };
+    private static final DenseVector[] trainData2 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 100.0),
+                Vectors.dense(10.0, 100.3),
+                Vectors.dense(10.3, 100.0),
+                Vectors.dense(-10.0, -100.0),
+                Vectors.dense(-10.0, -100.6),
+                Vectors.dense(-10.6, -100.0)
+            };
+    private static final DenseVector[] predictData =
+            new DenseVector[] {
+                Vectors.dense(10.0, 10.0),
+                Vectors.dense(10.3, 10.0),
+                Vectors.dense(10.0, 10.3),
+                Vectors.dense(-10.0, 10.0),
+                Vectors.dense(-10.3, 10.0),
+                Vectors.dense(-10.0, 10.3)
+            };
+    private static final List<Set<DenseVector>> expectedGroups1 =
+            Arrays.asList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3))),
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+    private static final List<Set<DenseVector>> expectedGroups2 =
+            Collections.singletonList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3),
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+
+    private static final int defaultParallelism = 4;
+    private static final int numTaskManagers = 2;
+    private static final int numSlotsPerTaskManager = 2;
+
+    private String currentModelDataVersion;
+
+    private InMemorySourceFunction<DenseVector> trainSource;
+    private InMemorySourceFunction<DenseVector> predictSource;
+    private InMemorySinkFunction<Row> outputSink;
+    private InMemorySinkFunction<KMeansModelData> modelDataSink;
+
+    private InMemoryReporter reporter;
+    private MiniCluster miniCluster;
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+
+    private Table offlineTrainTable;
+    private Table trainTable;
+    private Table predictTable;
+
+    @Before
+    public void before() throws Exception {
+        currentModelDataVersion = "0";
+
+        trainSource = new InMemorySourceFunction<>();
+        predictSource = new InMemorySourceFunction<>();
+        outputSink = new InMemorySinkFunction<>();
+        modelDataSink = new InMemorySinkFunction<>();
+
+        Configuration config = new Configuration();
+        config.set(RestOptions.BIND_PORT, "18081-19091");
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        reporter = InMemoryReporter.createWithRetainedMetrics();
+        reporter.addToConfiguration(config);
+
+        miniCluster =
+                new MiniCluster(
+                        new MiniClusterConfiguration.Builder()
+                                .setConfiguration(config)
+                                .setNumTaskManagers(numTaskManagers)
+                                .setNumSlotsPerTaskManager(numSlotsPerTaskManager)
+                                .build());
+        miniCluster.start();
+
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(defaultParallelism);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        Schema schema = Schema.newBuilder().column("f0", DataTypes.of(DenseVector.class)).build();
+
+        offlineTrainTable =
+                tEnv.fromDataStream(env.fromElements(trainData1), schema).as("features");
+        trainTable =
+                tEnv.fromDataStream(
+                                env.addSource(trainSource, DenseVectorTypeInfo.INSTANCE), schema)
+                        .as("features");
+        predictTable =
+                tEnv.fromDataStream(
+                                env.addSource(predictSource, DenseVectorTypeInfo.INSTANCE), schema)
+                        .as("features");
+    }
+
+    @After
+    public void after() throws Exception {
+        miniCluster.close();
+    }
+
+    /**
+     * Performs transform() on the provided model with predictTable, and adds sinks for
+     * OnlineKMeansModel's transform output and model data.
+     */
+    private void configTransformAndSink(OnlineKMeansModel onlineModel) {
+        Table outputTable = onlineModel.transform(predictTable)[0];

Review comment:
       nit: Would it be a bit simpler to do the following?
   
   ```
   Table outputTable = onlineModel.transform(predictTable)[0];
   tEnv.toDataStream(outputTable).addSink(outputSink);
   
   Table modelDataTable = onlineModel.getModelData()[0];
   KMeansModelData.getModelDataStream(modelDataTable).addSink(modelDataSink);
   ```

##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java
##########
@@ -0,0 +1,487 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.RestOptions;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.clustering.kmeans.KMeans;
+import org.apache.flink.ml.clustering.kmeans.KMeansModel;
+import org.apache.flink.ml.clustering.kmeans.KMeansModelData;
+import org.apache.flink.ml.clustering.kmeans.OnlineKMeans;
+import org.apache.flink.ml.clustering.kmeans.OnlineKMeansModel;
+import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.util.InMemorySinkFunction;
+import org.apache.flink.ml.util.InMemorySourceFunction;
+import org.apache.flink.runtime.minicluster.MiniCluster;
+import org.apache.flink.runtime.minicluster.MiniClusterConfiguration;
+import org.apache.flink.runtime.testutils.InMemoryReporter;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.CollectionUtils;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import static org.apache.flink.ml.clustering.KMeansTest.groupFeaturesByPrediction;
+
+/** Tests {@link OnlineKMeans} and {@link OnlineKMeansModel}. */
+public class OnlineKMeansTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private static final DenseVector[] trainData1 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 0.0),
+                Vectors.dense(10.0, 0.3),
+                Vectors.dense(10.3, 0.0),
+                Vectors.dense(-10.0, 0.0),
+                Vectors.dense(-10.0, 0.6),
+                Vectors.dense(-10.6, 0.0)
+            };
+    private static final DenseVector[] trainData2 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 100.0),
+                Vectors.dense(10.0, 100.3),
+                Vectors.dense(10.3, 100.0),
+                Vectors.dense(-10.0, -100.0),
+                Vectors.dense(-10.0, -100.6),
+                Vectors.dense(-10.6, -100.0)
+            };
+    private static final DenseVector[] predictData =
+            new DenseVector[] {
+                Vectors.dense(10.0, 10.0),
+                Vectors.dense(10.3, 10.0),
+                Vectors.dense(10.0, 10.3),
+                Vectors.dense(-10.0, 10.0),
+                Vectors.dense(-10.3, 10.0),
+                Vectors.dense(-10.0, 10.3)
+            };
+    private static final List<Set<DenseVector>> expectedGroups1 =
+            Arrays.asList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3))),
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+    private static final List<Set<DenseVector>> expectedGroups2 =
+            Collections.singletonList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3),
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+
+    private static final int defaultParallelism = 4;
+    private static final int numTaskManagers = 2;
+    private static final int numSlotsPerTaskManager = 2;
+
+    private String currentModelDataVersion;
+
+    private InMemorySourceFunction<DenseVector> trainSource;
+    private InMemorySourceFunction<DenseVector> predictSource;
+    private InMemorySinkFunction<Row> outputSink;
+    private InMemorySinkFunction<KMeansModelData> modelDataSink;
+
+    private InMemoryReporter reporter;
+    private MiniCluster miniCluster;
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+
+    private Table offlineTrainTable;
+    private Table trainTable;
+    private Table predictTable;
+
+    @Before
+    public void before() throws Exception {
+        currentModelDataVersion = "0";
+
+        trainSource = new InMemorySourceFunction<>();
+        predictSource = new InMemorySourceFunction<>();
+        outputSink = new InMemorySinkFunction<>();
+        modelDataSink = new InMemorySinkFunction<>();
+
+        Configuration config = new Configuration();
+        config.set(RestOptions.BIND_PORT, "18081-19091");
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        reporter = InMemoryReporter.createWithRetainedMetrics();
+        reporter.addToConfiguration(config);
+
+        miniCluster =
+                new MiniCluster(
+                        new MiniClusterConfiguration.Builder()
+                                .setConfiguration(config)
+                                .setNumTaskManagers(numTaskManagers)
+                                .setNumSlotsPerTaskManager(numSlotsPerTaskManager)
+                                .build());
+        miniCluster.start();
+
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(defaultParallelism);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        Schema schema = Schema.newBuilder().column("f0", DataTypes.of(DenseVector.class)).build();
+
+        offlineTrainTable =
+                tEnv.fromDataStream(env.fromElements(trainData1), schema).as("features");
+        trainTable =
+                tEnv.fromDataStream(
+                                env.addSource(trainSource, DenseVectorTypeInfo.INSTANCE), schema)
+                        .as("features");
+        predictTable =
+                tEnv.fromDataStream(
+                                env.addSource(predictSource, DenseVectorTypeInfo.INSTANCE), schema)
+                        .as("features");
+    }
+
+    @After
+    public void after() throws Exception {
+        miniCluster.close();
+    }
+
+    /**
+     * Performs transform() on the provided model with predictTable, and adds sinks for
+     * OnlineKMeansModel's transform output and model data.
+     */
+    private void configTransformAndSink(OnlineKMeansModel onlineModel) {
+        Table outputTable = onlineModel.transform(predictTable)[0];
+        DataStream<Row> output = tEnv.toDataStream(outputTable);
+        output.addSink(outputSink);
+
+        Table modelDataTable = onlineModel.getModelData()[0];
+        DataStream<KMeansModelData> modelDataStream =
+                KMeansModelData.getModelDataStream(modelDataTable);
+        modelDataStream.addSink(modelDataSink);
+    }
+
+    /** Blocks the thread until Model has set up init model data. */
+    private void waitInitModelDataSetup() {
+        while (reporter.findMetrics("modelDataVersion").size() < defaultParallelism) {
+            Thread.yield();
+        }
+        waitModelDataUpdate();
+    }
+
+    /** Blocks the thread until the Model has received the next model-data-update event. */
+    @SuppressWarnings("unchecked")
+    private void waitModelDataUpdate() {
+        do {
+            String tmpModelDataVersion =
+                    String.valueOf(
+                            reporter.findMetrics("modelDataVersion").values().stream()
+                                    .map(x -> Integer.parseInt(((Gauge<String>) x).getValue()))
+                                    .min(Integer::compareTo)
+                                    .orElse(Integer.parseInt(currentModelDataVersion)));
+            if (tmpModelDataVersion.equals(currentModelDataVersion)) {
+                Thread.yield();
+            } else {
+                currentModelDataVersion = tmpModelDataVersion;
+                break;
+            }
+        } while (true);
+    }
+
+    /**
+     * Inserts default predict data to the predict queue, fetches the prediction results, and
+     * asserts that the grouping result is as expected.
+     *
+     * @param expectedGroups A list containing sets of features, which is the expected group result
+     * @param featuresCol Name of the column in the table that contains the features
+     * @param predictionCol Name of the column in the table that contains the prediction result
+     */
+    private void predictAndAssert(
+            List<Set<DenseVector>> expectedGroups, String featuresCol, String predictionCol)
+            throws Exception {
+        predictSource.offerAll(OnlineKMeansTest.predictData);
+        List<Row> rawResult = outputSink.poll(OnlineKMeansTest.predictData.length);
+        List<Set<DenseVector>> actualGroups =
+                groupFeaturesByPrediction(rawResult, featuresCol, predictionCol);
+        Assert.assertTrue(CollectionUtils.isEqualCollection(expectedGroups, actualGroups));
+    }
+
+    @Test
+    public void testParam() {
+        OnlineKMeans onlineKMeans = new OnlineKMeans();
+        Assert.assertEquals("features", onlineKMeans.getFeaturesCol());
+        Assert.assertEquals("prediction", onlineKMeans.getPredictionCol());
+        Assert.assertEquals(EuclideanDistanceMeasure.NAME, onlineKMeans.getDistanceMeasure());
+        Assert.assertEquals("random", onlineKMeans.getInitMode());
+        Assert.assertEquals(2, onlineKMeans.getK());
+        Assert.assertEquals(1, onlineKMeans.getDims());
+        Assert.assertEquals("count", onlineKMeans.getBatchStrategy());
+        Assert.assertEquals(1, onlineKMeans.getBatchSize());
+        Assert.assertEquals(0., onlineKMeans.getDecayFactor(), 1e-5);
+        Assert.assertEquals("random", onlineKMeans.getInitMode());
+        Assert.assertEquals(OnlineKMeans.class.getName().hashCode(), onlineKMeans.getSeed());
+
+        onlineKMeans
+                .setK(9)
+                .setFeaturesCol("test_feature")
+                .setPredictionCol("test_prediction")
+                .setK(3)
+                .setDims(5)
+                .setBatchSize(5)
+                .setDecayFactor(0.25)
+                .setInitMode("direct")
+                .setSeed(100);
+
+        Assert.assertEquals("test_feature", onlineKMeans.getFeaturesCol());
+        Assert.assertEquals("test_prediction", onlineKMeans.getPredictionCol());
+        Assert.assertEquals(3, onlineKMeans.getK());
+        Assert.assertEquals(5, onlineKMeans.getDims());
+        Assert.assertEquals("count", onlineKMeans.getBatchStrategy());
+        Assert.assertEquals(5, onlineKMeans.getBatchSize());
+        Assert.assertEquals(0.25, onlineKMeans.getDecayFactor(), 1e-5);
+        Assert.assertEquals("direct", onlineKMeans.getInitMode());
+        Assert.assertEquals(100, onlineKMeans.getSeed());
+    }
+
+    @Test
+    public void testFitAndPredict() throws Exception {
+        OnlineKMeans onlineKMeans =
+                new OnlineKMeans()
+                        .setInitMode("random")
+                        .setDims(2)
+                        .setInitWeights(new Double[] {0., 0.})
+                        .setBatchSize(6)
+                        .setFeaturesCol("features")
+                        .setPredictionCol("prediction");
+        OnlineKMeansModel onlineModel = onlineKMeans.fit(trainTable);
+        configTransformAndSink(onlineModel);
+
+        miniCluster.submitJob(env.getStreamGraph().getJobGraph());
+        waitInitModelDataSetup();
+
+        trainSource.offerAll(trainData1);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups1, onlineKMeans.getFeaturesCol(), onlineKMeans.getPredictionCol());
+
+        trainSource.offerAll(trainData2);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups2, onlineKMeans.getFeaturesCol(), onlineKMeans.getPredictionCol());
+    }
+
+    @Test
+    public void testInitWithKMeans() throws Exception {
+        KMeans kMeans = new KMeans().setFeaturesCol("features").setPredictionCol("prediction");
+        KMeansModel kMeansModel = kMeans.fit(offlineTrainTable);
+
+        OnlineKMeans onlineKMeans =
+                new OnlineKMeans(kMeansModel.getModelData())
+                        .setFeaturesCol("features")
+                        .setPredictionCol("prediction")
+                        .setInitMode("direct")
+                        .setDims(2)
+                        .setInitWeights(new Double[] {0., 0.})
+                        .setBatchSize(6);
+
+        OnlineKMeansModel onlineModel = onlineKMeans.fit(trainTable);
+        configTransformAndSink(onlineModel);
+
+        miniCluster.submitJob(env.getStreamGraph().getJobGraph());
+        waitInitModelDataSetup();
+        predictAndAssert(
+                expectedGroups1, onlineKMeans.getFeaturesCol(), onlineKMeans.getPredictionCol());
+
+        trainSource.offerAll(trainData2);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups2, onlineKMeans.getFeaturesCol(), onlineKMeans.getPredictionCol());
+    }
+
+    @Test
+    public void testDecayFactor() throws Exception {
+        KMeans kMeans = new KMeans().setFeaturesCol("features").setPredictionCol("prediction");
+        KMeansModel kMeansModel = kMeans.fit(offlineTrainTable);
+
+        OnlineKMeans onlineKMeans =
+                new OnlineKMeans(kMeansModel.getModelData())
+                        .setDims(2)
+                        .setInitWeights(new Double[] {3., 3.})
+                        .setDecayFactor(0.5)
+                        .setBatchSize(6)
+                        .setFeaturesCol("features")
+                        .setPredictionCol("prediction");
+        OnlineKMeansModel onlineModel = onlineKMeans.fit(trainTable);
+        configTransformAndSink(onlineModel);
+
+        miniCluster.submitJob(env.getStreamGraph().getJobGraph());
+        modelDataSink.poll();
+
+        trainSource.offerAll(trainData2);
+        KMeansModelData actualModelData = modelDataSink.poll();
+
+        KMeansModelData expectedModelData =
+                new KMeansModelData(
+                        new DenseVector[] {
+                            Vectors.dense(10.1, 200.3 / 3), Vectors.dense(-10.2, -200.2 / 3)
+                        });
+
+        Assert.assertEquals(expectedModelData.centroids.length, actualModelData.centroids.length);
+        Arrays.sort(actualModelData.centroids, (o1, o2) -> (int) (o2.values[0] - o1.values[0]));

Review comment:
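       Would it be simpler to use `(o1, o2) -> Double.compare(o2.values[0], o1.values[0])`? Besides being shorter, it avoids the `(int)` cast truncating differences smaller than 1 to 0, and with the arguments in this order it keeps the current descending order that `expectedModelData` relies on.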
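The sketch below shows how the two comparator styles differ. It uses plain `double[]` arrays as stand-ins for `DenseVector.values`, and the class `CentroidOrdering` is hypothetical. The `(int)` cast version treats first coordinates closer than 1.0 as equal. JDK `Double.compare` with swapped arguments gives an exact descending order.

```java
import java.util.Arrays;
import java.util.Comparator;

/** Hypothetical helper comparing two ways of ordering centroids by their first coordinate. */
class CentroidOrdering {
    // Style used in the test: the (int) cast truncates any difference in (-1, 1) to 0,
    // so centroids whose first coordinates are close compare as "equal".
    static final Comparator<double[]> CAST_DESC = (o1, o2) -> (int) (o2[0] - o1[0]);

    // JDK Double.compare: exact three-way comparison; o2 before o1 gives descending order.
    static final Comparator<double[]> COMPARE_DESC = (o1, o2) -> Double.compare(o2[0], o1[0]);

    /** Returns a sorted shallow copy; Arrays.sort on objects is stable. */
    static double[][] sorted(double[][] centroids, Comparator<double[]> cmp) {
        double[][] copy = centroids.clone();
        Arrays.sort(copy, cmp);
        return copy;
    }
}
```

With the centroids in this test (first coordinates about 10.1 and -10.2), both comparators happen to agree. The cast only misorders centroids whose first coordinates differ by less than 1.0.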

##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java
##########
@@ -0,0 +1,487 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.RestOptions;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.clustering.kmeans.KMeans;
+import org.apache.flink.ml.clustering.kmeans.KMeansModel;
+import org.apache.flink.ml.clustering.kmeans.KMeansModelData;
+import org.apache.flink.ml.clustering.kmeans.OnlineKMeans;
+import org.apache.flink.ml.clustering.kmeans.OnlineKMeansModel;
+import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.util.InMemorySinkFunction;
+import org.apache.flink.ml.util.InMemorySourceFunction;
+import org.apache.flink.runtime.minicluster.MiniCluster;
+import org.apache.flink.runtime.minicluster.MiniClusterConfiguration;
+import org.apache.flink.runtime.testutils.InMemoryReporter;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.CollectionUtils;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import static org.apache.flink.ml.clustering.KMeansTest.groupFeaturesByPrediction;
+
+/** Tests {@link OnlineKMeans} and {@link OnlineKMeansModel}. */
+public class OnlineKMeansTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private static final DenseVector[] trainData1 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 0.0),
+                Vectors.dense(10.0, 0.3),
+                Vectors.dense(10.3, 0.0),
+                Vectors.dense(-10.0, 0.0),
+                Vectors.dense(-10.0, 0.6),
+                Vectors.dense(-10.6, 0.0)
+            };
+    private static final DenseVector[] trainData2 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 100.0),
+                Vectors.dense(10.0, 100.3),
+                Vectors.dense(10.3, 100.0),
+                Vectors.dense(-10.0, -100.0),
+                Vectors.dense(-10.0, -100.6),
+                Vectors.dense(-10.6, -100.0)
+            };
+    private static final DenseVector[] predictData =
+            new DenseVector[] {
+                Vectors.dense(10.0, 10.0),
+                Vectors.dense(10.3, 10.0),
+                Vectors.dense(10.0, 10.3),
+                Vectors.dense(-10.0, 10.0),
+                Vectors.dense(-10.3, 10.0),
+                Vectors.dense(-10.0, 10.3)
+            };
+    private static final List<Set<DenseVector>> expectedGroups1 =
+            Arrays.asList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3))),
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+    private static final List<Set<DenseVector>> expectedGroups2 =
+            Collections.singletonList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3),
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+
+    private static final int defaultParallelism = 4;
+    private static final int numTaskManagers = 2;
+    private static final int numSlotsPerTaskManager = 2;
+
+    private String currentModelDataVersion;
+
+    private InMemorySourceFunction<DenseVector> trainSource;
+    private InMemorySourceFunction<DenseVector> predictSource;
+    private InMemorySinkFunction<Row> outputSink;
+    private InMemorySinkFunction<KMeansModelData> modelDataSink;
+
+    private InMemoryReporter reporter;
+    private MiniCluster miniCluster;
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+
+    private Table offlineTrainTable;
+    private Table trainTable;
+    private Table predictTable;
+
+    @Before
+    public void before() throws Exception {
+        currentModelDataVersion = "0";
+
+        trainSource = new InMemorySourceFunction<>();
+        predictSource = new InMemorySourceFunction<>();
+        outputSink = new InMemorySinkFunction<>();
+        modelDataSink = new InMemorySinkFunction<>();
+
+        Configuration config = new Configuration();
+        config.set(RestOptions.BIND_PORT, "18081-19091");
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        reporter = InMemoryReporter.createWithRetainedMetrics();
+        reporter.addToConfiguration(config);
+
+        miniCluster =
+                new MiniCluster(
+                        new MiniClusterConfiguration.Builder()
+                                .setConfiguration(config)
+                                .setNumTaskManagers(numTaskManagers)
+                                .setNumSlotsPerTaskManager(numSlotsPerTaskManager)
+                                .build());
+        miniCluster.start();
+
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(defaultParallelism);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        Schema schema = Schema.newBuilder().column("f0", DataTypes.of(DenseVector.class)).build();
+
+        offlineTrainTable =
+                tEnv.fromDataStream(env.fromElements(trainData1), schema).as("features");
+        trainTable =
+                tEnv.fromDataStream(
+                                env.addSource(trainSource, DenseVectorTypeInfo.INSTANCE), schema)
+                        .as("features");
+        predictTable =
+                tEnv.fromDataStream(
+                                env.addSource(predictSource, DenseVectorTypeInfo.INSTANCE), schema)
+                        .as("features");
+    }
+
+    @After
+    public void after() throws Exception {
+        miniCluster.close();
+    }
+
+    /**
+     * Performs transform() on the provided model with predictTable, and adds sinks for
+     * OnlineKMeansModel's transform output and model data.
+     */
+    private void configTransformAndSink(OnlineKMeansModel onlineModel) {
+        Table outputTable = onlineModel.transform(predictTable)[0];
+        DataStream<Row> output = tEnv.toDataStream(outputTable);
+        output.addSink(outputSink);
+
+        Table modelDataTable = onlineModel.getModelData()[0];
+        DataStream<KMeansModelData> modelDataStream =
+                KMeansModelData.getModelDataStream(modelDataTable);
+        modelDataStream.addSink(modelDataSink);
+    }
+
+    /** Blocks the thread until Model has set up init model data. */
+    private void waitInitModelDataSetup() {
+        while (reporter.findMetrics("modelDataVersion").size() < defaultParallelism) {
+            Thread.yield();
+        }
+        waitModelDataUpdate();
+    }
+
+    /** Blocks the thread until the Model has received the next model-data-update event. */
+    @SuppressWarnings("unchecked")
+    private void waitModelDataUpdate() {
+        do {
+            String tmpModelDataVersion =
+                    String.valueOf(
+                            reporter.findMetrics("modelDataVersion").values().stream()
+                                    .map(x -> Integer.parseInt(((Gauge<String>) x).getValue()))
+                                    .min(Integer::compareTo)
+                                    .orElse(Integer.parseInt(currentModelDataVersion)));
+            if (tmpModelDataVersion.equals(currentModelDataVersion)) {
+                Thread.yield();
+            } else {
+                currentModelDataVersion = tmpModelDataVersion;
+                break;
+            }
+        } while (true);
+    }
+
+    /**
+     * Inserts default predict data to the predict queue, fetches the prediction results, and
+     * asserts that the grouping result is as expected.
+     *
+     * @param expectedGroups A list containing sets of features, which is the expected group result
+     * @param featuresCol Name of the column in the table that contains the features
+     * @param predictionCol Name of the column in the table that contains the prediction result
+     */
+    private void predictAndAssert(
+            List<Set<DenseVector>> expectedGroups, String featuresCol, String predictionCol)
+            throws Exception {
+        predictSource.offerAll(OnlineKMeansTest.predictData);
+        List<Row> rawResult = outputSink.poll(OnlineKMeansTest.predictData.length);
+        List<Set<DenseVector>> actualGroups =
+                groupFeaturesByPrediction(rawResult, featuresCol, predictionCol);
+        Assert.assertTrue(CollectionUtils.isEqualCollection(expectedGroups, actualGroups));
+    }
+
+    @Test
+    public void testParam() {
+        OnlineKMeans onlineKMeans = new OnlineKMeans();
+        Assert.assertEquals("features", onlineKMeans.getFeaturesCol());
+        Assert.assertEquals("prediction", onlineKMeans.getPredictionCol());
+        Assert.assertEquals(EuclideanDistanceMeasure.NAME, onlineKMeans.getDistanceMeasure());
+        Assert.assertEquals("random", onlineKMeans.getInitMode());

Review comment:
       `onlineKMeans.getInitMode()` is checked twice. Should we remove this one?

##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java
##########
@@ -0,0 +1,487 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.RestOptions;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.clustering.kmeans.KMeans;
+import org.apache.flink.ml.clustering.kmeans.KMeansModel;
+import org.apache.flink.ml.clustering.kmeans.KMeansModelData;
+import org.apache.flink.ml.clustering.kmeans.OnlineKMeans;
+import org.apache.flink.ml.clustering.kmeans.OnlineKMeansModel;
+import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.util.InMemorySinkFunction;
+import org.apache.flink.ml.util.InMemorySourceFunction;
+import org.apache.flink.runtime.minicluster.MiniCluster;
+import org.apache.flink.runtime.minicluster.MiniClusterConfiguration;
+import org.apache.flink.runtime.testutils.InMemoryReporter;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.CollectionUtils;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import static org.apache.flink.ml.clustering.KMeansTest.groupFeaturesByPrediction;
+
+/** Tests {@link OnlineKMeans} and {@link OnlineKMeansModel}. */
+public class OnlineKMeansTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private static final DenseVector[] trainData1 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 0.0),
+                Vectors.dense(10.0, 0.3),
+                Vectors.dense(10.3, 0.0),
+                Vectors.dense(-10.0, 0.0),
+                Vectors.dense(-10.0, 0.6),
+                Vectors.dense(-10.6, 0.0)
+            };
+    private static final DenseVector[] trainData2 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 100.0),
+                Vectors.dense(10.0, 100.3),
+                Vectors.dense(10.3, 100.0),
+                Vectors.dense(-10.0, -100.0),
+                Vectors.dense(-10.0, -100.6),
+                Vectors.dense(-10.6, -100.0)
+            };
+    private static final DenseVector[] predictData =
+            new DenseVector[] {
+                Vectors.dense(10.0, 10.0),
+                Vectors.dense(10.3, 10.0),
+                Vectors.dense(10.0, 10.3),
+                Vectors.dense(-10.0, 10.0),
+                Vectors.dense(-10.3, 10.0),
+                Vectors.dense(-10.0, 10.3)
+            };
+    private static final List<Set<DenseVector>> expectedGroups1 =
+            Arrays.asList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3))),
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+    private static final List<Set<DenseVector>> expectedGroups2 =
+            Collections.singletonList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3),
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+
+    private static final int defaultParallelism = 4;
+    private static final int numTaskManagers = 2;
+    private static final int numSlotsPerTaskManager = 2;
+
+    private String currentModelDataVersion;
+
+    private InMemorySourceFunction<DenseVector> trainSource;
+    private InMemorySourceFunction<DenseVector> predictSource;
+    /** Blocks the thread until Model has set up init model data. */
+    private void waitInitModelDataSetup() {
+        while (reporter.findMetrics("modelDataVersion").size() < defaultParallelism) {
+            Thread.yield();

Review comment:
       `Thread.yield()` could lead to a busy loop that wastes CPU cycles.
   
   How about we use `Thread.sleep(100)` instead?
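   A minimal sketch of the suggestion: poll the condition at a fixed interval with `Thread.sleep` instead of spinning, and give up after a deadline so a stuck job fails the test instead of hanging it. The class `PollingWait`, the method `waitUntil`, and the timeout parameter are hypothetical illustrations, not existing Flink ML utilities.

```java
import java.util.function.BooleanSupplier;

/** Hypothetical helper: polls a condition with a fixed sleep instead of a Thread.yield() spin loop. */
public class PollingWait {

    /**
     * Blocks until {@code condition} returns true or {@code timeoutMs} elapses.
     *
     * @return true if the condition became true, false if the deadline passed first
     */
    public static boolean waitUntil(BooleanSupplier condition, long pollIntervalMs, long timeoutMs)
            throws InterruptedException {
        long deadline = System.nanoTime() + timeoutMs * 1_000_000L;
        while (!condition.getAsBoolean()) {
            if (System.nanoTime() - deadline >= 0) {
                return false; // timed out; the caller decides how to fail
            }
            // Sleeping releases the CPU between checks, unlike yielding in a tight loop.
            Thread.sleep(pollIntervalMs);
        }
        return true;
    }

    public static void main(String[] args) throws InterruptedException {
        int[] checks = {0};
        // The condition becomes true on its third evaluation.
        boolean ok = waitUntil(() -> ++checks[0] >= 3, 10, 1_000);
        System.out.println(ok + " after " + checks[0] + " checks");
    }
}
```

   In the test, `waitInitModelDataSetup()` could then wrap its existing check, e.g. `waitUntil(() -> reporter.findMetrics("modelDataVersion").size() >= defaultParallelism, 100, timeoutMs)`, and pass the result to `Assert.assertTrue`; the timeout value is an assumption left to the author.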

##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java
##########
@@ -0,0 +1,487 @@
+    @Test
+    public void testParam() {
+        OnlineKMeans onlineKMeans = new OnlineKMeans();
+        Assert.assertEquals("features", onlineKMeans.getFeaturesCol());
+        Assert.assertEquals("prediction", onlineKMeans.getPredictionCol());
+        Assert.assertEquals(EuclideanDistanceMeasure.NAME, onlineKMeans.getDistanceMeasure());
+        Assert.assertEquals("random", onlineKMeans.getInitMode());
+        Assert.assertEquals(2, onlineKMeans.getK());
+        Assert.assertEquals(1, onlineKMeans.getDims());
+        Assert.assertEquals("count", onlineKMeans.getBatchStrategy());
+        Assert.assertEquals(1, onlineKMeans.getBatchSize());
+        Assert.assertEquals(0., onlineKMeans.getDecayFactor(), 1e-5);
+        Assert.assertEquals("random", onlineKMeans.getInitMode());
+        Assert.assertEquals(OnlineKMeans.class.getName().hashCode(), onlineKMeans.getSeed());
+
+        onlineKMeans
+                .setK(9)

Review comment:
       Should we remove this one?

##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java
##########
@@ -0,0 +1,487 @@
+    /**
+     * Performs transform() on the provided model with predictTable, and adds sinks for
+     * OnlineKMeansModel's transform output and model data.
+     */
+    private void configTransformAndSink(OnlineKMeansModel onlineModel) {
+        Table outputTable = onlineModel.transform(predictTable)[0];
+        DataStream<Row> output = tEnv.toDataStream(outputTable);
+        output.addSink(outputSink);
+
+        Table modelDataTable = onlineModel.getModelData()[0];
+        DataStream<KMeansModelData> modelDataStream =
+                KMeansModelData.getModelDataStream(modelDataTable);
+        modelDataStream.addSink(modelDataSink);
+    }
+
+    /** Blocks the thread until the model has set up its initial model data. */
+    private void waitInitModelDataSetup() {
+        while (reporter.findMetrics("modelDataVersion").size() < defaultParallelism) {
+            Thread.yield();
+        }
+        waitModelDataUpdate();
+    }
+
+    /** Blocks the thread until the Model has received the next model-data-update event. */
+    @SuppressWarnings("unchecked")
+    private void waitModelDataUpdate() {
+        do {
+            String tmpModelDataVersion =
+                    String.valueOf(
+                            reporter.findMetrics("modelDataVersion").values().stream()
+                                    .map(x -> Integer.parseInt(((Gauge<String>) x).getValue()))
+                                    .min(Integer::compareTo)
+                                    .orElse(Integer.parseInt(currentModelDataVersion)));
+            if (tmpModelDataVersion.equals(currentModelDataVersion)) {
+                Thread.yield();
+            } else {
+                currentModelDataVersion = tmpModelDataVersion;
+                break;
+            }
+        } while (true);
+    }
+
+    /**
+     * Inserts default predict data to the predict queue, fetches the prediction results, and
+     * asserts that the grouping result is as expected.
+     *
+     * @param expectedGroups A list containing sets of features, which is the expected group result
+     * @param featuresCol Name of the column in the table that contains the features
+     * @param predictionCol Name of the column in the table that contains the prediction result
+     */
+    private void predictAndAssert(
+            List<Set<DenseVector>> expectedGroups, String featuresCol, String predictionCol)
+            throws Exception {
+        predictSource.offerAll(OnlineKMeansTest.predictData);
+        List<Row> rawResult = outputSink.poll(OnlineKMeansTest.predictData.length);
+        List<Set<DenseVector>> actualGroups =
+                groupFeaturesByPrediction(rawResult, featuresCol, predictionCol);
+        Assert.assertTrue(CollectionUtils.isEqualCollection(expectedGroups, actualGroups));
+    }
+
+    @Test
+    public void testParam() {
+        OnlineKMeans onlineKMeans = new OnlineKMeans();
+        Assert.assertEquals("features", onlineKMeans.getFeaturesCol());
+        Assert.assertEquals("prediction", onlineKMeans.getPredictionCol());
+        Assert.assertEquals(EuclideanDistanceMeasure.NAME, onlineKMeans.getDistanceMeasure());
+        Assert.assertEquals("random", onlineKMeans.getInitMode());
+        Assert.assertEquals(2, onlineKMeans.getK());
+        Assert.assertEquals(1, onlineKMeans.getDims());
+        Assert.assertEquals("count", onlineKMeans.getBatchStrategy());
+        Assert.assertEquals(1, onlineKMeans.getBatchSize());
+        Assert.assertEquals(0., onlineKMeans.getDecayFactor(), 1e-5);
+        Assert.assertEquals("random", onlineKMeans.getInitMode());
+        Assert.assertEquals(OnlineKMeans.class.getName().hashCode(), onlineKMeans.getSeed());
+
+        onlineKMeans
+                .setK(9)
+                .setFeaturesCol("test_feature")
+                .setPredictionCol("test_prediction")
+                .setK(3)
+                .setDims(5)
+                .setBatchSize(5)
+                .setDecayFactor(0.25)
+                .setInitMode("direct")
+                .setSeed(100);
+
+        Assert.assertEquals("test_feature", onlineKMeans.getFeaturesCol());
+        Assert.assertEquals("test_prediction", onlineKMeans.getPredictionCol());
+        Assert.assertEquals(3, onlineKMeans.getK());
+        Assert.assertEquals(5, onlineKMeans.getDims());
+        Assert.assertEquals("count", onlineKMeans.getBatchStrategy());
+        Assert.assertEquals(5, onlineKMeans.getBatchSize());
+        Assert.assertEquals(0.25, onlineKMeans.getDecayFactor(), 1e-5);
+        Assert.assertEquals("direct", onlineKMeans.getInitMode());
+        Assert.assertEquals(100, onlineKMeans.getSeed());
+    }
+
+    @Test
+    public void testFitAndPredict() throws Exception {
+        OnlineKMeans onlineKMeans =
+                new OnlineKMeans()
+                        .setInitMode("random")

Review comment:
       nit: Could we use the same order of `setXXX(...)` calls across the unit tests?
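For context on the test excerpt above: `predictAndAssert` compares clusterings as collections of feature groups instead of comparing raw prediction ids, because the cluster index K-Means assigns to a group is arbitrary. The following is a minimal JDK-only sketch of that idea. `ClusterGroups` and `groupByPrediction` are hypothetical names, not the test's `KMeansTest.groupFeaturesByPrediction`. Features are copied into `List<Double>` so they have value-based `equals`/`hashCode` inside a `HashSet`, and a `Set` of groups is used for brevity, whereas the test compares a `List` of sets with `CollectionUtils.isEqualCollection`.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

class ClusterGroups {
    /** Groups features by predicted cluster id, then drops the ids and keeps only the groups. */
    static Set<Set<List<Double>>> groupByPrediction(List<double[]> features, List<Integer> predictions) {
        Map<Integer, Set<List<Double>>> byCluster = new HashMap<>();
        for (int i = 0; i < features.size(); i++) {
            List<Double> point = new ArrayList<>();
            for (double v : features.get(i)) {
                point.add(v); // boxed copy with value-based equals/hashCode
            }
            byCluster.computeIfAbsent(predictions.get(i), id -> new HashSet<>()).add(point);
        }
        // Inner sets are complete at this point, so their hash codes stay stable in the outer set.
        return new HashSet<>(byCluster.values());
    }

    public static void main(String[] args) {
        List<double[]> features = List.of(
                new double[] {10.0, 10.0}, new double[] {10.3, 10.0},
                new double[] {-10.0, 10.0}, new double[] {-10.3, 10.0});
        // Same partition in both runs, but the cluster ids are swapped.
        Set<Set<List<Double>>> runA = groupByPrediction(features, List.of(0, 0, 1, 1));
        Set<Set<List<Double>>> runB = groupByPrediction(features, List.of(1, 1, 0, 0));
        System.out.println(runA.equals(runB)); // true: a label permutation does not matter
    }
}
```

Grouping first makes the assertion independent of which cluster id the model happened to assign, which is what lets `expectedGroups1` and `expectedGroups2` be written without any ids.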

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,569 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Random;
+
+/**
+ * OnlineKMeans extends the functionality of {@link KMeans} by supporting continuous training of
+ * a K-Means model from an unbounded stream of training data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" K-Means rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids as the weighted average of the original and the
+ * estimated centroids. The weight of an estimated centroid is the number of points assigned to
+ * it in the current batch. The weight of an original centroid is also a number of points (its
+ * accumulated weight), additionally multiplied by the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters estimated so far. If the decay
+ * factor is 1, all batches are weighted equally. If the decay factor is 0, the new centroids are
+ * determined entirely by the most recent batch. Lower values correspond to more forgetting.
+ *
+ * <p>NOTE: This class's current naive implementation performs the training process in a
+ * single-threaded way. Correctness is not affected but there are performance issues.

Review comment:
       Given that the current implementation supports parallelization, could we remove this statement now?
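To make the update rule in the class Javadoc above concrete, here is a minimal single-threaded sketch of one decayed mini-batch step. It assumes squared Euclidean distance and plain `double[]` vectors. `DecayedMiniBatchKMeans` and its methods are hypothetical and are not the PR's `ModelDataPartialUpdater`/`ModelDataGlobalUpdater`, which appear to split this work between parallel partial updaters and a single global updater.

```java
import java.util.Arrays;

class DecayedMiniBatchKMeans {
    /** Returns the index of the centroid with the smallest squared Euclidean distance to point. */
    static int closest(double[] point, double[][] centroids) {
        int best = 0;
        double bestDist = Double.MAX_VALUE;
        for (int i = 0; i < centroids.length; i++) {
            double dist = 0.0;
            for (int d = 0; d < point.length; d++) {
                double diff = point[d] - centroids[i][d];
                dist += diff * diff;
            }
            if (dist < bestDist) {
                bestDist = dist;
                best = i;
            }
        }
        return best;
    }

    /** Applies one decayed mini-batch update to centroids and weights, in place. */
    static void update(double[][] centroids, double[] weights, double[][] batch, double decayFactor) {
        int k = centroids.length;
        int dims = centroids[0].length;
        double[][] sums = new double[k][dims]; // per-cluster sum of the assigned batch points
        int[] counts = new int[k];             // per-cluster number of assigned batch points
        for (double[] point : batch) {
            int c = closest(point, centroids); // assignment uses the centroids from before this batch
            counts[c]++;
            for (int d = 0; d < dims; d++) {
                sums[c][d] += point[d];
            }
        }
        for (int i = 0; i < k; i++) {
            double oldWeight = weights[i] * decayFactor; // decayed weight of the original centroid
            double newWeight = oldWeight + counts[i];    // estimated centroid weighs its point count
            if (newWeight > 0) {
                for (int d = 0; d < dims; d++) {
                    // sums[i][d] == counts[i] * batchMean[i][d], so this is the weighted average
                    centroids[i][d] = (oldWeight * centroids[i][d] + sums[i][d]) / newWeight;
                }
            }
            weights[i] = newWeight;
        }
    }

    public static void main(String[] args) {
        double[][] centroids = {{0.0, 0.0}, {10.0, 10.0}};
        double[] weights = {4.0, 4.0};
        double[][] batch = {{2.0, 0.0}, {10.0, 12.0}, {10.0, 14.0}};
        update(centroids, weights, batch, 0.5);
        System.out.println(Arrays.deepToString(centroids) + " " + Arrays.toString(weights));
    }
}
```

With `decayFactor = 0.5`, the first centroid keeps weight 4 * 0.5 = 2, absorbs the single point (2.0, 0.0), and moves to about (0.667, 0.0) with weight 3. The second centroid moves to (10.0, 11.5) with weight 4. With a decay factor of 0, each centroid that received points becomes exactly the mean of its batch points.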

##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/util/InMemorySinkFunction.java
##########
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util;
+
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.streaming.api.functions.sink.RichSinkFunction;
+import org.apache.flink.streaming.api.functions.sink.SinkFunction;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.UUID;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
+
+/** A {@link SinkFunction} implementation that makes all collected records available for tests. */
+@SuppressWarnings({"unchecked", "rawtypes"})
+public class InMemorySinkFunction<T> extends RichSinkFunction<T> {
+    private static final Map<UUID, BlockingQueue> queueMap = new ConcurrentHashMap<>();
+    private final UUID id;
+    private BlockingQueue<T> queue;
+
+    public InMemorySinkFunction() {
+        id = UUID.randomUUID();
+        queue = new LinkedBlockingQueue();
+        queueMap.put(id, queue);
+    }
+
+    @Override
+    public void open(Configuration parameters) throws Exception {
+        super.open(parameters);
+        queue = queueMap.get(id);
+    }
+
+    @Override
+    public void close() throws Exception {
+        super.close();
+        queueMap.remove(id);
+    }
+
+    @Override
+    public void invoke(T value, Context context) {
+        if (!queue.offer(value)) {
+            throw new RuntimeException(
+                    "Failed to offer " + value + " to blocking queue " + id + ".");
+        }
+    }
+
+    public List<T> poll(int num) throws InterruptedException {

Review comment:
       `poll(...)` typically takes a timeout, and it may return null if there are not enough values in the queue before the timeout expires.
   
   This method currently requires that at least the expected number of values be in the queue, and throws an exception otherwise. These semantics seem closer to `take()`.
   
   How about renaming this method and the method below to `take(...)`?
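The semantic difference the comment refers to can be sketched with the JDK's `BlockingQueue`, where `take()` waits indefinitely and `poll(timeout, unit)` returns null once the timeout expires. `QueueReads` and its two methods are hypothetical helpers for illustration, not the PR's `InMemorySinkFunction` API. The `poll` variant here applies one overall deadline to the whole read rather than one timeout per element.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;

class QueueReads {
    /** take-style read: blocks until exactly num elements have been received. */
    static <T> List<T> take(BlockingQueue<T> queue, int num) throws InterruptedException {
        List<T> result = new ArrayList<>(num);
        for (int i = 0; i < num; i++) {
            result.add(queue.take()); // waits indefinitely for each element
        }
        return result;
    }

    /**
     * poll-style read: returns num elements, or null if they do not all arrive before the
     * overall timeout. Elements removed before the timeout are discarded in this sketch.
     */
    static <T> List<T> poll(BlockingQueue<T> queue, int num, long timeout, TimeUnit unit)
            throws InterruptedException {
        long deadline = System.nanoTime() + unit.toNanos(timeout);
        List<T> result = new ArrayList<>(num);
        for (int i = 0; i < num; i++) {
            long remaining = deadline - System.nanoTime();
            T value = queue.poll(Math.max(0L, remaining), TimeUnit.NANOSECONDS);
            if (value == null) {
                return null; // not enough values before the deadline
            }
            result.add(value);
        }
        return result;
    }

    public static void main(String[] args) throws InterruptedException {
        BlockingQueue<String> queue = new LinkedBlockingQueue<>(List.of("a", "b"));
        System.out.println(take(queue, 2));                           // [a, b]
        System.out.println(poll(queue, 1, 10, TimeUnit.MILLISECONDS)); // null: queue is empty
    }
}
```

Under this naming, a helper that fails when values are missing maps naturally onto `take`, while `poll` signals a shortfall through its return value.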
   
   
   
   

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,569 @@
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public OnlineKMeans(Table... initModelDataTables) {
+        Preconditions.checkArgument(initModelDataTables.length == 1);
+        this.initModelDataTable = initModelDataTables[0];
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        Preconditions.checkArgument(HasBatchStrategy.COUNT_STRATEGY.equals(getBatchStrategy()));
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            Preconditions.checkState(initModelDataTable == null);
+            initModelDataStream = createRandomCentroids(env, getDims(), getK(), getSeed());
+        } else {
+            initModelDataStream = KMeansModelData.getModelDataStream(initModelDataTable);
+        }
+        DataStream<Tuple2<KMeansModelData, DenseVector>> initModelDataWithWeightsStream =
+                initModelDataStream.map(new InitWeightAssigner(getInitWeights()));
+        initModelDataWithWeightsStream.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getBatchSize(),
+                        getK(),
+                        getDims());
+
+        DataStream<KMeansModelData> finalModelDataStream =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelDataWithWeightsStream),
+                                DataStreamList.of(points),
+                                body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelDataStream);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class InitWeightAssigner
+            implements MapFunction<KMeansModelData, Tuple2<KMeansModelData, DenseVector>> {
+        private final double[] initWeights;
+
+        private InitWeightAssigner(Double[] initWeights) {
+            this.initWeights = ArrayUtils.toPrimitive(initWeights);
+        }
+
+        @Override
+        public Tuple2<KMeansModelData, DenseVector> map(KMeansModelData modelData)
+                throws Exception {
+            return Tuple2.of(modelData, Vectors.dense(initWeights));
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        Path initModelDataPath = Paths.get(path, "data");
+        if (Files.exists(initModelDataPath)) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+            kMeans.setInitMode("direct");
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+        private final int k;
+        private final int dims;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure,
+                double decayFactor,
+                int batchSize,
+                int k,
+                int dims) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+            this.k = k;
+            this.dims = dims;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<Tuple2<KMeansModelData, DenseVector>> modelDataWithWeights =
+                    variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<Tuple2<KMeansModelData, DenseVector>> newModelDataWithWeights =
+                    points.countWindowAll(batchSize)
+                            .aggregate(new MiniBatchCreator())
+                            .flatMap(new MiniBatchDistributor(parallelism))
+                            .rebalance()
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "ModelDataPartialUpdater",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new ModelDataPartialUpdater(distanceMeasure, k))
+                            .setParallelism(parallelism)
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "ModelDataGlobalUpdater",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new ModelDataGlobalUpdater(k, dims, parallelism, decayFactor))
+                            .setParallelism(1);
+
+            DataStream<KMeansModelData> outputModelData =
+                    modelDataWithWeights.map(
+                            (MapFunction<Tuple2<KMeansModelData, DenseVector>, KMeansModelData>)
+                                    x -> x.f0);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelDataWithWeights), DataStreamList.of(outputModelData));
+        }
+    }
+
+    private static class ModelDataGlobalUpdater
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, DenseVector>>
+            implements TwoInputStreamOperator<
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final int k;
+        private final int dims;
+        private final int upstreamParallelism;
+        private final double decayFactor;
+
+        private ListState<Integer> partialModelDataReceivingState;
+        private ListState<Boolean> initModelDataReceivingState;
+        private ListState<KMeansModelData> modelDataState;
+        private ListState<DenseVector> weightsState;
+
+        private ModelDataGlobalUpdater(
+                int k, int dims, int upstreamParallelism, double decayFactor) {
+            this.k = k;
+            this.dims = dims;
+            this.upstreamParallelism = upstreamParallelism;
+            this.decayFactor = decayFactor;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            partialModelDataReceivingState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "partialModelDataReceiving",
+                                            TypeInformation.of(Integer.class)));

Review comment:
       My understanding is that for basic types such as int, we use `BasicTypeInfo.INT_TYPE_INFO`. I am not sure whether `TypeInformation.of(Integer.class)` would have inferior performance.
   
   It seems simpler to just use `BasicTypeInfo.INT_TYPE_INFO`. Alternatively, please feel free to ask Yun Gao which approach is recommended.
   
   

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasBatchStrategy.java
##########
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringParam;
+import org.apache.flink.ml.param.WithParams;
+
+/** Interface for the shared batch strategy param. */
+@SuppressWarnings("unchecked")
+public interface HasBatchStrategy<T> extends WithParams<T> {
+    String COUNT_STRATEGY = "count";
+
+    Param<String> BATCH_STRATEGY =
+            new StringParam(
+                    "batchStrategy",
+                    "Strategy to create mini batch from online train data.",
+                    COUNT_STRATEGY,
+                    ParamValidators.inArray(COUNT_STRATEGY));
+
+    Param<Integer> BATCH_SIZE =

Review comment:
       There is a difference between the global batch size and the local batch size: `global_batch_size = local_batch_size * num_workers`. This is well defined when multiple workers are used in synchronous training, which seems to be the case for `OnlineKMeans`.
   
   Instead of re-defining a global batch size parameter here, how about we reuse `HasGlobalBatchSize`?
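Per the formula above, a global batch of B records processed by P workers gives each worker B / P records when B divides evenly. The sketch below illustrates that bookkeeping and assumes an as-even-as-possible split when it does not divide evenly. `BatchSplit` is hypothetical and makes no claim about how the PR's `MiniBatchDistributor` actually assigns records to subtasks.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class BatchSplit {
    /** Local batch sizes when a global batch is spread as evenly as possible over the workers. */
    static int[] localBatchSizes(int globalBatchSize, int numWorkers) {
        int[] sizes = new int[numWorkers];
        int base = globalBatchSize / numWorkers;
        int remainder = globalBatchSize % numWorkers;
        for (int i = 0; i < numWorkers; i++) {
            sizes[i] = base + (i < remainder ? 1 : 0); // the first `remainder` workers take one extra
        }
        return sizes;
    }

    /** Cuts one global mini-batch into consecutive local batches, one per worker. */
    static <T> List<List<T>> split(List<T> globalBatch, int numWorkers) {
        int[] sizes = localBatchSizes(globalBatch.size(), numWorkers);
        List<List<T>> result = new ArrayList<>(numWorkers);
        int start = 0;
        for (int size : sizes) {
            result.add(new ArrayList<>(globalBatch.subList(start, start + size)));
            start += size;
        }
        return result;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(localBatchSizes(8, 4)));  // [2, 2, 2, 2]: 8 = 2 * 4
        System.out.println(Arrays.toString(localBatchSizes(10, 4))); // [3, 3, 2, 2]
        System.out.println(split(List.of(1, 2, 3, 4, 5), 2));        // [[1, 2, 3], [4, 5]]
    }
}
```

In this framing, the user-facing parameter is the global size B, and each worker's local size is derived from B and the parallelism. That is why reusing a shared global-batch-size parameter fits synchronous training.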
   
   
   

##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java
##########
@@ -0,0 +1,487 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.RestOptions;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.clustering.kmeans.KMeans;
+import org.apache.flink.ml.clustering.kmeans.KMeansModel;
+import org.apache.flink.ml.clustering.kmeans.KMeansModelData;
+import org.apache.flink.ml.clustering.kmeans.OnlineKMeans;
+import org.apache.flink.ml.clustering.kmeans.OnlineKMeansModel;
+import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.util.InMemorySinkFunction;
+import org.apache.flink.ml.util.InMemorySourceFunction;
+import org.apache.flink.runtime.minicluster.MiniCluster;
+import org.apache.flink.runtime.minicluster.MiniClusterConfiguration;
+import org.apache.flink.runtime.testutils.InMemoryReporter;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.CollectionUtils;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import static org.apache.flink.ml.clustering.KMeansTest.groupFeaturesByPrediction;
+
+/** Tests {@link OnlineKMeans} and {@link OnlineKMeansModel}. */
+public class OnlineKMeansTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private static final DenseVector[] trainData1 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 0.0),
+                Vectors.dense(10.0, 0.3),
+                Vectors.dense(10.3, 0.0),
+                Vectors.dense(-10.0, 0.0),
+                Vectors.dense(-10.0, 0.6),
+                Vectors.dense(-10.6, 0.0)
+            };
+    private static final DenseVector[] trainData2 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 100.0),
+                Vectors.dense(10.0, 100.3),
+                Vectors.dense(10.3, 100.0),
+                Vectors.dense(-10.0, -100.0),
+                Vectors.dense(-10.0, -100.6),
+                Vectors.dense(-10.6, -100.0)
+            };
+    private static final DenseVector[] predictData =
+            new DenseVector[] {
+                Vectors.dense(10.0, 10.0),
+                Vectors.dense(10.3, 10.0),
+                Vectors.dense(10.0, 10.3),
+                Vectors.dense(-10.0, 10.0),
+                Vectors.dense(-10.3, 10.0),
+                Vectors.dense(-10.0, 10.3)
+            };
+    private static final List<Set<DenseVector>> expectedGroups1 =
+            Arrays.asList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3))),
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+    private static final List<Set<DenseVector>> expectedGroups2 =
+            Collections.singletonList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3),
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+
+    private static final int defaultParallelism = 4;
+    private static final int numTaskManagers = 2;
+    private static final int numSlotsPerTaskManager = 2;
+
+    private String currentModelDataVersion;

Review comment:
       Would it be simpler to use `int currentModelDataVersion`?
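   A minimal sketch of what an `int`-typed version could look like, written in plain Java so it runs without a MiniCluster. `ModelVersionWaiter` is hypothetical; the `Supplier` stands in for reading every subtask's `modelDataVersion` gauge through `InMemoryReporter`, so the gauge values would still have to be converted to `int` at that boundary. Comparing with `>` instead of string equality is a design choice of this sketch.

```java
import java.util.Arrays;
import java.util.Collection;
import java.util.function.Supplier;

/** Hypothetical int-based variant of waitModelDataUpdate(); not part of the PR. */
final class ModelVersionWaiter {
    // Last model data version observed across all subtasks.
    private int currentModelDataVersion = 0;

    /**
     * Busy-waits until the minimum version reported by all subtasks is newer than the last
     * observed one, then records and returns it.
     *
     * @param versionsSupplier stands in for reading each subtask's "modelDataVersion" gauge
     */
    int waitModelDataUpdate(Supplier<Collection<Integer>> versionsSupplier) {
        while (true) {
            int minVersion =
                    versionsSupplier.get().stream()
                            .mapToInt(Integer::intValue)
                            .min()
                            .orElse(currentModelDataVersion);
            if (minVersion > currentModelDataVersion) {
                currentModelDataVersion = minVersion;
                return minVersion;
            }
            Thread.yield();
        }
    }

    public static void main(String[] args) {
        ModelVersionWaiter waiter = new ModelVersionWaiter();
        int[] polls = {0};
        // The first three polls still see one subtask at version 0; the fourth sees both at 1.
        int version =
                waiter.waitModelDataUpdate(
                        () -> ++polls[0] < 4 ? Arrays.asList(0, 1) : Arrays.asList(1, 1));
        System.out.println(version + " after " + polls[0] + " polls"); // 1 after 4 polls
    }
}
```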

##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java
##########
@@ -0,0 +1,487 @@
+
+    private InMemorySourceFunction<DenseVector> trainSource;
+    private InMemorySourceFunction<DenseVector> predictSource;
+    private InMemorySinkFunction<Row> outputSink;
+    private InMemorySinkFunction<KMeansModelData> modelDataSink;
+
+    private InMemoryReporter reporter;
+    private MiniCluster miniCluster;
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+
+    private Table offlineTrainTable;
+    private Table trainTable;
+    private Table predictTable;
+
+    @Before
+    public void before() throws Exception {
+        currentModelDataVersion = "0";
+
+        trainSource = new InMemorySourceFunction<>();
+        predictSource = new InMemorySourceFunction<>();
+        outputSink = new InMemorySinkFunction<>();
+        modelDataSink = new InMemorySinkFunction<>();
+
+        Configuration config = new Configuration();
+        config.set(RestOptions.BIND_PORT, "18081-19091");
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        reporter = InMemoryReporter.createWithRetainedMetrics();
+        reporter.addToConfiguration(config);
+
+        miniCluster =
+                new MiniCluster(
+                        new MiniClusterConfiguration.Builder()
+                                .setConfiguration(config)
+                                .setNumTaskManagers(numTaskManagers)
+                                .setNumSlotsPerTaskManager(numSlotsPerTaskManager)
+                                .build());
+        miniCluster.start();
+
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(defaultParallelism);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        Schema schema = Schema.newBuilder().column("f0", DataTypes.of(DenseVector.class)).build();
+
+        offlineTrainTable =
+                tEnv.fromDataStream(env.fromElements(trainData1), schema).as("features");
+        trainTable =
+                tEnv.fromDataStream(
+                                env.addSource(trainSource, DenseVectorTypeInfo.INSTANCE), schema)
+                        .as("features");
+        predictTable =
+                tEnv.fromDataStream(
+                                env.addSource(predictSource, DenseVectorTypeInfo.INSTANCE), schema)
+                        .as("features");
+    }
+
+    @After
+    public void after() throws Exception {
+        miniCluster.close();
+    }
+
+    /**
+     * Performs transform() on the provided model with predictTable, and adds sinks for
+     * OnlineKMeansModel's transform output and model data.
+     */
+    private void configTransformAndSink(OnlineKMeansModel onlineModel) {
+        Table outputTable = onlineModel.transform(predictTable)[0];
+        DataStream<Row> output = tEnv.toDataStream(outputTable);
+        output.addSink(outputSink);
+
+        Table modelDataTable = onlineModel.getModelData()[0];
+        DataStream<KMeansModelData> modelDataStream =
+                KMeansModelData.getModelDataStream(modelDataTable);
+        modelDataStream.addSink(modelDataSink);
+    }
+
+    /** Blocks the thread until the Model has set up the initial model data. */
+    private void waitInitModelDataSetup() {
+        while (reporter.findMetrics("modelDataVersion").size() < defaultParallelism) {
+            Thread.yield();
+        }
+        waitModelDataUpdate();
+    }
+
+    /** Blocks the thread until the Model has received the next model-data-update event. */
+    @SuppressWarnings("unchecked")
+    private void waitModelDataUpdate() {
+        do {
+            String tmpModelDataVersion =
+                    String.valueOf(
+                            reporter.findMetrics("modelDataVersion").values().stream()
+                                    .map(x -> Integer.parseInt(((Gauge<String>) x).getValue()))
+                                    .min(Integer::compareTo)
+                                    .orElse(Integer.parseInt(currentModelDataVersion)));
+            if (tmpModelDataVersion.equals(currentModelDataVersion)) {
+                Thread.yield();
+            } else {
+                currentModelDataVersion = tmpModelDataVersion;
+                break;
+            }
+        } while (true);
+    }
+
+    /**
+     * Inserts the default predict data into the predict queue, fetches the prediction results,
+     * and asserts that the grouping result is as expected.
+     *
+     * @param expectedGroups A list of feature sets representing the expected grouping result
+     * @param featuresCol Name of the column in the table that contains the features
+     * @param predictionCol Name of the column in the table that contains the prediction result
+     */
+    private void predictAndAssert(
+            List<Set<DenseVector>> expectedGroups, String featuresCol, String predictionCol)
+            throws Exception {
+        predictSource.offerAll(OnlineKMeansTest.predictData);
+        List<Row> rawResult = outputSink.poll(OnlineKMeansTest.predictData.length);
+        List<Set<DenseVector>> actualGroups =
+                groupFeaturesByPrediction(rawResult, featuresCol, predictionCol);
+        Assert.assertTrue(CollectionUtils.isEqualCollection(expectedGroups, actualGroups));
+    }
+
+    @Test
+    public void testParam() {
+        OnlineKMeans onlineKMeans = new OnlineKMeans();
+        Assert.assertEquals("features", onlineKMeans.getFeaturesCol());
+        Assert.assertEquals("prediction", onlineKMeans.getPredictionCol());
+        Assert.assertEquals(EuclideanDistanceMeasure.NAME, onlineKMeans.getDistanceMeasure());
+        Assert.assertEquals("random", onlineKMeans.getInitMode());
+        Assert.assertEquals(2, onlineKMeans.getK());
+        Assert.assertEquals(1, onlineKMeans.getDims());
+        Assert.assertEquals("count", onlineKMeans.getBatchStrategy());
+        Assert.assertEquals(1, onlineKMeans.getBatchSize());
+        Assert.assertEquals(0., onlineKMeans.getDecayFactor(), 1e-5);
+        Assert.assertEquals(OnlineKMeans.class.getName().hashCode(), onlineKMeans.getSeed());
+
+        onlineKMeans
+                .setK(9)
+                .setFeaturesCol("test_feature")
+                .setPredictionCol("test_prediction")
+                .setK(3)
+                .setDims(5)
+                .setBatchSize(5)
+                .setDecayFactor(0.25)
+                .setInitMode("direct")
+                .setSeed(100);
+
+        Assert.assertEquals("test_feature", onlineKMeans.getFeaturesCol());
+        Assert.assertEquals("test_prediction", onlineKMeans.getPredictionCol());
+        Assert.assertEquals(3, onlineKMeans.getK());
+        Assert.assertEquals(5, onlineKMeans.getDims());
+        Assert.assertEquals("count", onlineKMeans.getBatchStrategy());
+        Assert.assertEquals(5, onlineKMeans.getBatchSize());
+        Assert.assertEquals(0.25, onlineKMeans.getDecayFactor(), 1e-5);
+        Assert.assertEquals("direct", onlineKMeans.getInitMode());
+        Assert.assertEquals(100, onlineKMeans.getSeed());
+    }
+
+    @Test
+    public void testFitAndPredict() throws Exception {
+        OnlineKMeans onlineKMeans =
+                new OnlineKMeans()
+                        .setInitMode("random")
+                        .setDims(2)
+                        .setInitWeights(new Double[] {0., 0.})
+                        .setBatchSize(6)
+                        .setFeaturesCol("features")
+                        .setPredictionCol("prediction");
+        OnlineKMeansModel onlineModel = onlineKMeans.fit(trainTable);
+        configTransformAndSink(onlineModel);
+
+        miniCluster.submitJob(env.getStreamGraph().getJobGraph());
+        waitInitModelDataSetup();
+
+        trainSource.offerAll(trainData1);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups1, onlineKMeans.getFeaturesCol(), onlineKMeans.getPredictionCol());
+
+        trainSource.offerAll(trainData2);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups2, onlineKMeans.getFeaturesCol(), onlineKMeans.getPredictionCol());
+    }
+
+    @Test
+    public void testInitWithKMeans() throws Exception {
+        KMeans kMeans = new KMeans().setFeaturesCol("features").setPredictionCol("prediction");
+        KMeansModel kMeansModel = kMeans.fit(offlineTrainTable);
+
+        OnlineKMeans onlineKMeans =
+                new OnlineKMeans(kMeansModel.getModelData())
+                        .setFeaturesCol("features")
+                        .setPredictionCol("prediction")
+                        .setInitMode("direct")
+                        .setDims(2)
+                        .setInitWeights(new Double[] {0., 0.})
+                        .setBatchSize(6);
+
+        OnlineKMeansModel onlineModel = onlineKMeans.fit(trainTable);
+        configTransformAndSink(onlineModel);
+
+        miniCluster.submitJob(env.getStreamGraph().getJobGraph());
+        waitInitModelDataSetup();
+        predictAndAssert(
+                expectedGroups1, onlineKMeans.getFeaturesCol(), onlineKMeans.getPredictionCol());
+
+        trainSource.offerAll(trainData2);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups2, onlineKMeans.getFeaturesCol(), onlineKMeans.getPredictionCol());
+    }
+
+    @Test
+    public void testDecayFactor() throws Exception {
+        KMeans kMeans = new KMeans().setFeaturesCol("features").setPredictionCol("prediction");
+        KMeansModel kMeansModel = kMeans.fit(offlineTrainTable);
+
+        OnlineKMeans onlineKMeans =
+                new OnlineKMeans(kMeansModel.getModelData())
+                        .setDims(2)
+                        .setInitWeights(new Double[] {3., 3.})
+                        .setDecayFactor(0.5)
+                        .setBatchSize(6)
+                        .setFeaturesCol("features")
+                        .setPredictionCol("prediction");
+        OnlineKMeansModel onlineModel = onlineKMeans.fit(trainTable);
+        configTransformAndSink(onlineModel);
+
+        miniCluster.submitJob(env.getStreamGraph().getJobGraph());
+        modelDataSink.poll();
+
+        trainSource.offerAll(trainData2);
+        KMeansModelData actualModelData = modelDataSink.poll();
+
+        KMeansModelData expectedModelData =
+                new KMeansModelData(
+                        new DenseVector[] {
+                            Vectors.dense(10.1, 200.3 / 3), Vectors.dense(-10.2, -200.2 / 3)
+                        });
+
+        Assert.assertEquals(expectedModelData.centroids.length, actualModelData.centroids.length);
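+        // Sort centroids by their first coordinate in descending order. Double.compare avoids
+        // truncating the difference to int, which can break the comparator contract.
+        Arrays.sort(
+                actualModelData.centroids,
+                (o1, o2) -> Double.compare(o2.values[0], o1.values[0]));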
+        for (int i = 0; i < expectedModelData.centroids.length; i++) {
+            Assert.assertArrayEquals(
+                    expectedModelData.centroids[i].values,
+                    actualModelData.centroids[i].values,
+                    1e-5);
+        }
+    }
+
+    @Test
+    public void testSaveAndReload() throws Exception {
+        KMeans kMeans = new KMeans().setFeaturesCol("features").setPredictionCol("prediction");
+        KMeansModel kMeansModel = kMeans.fit(offlineTrainTable);
+
+        OnlineKMeans onlineKMeans =
+                new OnlineKMeans(kMeansModel.getModelData())
+                        .setFeaturesCol("features")
+                        .setPredictionCol("prediction")
+                        .setInitMode("direct")
+                        .setDims(2)
+                        .setInitWeights(new Double[] {0., 0.})
+                        .setBatchSize(6);
+
+        String savePath = tempFolder.newFolder().getAbsolutePath();
+        onlineKMeans.save(savePath);
+        miniCluster.executeJobBlocking(env.getStreamGraph().getJobGraph());
+        OnlineKMeans loadedKMeans = OnlineKMeans.load(env, savePath);
+
+        OnlineKMeansModel onlineModel = loadedKMeans.fit(trainTable);

Review comment:
       nit: would it be more readable to rename this variable to `model`, so that its name is consistent with the `loadedModel` used below?

##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/util/InMemorySourceFunction.java
##########
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util;
+
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.streaming.api.functions.source.RichSourceFunction;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+
+import java.util.Map;
+import java.util.UUID;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
+
+/** A {@link SourceFunction} implementation that can directly receive records from tests. */
+@SuppressWarnings({"unchecked", "rawtypes"})
+public class InMemorySourceFunction<T> extends RichSourceFunction<T> {
+    private static final Map<UUID, BlockingQueue> queueMap = new ConcurrentHashMap<>();
+    private final UUID id;
+    private BlockingQueue<T> queue;
+    private volatile boolean isRunning = true;
+
+    public InMemorySourceFunction() {
+        id = UUID.randomUUID();
+        queue = new LinkedBlockingQueue();
+        queueMap.put(id, queue);
+    }
+
+    @Override
+    public void open(Configuration parameters) throws Exception {
+        super.open(parameters);
+        queue = queueMap.get(id);
+    }
+
+    @Override
+    public void close() throws Exception {
+        super.close();
+        queueMap.remove(id);
+    }
+
+    @Override
+    public void run(SourceContext<T> context) {
+        while (isRunning) {
+            T value = queue.poll();
+            if (value == null) {
+                Thread.yield();
+            } else {
+                context.collect(value);
+            }
+        }
+    }
+
+    @Override
+    public void cancel() {
+        isRunning = false;
+    }
+
+    @SafeVarargs
+    public final void offerAll(T... values) throws InterruptedException {
+        for (T value : values) {
+            offer(value);
+        }
+    }
+
+    public void offer(T value) throws InterruptedException {
+        offer(value, 1, TimeUnit.MINUTES);

Review comment:
       Given that we don't limit the capacity of this queue, would it be simpler to call `queue.add(value)` here?
   
   If we agree to do this, we will also need to rename/remove those `offer*` methods as appropriate.
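   Since the `LinkedBlockingQueue` is created without a capacity (which then defaults to `Integer.MAX_VALUE`), `add` will not block and, in practice, will not fail for lack of space; unlike the timed `offer`, it also does not throw `InterruptedException`. A hypothetical, Flink-free sketch of the simplified queue handling follows. `UnboundedQueueSource`, `addAll` and `collect` are illustrative names only. The sketch also marks the running flag `volatile`, since `cancel()` is typically invoked from a different thread than the polling loop.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;

/** Hypothetical stand-in for InMemorySourceFunction's queue handling; not part of the PR. */
final class UnboundedQueueSource<T> {
    // No capacity argument, so the capacity is Integer.MAX_VALUE and add() does not block.
    private final BlockingQueue<T> queue = new LinkedBlockingQueue<>();
    // volatile, because cancel() is normally called from a different thread than the loop.
    private volatile boolean isRunning = true;

    /** Replaces offerAll(): no timeout and no InterruptedException to declare. */
    @SafeVarargs
    public final void addAll(T... values) {
        queue.addAll(Arrays.asList(values));
    }

    /** Mimics run(): polls without blocking until maxRecords arrive or the source is cancelled. */
    List<T> collect(int maxRecords) {
        List<T> collected = new ArrayList<>();
        while (isRunning && collected.size() < maxRecords) {
            T value = queue.poll();
            if (value == null) {
                Thread.yield();
            } else {
                collected.add(value);
            }
        }
        return collected;
    }

    void cancel() {
        isRunning = false;
    }

    public static void main(String[] args) {
        UnboundedQueueSource<String> source = new UnboundedQueueSource<>();
        source.addAll("a", "b", "c");
        System.out.println(source.collect(3)); // [a, b, c]
    }
}
```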




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r832807869



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,569 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Random;
+
+/**
+ * OnlineKMeans extends the functionality of {@link KMeans}, supporting continuous training of a
+ * K-Means model on an unbounded stream of training data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After estimating the centroids on the current batch, OnlineKMeans
+ * computes the new centroids as the weighted average of the original and the estimated
+ * centroids. The weight of each estimated centroid is the number of points assigned to it. The
+ * weight of each original centroid is likewise its number of points, additionally multiplied by
+ * the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If the decay
+ * factor is 1, all batches are weighted equally. If it is 0, the new centroids are determined
+ * entirely by the most recent data. Lower values correspond to more forgetting.
+ *
+ * <p>NOTE: The current naive implementation of this class performs training in a single thread.
+ * This does not affect correctness, but it limits performance.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public OnlineKMeans(Table... initModelDataTables) {

Review comment:
       Yes, that makes better sense for now. I'll make the change.
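   The decayed mini-batch update described in the `OnlineKMeans` Javadoc quoted above can be sketched in plain Java as follows. `DecayedMiniBatchKMeans` is a hypothetical, single-threaded illustration, not the Flink ML implementation. It assumes Euclidean distance, and it assumes each centroid's weight is carried forward as `weight * decayFactor + batchCount`, which the Javadoc does not state explicitly.

```java
import java.util.Arrays;

/** Hypothetical sketch of one mini-batch KMeans step with decay; not the Flink ML code. */
final class DecayedMiniBatchKMeans {

    /** Assigns each batch point to its nearest centroid, then applies the decayed update. */
    static void step(double[][] centroids, double[] weights, double[][] batch, double decayFactor) {
        int k = centroids.length;
        int dims = centroids[0].length;
        double[][] sums = new double[k][dims]; // per-cluster sum of assigned points
        int[] counts = new int[k]; // per-cluster number of assigned points
        for (double[] point : batch) {
            int nearest = 0;
            double nearestDist = Double.POSITIVE_INFINITY;
            for (int c = 0; c < k; c++) {
                double dist = 0; // squared Euclidean distance
                for (int d = 0; d < dims; d++) {
                    double diff = point[d] - centroids[c][d];
                    dist += diff * diff;
                }
                if (dist < nearestDist) {
                    nearestDist = dist;
                    nearest = c;
                }
            }
            counts[nearest]++;
            for (int d = 0; d < dims; d++) {
                sums[nearest][d] += point[d];
            }
        }
        for (int c = 0; c < k; c++) {
            // Original centroid weighted by weight * decay, batch estimate weighted by its count.
            double oldWeight = weights[c] * decayFactor;
            double newWeight = oldWeight + counts[c];
            weights[c] = newWeight; // assumed weight carry-over
            if (newWeight == 0) {
                continue; // neither history nor new points: keep the centroid as-is
            }
            for (int d = 0; d < dims; d++) {
                // sums[c][d] == batchMean * counts[c], so this is the weighted average of the two.
                centroids[c][d] = (centroids[c][d] * oldWeight + sums[c][d]) / newWeight;
            }
        }
    }

    public static void main(String[] args) {
        double[][] centroids = {{10.1, 0.1}, {-10.2, 0.2}};
        double[] weights = {3.0, 3.0};
        double[][] batch = {
            {10.0, 100.0}, {10.0, 100.3}, {10.3, 100.0},
            {-10.0, -100.0}, {-10.0, -100.6}, {-10.6, -100.0}
        };
        step(centroids, weights, batch, 0.5);
        // Approximately [[10.1, 66.7667], [-10.2, -66.7333]], i.e. 200.3 / 3 and -200.2 / 3.
        System.out.println(Arrays.deepToString(centroids));
    }
}
```

   As a cross-check, starting from the per-cluster means of `trainData1`, (10.1, 0.1) and (-10.2, 0.2) (assuming the offline `KMeans` converges to them), with weights {3, 3}, decay factor 0.5 and the `trainData2` batch, the sketch yields (10.1, 200.3 / 3) and (-10.2, -200.2 / 3), the centroids that `testDecayFactor` expects.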







[GitHub] [flink-ml] yunfengzhou-hub commented on pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#issuecomment-1076331386


   Thanks for the comments. I have updated the PR accordingly.





[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r832760331



##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/util/InMemorySourceFunction.java
##########
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util;
+
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.streaming.api.functions.source.RichSourceFunction;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+
+import java.util.Map;
+import java.util.UUID;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
+
+/** A {@link SourceFunction} implementation that can directly receive records from tests. */
+@SuppressWarnings({"unchecked", "rawtypes"})
+public class InMemorySourceFunction<T> extends RichSourceFunction<T> {
+    private static final Map<UUID, BlockingQueue> queueMap = new ConcurrentHashMap<>();
+    private final UUID id;
+    private BlockingQueue<T> queue;
+    private boolean isRunning = true;
+
+    public InMemorySourceFunction() {
+        id = UUID.randomUUID();
+        queue = new LinkedBlockingQueue();
+        queueMap.put(id, queue);
+    }
+
+    @Override
+    public void open(Configuration parameters) throws Exception {
+        super.open(parameters);
+        queue = queueMap.get(id);
+    }
+
+    @Override
+    public void close() throws Exception {
+        super.close();
+        queueMap.remove(id);
+    }
+
+    @Override
+    public void run(SourceContext<T> context) {
+        while (isRunning) {
+            T value = queue.poll();
+            if (value == null) {
+                Thread.yield();
+            } else {
+                context.collect(value);
+            }
+        }
+    }
+
+    @Override
+    public void cancel() {
+        isRunning = false;
+    }
+
+    @SafeVarargs
+    public final void offerAll(T... values) throws InterruptedException {
+        for (T value : values) {
+            offer(value);
+        }
+    }
+
+    public void offer(T value) throws InterruptedException {
+        offer(value, 1, TimeUnit.MINUTES);

Review comment:
       I agree it would be simpler. I'll make the change.
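The static `queueMap` in `InMemorySourceFunction` exists because Flink serializes a source function when the job is submitted, so the instance running inside the task is a copy of the one the test holds. Only the `UUID` needs to survive that copy; `open()` then looks the shared queue up again. This only works because the MiniCluster runs in the same JVM as the test. The sketch below reproduces the idea without Flink, using plain Java serialization as a rough stand-in for job submission; `SharedQueueHandle` and `roundTrip` are hypothetical names.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.LinkedBlockingQueue;

/** Hypothetical, Flink-free stand-in for the static-registry trick in InMemorySourceFunction. */
class SharedQueueHandle<T> implements Serializable {
    private static final long serialVersionUID = 1L;

    // JVM-wide registry. Static fields are not serialized, so every copy shares this map.
    private static final Map<UUID, BlockingQueue<?>> QUEUES = new ConcurrentHashMap<>();

    // Only the key travels with the (serialized) instance.
    private final UUID id = UUID.randomUUID();

    SharedQueueHandle() {
        QUEUES.put(id, new LinkedBlockingQueue<T>());
    }

    /** Looks up the shared queue; returns null once release() has been called. */
    @SuppressWarnings("unchecked")
    BlockingQueue<T> queue() {
        return (BlockingQueue<T>) QUEUES.get(id);
    }

    /** Counterpart of close(): drops the registry entry so the queue can be garbage-collected. */
    void release() {
        QUEUES.remove(id);
    }

    /** Roughly simulates job submission: Java-serializes the object and reads back a copy. */
    @SuppressWarnings("unchecked")
    static <S extends Serializable> S roundTrip(S obj) throws IOException, ClassNotFoundException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
            out.writeObject(obj);
        }
        try (ObjectInputStream in =
                new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {
            return (S) in.readObject();
        }
    }

    public static void main(String[] args) throws Exception {
        SharedQueueHandle<String> testSide = new SharedQueueHandle<>();
        SharedQueueHandle<String> taskSide = roundTrip(testSide); // a distinct copy
        testSide.queue().offer("hello");
        System.out.println(taskSide.queue().poll()); // the copy sees the same queue: hello
        testSide.release();
    }
}
```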







[GitHub] [flink-ml] lindong28 commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
lindong28 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r832799295



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,409 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;

Review comment:
       Strictly speaking, they are different algorithms, for the following reasons:
   - Given the same bounded training data, KMeans iterates over the data multiple times, while OnlineKMeans iterates over it exactly once. The algorithms are therefore different, and so are the resulting model data and prediction accuracy.
   - They differ in whether they support unbounded training data.
   
   You are right that the KMeans and OnlineKMeans implementations share some infrastructure code. But that is an implementation detail that mainly affects developer experience.
   
   After thinking about this more, I don't have a strong opinion here and both approaches work for me. We can keep it as is and separate them into different packages in the future when needed.
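The first point above can be illustrated with a tiny 1-D sketch: batch KMeans repeats the assignment-and-mean step over the same bounded data until the centroids stop moving, while a single-pass algorithm consumes each point once. To keep it short, the sketch assumes the whole bounded dataset arrives as one mini-batch and that the initial centroids carry zero weight, in which case a decayed mini-batch update reduces to a single assignment-and-mean step. `PassesComparison` and its methods are hypothetical and not part of Flink ML.

```java
import java.util.Arrays;

/** Hypothetical 1-D illustration: multi-pass (batch) vs single-pass KMeans on the same data. */
class PassesComparison {

    /** One assignment + mean step; clusters that receive no points keep their old centroid. */
    static double[] lloydStep(double[] data, double[] centroids) {
        double[] sums = new double[centroids.length];
        int[] counts = new int[centroids.length];
        for (double x : data) {
            int closest = 0;
            for (int c = 1; c < centroids.length; c++) {
                if (Math.abs(x - centroids[c]) < Math.abs(x - centroids[closest])) {
                    closest = c;
                }
            }
            sums[closest] += x;
            counts[closest]++;
        }
        double[] next = centroids.clone();
        for (int c = 0; c < centroids.length; c++) {
            if (counts[c] > 0) {
                next[c] = sums[c] / counts[c];
            }
        }
        return next;
    }

    /** Batch KMeans: iterates over the same data until the centroids stop moving. */
    static double[] batchKMeans(double[] data, double[] init, int maxIter) {
        double[] centroids = init.clone();
        for (int iter = 0; iter < maxIter; iter++) {
            double[] next = lloydStep(data, centroids);
            if (Arrays.equals(next, centroids)) {
                break;
            }
            centroids = next;
        }
        return centroids;
    }

    /** Single pass under the stated assumptions: every point is consumed exactly once. */
    static double[] singlePass(double[] data, double[] init) {
        return lloydStep(data, init);
    }

    public static void main(String[] args) {
        double[] data = {0.0, 1.0, 9.0, 10.0};
        double[] init = {0.0, 1.0};
        System.out.println("batch:       " + Arrays.toString(batchKMeans(data, init, 10)));
        System.out.println("single pass: " + Arrays.toString(singlePass(data, init)));
    }
}
```

On the data `{0, 1, 9, 10}` with initial centroids `{0, 1}`, the multi-pass version converges to `{0.5, 9.5}`, while the single pass stops at roughly `{0.0, 6.67}`, so the resulting models differ.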
   
   







[GitHub] [flink-ml] lindong28 commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
lindong28 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r832803814



##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/util/InMemorySourceFunction.java
##########
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util;
+
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.streaming.api.functions.source.RichSourceFunction;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+
+import java.util.Map;
+import java.util.UUID;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
+
+/** A {@link SourceFunction} implementation that can directly receive records from tests. */
+@SuppressWarnings({"unchecked", "rawtypes"})
+public class InMemorySourceFunction<T> extends RichSourceFunction<T> {
+    private static final Map<UUID, BlockingQueue> queueMap = new ConcurrentHashMap<>();
+    private final UUID id;
+    private BlockingQueue<T> queue;
+    private boolean isRunning = true;
+
+    public InMemorySourceFunction() {
+        id = UUID.randomUUID();
+        queue = new LinkedBlockingQueue();
+        queueMap.put(id, queue);
+    }
+
+    @Override
+    public void open(Configuration parameters) throws Exception {
+        super.open(parameters);
+        queue = queueMap.get(id);
+    }
+
+    @Override
+    public void close() throws Exception {
+        super.close();
+        queueMap.remove(id);
+    }
+
+    @Override
+    public void run(SourceContext<T> context) {
+        while (isRunning) {
+            T value = queue.poll();
+            if (value == null) {
+                Thread.yield();

Review comment:
       Sounds good.







[GitHub] [flink-ml] lindong28 commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
lindong28 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r834235458



##########
File path: flink-ml-core/src/main/java/org/apache/flink/ml/api/Stage.java
##########
@@ -40,6 +40,6 @@
  */
 @PublicEvolving
 public interface Stage<T extends Stage<T>> extends WithParams<T>, Serializable {
-    /** Saves this stage to the given path. */
+    /** Saves the metadata and bounded model data of this stage to the given path. */

Review comment:
       Could we replace `bounded model data` with `bounded data` so that the description is a bit more general?

##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java
##########
@@ -0,0 +1,464 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.RestOptions;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.clustering.kmeans.KMeans;
+import org.apache.flink.ml.clustering.kmeans.KMeansModel;
+import org.apache.flink.ml.clustering.kmeans.KMeansModelData;
+import org.apache.flink.ml.clustering.kmeans.OnlineKMeans;
+import org.apache.flink.ml.clustering.kmeans.OnlineKMeansModel;
+import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.util.InMemorySinkFunction;
+import org.apache.flink.ml.util.InMemorySourceFunction;
+import org.apache.flink.runtime.minicluster.MiniCluster;
+import org.apache.flink.runtime.minicluster.MiniClusterConfiguration;
+import org.apache.flink.runtime.testutils.InMemoryReporter;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.CollectionUtils;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import static org.apache.flink.ml.clustering.KMeansTest.groupFeaturesByPrediction;
+
+/** Tests {@link OnlineKMeans} and {@link OnlineKMeansModel}. */
+public class OnlineKMeansTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private static final DenseVector[] trainData1 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 0.0),
+                Vectors.dense(10.0, 0.3),
+                Vectors.dense(10.3, 0.0),
+                Vectors.dense(-10.0, 0.0),
+                Vectors.dense(-10.0, 0.6),
+                Vectors.dense(-10.6, 0.0)
+            };
+    private static final DenseVector[] trainData2 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 100.0),
+                Vectors.dense(10.0, 100.3),
+                Vectors.dense(10.3, 100.0),
+                Vectors.dense(-10.0, -100.0),
+                Vectors.dense(-10.0, -100.6),
+                Vectors.dense(-10.6, -100.0)
+            };
+    private static final DenseVector[] predictData =
+            new DenseVector[] {
+                Vectors.dense(10.0, 10.0),
+                Vectors.dense(10.3, 10.0),
+                Vectors.dense(10.0, 10.3),
+                Vectors.dense(-10.0, 10.0),
+                Vectors.dense(-10.3, 10.0),
+                Vectors.dense(-10.0, 10.3)
+            };
+    private static final List<Set<DenseVector>> expectedGroups1 =
+            Arrays.asList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3))),
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+    private static final List<Set<DenseVector>> expectedGroups2 =
+            Collections.singletonList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3),
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+
+    private static final int defaultParallelism = 4;
+    private static final int numTaskManagers = 2;
+    private static final int numSlotsPerTaskManager = 2;
+
+    private int currentModelDataVersion;
+
+    private InMemorySourceFunction<DenseVector> trainSource;
+    private InMemorySourceFunction<DenseVector> predictSource;
+    private InMemorySinkFunction<Row> outputSink;
+    private InMemorySinkFunction<KMeansModelData> modelDataSink;
+
+    private InMemoryReporter reporter;
+    private MiniCluster miniCluster;
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+
+    private Table offlineTrainTable;
+    private Table trainTable;
+    private Table predictTable;
+
+    @Before
+    public void before() throws Exception {
+        currentModelDataVersion = 0;
+
+        trainSource = new InMemorySourceFunction<>();
+        predictSource = new InMemorySourceFunction<>();
+        outputSink = new InMemorySinkFunction<>();
+        modelDataSink = new InMemorySinkFunction<>();
+
+        Configuration config = new Configuration();
+        config.set(RestOptions.BIND_PORT, "18081-19091");
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        reporter = InMemoryReporter.createWithRetainedMetrics();
+        reporter.addToConfiguration(config);
+
+        miniCluster =
+                new MiniCluster(
+                        new MiniClusterConfiguration.Builder()
+                                .setConfiguration(config)
+                                .setNumTaskManagers(numTaskManagers)
+                                .setNumSlotsPerTaskManager(numSlotsPerTaskManager)
+                                .build());
+        miniCluster.start();

Review comment:
       According to `AbstractTestBase`'s Javadoc, reusing the same mini-cluster across tests could save a significant amount of time.
   
   Could we re-use the same mini-cluster for tests using a similar approach as `AbstractTestBase`?
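For reference, `AbstractTestBase` shares its cluster through a static JUnit `@ClassRule` (a `MiniClusterWithClientResource`), so the cluster is started once per test class rather than once per test method. The sketch below shows the same per-class reuse pattern without any Flink or JUnit dependency; `FakeCluster` and `SharedClusterHolder` are hypothetical stand-ins, and in a real test `shutdown()` would be called from an `@AfterClass` method.

```java
import java.util.concurrent.atomic.AtomicInteger;

/** Hypothetical stand-in for an expensive fixture such as a MiniCluster. */
class FakeCluster {
    static final AtomicInteger STARTS = new AtomicInteger();
    private boolean running = true;

    FakeCluster() {
        STARTS.incrementAndGet(); // imagine several seconds of startup work here
    }

    boolean isRunning() {
        return running;
    }

    void close() {
        running = false;
    }
}

/** Starts the cluster lazily, once per class, and lets every test method reuse it. */
class SharedClusterHolder {
    private static FakeCluster cluster;

    static synchronized FakeCluster get() {
        if (cluster == null) {
            cluster = new FakeCluster();
        }
        return cluster;
    }

    /** Would be invoked once from an @AfterClass method after all tests have run. */
    static synchronized void shutdown() {
        if (cluster != null) {
            cluster.close();
            cluster = null;
        }
    }

    public static void main(String[] args) {
        for (int test = 1; test <= 3; test++) {
            // Each simulated test method asks for the cluster instead of starting its own.
            System.out.println("test " + test + " running: " + get().isRunning());
        }
        shutdown();
        System.out.println("cluster starts: " + FakeCluster.STARTS.get());
    }
}
```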

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeansModel.java
##########
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.co.CoProcessFunction;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * OnlineKMeansModel can be regarded as an advanced {@link KMeansModel} operator that can update
+ * its model data from a stream, using the model data produced by {@link OnlineKMeans}.
+ */
+public class OnlineKMeansModel
+        implements Model<OnlineKMeansModel>, KMeansModelParams<OnlineKMeansModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public OnlineKMeansModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel setModelData(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        modelDataTable = inputs[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), Types.INT),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol()));
+
+        DataStream<Row> predictionResult =
+                KMeansModelData.getModelDataStream(modelDataTable)
+                        .broadcast()
+                        .connect(tEnv.toDataStream(inputs[0]))
+                        .process(
+                                new PredictLabelFunction(
+                                        getFeaturesCol(),
+                                        DistanceMeasure.getInstance(getDistanceMeasure())),
+                                outputTypeInfo);
+
+        return new Table[] {tEnv.fromDataStream(predictionResult)};
+    }
+
+    /** A utility function used for prediction. */
+    private static class PredictLabelFunction extends CoProcessFunction<KMeansModelData, Row, Row> {
+        private final String featuresCol;
+
+        private final DistanceMeasure distanceMeasure;
+
+        private DenseVector[] centroids;
+
+        // TODO: replace this with a complete solution for reading the first model data from the
+        // unbounded model data stream before processing the first predict data.
+        private final List<Row> bufferedPoints = new ArrayList<>();
+
+        // TODO: replace this simple implementation of model data version with the formal API to
+        // track model version after its design is settled.
+        private int modelDataVersion;

Review comment:
       After the operator is opened, and before it processes any input value, the value of this variable is already exposed. It would be better to initialize it explicitly.
   
   How about we set its default value to 0, meaning `the model version is unknown`, and have valid model versions start from 1?
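A minimal, Flink-free sketch of the bookkeeping discussed here: the version counter is explicitly initialized to 0 ("unknown"), the first model received becomes version 1, and predict data that arrives before any model is buffered and flushed once the first model shows up (mirroring the `bufferedPoints` TODO above). `VersionedPredictor` is a hypothetical class using 1-D centroids; in the PR these responsibilities live in `PredictLabelFunction`, a `CoProcessFunction`, where `onModelData` and `onPredictData` would roughly correspond to `processElement1` and `processElement2`.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical, Flink-free sketch of PredictLabelFunction's bookkeeping with 1-D centroids. */
class VersionedPredictor {
    private double[] centroids; // null until the first model data arrives
    private final List<Double> bufferedPoints = new ArrayList<>();
    // Explicitly initialized: 0 means "the model version is unknown"; real versions start at 1.
    private int modelDataVersion = 0;
    private final List<Integer> predictions = new ArrayList<>();

    /** Roughly processElement1: a new model arrives. */
    void onModelData(double[] newCentroids) {
        centroids = newCentroids.clone();
        modelDataVersion++;
        // Flush points that had to wait for the first model.
        for (double point : bufferedPoints) {
            predictions.add(predict(point));
        }
        bufferedPoints.clear();
    }

    /** Roughly processElement2: a point to predict. */
    void onPredictData(double point) {
        if (centroids == null) {
            bufferedPoints.add(point);
        } else {
            predictions.add(predict(point));
        }
    }

    /** The value a metric could report at any time, even before any input is processed. */
    int getModelDataVersion() {
        return modelDataVersion;
    }

    List<Integer> getPredictions() {
        return predictions;
    }

    /** Index of the closest centroid. */
    private int predict(double point) {
        int closest = 0;
        for (int c = 1; c < centroids.length; c++) {
            if (Math.abs(point - centroids[c]) < Math.abs(point - centroids[closest])) {
                closest = c;
            }
        }
        return closest;
    }

    public static void main(String[] args) {
        VersionedPredictor predictor = new VersionedPredictor();
        System.out.println(predictor.getModelDataVersion()); // 0: unknown
        predictor.onPredictData(9.0); // buffered, no model yet
        predictor.onModelData(new double[] {0.0, 10.0}); // version 1, flushes the buffer
        predictor.onPredictData(1.0);
        System.out.println(predictor.getModelDataVersion() + " " + predictor.getPredictions());
    }
}
```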

##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java
##########
@@ -0,0 +1,464 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.RestOptions;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.clustering.kmeans.KMeans;
+import org.apache.flink.ml.clustering.kmeans.KMeansModel;
+import org.apache.flink.ml.clustering.kmeans.KMeansModelData;
+import org.apache.flink.ml.clustering.kmeans.OnlineKMeans;
+import org.apache.flink.ml.clustering.kmeans.OnlineKMeansModel;
+import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.util.InMemorySinkFunction;
+import org.apache.flink.ml.util.InMemorySourceFunction;
+import org.apache.flink.runtime.minicluster.MiniCluster;
+import org.apache.flink.runtime.minicluster.MiniClusterConfiguration;
+import org.apache.flink.runtime.testutils.InMemoryReporter;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.CollectionUtils;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import static org.apache.flink.ml.clustering.KMeansTest.groupFeaturesByPrediction;
+
+/** Tests {@link OnlineKMeans} and {@link OnlineKMeansModel}. */
+public class OnlineKMeansTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private static final DenseVector[] trainData1 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 0.0),
+                Vectors.dense(10.0, 0.3),
+                Vectors.dense(10.3, 0.0),
+                Vectors.dense(-10.0, 0.0),
+                Vectors.dense(-10.0, 0.6),
+                Vectors.dense(-10.6, 0.0)
+            };
+    private static final DenseVector[] trainData2 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 100.0),
+                Vectors.dense(10.0, 100.3),
+                Vectors.dense(10.3, 100.0),
+                Vectors.dense(-10.0, -100.0),
+                Vectors.dense(-10.0, -100.6),
+                Vectors.dense(-10.6, -100.0)
+            };
+    private static final DenseVector[] predictData =
+            new DenseVector[] {
+                Vectors.dense(10.0, 10.0),
+                Vectors.dense(10.3, 10.0),
+                Vectors.dense(10.0, 10.3),
+                Vectors.dense(-10.0, 10.0),
+                Vectors.dense(-10.3, 10.0),
+                Vectors.dense(-10.0, 10.3)
+            };
+    private static final List<Set<DenseVector>> expectedGroups1 =
+            Arrays.asList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3))),
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+    private static final List<Set<DenseVector>> expectedGroups2 =
+            Collections.singletonList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3),
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+
+    private static final int defaultParallelism = 4;
+    private static final int numTaskManagers = 2;
+    private static final int numSlotsPerTaskManager = 2;
+
+    private int currentModelDataVersion;
+
+    private InMemorySourceFunction<DenseVector> trainSource;
+    private InMemorySourceFunction<DenseVector> predictSource;
+    private InMemorySinkFunction<Row> outputSink;
+    private InMemorySinkFunction<KMeansModelData> modelDataSink;
+
+    private InMemoryReporter reporter;
+    private MiniCluster miniCluster;
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+
+    private Table offlineTrainTable;
+    private Table trainTable;

Review comment:
       Given that we already have `offlineTrainTable`, how about renaming this table as `onlineTrainTable`?

##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/util/InMemorySourceFunction.java
##########
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util;
+
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.streaming.api.functions.source.RichSourceFunction;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+
+import java.util.Arrays;
+import java.util.Map;
+import java.util.UUID;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
+
+/** A {@link SourceFunction} implementation that can directly receive records from tests. */
+@SuppressWarnings({"unchecked", "rawtypes"})
+public class InMemorySourceFunction<T> extends RichSourceFunction<T> {
+    private static final Map<UUID, BlockingQueue> queueMap = new ConcurrentHashMap<>();
+    private final UUID id;
+    private BlockingQueue<T> queue;
+    private boolean isRunning = true;
+
+    public InMemorySourceFunction() {
+        id = UUID.randomUUID();
+        queue = new LinkedBlockingQueue();
+        queueMap.put(id, queue);
+    }
+
+    @Override
+    public void open(Configuration parameters) throws Exception {
+        super.open(parameters);
+        queue = queueMap.get(id);
+    }
+
+    @Override
+    public void close() throws Exception {
+        super.close();
+        queueMap.remove(id);
+    }
+
+    @Override
+    public void run(SourceContext<T> context) throws InterruptedException {
+        while (isRunning) {
+            context.collect(queue.poll(1, TimeUnit.MINUTES));

Review comment:
       With the current approach, after a graceful shutdown is requested, the source operator might need to wait up to 1 minute before it can actually be shut down.
   
   How about we use `BlockingQueue<Optional<T>> queue`, let `cancel()` insert `Optional.empty()` into the queue, and have `run()` skip empty values?
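The suggestion above can be sketched as follows. This is a simplified, Flink-free model, not the PR's code: `OptionalQueueSource` and `offer()` are hypothetical names, and `run()` takes a `Consumer` in place of Flink's `SourceContext`. Because `run()` blocks in `take()` with no timeout, it also avoids the case where `poll(1, TimeUnit.MINUTES)` times out, returns `null`, and that `null` is passed to `collect()`.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.function.Consumer;

/** Hypothetical, Flink-free model of a queue-backed source that can be cancelled promptly. */
public class OptionalQueueSource<T> {
    // Optional.empty() is never a real record; it only serves to wake up a blocked take().
    private final BlockingQueue<Optional<T>> queue = new LinkedBlockingQueue<>();
    private volatile boolean isRunning = true;

    /** Called by the test to hand a record to the source. */
    public void offer(T value) {
        queue.add(Optional.of(value));
    }

    /** Blocks without a timeout; the collector stands in for SourceContext#collect. */
    public void run(Consumer<T> collector) throws InterruptedException {
        while (isRunning) {
            Optional<T> next = queue.take();
            // Skip the wake-up marker instead of emitting it downstream.
            next.ifPresent(collector);
        }
    }

    /** Flips the flag first, then unblocks a run() that may be waiting in take(). */
    public void cancel() {
        isRunning = false;
        queue.add(Optional.empty());
    }

    public static void main(String[] args) throws InterruptedException {
        OptionalQueueSource<String> source = new OptionalQueueSource<>();
        List<String> collected = Collections.synchronizedList(new ArrayList<>());
        Thread worker =
                new Thread(
                        () -> {
                            try {
                                source.run(collected::add);
                            } catch (InterruptedException e) {
                                Thread.currentThread().interrupt();
                            }
                        });
        worker.start();
        source.offer("a");
        source.cancel();
        worker.join(5_000);
        System.out.println("stopped: " + !worker.isAlive() + ", collected: " + collected);
    }
}
```

Records still queued when `cancel()` is called may be dropped; that seems acceptable for cancellation in tests, since `cancel()` is meant to stop the source quickly rather than drain it.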
   

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeansModel.java
##########
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.co.CoProcessFunction;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * OnlineKMeansModel can be regarded as an advanced {@link KMeansModel} operator which can update
+ * model data in a streaming format, using the model data provided by {@link OnlineKMeans}.
+ */
+public class OnlineKMeansModel
+        implements Model<OnlineKMeansModel>, KMeansModelParams<OnlineKMeansModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public OnlineKMeansModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel setModelData(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        modelDataTable = inputs[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), Types.INT),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol()));
+
+        DataStream<Row> predictionResult =
+                KMeansModelData.getModelDataStream(modelDataTable)
+                        .broadcast()
+                        .connect(tEnv.toDataStream(inputs[0]))
+                        .process(
+                                new PredictLabelFunction(
+                                        getFeaturesCol(),
+                                        DistanceMeasure.getInstance(getDistanceMeasure())),
+                                outputTypeInfo);
+
+        return new Table[] {tEnv.fromDataStream(predictionResult)};
+    }
+
+    /** A utility function used for prediction. */
+    private static class PredictLabelFunction extends CoProcessFunction<KMeansModelData, Row, Row> {
+        private final String featuresCol;
+
+        private final DistanceMeasure distanceMeasure;
+
+        private DenseVector[] centroids;
+
+        // TODO: replace this with a complete solution of reading first model data from unbounded
+        // model data stream before processing the first predict data.
+        private final List<Row> bufferedPoints = new ArrayList<>();
+
+        // TODO: replace this simple implementation of model data version with the formal API to
+        // track model version after its design is settled.
+        private int modelDataVersion;
+
+        public PredictLabelFunction(String featuresCol, DistanceMeasure distanceMeasure) {
+            this.featuresCol = featuresCol;
+            this.distanceMeasure = distanceMeasure;
+        }
+
+        @Override
+        public void open(Configuration parameters) throws Exception {
+            super.open(parameters);
+
+            getRuntimeContext()
+                    .getMetricGroup()
+                    .gauge(
+                            "modelDataVersion",

Review comment:
       Could we put the metric name in a static final variable? Then `OnlineKMeansTest` could use the variable instead of copying the string.
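A minimal sketch of the suggestion: the gauge name is defined once and referenced from both sides. The constant name `MODEL_DATA_VERSION_GAUGE_KEY` and the holder class are hypothetical; in the PR the constant could live in `OnlineKMeansModel` and be used both in the `gauge(...)` call and in the test's `reporter.findMetrics(...)` calls. A plain `Map` stands in for the metric reporter here.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical holder for the metric name shared by the model and its test. */
public class ModelDataVersionMetric {
    /** Name of the gauge reporting the model data version; same literal as in the PR. */
    public static final String MODEL_DATA_VERSION_GAUGE_KEY = "modelDataVersion";

    public static void main(String[] args) {
        // A plain map stands in for the metric reporter.
        Map<String, String> reported = new HashMap<>();
        // Model side: register the metric under the shared constant.
        reported.put(MODEL_DATA_VERSION_GAUGE_KEY, "3");
        // Test side: look it up through the same constant instead of a copied string literal.
        System.out.println(reported.get(MODEL_DATA_VERSION_GAUGE_KEY)); // prints 3
    }
}
```

If the metric is ever renamed, both the registration and the lookup change together, so the test cannot silently wait on a metric that no longer exists.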

##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java
##########
@@ -0,0 +1,464 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.RestOptions;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.clustering.kmeans.KMeans;
+import org.apache.flink.ml.clustering.kmeans.KMeansModel;
+import org.apache.flink.ml.clustering.kmeans.KMeansModelData;
+import org.apache.flink.ml.clustering.kmeans.OnlineKMeans;
+import org.apache.flink.ml.clustering.kmeans.OnlineKMeansModel;
+import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.util.InMemorySinkFunction;
+import org.apache.flink.ml.util.InMemorySourceFunction;
+import org.apache.flink.runtime.minicluster.MiniCluster;
+import org.apache.flink.runtime.minicluster.MiniClusterConfiguration;
+import org.apache.flink.runtime.testutils.InMemoryReporter;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.CollectionUtils;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import static org.apache.flink.ml.clustering.KMeansTest.groupFeaturesByPrediction;
+
+/** Tests {@link OnlineKMeans} and {@link OnlineKMeansModel}. */
+public class OnlineKMeansTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private static final DenseVector[] trainData1 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 0.0),
+                Vectors.dense(10.0, 0.3),
+                Vectors.dense(10.3, 0.0),
+                Vectors.dense(-10.0, 0.0),
+                Vectors.dense(-10.0, 0.6),
+                Vectors.dense(-10.6, 0.0)
+            };
+    private static final DenseVector[] trainData2 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 100.0),
+                Vectors.dense(10.0, 100.3),
+                Vectors.dense(10.3, 100.0),
+                Vectors.dense(-10.0, -100.0),
+                Vectors.dense(-10.0, -100.6),
+                Vectors.dense(-10.6, -100.0)
+            };
+    private static final DenseVector[] predictData =
+            new DenseVector[] {
+                Vectors.dense(10.0, 10.0),
+                Vectors.dense(10.3, 10.0),
+                Vectors.dense(10.0, 10.3),
+                Vectors.dense(-10.0, 10.0),
+                Vectors.dense(-10.3, 10.0),
+                Vectors.dense(-10.0, 10.3)
+            };
+    private static final List<Set<DenseVector>> expectedGroups1 =
+            Arrays.asList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3))),
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+    private static final List<Set<DenseVector>> expectedGroups2 =
+            Collections.singletonList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3),
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+
+    private static final int defaultParallelism = 4;
+    private static final int numTaskManagers = 2;
+    private static final int numSlotsPerTaskManager = 2;
+
+    private int currentModelDataVersion;
+
+    private InMemorySourceFunction<DenseVector> trainSource;
+    private InMemorySourceFunction<DenseVector> predictSource;
+    private InMemorySinkFunction<Row> outputSink;
+    private InMemorySinkFunction<KMeansModelData> modelDataSink;
+
+    private InMemoryReporter reporter;
+    private MiniCluster miniCluster;
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+
+    private Table offlineTrainTable;
+    private Table trainTable;
+    private Table predictTable;
+
+    @Before
+    public void before() throws Exception {
+        currentModelDataVersion = 0;
+
+        trainSource = new InMemorySourceFunction<>();
+        predictSource = new InMemorySourceFunction<>();
+        outputSink = new InMemorySinkFunction<>();
+        modelDataSink = new InMemorySinkFunction<>();
+
+        Configuration config = new Configuration();
+        config.set(RestOptions.BIND_PORT, "18081-19091");
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        reporter = InMemoryReporter.createWithRetainedMetrics();
+        reporter.addToConfiguration(config);
+
+        miniCluster =
+                new MiniCluster(
+                        new MiniClusterConfiguration.Builder()
+                                .setConfiguration(config)
+                                .setNumTaskManagers(numTaskManagers)
+                                .setNumSlotsPerTaskManager(numSlotsPerTaskManager)
+                                .build());
+        miniCluster.start();
+
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(defaultParallelism);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        Schema schema = Schema.newBuilder().column("f0", DataTypes.of(DenseVector.class)).build();
+
+        offlineTrainTable =
+                tEnv.fromDataStream(env.fromElements(trainData1), schema).as("features");
+        trainTable =
+                tEnv.fromDataStream(
+                                env.addSource(trainSource, DenseVectorTypeInfo.INSTANCE), schema)
+                        .as("features");
+        predictTable =
+                tEnv.fromDataStream(
+                                env.addSource(predictSource, DenseVectorTypeInfo.INSTANCE), schema)
+                        .as("features");
+    }
+
+    @After
+    public void after() throws Exception {
+        miniCluster.close();
+    }
+
+    /**
+     * Performs transform() on the provided model with predictTable, and adds sinks for
+     * OnlineKMeansModel's transform output and model data.
+     */
+    private void transformAndOutputData(OnlineKMeansModel onlineModel) {
+        Table outputTable = onlineModel.transform(predictTable)[0];
+        tEnv.toDataStream(outputTable).addSink(outputSink);
+
+        Table modelDataTable = onlineModel.getModelData()[0];
+        KMeansModelData.getModelDataStream(modelDataTable).addSink(modelDataSink);
+    }
+
+    /** Blocks the thread until the Model has set up the initial model data. */
+    private void waitInitModelDataSetup() throws InterruptedException {
+        while (reporter.findMetrics("modelDataVersion").size() < defaultParallelism) {
+            Thread.sleep(100);
+        }
+        waitModelDataUpdate();
+    }
+
+    /** Blocks the thread until the Model has received the next model-data-update event. */
+    @SuppressWarnings("unchecked")
+    private void waitModelDataUpdate() throws InterruptedException {
+        do {
+            int tmpModelDataVersion =
+                    reporter.findMetrics("modelDataVersion").values().stream()
+                            .map(x -> Integer.parseInt(((Gauge<String>) x).getValue()))
+                            .min(Integer::compareTo)
+                            .orElse(currentModelDataVersion);

Review comment:
       Given that `waitModelDataUpdate()` is only called after or within `waitInitModelDataSetup()`, the `orElse(...)` will never be needed, right?
   
   Would it be simpler to replace `orElse(...)` with `get()`?
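A small sketch of the `get()` variant, simplified to work on the gauge values as strings rather than on the `Gauge<String>` objects returned by `reporter.findMetrics(...)`. The class and method names are hypothetical. With `get()`, an unexpectedly empty collection fails fast with `NoSuchElementException` instead of silently falling back to `currentModelDataVersion`.

```java
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.NoSuchElementException;

/** Hypothetical helper computing the smallest model data version across subtasks. */
public class MinModelDataVersion {
    public static int minVersion(Collection<String> gaugeValues) {
        return gaugeValues.stream()
                .map(Integer::parseInt)
                .min(Integer::compareTo)
                // No fallback: an empty input indicates a bug in the test setup.
                .get();
    }

    public static void main(String[] args) {
        System.out.println(minVersion(Arrays.asList("3", "1", "2"))); // prints 1
        try {
            minVersion(Collections.<String>emptyList());
        } catch (NoSuchElementException e) {
            System.out.println("empty input fails fast");
        }
    }
}
```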

##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/util/InMemorySourceFunction.java
##########
@@ -0,0 +1,75 @@
+    private boolean isRunning = true;

Review comment:
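       Should it be `volatile`? `cancel()` is called from a different thread than `run()`, so without `volatile` the loop is not guaranteed to observe the update.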
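The pattern behind the `volatile` question, as a hypothetical Flink-free illustration (`CancellableLoop` is not part of the PR): `run()` executes on one thread while `cancel()` is invoked from another, as with a `SourceFunction`. The `volatile` modifier guarantees that the write in `cancel()` becomes visible to the looping thread; a plain field gives no such guarantee under the Java memory model.

```java
/** Hypothetical illustration of a loop cancelled from another thread. */
public class CancellableLoop {
    // volatile makes the write in cancel() visible to the thread spinning in run(). With a plain
    // field, the loop could legally keep reading a stale 'true'.
    private volatile boolean isRunning = true;

    public void run() {
        while (isRunning) {
            Thread.yield(); // stand-in for real work such as emitting records
        }
    }

    public void cancel() {
        isRunning = false;
    }

    public static void main(String[] args) throws InterruptedException {
        CancellableLoop loop = new CancellableLoop();
        Thread worker = new Thread(loop::run);
        worker.start();
        loop.cancel(); // invoked from the main thread, not from the worker
        worker.join(5_000);
        System.out.println("worker stopped: " + !worker.isAlive());
    }
}
```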

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,437 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.function.Supplier;
+
+/**
+ * OnlineKMeans extends the functionality of {@link KMeans} by training a K-Means model
+ * continuously on an unbounded stream of training data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids as the weighted average of the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of points assigned to
+ * them. The weight of the original centroids is also their number of points, additionally
+ * multiplied by the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If the decay
+ * factor is 1, all batches are weighted equally. If the decay factor is 0, the new centroids are
+ * determined entirely by the current batch. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();

Review comment:
       Could this line be removed? `env` does not appear to be used in `fit()`.
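The mini-batch update rule described in the `OnlineKMeans` Javadoc can be sketched on plain arrays as below. This is a hypothetical, single-threaded illustration, not the PR's implementation: it does not reproduce the split between `ModelDataLocalUpdater` and `ModelDataGlobalReducer`, and storing `n * decay + m` as the new weight is an assumption. Per cluster, the rule is c' = (c * n * decay + x * m) / (n * decay + m), where c and n are the current centroid and weight, and x and m are the batch-estimated centroid and the number of batch points assigned to it.

```java
import java.util.Arrays;

/** Hypothetical sketch of the decayed mini-batch centroid update. */
public class DecayedCentroidUpdate {

    /** Updates {@code centroids} in place and returns the new (assumed) weights. */
    public static double[] update(
            double[][] centroids,
            double[] weights,
            double[][] batchCentroids,
            double[] batchCounts,
            double decayFactor) {
        double[] newWeights = new double[weights.length];
        for (int i = 0; i < centroids.length; i++) {
            // The history is weighted by its point count, scaled down by the decay factor.
            double decayedWeight = weights[i] * decayFactor;
            double totalWeight = decayedWeight + batchCounts[i];
            newWeights[i] = totalWeight;
            if (totalWeight == 0) {
                continue; // no history and no new points: keep the centroid unchanged
            }
            for (int j = 0; j < centroids[i].length; j++) {
                centroids[i][j] =
                        (centroids[i][j] * decayedWeight + batchCentroids[i][j] * batchCounts[i])
                                / totalWeight;
            }
        }
        return newWeights;
    }

    public static void main(String[] args) {
        // decayFactor = 1: history and batch are weighted purely by their point counts.
        double[][] centroids = {{0.0, 0.0}};
        double[] weights =
                update(
                        centroids,
                        new double[] {2.0},
                        new double[][] {{3.0, 3.0}},
                        new double[] {1.0},
                        1.0);
        System.out.println(Arrays.toString(centroids[0]) + " " + Arrays.toString(weights));
        // prints [1.0, 1.0] [3.0]
    }
}
```

With `decayFactor = 0` the same call yields `[3.0, 3.0]`, i.e. the batch estimate alone, matching the Javadoc's description of full forgetting.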

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,437 @@
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelData =
+                KMeansModelData.getModelDataStream(initModelDataTable);
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize());
+
+        DataStream<KMeansModelData> finalModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelData);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Saves the metadata AND the bounded model data table (if it exists) to the given path. */
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        String initModelDataPath = ReadWriteUtils.getDataPath(path);
+        if (Files.exists(Paths.get(initModelDataPath))) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure, double decayFactor, int batchSize) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<KMeansModelData> modelData = variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<KMeansModelData> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator())
+                            .flatMap(new GlobalBatchSplitter(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(KMeansModelData.class),
+                                    new ModelDataLocalUpdater(distanceMeasure, decayFactor))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new ModelDataGlobalReducer());
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelData), DataStreamList.of(modelData));
+        }
+    }
+
+    private static class ModelDataGlobalReducer implements ReduceFunction<KMeansModelData> {
+        @Override
+        public KMeansModelData reduce(KMeansModelData modelData, KMeansModelData newModelData) {
+            DenseVector weights = modelData.weights;
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector newWeights = newModelData.weights;
+            DenseVector[] newCentroids = newModelData.centroids;
+
+            int k = newCentroids.length;
+            int dim = newCentroids[0].size();
+
+            for (int i = 0; i < k; i++) {
+                for (int j = 0; j < dim; j++) {
+                    centroids[i].values[j] =
+                            (centroids[i].values[j] * weights.values[i]
+                                            + newCentroids[i].values[j] * newWeights.values[i])
+                                    / Math.max(weights.values[i] + newWeights.values[i], 1e-16);
+                }
+                weights.values[i] += newWeights.values[i];
+            }
+
+            return new KMeansModelData(centroids, weights);
+        }
+    }
+
+    private static class ModelDataLocalUpdater extends AbstractStreamOperator<KMeansModelData>
+            implements TwoInputStreamOperator<DenseVector[], KMeansModelData, KMeansModelData> {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private ListState<DenseVector[]> localBatchState;
+        private ListState<KMeansModelData> modelDataState;
+
+        private ModelDataLocalUpdater(DistanceMeasure distanceMeasure, double decayFactor) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            localBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("localBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> pointsRecord) throws Exception {
+            localBatchState.add(pointsRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<KMeansModelData> modelDataRecord)
+                throws Exception {
+            modelDataState.add(modelDataRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        private void alignAndComputeModelData() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !localBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            KMeansModelData modelData =
+                    OperatorStateUtils.getUniqueElement(modelDataState, "modelData")
+                            .orElseThrow((Supplier<Exception>) NullPointerException::new);
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector weights = modelData.weights;
+            modelDataState.clear();
+
+            List<DenseVector[]> pointsList = IteratorUtils.toList(localBatchState.get().iterator());
+            DenseVector[] points = pointsList.remove(0);
+            localBatchState.update(pointsList);
+
+            int dim = centroids[0].size();
+            int k = centroids.length;
+            int parallelism = getRuntimeContext().getNumberOfParallelSubtasks();
+
+            // Computes new centroids.
+            DenseVector[] sums = new DenseVector[k];
+            int[] counts = new int[k];
+
+            for (int i = 0; i < k; i++) {
+                sums[i] = new DenseVector(dim);
+                counts[i] = 0;
+            }
+            for (DenseVector point : points) {
+                int closestCentroidId =
+                        KMeans.findClosestCentroidId(centroids, point, distanceMeasure);
+                counts[closestCentroidId]++;
+                for (int j = 0; j < dim; j++) {
+                    sums[closestCentroidId].values[j] += point.values[j];
+                }
+            }
+
+            // Considers weight and decay factor when updating centroids.
+            BLAS.scal(decayFactor / parallelism, weights);
+            for (int i = 0; i < k; i++) {
+                if (counts[i] == 0) {
+                    continue;
+                }
+
+                DenseVector centroid = centroids[i];
+                weights.values[i] = weights.values[i] + counts[i];
+                double lambda = counts[i] / weights.values[i];
+
+                BLAS.scal(1.0 - lambda, centroid);
+                BLAS.axpy(lambda / counts[i], sums[i], centroid);
+            }
+
+            output.collect(new StreamRecord<>(new KMeansModelData(centroids, weights)));
+        }
+    }
+
+    private static class FeaturesExtractor implements MapFunction<Row, DenseVector> {
+        private final String featuresCol;
+
+        private FeaturesExtractor(String featuresCol) {
+            this.featuresCol = featuresCol;
+        }
+
+        @Override
+        public DenseVector map(Row row) throws Exception {
+            return (DenseVector) row.getField(featuresCol);
+        }
+    }
+
+    // A function that splits a global batch into evenly-sized local batches and distributes them
+    // to the downstream operator's subtasks.
+    private static class GlobalBatchSplitter
+            implements FlatMapFunction<DenseVector[], DenseVector[]> {
+        private final int downStreamParallelism;
+
+        private GlobalBatchSplitter(int downStreamParallelism) {
+            this.downStreamParallelism = downStreamParallelism;
+        }
+
+        @Override
+        public void flatMap(DenseVector[] values, Collector<DenseVector[]> collector) {
+            // Calculate the batch sizes to be distributed on each subtask.
+            List<Integer> sizes = new ArrayList<>();
+            for (int i = 0; i < downStreamParallelism; i++) {
+                int start = i * values.length / downStreamParallelism;
+                int end = (i + 1) * values.length / downStreamParallelism;
+                sizes.add(end - start);
+            }
+
+            int offset = 0;
+            for (Integer size : sizes) {
+                collector.collect(Arrays.copyOfRange(values, offset, offset + size));
+                offset += size;
+            }
+        }
+    }
+
+    private static class GlobalBatchCreator
+            implements AllWindowFunction<DenseVector, DenseVector[], GlobalWindow> {
+        @Override
+        public void apply(
+                GlobalWindow timeWindow,
+                Iterable<DenseVector> iterable,
+                Collector<DenseVector[]> collector) {
+            List<DenseVector> points = IteratorUtils.toList(iterable.iterator());
+            collector.collect(points.toArray(new DenseVector[0]));
+        }
+    }
+
+    /**
+     * Sets the initial model data of the online training process with the provided model data
+     * table.
+     *
+     * <p>This would override the effect of previously invoked {@link #setInitialModelData(Table)}
+     * or {@link #setRandomCentroids(StreamTableEnvironment, int, int, double)}.
+     */
+    public OnlineKMeans setInitialModelData(Table initModelDataTable) {
+        this.initModelDataTable = initModelDataTable;
+        return this;
+    }
+
+    /**
+     * Sets the initial model data of the online training process with randomly created centroids.
+     *
+     * <p>This would override the effect of previously invoked {@link #setInitialModelData(Table)}
+     * or {@link #setRandomCentroids(StreamTableEnvironment, int, int, double)}.
+     *
+     * @param tEnv The stream table environment to create the centroids in.
+     * @param dim The dimension of the centroids to create.
+     * @param k The number of centroids to create.
+     * @param weight The weight of the centroids to create.
+     */
+    public OnlineKMeans setRandomCentroids(
+            StreamTableEnvironment tEnv, int dim, int k, double weight) {
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();

Review comment:
       While I believe we will have a public API on Table to expose its environment, I am not sure we will also have a public API on `StreamTableEnvironmentImpl` to expose its StreamExecutionEnvironment.
   
   It would be better to reduce the use of internal APIs where possible. How about we just pass `StreamExecutionEnvironment` as the method's parameter? This would also be more consistent with `load(...)`.
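The hunk above splits each update between `ModelDataLocalUpdater`, which scales the old weights by `decayFactor / parallelism` before blending in its local slice, and `ModelDataGlobalReducer`, which merges the partial results by weighted average. The following plain-Java sketch can be used to check that this split reproduces the single-node decayed mini-batch rule c' = (d·w·c + Σx) / (d·w + n). It has no Flink dependencies, the class `DecayedMiniBatchCheck` and its methods are hypothetical, and it models a single cluster only, so every point is assumed to be assigned to it.

```java
import java.util.Arrays;

// Hypothetical, Flink-free sketch: one cluster only, so every point in the
// batch is assumed to be assigned to it (no closest-centroid search).
final class DecayedMiniBatchCheck {

    /** A centroid plus the (decayed) weight accumulated behind it. */
    static final class Cluster {
        final double[] centroid;
        final double weight;

        Cluster(double[] centroid, double weight) {
            this.centroid = centroid;
            this.weight = weight;
        }
    }

    /** Sums coordinate j over all points of a batch. */
    static double sum(double[][] points, int j) {
        double s = 0.0;
        for (double[] p : points) {
            s += p[j];
        }
        return s;
    }

    /** Single-node rule: w' = d*w + n and c' = (d*w*c + sum) / w'. */
    static Cluster singleNode(Cluster model, double[][] points, double decay) {
        double newWeight = decay * model.weight + points.length;
        double[] c = new double[model.centroid.length];
        for (int j = 0; j < c.length; j++) {
            c[j] = (decay * model.weight * model.centroid[j] + sum(points, j)) / newWeight;
        }
        return new Cluster(c, newWeight);
    }

    /** Mirrors ModelDataLocalUpdater: the decayed prior weight is split evenly over p subtasks. */
    static Cluster localUpdate(Cluster model, double[][] points, double decay, int p) {
        double weight = decay * model.weight / p;
        int n = points.length;
        double[] c = model.centroid.clone();
        if (n == 0) {
            return new Cluster(c, weight); // no points: centroid kept, weight only decayed
        }
        weight += n;
        double lambda = n / weight; // share of the local slice in the blended centroid
        for (int j = 0; j < c.length; j++) {
            c[j] = (1.0 - lambda) * c[j] + lambda / n * sum(points, j);
        }
        return new Cluster(c, weight);
    }

    /** Mirrors ModelDataGlobalReducer: weight-averaged merge of two partial results. */
    static Cluster merge(Cluster a, Cluster b) {
        double total = a.weight + b.weight;
        double[] c = new double[a.centroid.length];
        for (int j = 0; j < c.length; j++) {
            c[j] = (a.centroid[j] * a.weight + b.centroid[j] * b.weight) / Math.max(total, 1e-16);
        }
        return new Cluster(c, total);
    }

    /** Slices the batch like GlobalBatchSplitter, updates each slice locally, then reduces. */
    static Cluster distributed(Cluster model, double[][] points, double decay, int p) {
        Cluster result = null;
        for (int i = 0; i < p; i++) {
            int start = i * points.length / p;
            int end = (i + 1) * points.length / p;
            Cluster local = localUpdate(model, Arrays.copyOfRange(points, start, end), decay, p);
            result = (result == null) ? local : merge(result, local);
        }
        return result;
    }

    public static void main(String[] args) {
        Cluster model = new Cluster(new double[] {0.0, 0.0}, 4.0);
        double[][] batch = {{1, 1}, {2, 0}, {3, 2}, {0, 1}, {4, 4}};
        Cluster a = singleNode(model, batch, 0.5);
        Cluster b = distributed(model, batch, 0.5, 3);
        // Both lines should agree up to floating-point rounding.
        System.out.println(Arrays.toString(a.centroid) + " weight=" + a.weight);
        System.out.println(Arrays.toString(b.centroid) + " weight=" + b.weight);
    }
}
```

The two results agree because each subtask's output satisfies w_i·c_i = (d·w/p)·c + Σx_i, even for empty slices, and the reducer accumulates exactly these products and weights, so the total is d·w·c + Σx over d·w + n.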

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModelData.java
##########
@@ -47,8 +47,16 @@
 
     public DenseVector[] centroids;
 
+    public DenseVector weights;
+
+    public KMeansModelData(DenseVector[] centroids, DenseVector weights) {
+        this.centroids = centroids;
+        this.weights = weights;
+    }
+
     public KMeansModelData(DenseVector[] centroids) {
         this.centroids = centroids;
+        this.weights = new DenseVector(centroids.length);

Review comment:
       It seems a bit odd for the weights to default to a vector of 0s.
   
   Would it be simpler to remove this constructor and explicitly specify the weights in tests? The weights can be specified easily, e.g. with `Vectors.dense(1, 1)`.
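To make this concern concrete: with all-zero default weights, the first batch that assigns any point to a cluster gives lambda = count / (0 + count) = 1 in `ModelDataLocalUpdater`, so that cluster's initial centroid is discarded entirely. The hypothetical plain-Java sketch below reproduces that per-cluster step for parallelism 1 (no Flink types; `InitialWeightEffect` is an illustrative name, not part of the PR).

```java
import java.util.Arrays;

// Hypothetical sketch (plain Java, not Flink ML code) of the per-cluster step in
// ModelDataLocalUpdater with parallelism 1. It shows why all-zero initial
// weights make the first batch overwrite the initial centroids.
final class InitialWeightEffect {

    /** Returns the updated centroid after one batch of `count` points summing to batchSum. */
    static double[] update(
            double[] centroid, double weight, double[] batchSum, int count, double decay) {
        if (count == 0) {
            return centroid.clone(); // clusters with no assigned points keep their centroid
        }
        double newWeight = decay * weight + count;
        double lambda = count / newWeight; // share of the new batch in the result
        double[] result = new double[centroid.length];
        for (int j = 0; j < centroid.length; j++) {
            result[j] = (1.0 - lambda) * centroid[j] + lambda / count * batchSum[j];
        }
        return result;
    }

    public static void main(String[] args) {
        double[] init = {10.0, 10.0};
        double[] batchSum = {2.0, 4.0}; // two points, e.g. (0, 1) and (2, 3)
        // Zero weight: lambda = 1, so the result is exactly the batch mean.
        System.out.println(Arrays.toString(update(init, 0.0, batchSum, 2, 1.0))); // [1.0, 2.0]
        // Weight 2: the initial centroid keeps half of the influence.
        System.out.println(Arrays.toString(update(init, 2.0, batchSum, 2, 1.0))); // [5.5, 6.0]
    }
}
```

With explicit non-zero weights, the initial centroids act as a prior whose influence is proportional to the weight chosen in the test.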
   

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,437 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.function.Supplier;
+
+/**
+ * OnlineKMeans extends the functionality of {@link KMeans} to support training a K-Means model
+ * continuously on an unbounded stream of training data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids as the weighted average of the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of points assigned to
+ * them. The weight of the original centroids is their accumulated weight from previous batches,
+ * multiplied by the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If the decay
+ * factor is 1, all batches are weighted equally. If it is 0, new centroids are determined
+ * entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelData =
+                KMeansModelData.getModelDataStream(initModelDataTable);
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize());
+
+        DataStream<KMeansModelData> finalModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelData);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Saves the metadata AND bounded model data table (if it exists) to the given path. */
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        String initModelDataPath = ReadWriteUtils.getDataPath(path);
+        if (Files.exists(Paths.get(initModelDataPath))) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure, double decayFactor, int batchSize) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<KMeansModelData> modelData = variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<KMeansModelData> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator())
+                            .flatMap(new GlobalBatchSplitter(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(KMeansModelData.class),
+                                    new ModelDataLocalUpdater(distanceMeasure, decayFactor))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new ModelDataGlobalReducer());
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelData), DataStreamList.of(modelData));
+        }
+    }
+
+    private static class ModelDataGlobalReducer implements ReduceFunction<KMeansModelData> {
+        @Override
+        public KMeansModelData reduce(KMeansModelData modelData, KMeansModelData newModelData) {
+            DenseVector weights = modelData.weights;
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector newWeights = newModelData.weights;
+            DenseVector[] newCentroids = newModelData.centroids;
+
+            int k = newCentroids.length;
+            int dim = newCentroids[0].size();
+
+            for (int i = 0; i < k; i++) {
+                for (int j = 0; j < dim; j++) {
+                    centroids[i].values[j] =
+                            (centroids[i].values[j] * weights.values[i]
+                                            + newCentroids[i].values[j] * newWeights.values[i])
+                                    / Math.max(weights.values[i] + newWeights.values[i], 1e-16);
+                }
+                weights.values[i] += newWeights.values[i];
+            }
+
+            return new KMeansModelData(centroids, weights);
+        }
+    }
+
+    private static class ModelDataLocalUpdater extends AbstractStreamOperator<KMeansModelData>
+            implements TwoInputStreamOperator<DenseVector[], KMeansModelData, KMeansModelData> {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private ListState<DenseVector[]> localBatchState;
+        private ListState<KMeansModelData> modelDataState;
+
+        private ModelDataLocalUpdater(DistanceMeasure distanceMeasure, double decayFactor) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            localBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("localBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> pointsRecord) throws Exception {
+            localBatchState.add(pointsRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<KMeansModelData> modelDataRecord)
+                throws Exception {
+            modelDataState.add(modelDataRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        private void alignAndComputeModelData() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !localBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            KMeansModelData modelData =
+                    OperatorStateUtils.getUniqueElement(modelDataState, "modelData")
+                            .orElseThrow((Supplier<Exception>) NullPointerException::new);
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector weights = modelData.weights;
+            modelDataState.clear();
+
+            List<DenseVector[]> pointsList = IteratorUtils.toList(localBatchState.get().iterator());
+            DenseVector[] points = pointsList.remove(0);
+            localBatchState.update(pointsList);
+
+            int dim = centroids[0].size();
+            int k = centroids.length;
+            int parallelism = getRuntimeContext().getNumberOfParallelSubtasks();
+
+            // Computes new centroids.
+            DenseVector[] sums = new DenseVector[k];
+            int[] counts = new int[k];
+
+            for (int i = 0; i < k; i++) {
+                sums[i] = new DenseVector(dim);
+                counts[i] = 0;
+            }
+            for (DenseVector point : points) {
+                int closestCentroidId =
+                        KMeans.findClosestCentroidId(centroids, point, distanceMeasure);
+                counts[closestCentroidId]++;
+                for (int j = 0; j < dim; j++) {
+                    sums[closestCentroidId].values[j] += point.values[j];
+                }
+            }
+
+            // Considers weight and decay factor when updating centroids.
+            BLAS.scal(decayFactor / parallelism, weights);
+            for (int i = 0; i < k; i++) {
+                if (counts[i] == 0) {
+                    continue;
+                }
+
+                DenseVector centroid = centroids[i];
+                weights.values[i] = weights.values[i] + counts[i];
+                double lambda = counts[i] / weights.values[i];
+
+                BLAS.scal(1.0 - lambda, centroid);
+                BLAS.axpy(lambda / counts[i], sums[i], centroid);
+            }
+
+            output.collect(new StreamRecord<>(new KMeansModelData(centroids, weights)));
+        }
+    }
+
+    private static class FeaturesExtractor implements MapFunction<Row, DenseVector> {
+        private final String featuresCol;
+
+        private FeaturesExtractor(String featuresCol) {
+            this.featuresCol = featuresCol;
+        }
+
+        @Override
+        public DenseVector map(Row row) throws Exception {
+            return (DenseVector) row.getField(featuresCol);
+        }
+    }
+
+    // A function that splits a global batch into evenly-sized local batches and distributes them
+    // to the downstream operator's subtasks.
+    private static class GlobalBatchSplitter
+            implements FlatMapFunction<DenseVector[], DenseVector[]> {
+        private final int downStreamParallelism;
+
+        private GlobalBatchSplitter(int downStreamParallelism) {
+            this.downStreamParallelism = downStreamParallelism;
+        }
+
+        @Override
+        public void flatMap(DenseVector[] values, Collector<DenseVector[]> collector) {
+            // Calculate the batch sizes to be distributed on each subtask.
+            List<Integer> sizes = new ArrayList<>();
+            for (int i = 0; i < downStreamParallelism; i++) {
+                int start = i * values.length / downStreamParallelism;
+                int end = (i + 1) * values.length / downStreamParallelism;
+                sizes.add(end - start);
+            }
+
+            int offset = 0;
+            for (Integer size : sizes) {
+                collector.collect(Arrays.copyOfRange(values, offset, offset + size));
+                offset += size;
+            }
+        }
+    }
+
+    private static class GlobalBatchCreator
+            implements AllWindowFunction<DenseVector, DenseVector[], GlobalWindow> {
+        @Override
+        public void apply(
+                GlobalWindow timeWindow,
+                Iterable<DenseVector> iterable,
+                Collector<DenseVector[]> collector) {
+            List<DenseVector> points = IteratorUtils.toList(iterable.iterator());
+            collector.collect(points.toArray(new DenseVector[0]));
+        }
+    }
+
+    /**
+     * Sets the initial model data of the online training process with the provided model data
+     * table.
+     *
+     * <p>This would override the effect of previously invoked {@link #setInitialModelData(Table)}
+     * or {@link #setRandomCentroids(StreamTableEnvironment, int, int, double)}.
+     */
+    public OnlineKMeans setInitialModelData(Table initModelDataTable) {
+        this.initModelDataTable = initModelDataTable;
+        return this;
+    }
+
+    /**
+     * Sets the initial model data of the online training process with randomly created centroids.
+     *
+     * <p>This would override the effect of previously invoked {@link #setInitialModelData(Table)}
+     * or {@link #setRandomCentroids(StreamTableEnvironment, int, int, double)}.
+     *
+     * @param tEnv The stream table environment to create the centroids in.
+     * @param dim The dimension of the centroids to create.
+     * @param k The number of centroids to create.
+     * @param weight The weight of the centroids to create.
+     */
+    public OnlineKMeans setRandomCentroids(
+            StreamTableEnvironment tEnv, int dim, int k, double weight) {

Review comment:
       Would it be better to keep `k` as a parameter of `OnlineKMeans`, so that it is as similar to `KMeans` as possible?
   
   The number of centroids is a key aspect of the KMeans algorithm (including its online version), and users would want to know its value. It would be harder for users to get this information if we don't specify it as a parameter of OnlineKMeans.
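`GlobalBatchSplitter`, quoted in the hunk above, computes slice boundaries as `i * n / p` with integer division, so local batch sizes differ by at most one and always sum to the global batch size. The following plain-Java copy of that arithmetic (the class and method names are hypothetical) prints the resulting sizes.

```java
import java.util.Arrays;

// Plain-Java copy of the slicing arithmetic in GlobalBatchSplitter (no Flink types);
// BatchSplitSketch and localBatchSizes are illustrative names only.
final class BatchSplitSketch {

    /** Size of the slice each of `parallelism` subtasks receives from one global batch. */
    static int[] localBatchSizes(int globalBatchSize, int parallelism) {
        int[] sizes = new int[parallelism];
        for (int i = 0; i < parallelism; i++) {
            int start = i * globalBatchSize / parallelism;
            int end = (i + 1) * globalBatchSize / parallelism;
            sizes[i] = end - start;
        }
        return sizes;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(localBatchSizes(10, 3))); // [3, 3, 4]
        System.out.println(Arrays.toString(localBatchSizes(2, 4))); // [0, 1, 0, 1]
    }
}
```

When the global batch is smaller than the parallelism, some subtasks receive empty arrays. They are still emitted, which appears to matter here: the downstream `countWindowAll(parallelism)` only fires after every subtask has produced one model update for the round.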

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,437 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.function.Supplier;
+
+/**
+ * OnlineKMeans extends the functionality of {@link KMeans} to support training a K-Means model
+ * continuously from an unbounded stream of training data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids as the weighted average of the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of points assigned to
+ * them. The weight of the original centroids is likewise a number of points (those assigned in
+ * earlier batches), additionally multiplied by the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If the decay
+ * factor is 1, all batches are weighted equally. If the decay factor is 0, new centroids are
+ * determined entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelData =
+                KMeansModelData.getModelDataStream(initModelDataTable);
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize());
+
+        DataStream<KMeansModelData> finalModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelData);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Saves the metadata AND bounded model data table (if it exists) to the given path. */
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        String initModelDataPath = ReadWriteUtils.getDataPath(path);
+        if (Files.exists(Paths.get(initModelDataPath))) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure, double decayFactor, int batchSize) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<KMeansModelData> modelData = variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<KMeansModelData> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator())
+                            .flatMap(new GlobalBatchSplitter(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(KMeansModelData.class),
+                                    new ModelDataLocalUpdater(distanceMeasure, decayFactor))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new ModelDataGlobalReducer());
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelData), DataStreamList.of(modelData));
+        }
+    }
+
+    private static class ModelDataGlobalReducer implements ReduceFunction<KMeansModelData> {
+        @Override
+        public KMeansModelData reduce(KMeansModelData modelData, KMeansModelData newModelData) {
+            DenseVector weights = modelData.weights;
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector newWeights = newModelData.weights;
+            DenseVector[] newCentroids = newModelData.centroids;
+
+            int k = newCentroids.length;
+            int dim = newCentroids[0].size();
+
+            for (int i = 0; i < k; i++) {
+                for (int j = 0; j < dim; j++) {
+                    centroids[i].values[j] =
+                            (centroids[i].values[j] * weights.values[i]
+                                            + newCentroids[i].values[j] * newWeights.values[i])
+                                    / Math.max(weights.values[i] + newWeights.values[i], 1e-16);
+                }
+                weights.values[i] += newWeights.values[i];
+            }
+
+            return new KMeansModelData(centroids, weights);
+        }
+    }
+
+    private static class ModelDataLocalUpdater extends AbstractStreamOperator<KMeansModelData>

Review comment:
       Could we add Javadoc explaining the algorithm used in this operator? Same for `ModelDataGlobalReducer`.
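A minimal, dependency-free sketch of the update rule that `ModelDataLocalUpdater` and `ModelDataGlobalReducer` appear to implement, i.e. what such Javadoc could describe. It is hypothetical, not code from this PR: plain `double` arrays stand in for `DenseVector`/`KMeansModelData`, squared Euclidean distance stands in for the configurable `DistanceMeasure`, a round-robin split stands in for `GlobalBatchSplitter` plus `rebalance()`, and all class and method names are made up. Each subtask scales the previous weights by `decayFactor / parallelism`, adds its local counts, and moves each centroid toward the mean of its assigned points with step `count / newWeight`; the reducer then takes a weight-averaged combination of the partial models. Under these assumptions the merged model equals a single-subtask update over the whole batch (up to rounding), provided every cluster ends up with a nonzero weight.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Hypothetical, dependency-free sketch of a decayed mini-batch K-Means step split across
 *  subtasks; plain double arrays stand in for DenseVector / KMeansModelData. */
public class DecayedMiniBatchKMeansSketch {
    /** Index of the closest centroid; squared Euclidean distance is an assumption here. */
    static int closest(double[][] centroids, double[] point) {
        int best = 0;
        double bestDist = Double.POSITIVE_INFINITY;
        for (int i = 0; i < centroids.length; i++) {
            double dist = 0;
            for (int j = 0; j < point.length; j++) {
                double diff = centroids[i][j] - point[j];
                dist += diff * diff;
            }
            if (dist < bestDist) {
                bestDist = dist;
                best = i;
            }
        }
        return best;
    }

    /** Subtask step (in place): decay old weights by decayFactor / parallelism, then move each
     *  centroid toward the mean of its assigned points with step counts / newWeight. */
    static void localUpdate(double[][] centroids, double[] weights, double[][] points,
            double decayFactor, int parallelism) {
        int k = centroids.length;
        int dim = centroids[0].length;
        double[][] sums = new double[k][dim];
        int[] counts = new int[k];
        for (double[] p : points) { // assign with the pre-batch centroids
            int id = closest(centroids, p);
            counts[id]++;
            for (int j = 0; j < dim; j++) {
                sums[id][j] += p[j];
            }
        }
        for (int i = 0; i < k; i++) {
            weights[i] *= decayFactor / parallelism;
            if (counts[i] == 0) {
                continue; // no new points: keep the centroid and only the decayed weight
            }
            weights[i] += counts[i];
            double lambda = counts[i] / weights[i];
            for (int j = 0; j < dim; j++) {
                centroids[i][j] = (1 - lambda) * centroids[i][j] + lambda * sums[i][j] / counts[i];
            }
        }
    }

    /** Global step: weight-averages the partial centroids into acc and sums the weights. */
    static void merge(double[][] acc, double[] accW, double[][] other, double[] otherW) {
        for (int i = 0; i < acc.length; i++) {
            double total = Math.max(accW[i] + otherW[i], 1e-16);
            for (int j = 0; j < acc[i].length; j++) {
                acc[i][j] = (acc[i][j] * accW[i] + other[i][j] * otherW[i]) / total;
            }
            accW[i] += otherW[i];
        }
    }

    /** Round-robin split over `parallelism` subtasks, local updates on copies, then merge. */
    static void distributedUpdate(double[][] centroids, double[] weights, double[][] batch,
            double decayFactor, int parallelism) {
        double[][] accC = null;
        double[] accW = null;
        for (int s = 0; s < parallelism; s++) {
            List<double[]> part = new ArrayList<>();
            for (int idx = s; idx < batch.length; idx += parallelism) {
                part.add(batch[idx]);
            }
            // Each subtask works on its own copy of the broadcast model.
            double[][] c = Arrays.stream(centroids).map(double[]::clone).toArray(double[][]::new);
            double[] w = weights.clone();
            localUpdate(c, w, part.toArray(new double[0][]), decayFactor, parallelism);
            if (accC == null) {
                accC = c;
                accW = w;
            } else {
                merge(accC, accW, c, w);
            }
        }
        System.arraycopy(accC, 0, centroids, 0, centroids.length);
        System.arraycopy(accW, 0, weights, 0, weights.length);
    }

    public static void main(String[] args) {
        double[][] batch = {{2, 2}, {4, 4}, {12, 12}, {14, 14}};
        double[][] single = {{0, 0}, {10, 10}}, split = {{0, 0}, {10, 10}};
        double[] singleW = {2, 2}, splitW = {2, 2};
        localUpdate(single, singleW, batch, 0.5, 1); // whole batch on one subtask
        distributedUpdate(split, splitW, batch, 0.5, 2); // same batch over two subtasks
        System.out.println(Arrays.deepToString(single) + " " + Arrays.toString(singleW));
        System.out.println(Arrays.deepToString(split) + " " + Arrays.toString(splitW));
    }
}
```

In `main`, the same four-point batch is applied once on one subtask and once over two subtasks; with a decay factor of 0.5 and initial weights of 2, both runs should end with centroids (2, 2) and (12, 12) and weights 3 (up to rounding).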

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,437 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.function.Supplier;
+
+/**
+ * OnlineKMeans extends the functionality of {@link KMeans} to support training a K-Means model
+ * continuously from an unbounded stream of training data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids as the weighted average of the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of points assigned to
+ * them. The weight of the original centroids is likewise a number of points (those assigned in
+ * earlier batches), additionally multiplied by the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If the decay
+ * factor is 1, all batches are weighted equally. If the decay factor is 0, new centroids are
+ * determined entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelData =
+                KMeansModelData.getModelDataStream(initModelDataTable);
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize());
+
+        DataStream<KMeansModelData> finalModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelData);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Saves the metadata AND bounded model data table (if it exists) to the given path. */
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        String initModelDataPath = ReadWriteUtils.getDataPath(path);
+        if (Files.exists(Paths.get(initModelDataPath))) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure, double decayFactor, int batchSize) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<KMeansModelData> modelData = variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<KMeansModelData> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator())
+                            .flatMap(new GlobalBatchSplitter(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(KMeansModelData.class),
+                                    new ModelDataLocalUpdater(distanceMeasure, decayFactor))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new ModelDataGlobalReducer());
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelData), DataStreamList.of(modelData));
+        }
+    }
+
+    private static class ModelDataGlobalReducer implements ReduceFunction<KMeansModelData> {
+        @Override
+        public KMeansModelData reduce(KMeansModelData modelData, KMeansModelData newModelData) {
+            DenseVector weights = modelData.weights;
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector newWeights = newModelData.weights;
+            DenseVector[] newCentroids = newModelData.centroids;
+
+            int k = newCentroids.length;
+            int dim = newCentroids[0].size();
+
+            for (int i = 0; i < k; i++) {
+                for (int j = 0; j < dim; j++) {
+                    centroids[i].values[j] =
+                            (centroids[i].values[j] * weights.values[i]
+                                            + newCentroids[i].values[j] * newWeights.values[i])
+                                    / Math.max(weights.values[i] + newWeights.values[i], 1e-16);
+                }
+                weights.values[i] += newWeights.values[i];
+            }
+
+            return new KMeansModelData(centroids, weights);
+        }
+    }
+
+    private static class ModelDataLocalUpdater extends AbstractStreamOperator<KMeansModelData>
+            implements TwoInputStreamOperator<DenseVector[], KMeansModelData, KMeansModelData> {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private ListState<DenseVector[]> localBatchState;
+        private ListState<KMeansModelData> modelDataState;
+
+        private ModelDataLocalUpdater(DistanceMeasure distanceMeasure, double decayFactor) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            localBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("localBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> pointsRecord) throws Exception {
+            localBatchState.add(pointsRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<KMeansModelData> modelDataRecord)
+                throws Exception {
+            modelDataState.add(modelDataRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        private void alignAndComputeModelData() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !localBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            KMeansModelData modelData =
+                    OperatorStateUtils.getUniqueElement(modelDataState, "modelData")
+                            .orElseThrow((Supplier<Exception>) NullPointerException::new);
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector weights = modelData.weights;
+            modelDataState.clear();
+
+            List<DenseVector[]> pointsList = IteratorUtils.toList(localBatchState.get().iterator());
+            DenseVector[] points = pointsList.remove(0);
+            localBatchState.update(pointsList);
+
+            int dim = centroids[0].size();
+            int k = centroids.length;
+            int parallelism = getRuntimeContext().getNumberOfParallelSubtasks();
+
+            // Computes new centroids.
+            DenseVector[] sums = new DenseVector[k];
+            int[] counts = new int[k];
+
+            for (int i = 0; i < k; i++) {
+                sums[i] = new DenseVector(dim);
+                counts[i] = 0;
+            }
+            for (DenseVector point : points) {
+                int closestCentroidId =
+                        KMeans.findClosestCentroidId(centroids, point, distanceMeasure);
+                counts[closestCentroidId]++;
+                for (int j = 0; j < dim; j++) {
+                    sums[closestCentroidId].values[j] += point.values[j];
+                }
+            }
+
+            // Considers weight and decay factor when updating centroids.
+            BLAS.scal(decayFactor / parallelism, weights);
+            for (int i = 0; i < k; i++) {
+                if (counts[i] == 0) {
+                    continue;
+                }
+
+                DenseVector centroid = centroids[i];
+                weights.values[i] = weights.values[i] + counts[i];
+                double lambda = counts[i] / weights.values[i];
+
+                BLAS.scal(1.0 - lambda, centroid);
+                BLAS.axpy(lambda / counts[i], sums[i], centroid);
+            }
+
+            output.collect(new StreamRecord<>(new KMeansModelData(centroids, weights)));
+        }
+    }
+
+    private static class FeaturesExtractor implements MapFunction<Row, DenseVector> {
+        private final String featuresCol;
+
+        private FeaturesExtractor(String featuresCol) {
+            this.featuresCol = featuresCol;
+        }
+
+        @Override
+        public DenseVector map(Row row) throws Exception {

Review comment:
       Would it be simpler to remove `throws Exception` here? Same for other methods.
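For context: Flink's `MapFunction#map` is declared with `throws Exception`, but Java allows an overriding method to declare fewer checked exceptions, so dropping the clause from `FeaturesExtractor#map` should still compile. A dependency-free sketch of that rule follows; `ThrowingMapper` and `UpperCaser` are hypothetical stand-ins, not Flink types.

```java
import java.util.Locale;

/** Hypothetical sketch: dropping `throws Exception` from an overriding method. */
public class NarrowedThrowsSketch {

    /** Stand-in for an interface like Flink's MapFunction, whose map() declares `throws Exception`. */
    interface ThrowingMapper<IN, OUT> {
        OUT map(IN value) throws Exception;
    }

    /** An override may declare fewer checked exceptions than the method it overrides. */
    static class UpperCaser implements ThrowingMapper<String, String> {
        @Override
        public String map(String value) {
            return value.toUpperCase(Locale.ROOT);
        }
    }

    public static void main(String[] args) {
        UpperCaser mapper = new UpperCaser();
        // Through the concrete type, no try/catch is required.
        System.out.println(mapper.map("kmeans"));

        // Through the interface type, callers still have to handle Exception.
        ThrowingMapper<String, String> asInterface = mapper;
        try {
            System.out.println(asInterface.map("online"));
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
}
```

The narrower signature only benefits callers that hold the concrete type; code calling through the interface, as Flink's runtime does, still sees `throws Exception`, so the change is mostly about readability.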







[GitHub] [flink-ml] lindong28 commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
lindong28 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r833278166



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,562 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Random;
+
+/**
+ * OnlineKMeans extends the functionality of {@link KMeans} to support training a K-Means model
+ * continuously from an unbounded stream of training data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids as the weighted average of the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of points assigned to
+ * them. The weight of the original centroids is likewise a number of points (those assigned in
+ * earlier batches), additionally multiplied by the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If the decay
+ * factor is 1, all batches are weighted equally. If the decay factor is 0, new centroids are
+ * determined entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public OnlineKMeans(Table initModelDataTable) {
+        this.initModelDataTable = initModelDataTable;
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            Preconditions.checkState(initModelDataTable == null);
+            initModelDataStream = createRandomCentroids(env, getDim(), getK(), getSeed());
+        } else {
+            initModelDataStream = KMeansModelData.getModelDataStream(initModelDataTable);
+        }
+        DataStream<Tuple2<KMeansModelData, DenseVector>> initModelDataWithWeightsStream =
+                initModelDataStream.map(new InitWeightAssigner(getInitWeights()));
+        initModelDataWithWeightsStream.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize(),
+                        getK(),
+                        getDim());
+
+        DataStream<KMeansModelData> finalModelDataStream =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelDataWithWeightsStream),
+                                DataStreamList.of(points),
+                                body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelDataStream);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class InitWeightAssigner
+            implements MapFunction<KMeansModelData, Tuple2<KMeansModelData, DenseVector>> {
+        private final double[] initWeights;
+
+        private InitWeightAssigner(Double[] initWeights) {
+            this.initWeights = ArrayUtils.toPrimitive(initWeights);
+        }
+
+        @Override
+        public Tuple2<KMeansModelData, DenseVector> map(KMeansModelData modelData)
+                throws Exception {
+            return Tuple2.of(modelData, Vectors.dense(initWeights));
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        Path initModelDataPath = Paths.get(path, "data");
+        if (Files.exists(initModelDataPath)) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+            kMeans.setInitMode("direct");
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+        private final int k;
+        private final int dim;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure,
+                double decayFactor,
+                int batchSize,
+                int k,
+                int dim) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+            this.k = k;
+            this.dim = dim;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<Tuple2<KMeansModelData, DenseVector>> modelDataWithWeights =
+                    variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<Tuple2<KMeansModelData, DenseVector>> newModelDataWithWeights =
+                    points.countWindowAll(batchSize)
+                            .aggregate(new MiniBatchCreator())
+                            .flatMap(new MiniBatchDistributor(parallelism))
+                            .rebalance()
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "ModelDataPartialUpdater",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new ModelDataPartialUpdater(distanceMeasure, k))
+                            .setParallelism(parallelism)
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "ModelDataGlobalUpdater",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new ModelDataGlobalUpdater(k, dim, parallelism, decayFactor))
+                            .setParallelism(1);
+
+            DataStream<KMeansModelData> outputModelData =
+                    modelDataWithWeights.map(
+                            (MapFunction<Tuple2<KMeansModelData, DenseVector>, KMeansModelData>)
+                                    x -> x.f0);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelDataWithWeights), DataStreamList.of(outputModelData));
+        }
+    }
+
+    private static class ModelDataGlobalUpdater
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, DenseVector>>
+            implements TwoInputStreamOperator<
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final int k;
+        private final int dim;
+        private final int upstreamParallelism;
+        private final double decayFactor;
+
+        private ListState<Integer> partialModelDataReceivingState;
+        private ListState<Boolean> initModelDataReceivingState;
+        private ListState<KMeansModelData> modelDataState;
+        private ListState<DenseVector> weightsState;
+
+        private ModelDataGlobalUpdater(
+                int k, int dim, int upstreamParallelism, double decayFactor) {
+            this.k = k;
+            this.dim = dim;
+            this.upstreamParallelism = upstreamParallelism;
+            this.decayFactor = decayFactor;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            partialModelDataReceivingState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "partialModelDataReceiving",
+                                            BasicTypeInfo.INT_TYPE_INFO));
+
+            initModelDataReceivingState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "initModelDataReceiving",
+                                            BasicTypeInfo.BOOLEAN_TYPE_INFO));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+
+            weightsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "weights", DenseVectorTypeInfo.INSTANCE));
+
+            initStateValues();
+        }
+
+        private void initStateValues() throws Exception {
+            partialModelDataReceivingState.update(Collections.singletonList(0));
+            initModelDataReceivingState.update(Collections.singletonList(false));
+            DenseVector[] emptyCentroids = new DenseVector[k];
+            for (int i = 0; i < k; i++) {
+                emptyCentroids[i] = new DenseVector(dim);
+            }
+            modelDataState.update(Collections.singletonList(new KMeansModelData(emptyCentroids)));
+            weightsState.update(Collections.singletonList(new DenseVector(k)));
+        }
+
+        @Override
+        public void processElement1(
+                StreamRecord<Tuple2<KMeansModelData, DenseVector>> partialModelDataUpdateRecord)
+                throws Exception {
+            int partialModelDataReceiving =
+                    getUniqueElement(partialModelDataReceivingState, "partialModelDataReceiving");
+            Preconditions.checkState(partialModelDataReceiving < upstreamParallelism);
+            partialModelDataReceivingState.update(
+                    Collections.singletonList(partialModelDataReceiving + 1));
+            processElement(
+                    partialModelDataUpdateRecord.getValue().f0.centroids,
+                    partialModelDataUpdateRecord.getValue().f1,
+                    1.0);
+        }
+
+        @Override
+        public void processElement2(
+                StreamRecord<Tuple2<KMeansModelData, DenseVector>> initModelDataRecord)
+                throws Exception {
+            boolean initModelDataReceiving =
+                    getUniqueElement(initModelDataReceivingState, "initModelDataReceiving");
+            Preconditions.checkState(!initModelDataReceiving);
+            initModelDataReceivingState.update(Collections.singletonList(true));
+            processElement(
+                    initModelDataRecord.getValue().f0.centroids,
+                    initModelDataRecord.getValue().f1,
+                    decayFactor);
+        }
+
+        private void processElement(
+                DenseVector[] newCentroids, DenseVector newWeights, double decayFactor)
+                throws Exception {
+            DenseVector weights = getUniqueElement(weightsState, "weights");
+            DenseVector[] centroids = getUniqueElement(modelDataState, "modelData").centroids;
+
+            for (int i = 0; i < k; i++) {
+                newWeights.values[i] *= decayFactor;
+                for (int j = 0; j < dim; j++) {
+                    centroids[i].values[j] =
+                            (centroids[i].values[j] * weights.values[i]
+                                            + newCentroids[i].values[j] * newWeights.values[i])
+                                    / Math.max(weights.values[i] + newWeights.values[i], 1e-16);
+                }
+                weights.values[i] += newWeights.values[i];
+            }
+
+            if (getUniqueElement(initModelDataReceivingState, "initModelDataReceiving")
+                    && getUniqueElement(partialModelDataReceivingState, "partialModelDataReceiving")
+                            >= upstreamParallelism) {
+                output.collect(
+                        new StreamRecord<>(Tuple2.of(new KMeansModelData(centroids), weights)));
+                initStateValues();
+            } else {
+                modelDataState.update(Collections.singletonList(new KMeansModelData(centroids)));
+                weightsState.update(Collections.singletonList(weights));
+            }
+        }
+    }
+
+    private static <T> T getUniqueElement(ListState<T> state, String stateName) throws Exception {
+        T value = OperatorStateUtils.getUniqueElement(state, stateName).orElse(null);
+        return Objects.requireNonNull(value);
+    }
+
+    private static class ModelDataPartialUpdater
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, DenseVector>>
+            implements TwoInputStreamOperator<
+                    DenseVector[],
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final DistanceMeasure distanceMeasure;
+        private final int k;
+        private ListState<DenseVector[]> miniBatchState;
+        private ListState<KMeansModelData> modelDataState;
+
+        private ModelDataPartialUpdater(DistanceMeasure distanceMeasure, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.k = k;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            miniBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("miniBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> pointsRecord) throws Exception {
+            miniBatchState.add(pointsRecord.getValue());
+            processElement();
+        }
+
+        @Override
+        public void processElement2(
+                StreamRecord<Tuple2<KMeansModelData, DenseVector>> modelDataAndWeightsRecord)
+                throws Exception {
+            modelDataState.add(modelDataAndWeightsRecord.getValue().f0);
+            processElement();
+        }
+
+        private void processElement() throws Exception {

Review comment:
       `method does not have a guaranteed behavior` is exactly the case here. This method might not compute or emit model data if, e.g., `modelDataState` is empty.
   
   A name like `alignAndComputeModelData` would also work.
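The sketch below shows what a name like `alignAndComputeModelData` conveys. It is written in plain Java instead of against Flink's operator state, and the class `MiniBatchModelAligner` and its methods are hypothetical, not part of the PR. Each input is buffered, and a result is produced only when both a mini-batch and a model are available. Otherwise nothing is emitted. For brevity it assigns points with squared Euclidean distance, whereas the PR uses a configurable `DistanceMeasure`.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

/** Hypothetical sketch: aligns two inputs before computing anything. */
class MiniBatchModelAligner {
    // Buffers standing in for the operator's ListState fields.
    private final List<double[][]> pendingBatches = new ArrayList<>();
    private final List<double[][]> pendingModels = new ArrayList<>();

    /** Input 1: a mini-batch of points. */
    Optional<int[]> onBatch(double[][] batch) {
        pendingBatches.add(batch);
        return alignAndComputeModelData();
    }

    /** Input 2: the latest centroids. */
    Optional<int[]> onModel(double[][] centroids) {
        pendingModels.add(centroids);
        return alignAndComputeModelData();
    }

    /**
     * Returns per-centroid point counts only when both inputs have a pending
     * element; otherwise returns empty, which is what "align" signals.
     */
    private Optional<int[]> alignAndComputeModelData() {
        if (pendingBatches.isEmpty() || pendingModels.isEmpty()) {
            return Optional.empty();
        }
        double[][] batch = pendingBatches.remove(0);
        double[][] centroids = pendingModels.remove(0);
        int[] counts = new int[centroids.length];
        for (double[] point : batch) {
            counts[closest(point, centroids)]++;
        }
        return Optional.of(counts);
    }

    /** Index of the nearest centroid by squared Euclidean distance. */
    private static int closest(double[] p, double[][] centroids) {
        int best = 0;
        double bestDist = Double.MAX_VALUE;
        for (int i = 0; i < centroids.length; i++) {
            double d = 0.0;
            for (int j = 0; j < p.length; j++) {
                double diff = p[j] - centroids[i][j];
                d += diff * diff;
            }
            if (d < bestDist) {
                bestDist = d;
                best = i;
            }
        }
        return best;
    }
}
```

A batch that arrives before any model yields `Optional.empty()`. The counts are emitted only once the matching model arrives.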




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [flink-ml] lindong28 commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
lindong28 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r833316743



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,562 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Random;
+
+/**
+ * OnlineKMeans extends the function of {@link KMeans}, supporting to train a K-Means model
+ * continuously according to an unbounded stream of train data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids from the weighted average between the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of points assigned to
+ * them. The weight of the original centroids is also the number of points, but additionally
+ * multiplying with the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If decay
+ * factor is 1, all batches are weighted equally. If decay factor is 0, new centroids are determined
+ * entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public OnlineKMeans(Table initModelDataTable) {
+        this.initModelDataTable = initModelDataTable;
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            Preconditions.checkState(initModelDataTable == null);
+            initModelDataStream = createRandomCentroids(env, getDim(), getK(), getSeed());
+        } else {
+            initModelDataStream = KMeansModelData.getModelDataStream(initModelDataTable);
+        }
+        DataStream<Tuple2<KMeansModelData, DenseVector>> initModelDataWithWeightsStream =
+                initModelDataStream.map(new InitWeightAssigner(getInitWeights()));
+        initModelDataWithWeightsStream.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize(),
+                        getK(),
+                        getDim());
+
+        DataStream<KMeansModelData> finalModelDataStream =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelDataWithWeightsStream),
+                                DataStreamList.of(points),
+                                body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelDataStream);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class InitWeightAssigner
+            implements MapFunction<KMeansModelData, Tuple2<KMeansModelData, DenseVector>> {
+        private final double[] initWeights;
+
+        private InitWeightAssigner(Double[] initWeights) {
+            this.initWeights = ArrayUtils.toPrimitive(initWeights);
+        }
+
+        @Override
+        public Tuple2<KMeansModelData, DenseVector> map(KMeansModelData modelData)
+                throws Exception {
+            return Tuple2.of(modelData, Vectors.dense(initWeights));
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        Path initModelDataPath = Paths.get(path, "data");
+        if (Files.exists(initModelDataPath)) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+            kMeans.setInitMode("direct");
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+        private final int k;
+        private final int dim;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure,
+                double decayFactor,
+                int batchSize,
+                int k,
+                int dim) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+            this.k = k;
+            this.dim = dim;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<Tuple2<KMeansModelData, DenseVector>> modelDataWithWeights =
+                    variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<Tuple2<KMeansModelData, DenseVector>> newModelDataWithWeights =
+                    points.countWindowAll(batchSize)
+                            .aggregate(new MiniBatchCreator())
+                            .flatMap(new MiniBatchDistributor(parallelism))
+                            .rebalance()
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "ModelDataPartialUpdater",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new ModelDataPartialUpdater(distanceMeasure, k))
+                            .setParallelism(parallelism)
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "ModelDataGlobalUpdater",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new ModelDataGlobalUpdater(k, dim, parallelism, decayFactor))

Review comment:
       After thinking about this more, I think we can change the algorithm used in `ModelDataLocalUpdater` and `ModelDataGlobalUpdater` so that `ModelDataGlobalUpdater` only needs to read from `ModelDataLocalUpdater`'s output and still computes the right result. We might need to change the type of output emitted by `ModelDataLocalUpdater`.
   
   Here is the idea. Currently `ModelDataGlobalUpdater` calculates the weight for the first centroid as `weight_from_last_iteration + sum_of_weights_from_local_updater`. We can change `ModelDataLocalUpdater` to emit `weight_from_last_iteration / parallelism + weights_from_local_batch`. Then `ModelDataGlobalUpdater` can derive the weight for the first centroid as `sum_of_outputs_from_local_updater`, which depends only on the output of `ModelDataLocalUpdater`.
   
   This approach adds complexity to the operator, but it could make the Flink job simpler and more performant.
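Here is a minimal sketch of the proposed decomposition, using plain arrays of per-centroid weights. The class `WeightDecomposition` and its methods are hypothetical. It checks that summing each local updater's `weight_from_last_iteration / parallelism + weights_from_local_batch` gives the same per-centroid weights as the current `weight_from_last_iteration + sum_of_weights_from_local_updater`.

```java
/** Hypothetical sketch of splitting the global weight update across local updaters. */
class WeightDecomposition {

    /** Current approach: the global updater adds all local weights to the last weights. */
    static double[] globalFromLastAndLocal(double[] lastWeights, double[][] localWeights) {
        double[] result = lastWeights.clone();
        for (double[] local : localWeights) {
            for (int i = 0; i < result.length; i++) {
                result[i] += local[i];
            }
        }
        return result;
    }

    /** Proposed local output: an equal share of the last weights plus this subtask's weights. */
    static double[] localOutput(double[] lastWeights, double[] localWeights, int parallelism) {
        double[] out = new double[lastWeights.length];
        for (int i = 0; i < out.length; i++) {
            out[i] = lastWeights[i] / parallelism + localWeights[i];
        }
        return out;
    }

    /** Proposed global update: needs nothing but the local outputs. */
    static double[] globalFromLocalOutputs(double[][] localOutputs) {
        double[] result = new double[localOutputs[0].length];
        for (double[] out : localOutputs) {
            for (int i = 0; i < result.length; i++) {
                result[i] += out[i];
            }
        }
        return result;
    }
}
```

The two results agree up to floating-point rounding, because dividing by `parallelism` and summing back may not be exact. In the quoted diff, `ModelDataGlobalUpdater` multiplies the previous weights by the decay factor. Applying the same split to `decayFactor * weight_from_last_iteration / parallelism` should carry over, but that extension is an assumption and is not shown here.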
   







[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r836041520



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,407 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Supplier;
+
+/**
+ * OnlineKMeans extends the function of {@link KMeans}, supporting to train a K-Means model
+ * continuously according to an unbounded stream of train data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids from the weighted average between the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of points assigned to
+ * them. The weight of the original centroids is also the number of points, but additionally
+ * multiplying with the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If the decay
+ * factor is 1, all batches are weighted equally. If the decay factor is 0, new centroids are
+ * determined entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelData =
+                KMeansModelData.getModelDataStream(initModelDataTable);
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getK(),
+                        getDecayFactor(),
+                        getGlobalBatchSize());
+
+        DataStream<KMeansModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Saves the metadata AND bounded model data table (if exists) to the given path. */
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans onlineKMeans = ReadWriteUtils.loadStageParam(path);
+
+        String initModelDataPath = ReadWriteUtils.getDataPath(path);
+        if (Files.exists(Paths.get(initModelDataPath))) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            onlineKMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+        }
+
+        return onlineKMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final int k;
+        private final double decayFactor;
+        private final int batchSize;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure, int k, double decayFactor, int batchSize) {
+            this.distanceMeasure = distanceMeasure;
+            this.k = k;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<KMeansModelData> modelData = variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();

Review comment:
       Thanks for raising this discussion!
   I hold a different opinion on this. For a consistent user experience, we should probably also support the case where the parallelism is smaller than the batch size, but we could alert the user with a warning.
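The sketch below shows why the relationship between parallelism and the global batch size matters. It splits a global batch size as evenly as possible across parallel subtasks. The class `LocalBatchSizes` and its `split` method are hypothetical and not part of this PR. The even-split policy is also an assumption, because this hunk does not show the PR's actual partitioning logic. When the global batch size is smaller than the parallelism, some subtasks get a local batch size of 0. In that case the sketch prints a warning instead of rejecting the configuration, which is one way to follow the suggestion in the comment above.

```java
import java.util.Arrays;

/**
 * Hypothetical helper (not part of this PR): splits a global mini-batch size across
 * parallel subtasks so that the local sizes sum to the global size.
 */
final class LocalBatchSizes {

    /** Returns how many points each of the {@code parallelism} subtasks contributes per batch. */
    static int[] split(int globalBatchSize, int parallelism) {
        if (globalBatchSize <= 0 || parallelism <= 0) {
            throw new IllegalArgumentException("Batch size and parallelism must be positive.");
        }
        int[] local = new int[parallelism];
        int base = globalBatchSize / parallelism;
        int remainder = globalBatchSize % parallelism;
        for (int i = 0; i < parallelism; i++) {
            // The first `remainder` subtasks each take one extra point.
            local[i] = base + (i < remainder ? 1 : 0);
        }
        if (globalBatchSize < parallelism) {
            // Some subtasks contribute nothing to a batch: warn rather than fail.
            System.err.printf(
                    "WARN: global batch size %d < parallelism %d; %d subtask(s) get no points%n",
                    globalBatchSize, parallelism, parallelism - globalBatchSize);
        }
        return local;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(split(32, 4))); // [8, 8, 8, 8]
        System.out.println(Arrays.toString(split(6, 4))); // [2, 2, 1, 1]
        System.out.println(Arrays.toString(split(2, 4))); // [1, 1, 0, 0], plus a warning
    }
}
```

Throwing an exception in the `globalBatchSize < parallelism` branch would be the stricter alternative. This thread has not yet decided which behavior is preferable.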




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [flink-ml] lindong28 commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
lindong28 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r836068872



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,407 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Supplier;
+
+/**
+ * OnlineKMeans extends the functionality of {@link KMeans}, supporting training a K-Means model
+ * continuously on an unbounded stream of training data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids as the weighted average of the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of points assigned to
+ * them. The weight of the original centroids is likewise the number of points assigned to them,
+ * additionally multiplied by the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If the decay
+ * factor is 1, all batches are weighted equally. If the decay factor is 0, new centroids are
+ * determined entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelData =
+                KMeansModelData.getModelDataStream(initModelDataTable);
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getK(),
+                        getDecayFactor(),
+                        getGlobalBatchSize());
+
+        DataStream<KMeansModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Saves the metadata AND bounded model data table (if exists) to the given path. */
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans onlineKMeans = ReadWriteUtils.loadStageParam(path);
+
+        String initModelDataPath = ReadWriteUtils.getDataPath(path);
+        if (Files.exists(Paths.get(initModelDataPath))) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            onlineKMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+        }
+
+        return onlineKMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final int k;
+        private final double decayFactor;
+        private final int batchSize;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure, int k, double decayFactor, int batchSize) {
+            this.distanceMeasure = distanceMeasure;
+            this.k = k;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<KMeansModelData> modelData = variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();

Review comment:
       @zhipeng93 Do you think there is any reason a user would actually want to run OnlineKMeans with parallelism < batchSize? If so, could you provide some details?
   
   If not, could you explain why we should support this?
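The sketch below shows the decayed mini-batch update from the `OnlineKMeans` Javadoc for a single cluster. It is a simplified, Flink-free illustration and not the PR's implementation. `DecayedMiniBatchUpdate` and `Cluster` are hypothetical names. The sketch assumes that each centroid carries a weight, and that after an update this weight becomes `decayFactor * oldWeight + batchCount`. That assumption matches the weights of 4.5 expected in `testDecayFactor`: an initial weight of 3 per cluster, a decay factor of 0.5, and 3 new points. The sample input (10.1, 0.1) is the mean of the three `trainData1` points near x = 10.

```java
import java.util.Locale;

/**
 * Hypothetical, Flink-free sketch of the decayed mini-batch K-Means update for one cluster.
 * Not the PR's implementation; names are illustrative only.
 */
final class DecayedMiniBatchUpdate {

    /** A centroid together with the (decayed) number of points it represents. */
    static final class Cluster {
        final double[] centroid;
        final double weight;

        Cluster(double[] centroid, double weight) {
            this.centroid = centroid;
            this.weight = weight;
        }
    }

    /**
     * Merges an existing cluster with the mean of the batchCount points assigned to it in the
     * current mini-batch. The old centroid's weight is scaled by decayFactor first.
     */
    static Cluster update(Cluster old, double[] batchMean, int batchCount, double decayFactor) {
        double decayedWeight = decayFactor * old.weight;
        double newWeight = decayedWeight + batchCount;
        if (newWeight == 0) {
            // Nothing remembered and nothing new: keep the old centroid unchanged.
            return old;
        }
        double[] merged = new double[old.centroid.length];
        for (int i = 0; i < merged.length; i++) {
            // Weighted average of the original centroid and the batch-estimated centroid.
            merged[i] = (decayedWeight * old.centroid[i] + batchCount * batchMean[i]) / newWeight;
        }
        return new Cluster(merged, newWeight);
    }

    public static void main(String[] args) {
        // Values mirroring testDecayFactor: initial centroid (10.1, 0.1) with weight 3,
        // three new points with mean (10.1, 100.1), decay factor 0.5.
        Cluster c = update(
                new Cluster(new double[] {10.1, 0.1}, 3), new double[] {10.1, 100.1}, 3, 0.5);
        System.out.printf(
                Locale.ROOT, "%.4f, %.4f, weight=%.1f%n", c.centroid[0], c.centroid[1], c.weight);
        // prints "10.1000, 66.7667, weight=4.5", i.e. (10.1, 200.3 / 3)
    }
}
```

With `decayFactor = 0` the new centroid is exactly the batch mean. With `decayFactor = 1`, historical and new points are weighted purely by count. These match the two limiting cases described in the Javadoc.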
   
   







[GitHub] [flink-ml] lindong28 commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
lindong28 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r836069446



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeansModel.java
##########
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.co.CoProcessFunction;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * OnlineKMeansModel can be regarded as an advanced {@link KMeansModel} operator which can update
+ * model data in a streaming format, using the model data provided by {@link OnlineKMeans}.
+ */
+public class OnlineKMeansModel
+        implements Model<OnlineKMeansModel>, KMeansModelParams<OnlineKMeansModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public OnlineKMeansModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel setModelData(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        modelDataTable = inputs[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), Types.INT),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol()));
+
+        DataStream<Row> predictionResult =
+                KMeansModelData.getModelDataStream(modelDataTable)
+                        .broadcast()
+                        .connect(tEnv.toDataStream(inputs[0]))
+                        .process(
+                                new PredictLabelFunction(
+                                        getFeaturesCol(),
+                                        DistanceMeasure.getInstance(getDistanceMeasure())),
+                                outputTypeInfo);
+
+        return new Table[] {tEnv.fromDataStream(predictionResult)};
+    }
+
+    /** A utility function used for prediction. */
+    private static class PredictLabelFunction extends CoProcessFunction<KMeansModelData, Row, Row> {
+        private final String featuresCol;
+
+        private final DistanceMeasure distanceMeasure;
+
+        private DenseVector[] centroids;
+
+        // TODO: replace this with a complete solution of reading first model data from unbounded
+        // model data stream before processing the first predict data.
+        private final List<Row> bufferedPoints = new ArrayList<>();

Review comment:
       OK. Then let's do it.
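The TODO in `PredictLabelFunction` refers to buffering predict data until the first model data arrives. Below is a stripped-down, Flink-free sketch of that idea. `BufferedPredictor`, `onModelData`, and `onPoint` are hypothetical stand-ins for the two inputs of the `CoProcessFunction` in the hunk above: `processElement1` receives model data and `processElement2` receives predict rows. The sketch also assumes squared Euclidean distance for the nearest-centroid lookup.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Consumer;

/**
 * Hypothetical sketch: buffers points that arrive before the first model data, then
 * flushes them once centroids are available. Not the PR's implementation.
 */
final class BufferedPredictor {
    private double[][] centroids; // null until the first model data arrives
    private final List<double[]> bufferedPoints = new ArrayList<>();
    private final Consumer<Integer> output;

    BufferedPredictor(Consumer<Integer> output) {
        this.output = output;
    }

    /** Analogue of processElement1: receives a (new) version of the model data. */
    void onModelData(double[][] newCentroids) {
        centroids = newCentroids;
        for (double[] point : bufferedPoints) {
            output.accept(predict(point)); // flush everything buffered so far
        }
        bufferedPoints.clear();
    }

    /** Analogue of processElement2: receives a point to predict. */
    void onPoint(double[] point) {
        if (centroids == null) {
            bufferedPoints.add(point); // no model yet: hold the point back
        } else {
            output.accept(predict(point));
        }
    }

    /** Returns the index of the nearest centroid by squared Euclidean distance. */
    private int predict(double[] point) {
        int best = 0;
        double bestDistance = Double.MAX_VALUE;
        for (int i = 0; i < centroids.length; i++) {
            double distance = 0;
            for (int j = 0; j < point.length; j++) {
                double diff = point[j] - centroids[i][j];
                distance += diff * diff;
            }
            if (distance < bestDistance) {
                bestDistance = distance;
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        List<Integer> out = new ArrayList<>();
        BufferedPredictor predictor = new BufferedPredictor(out::add);
        predictor.onPoint(new double[] {10, 10}); // buffered: no model yet
        predictor.onModelData(new double[][] {{-10, 0}, {10, 0}}); // flushes -> 1
        predictor.onPoint(new double[] {-10, 10}); // -> 0
        System.out.println(out); // [1, 0]
        System.out.println(out.equals(Arrays.asList(1, 0))); // true
    }
}
```

A plain in-memory list like this is not part of Flink checkpoints. A complete solution would need to keep buffered points in operator state or, as the TODO suggests, read the first model data before consuming any predict data.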







[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r836067127



##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java
##########
@@ -177,11 +177,20 @@ public void testFewerDistinctPointsThanCluster() {
         KMeans kmeans = new KMeans().setK(2);
         KMeansModel model = kmeans.fit(input);
         Table output = model.transform(input)[0];
-        List<Set<DenseVector>> expectedGroups =
-                Collections.singletonList(Collections.singleton(Vectors.dense(0.0, 0.1)));
-        List<Set<DenseVector>> actualGroups =
-                executeAndCollect(output, kmeans.getFeaturesCol(), kmeans.getPredictionCol());
-        assertTrue(CollectionUtils.isEqualCollection(expectedGroups, actualGroups));
+
+        try {

Review comment:
       OK. I'll change `K`'s description to `"The max number of clusters to create."`.







[GitHub] [flink-ml] lindong28 commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
lindong28 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r836267017



##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java
##########
@@ -0,0 +1,498 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.RestOptions;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.clustering.kmeans.KMeans;
+import org.apache.flink.ml.clustering.kmeans.KMeansModel;
+import org.apache.flink.ml.clustering.kmeans.KMeansModelData;
+import org.apache.flink.ml.clustering.kmeans.OnlineKMeans;
+import org.apache.flink.ml.clustering.kmeans.OnlineKMeansModel;
+import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.util.InMemorySinkFunction;
+import org.apache.flink.ml.util.InMemorySourceFunction;
+import org.apache.flink.runtime.minicluster.MiniCluster;
+import org.apache.flink.runtime.minicluster.MiniClusterConfiguration;
+import org.apache.flink.runtime.testutils.InMemoryReporter;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.TestLogger;
+
+import org.apache.commons.collections.CollectionUtils;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import static org.apache.flink.ml.clustering.KMeansTest.groupFeaturesByPrediction;
+import static org.junit.Assert.assertEquals;
+
+/** Tests {@link OnlineKMeans} and {@link OnlineKMeansModel}. */
+public class OnlineKMeansTest extends TestLogger {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private static final DenseVector[] trainData1 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 0.0),
+                Vectors.dense(10.0, 0.3),
+                Vectors.dense(10.3, 0.0),
+                Vectors.dense(-10.0, 0.0),
+                Vectors.dense(-10.0, 0.6),
+                Vectors.dense(-10.6, 0.0)
+            };
+    private static final DenseVector[] trainData2 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 100.0),
+                Vectors.dense(10.0, 100.3),
+                Vectors.dense(10.3, 100.0),
+                Vectors.dense(-10.0, -100.0),
+                Vectors.dense(-10.0, -100.6),
+                Vectors.dense(-10.6, -100.0)
+            };
+    private static final DenseVector[] predictData =
+            new DenseVector[] {
+                Vectors.dense(10.0, 10.0),
+                Vectors.dense(10.3, 10.0),
+                Vectors.dense(10.0, 10.3),
+                Vectors.dense(-10.0, 10.0),
+                Vectors.dense(-10.3, 10.0),
+                Vectors.dense(-10.0, 10.3)
+            };
+    private static final List<Set<DenseVector>> expectedGroups1 =
+            Arrays.asList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3))),
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+    private static final List<Set<DenseVector>> expectedGroups2 =
+            Collections.singletonList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3),
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+
+    private static final int defaultParallelism = 4;
+    private static final int numTaskManagers = 2;
+    private static final int numSlotsPerTaskManager = 2;
+
+    private int currentModelDataVersion;
+
+    private InMemorySourceFunction<DenseVector> trainSource;
+    private InMemorySourceFunction<DenseVector> predictSource;
+    private InMemorySinkFunction<Row> outputSink;
+    private InMemorySinkFunction<KMeansModelData> modelDataSink;
+
+    // TODO: creates static mini cluster once for whole test class after dependency upgrades to
+    // Flink 1.15.
+    private InMemoryReporter reporter;
+    private MiniCluster miniCluster;
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+
+    private Table offlineTrainTable;
+    private Table onlineTrainTable;
+    private Table onlinePredictTable;
+
+    @Before
+    public void before() throws Exception {
+        currentModelDataVersion = 0;
+
+        trainSource = new InMemorySourceFunction<>();
+        predictSource = new InMemorySourceFunction<>();
+        outputSink = new InMemorySinkFunction<>();
+        modelDataSink = new InMemorySinkFunction<>();
+
+        Configuration config = new Configuration();
+        config.set(RestOptions.BIND_PORT, "18081-19091");
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        reporter = InMemoryReporter.create();
+        reporter.addToConfiguration(config);
+
+        miniCluster =
+                new MiniCluster(
+                        new MiniClusterConfiguration.Builder()
+                                .setConfiguration(config)
+                                .setNumTaskManagers(numTaskManagers)
+                                .setNumSlotsPerTaskManager(numSlotsPerTaskManager)
+                                .build());
+        miniCluster.start();
+
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(defaultParallelism);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        offlineTrainTable = tEnv.fromDataStream(env.fromElements(trainData1)).as("features");
+        onlineTrainTable =
+                tEnv.fromDataStream(env.addSource(trainSource, DenseVectorTypeInfo.INSTANCE))
+                        .as("features");
+        onlinePredictTable =
+                tEnv.fromDataStream(env.addSource(predictSource, DenseVectorTypeInfo.INSTANCE))
+                        .as("features");
+    }
+
+    @After
+    public void after() throws Exception {
+        miniCluster.close();
+    }
+
+    /**
+     * Performs transform() on the provided model with predictTable, and adds sinks for
+     * OnlineKMeansModel's transform output and model data.
+     */
+    private void transformAndOutputData(OnlineKMeansModel onlineModel) {
+        Table outputTable = onlineModel.transform(onlinePredictTable)[0];
+        tEnv.toDataStream(outputTable).addSink(outputSink);
+
+        Table modelDataTable = onlineModel.getModelData()[0];
+        KMeansModelData.getModelDataStream(modelDataTable).addSink(modelDataSink);
+    }
+
+    /** Blocks the thread until Model has set up init model data. */
+    private void waitInitModelDataSetup() throws InterruptedException {
+        while (reporter.findMetrics(OnlineKMeansModel.MODEL_DATA_VERSION_GAUGE_KEY).size()
+                < defaultParallelism) {
+            Thread.sleep(100);
+        }
+        waitModelDataUpdate();
+    }
+
+    /** Blocks the thread until the Model has received the next model-data-update event. */
+    @SuppressWarnings("unchecked")
+    private void waitModelDataUpdate() throws InterruptedException {
+        do {
+            int tmpModelDataVersion =
+                    reporter.findMetrics(OnlineKMeansModel.MODEL_DATA_VERSION_GAUGE_KEY).values()
+                            .stream()
+                            .map(x -> Integer.parseInt(((Gauge<String>) x).getValue()))
+                            .min(Integer::compareTo)
+                            .get();
+            if (tmpModelDataVersion == currentModelDataVersion) {
+                Thread.sleep(100);
+            } else {
+                currentModelDataVersion = tmpModelDataVersion;
+                break;
+            }
+        } while (true);
+    }
+
+    /**
+     * Inserts default predict data to the predict queue, fetches the prediction results, and
+     * asserts that the grouping result is as expected.
+     *
+     * @param expectedGroups A list containing sets of features, which is the expected group result
+     * @param featuresCol Name of the column in the table that contains the features
+     * @param predictionCol Name of the column in the table that contains the prediction result
+     */
+    private void predictAndAssert(
+            List<Set<DenseVector>> expectedGroups, String featuresCol, String predictionCol)
+            throws Exception {
+        predictSource.addAll(OnlineKMeansTest.predictData);
+        List<Row> rawResult = outputSink.poll(OnlineKMeansTest.predictData.length);
+        List<Set<DenseVector>> actualGroups =
+                groupFeaturesByPrediction(rawResult, featuresCol, predictionCol);
+        Assert.assertTrue(CollectionUtils.isEqualCollection(expectedGroups, actualGroups));
+    }
+
+    @Test
+    public void testParam() {
+        OnlineKMeans onlineKMeans = new OnlineKMeans();
+        Assert.assertEquals("features", onlineKMeans.getFeaturesCol());
+        Assert.assertEquals("prediction", onlineKMeans.getPredictionCol());
+        Assert.assertEquals("count", onlineKMeans.getBatchStrategy());
+        Assert.assertEquals(EuclideanDistanceMeasure.NAME, onlineKMeans.getDistanceMeasure());
+        Assert.assertEquals(32, onlineKMeans.getGlobalBatchSize());
+        Assert.assertEquals(0., onlineKMeans.getDecayFactor(), 1e-5);
+        Assert.assertEquals(OnlineKMeans.class.getName().hashCode(), onlineKMeans.getSeed());
+
+        onlineKMeans
+                .setFeaturesCol("test_feature")
+                .setPredictionCol("test_prediction")
+                .setGlobalBatchSize(5)
+                .setDecayFactor(0.25)
+                .setSeed(100);
+
+        Assert.assertEquals("test_feature", onlineKMeans.getFeaturesCol());
+        Assert.assertEquals("test_prediction", onlineKMeans.getPredictionCol());
+        Assert.assertEquals("count", onlineKMeans.getBatchStrategy());
+        Assert.assertEquals(EuclideanDistanceMeasure.NAME, onlineKMeans.getDistanceMeasure());
+        Assert.assertEquals(5, onlineKMeans.getGlobalBatchSize());
+        Assert.assertEquals(0.25, onlineKMeans.getDecayFactor(), 1e-5);
+        Assert.assertEquals(100, onlineKMeans.getSeed());
+    }
+
+    @Test
+    public void testFitAndPredict() throws Exception {
+        OnlineKMeans onlineKMeans =
+                new OnlineKMeans()
+                        .setFeaturesCol("features")
+                        .setPredictionCol("prediction")
+                        .setGlobalBatchSize(6)
+                        .setInitialModelData(
+                                KMeansModelData.generateRandomModelData(env, 2, 2, 0.0, 0));
+        OnlineKMeansModel onlineModel = onlineKMeans.fit(onlineTrainTable);
+        transformAndOutputData(onlineModel);
+
+        miniCluster.submitJob(env.getStreamGraph().getJobGraph());
+        waitInitModelDataSetup();
+
+        trainSource.addAll(trainData1);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups1, onlineKMeans.getFeaturesCol(), onlineKMeans.getPredictionCol());
+
+        trainSource.addAll(trainData2);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups2, onlineKMeans.getFeaturesCol(), onlineKMeans.getPredictionCol());
+    }
+
+    @Test
+    public void testInitWithKMeans() throws Exception {
+        KMeans kMeans = new KMeans().setFeaturesCol("features").setPredictionCol("prediction");
+        KMeansModel model = kMeans.fit(offlineTrainTable);
+
+        OnlineKMeans onlineKMeans =
+                new OnlineKMeans()
+                        .setFeaturesCol("features")
+                        .setPredictionCol("prediction")
+                        .setGlobalBatchSize(6)
+                        .setInitialModelData(model.getModelData()[0]);
+
+        OnlineKMeansModel onlineModel = onlineKMeans.fit(onlineTrainTable);
+        transformAndOutputData(onlineModel);
+
+        miniCluster.submitJob(env.getStreamGraph().getJobGraph());
+        waitInitModelDataSetup();
+        predictAndAssert(
+                expectedGroups1, onlineKMeans.getFeaturesCol(), onlineKMeans.getPredictionCol());
+
+        trainSource.addAll(trainData2);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups2, onlineKMeans.getFeaturesCol(), onlineKMeans.getPredictionCol());
+    }
+
+    @Test
+    public void testDecayFactor() throws Exception {
+        KMeans kMeans = new KMeans().setFeaturesCol("features").setPredictionCol("prediction");
+        KMeansModel model = kMeans.fit(offlineTrainTable);
+
+        OnlineKMeans onlineKMeans =
+                new OnlineKMeans()
+                        .setFeaturesCol("features")
+                        .setPredictionCol("prediction")
+                        .setGlobalBatchSize(6)
+                        .setDecayFactor(0.5)
+                        .setInitialModelData(model.getModelData()[0]);
+        OnlineKMeansModel onlineModel = onlineKMeans.fit(onlineTrainTable);
+        transformAndOutputData(onlineModel);
+
+        miniCluster.submitJob(env.getStreamGraph().getJobGraph());
+        modelDataSink.poll();
+
+        trainSource.addAll(trainData2);
+        KMeansModelData actualModelData = modelDataSink.poll();
+
+        KMeansModelData expectedModelData =
+                new KMeansModelData(
+                        new DenseVector[] {
+                            Vectors.dense(-10.2, -200.2 / 3), Vectors.dense(10.1, 200.3 / 3)
+                        },
+                        Vectors.dense(4.5, 4.5));
+
+        Assert.assertArrayEquals(
+                expectedModelData.weights.values, actualModelData.weights.values, 1e-5);
+        Assert.assertEquals(expectedModelData.centroids.length, actualModelData.centroids.length);
+        Arrays.sort(actualModelData.centroids, Comparator.comparingDouble(vector -> vector.get(0)));
+        for (int i = 0; i < expectedModelData.centroids.length; i++) {
+            Assert.assertArrayEquals(
+                    expectedModelData.centroids[i].values,
+                    actualModelData.centroids[i].values,
+                    1e-5);
+        }
+    }
+
+    @Test
+    public void testFewerPointsThanSubtask() {

Review comment:
       Would it be better to rename this test to `testBatchSizeLessThanParallelism()`?
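   For reference, the proposed name describes the case where a global batch holds fewer points than there are parallel subtasks, so that some subtasks receive an empty local batch. The plain-Java sketch below reproduces the size arithmetic of `GlobalBatchSplitter` in `OnlineKMeans.java` (`start = i * n / p`, `end = (i + 1) * n / p`, quoted further down in this thread). `LocalBatchSizes` and `split` are hypothetical names and are not part of the PR.

```java
import java.util.Arrays;

/** Hypothetical helper mirroring the split-size arithmetic of GlobalBatchSplitter. */
public class LocalBatchSizes {

    /** Returns how many of the n points in a global batch each of the p subtasks receives. */
    static int[] split(int n, int p) {
        int[] sizes = new int[p];
        for (int i = 0; i < p; i++) {
            int start = i * n / p; // integer division, as in GlobalBatchSplitter
            int end = (i + 1) * n / p;
            sizes[i] = end - start;
        }
        return sizes;
    }

    public static void main(String[] args) {
        // Global batch size 6 with parallelism 4, as configured in these tests.
        System.out.println(Arrays.toString(split(6, 4))); // [1, 2, 1, 2]
        // A global batch smaller than the parallelism leaves one local batch empty.
        System.out.println(Arrays.toString(split(3, 4))); // [0, 1, 1, 1]
    }
}
```

   With a batch size of 6 and parallelism 4 every local batch gets at least one point; shrinking the batch to 3 leaves one of the four local batches empty, which is the situation the renamed test would cover.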

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModelData.java
##########
@@ -47,12 +52,69 @@
 
     public DenseVector[] centroids;
 
-    public KMeansModelData(DenseVector[] centroids) {
+    /**
+     * The weights of the centroids, used when updating the model data during online training.
+     *
+     * <p>KMeansModelData objects generated during {@link KMeans#fit(Table...)} also contain this
+     * field, so that they can be used as the initial model data of the online training process.
+     */
+    public DenseVector weights;
+
+    public KMeansModelData(DenseVector[] centroids, DenseVector weights) {
+        Preconditions.checkArgument(centroids.length == weights.size());
         this.centroids = centroids;
+        this.weights = weights;
     }
 
     public KMeansModelData() {}
 
+    /**
+     * Generates a Table containing a {@link KMeansModelData} instance with randomly generated
+     * centroids.
+     *
+     * @param env The environment in which to create the table.
+     * @param k The number of generated centroids.
+     * @param dim The dimension of the generated centroids.
+     * @param weight The weight of the centroids.
+     * @param seed Random seed.
+     */
+    public static Table generateRandomModelData(
+            StreamExecutionEnvironment env, int k, int dim, double weight, long seed) {
+        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+        return tEnv.fromDataStream(
+                env.fromElements(1).map(new RandomCentroidsCreator(k, dim, weight, seed)));
+    }
+
+    private static class RandomCentroidsCreator implements MapFunction<Integer, KMeansModelData> {
+        private final int k;
+        private final int dim;
+        private final long seed;
+        private final double weight;
+
+        private RandomCentroidsCreator(int k, int dim, double weight, long seed) {
+            this.k = k;
+            this.dim = dim;
+            this.seed = seed;

Review comment:
       nit: maybe reorder this line with the line below for consistency with the method's input parameters?
   
   Same for the private member variable declarations :)
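   Applied to the constructor above, the suggestion would declare and assign the fields in the same order as the parameters (`k`, `dim`, `weight`, `seed`). The hypothetical standalone class below sketches that layout. The `map` body of `RandomCentroidsCreator` is not part of this hunk, so the generation logic here (values drawn uniformly from [0, 1) by a `java.util.Random` seeded with `seed`, and every centroid given the same `weight`) is an assumption, and plain `double[]` arrays stand in for `DenseVector`.

```java
import java.util.Arrays;
import java.util.Random;

/** Hypothetical stand-in for RandomCentroidsCreator, with fields in parameter order. */
public class RandomCentroidsSketch {
    private final int k;
    private final int dim;
    private final double weight;
    private final long seed;

    public RandomCentroidsSketch(int k, int dim, double weight, long seed) {
        this.k = k;
        this.dim = dim;
        this.weight = weight;
        this.seed = seed;
    }

    /** Returns k centroids of size dim; values are uniform in [0, 1) (assumed distribution). */
    public double[][] centroids() {
        Random random = new Random(seed);
        double[][] result = new double[k][dim];
        for (int i = 0; i < k; i++) {
            for (int j = 0; j < dim; j++) {
                result[i][j] = random.nextDouble();
            }
        }
        return result;
    }

    /** Returns one weight per centroid, all equal to the configured weight. */
    public double[] weights() {
        double[] result = new double[k];
        Arrays.fill(result, weight);
        return result;
    }

    public static void main(String[] args) {
        // Same arguments as generateRandomModelData(env, 2, 2, 0.0, 0) in testFitAndPredict.
        RandomCentroidsSketch creator = new RandomCentroidsSketch(2, 2, 0.0, 0);
        System.out.println(Arrays.deepToString(creator.centroids()));
        System.out.println(Arrays.toString(creator.weights())); // [0.0, 0.0]
    }
}
```

   Because the `Random` is created from `seed` on each call, the same arguments always yield the same centroids, which keeps tests built on such model data deterministic.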




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r835200989



##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java
##########
@@ -0,0 +1,476 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.RestOptions;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.clustering.kmeans.KMeans;
+import org.apache.flink.ml.clustering.kmeans.KMeansModel;
+import org.apache.flink.ml.clustering.kmeans.KMeansModelData;
+import org.apache.flink.ml.clustering.kmeans.OnlineKMeans;
+import org.apache.flink.ml.clustering.kmeans.OnlineKMeansModel;
+import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.util.InMemorySinkFunction;
+import org.apache.flink.ml.util.InMemorySourceFunction;
+import org.apache.flink.runtime.minicluster.MiniCluster;
+import org.apache.flink.runtime.minicluster.MiniClusterConfiguration;
+import org.apache.flink.runtime.testutils.InMemoryReporter;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.TestLogger;
+
+import org.apache.commons.collections.CollectionUtils;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import static org.apache.flink.ml.clustering.KMeansTest.groupFeaturesByPrediction;
+
+/** Tests {@link OnlineKMeans} and {@link OnlineKMeansModel}. */
+public class OnlineKMeansTest extends TestLogger {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private static final DenseVector[] trainData1 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 0.0),
+                Vectors.dense(10.0, 0.3),
+                Vectors.dense(10.3, 0.0),
+                Vectors.dense(-10.0, 0.0),
+                Vectors.dense(-10.0, 0.6),
+                Vectors.dense(-10.6, 0.0)
+            };
+    private static final DenseVector[] trainData2 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 100.0),
+                Vectors.dense(10.0, 100.3),
+                Vectors.dense(10.3, 100.0),
+                Vectors.dense(-10.0, -100.0),
+                Vectors.dense(-10.0, -100.6),
+                Vectors.dense(-10.6, -100.0)
+            };
+    private static final DenseVector[] predictData =
+            new DenseVector[] {
+                Vectors.dense(10.0, 10.0),
+                Vectors.dense(10.3, 10.0),
+                Vectors.dense(10.0, 10.3),
+                Vectors.dense(-10.0, 10.0),
+                Vectors.dense(-10.3, 10.0),
+                Vectors.dense(-10.0, 10.3)
+            };
+    private static final List<Set<DenseVector>> expectedGroups1 =
+            Arrays.asList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3))),
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+    private static final List<Set<DenseVector>> expectedGroups2 =
+            Collections.singletonList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3),
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+
+    private static final int defaultParallelism = 4;
+    private static final int numTaskManagers = 2;
+    private static final int numSlotsPerTaskManager = 2;
+
+    private int currentModelDataVersion;
+
+    private InMemorySourceFunction<DenseVector> trainSource;
+    private InMemorySourceFunction<DenseVector> predictSource;
+    private InMemorySinkFunction<Row> outputSink;
+    private InMemorySinkFunction<KMeansModelData> modelDataSink;
+
+    // TODO: create a static mini cluster once for the whole test class after the dependency is
+    // upgraded to Flink 1.15.
+    private InMemoryReporter reporter;
+    private MiniCluster miniCluster;
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+
+    private Table offlineTrainTable;
+    private Table onlineTrainTable;
+    private Table onlinePredictTable;
+
+    @Before
+    public void before() throws Exception {
+        currentModelDataVersion = 0;
+
+        trainSource = new InMemorySourceFunction<>();
+        predictSource = new InMemorySourceFunction<>();
+        outputSink = new InMemorySinkFunction<>();
+        modelDataSink = new InMemorySinkFunction<>();
+
+        Configuration config = new Configuration();
+        config.set(RestOptions.BIND_PORT, "18081-19091");
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        reporter = InMemoryReporter.createWithRetainedMetrics();
+        reporter.addToConfiguration(config);
+
+        miniCluster =
+                new MiniCluster(
+                        new MiniClusterConfiguration.Builder()
+                                .setConfiguration(config)
+                                .setNumTaskManagers(numTaskManagers)
+                                .setNumSlotsPerTaskManager(numSlotsPerTaskManager)
+                                .build());
+        miniCluster.start();
+
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(defaultParallelism);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        Schema schema = Schema.newBuilder().column("f0", DataTypes.of(DenseVector.class)).build();

Review comment:
       Replacing `f0` with `features` seems to throw an exception here. I'll remove the `schema` variable directly.




[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r835209014



##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java
##########
@@ -0,0 +1,476 @@
+/** Tests {@link OnlineKMeans} and {@link OnlineKMeansModel}. */
+public class OnlineKMeansTest extends TestLogger {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private static final DenseVector[] trainData1 =

Review comment:
       I think making them static is enough. Most test classes in the lib module have been following this practice.




[GitHub] [flink-ml] lindong28 commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
lindong28 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r834993187
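As a rough check of the update rule described in the `OnlineKMeans` Javadoc below, the following standalone sketch performs the same two steps as the quoted code, using plain `double[]` arrays instead of `DenseVector`. Each of `p` subtasks scales the broadcast weights by `decayFactor / p` and folds in its local batch (as `ModelDataLocalUpdater` does), and the partial results are merged by weighted averaging with summed weights (as `ModelDataGlobalReducer` does). `DecayedMiniBatchSketch`, `localUpdate`, `merge` and `run` are hypothetical names. The inputs assume that the offline `KMeans` model in `testDecayFactor` was trained on `trainData1`, giving per-cluster means `(10.1, 0.1)` and `(-10.2, 0.2)` with weight 3 each; that training table is not shown in this excerpt.

```java
import java.util.Arrays;

/** Hypothetical sketch of OnlineKMeans' decayed mini-batch update split across p subtasks. */
public class DecayedMiniBatchSketch {

    static final class Model { // centroids plus one weight each (stand-in for KMeansModelData)
        final double[][] centroids;
        final double[] weights;
        Model(double[][] centroids, double[] weights) {
            this.centroids = centroids;
            this.weights = weights;
        }
    }

    /** Index of the centroid closest to point under squared Euclidean distance. */
    static int closest(double[][] centroids, double[] point) {
        int best = 0;
        double bestDist = Double.MAX_VALUE;
        for (int i = 0; i < centroids.length; i++) {
            double dist = 0;
            for (int j = 0; j < point.length; j++) {
                dist += (centroids[i][j] - point[j]) * (centroids[i][j] - point[j]);
            }
            if (dist < bestDist) {
                bestDist = dist;
                best = i;
            }
        }
        return best;
    }

    /** Local step: scale weights by decay / p, then pull centroids toward the local batch mean. */
    static Model localUpdate(Model model, double[][] batch, double decay, int p) {
        int k = model.centroids.length;
        int dim = model.centroids[0].length;
        double[][] sums = new double[k][dim];
        int[] counts = new int[k];
        for (double[] point : batch) {
            int c = closest(model.centroids, point);
            counts[c]++;
            for (int j = 0; j < dim; j++) {
                sums[c][j] += point[j];
            }
        }
        double[][] centroids = new double[k][];
        double[] weights = new double[k];
        for (int i = 0; i < k; i++) {
            centroids[i] = model.centroids[i].clone();
            weights[i] = model.weights[i] * decay / p + counts[i];
            if (counts[i] == 0) {
                continue; // centroid unchanged, weight only decayed
            }
            double lambda = counts[i] / weights[i];
            for (int j = 0; j < dim; j++) {
                centroids[i][j] = (1 - lambda) * centroids[i][j] + lambda * sums[i][j] / counts[i];
            }
        }
        return new Model(centroids, weights);
    }

    /** Global step: weighted average of two partial models; weights are summed. */
    static Model merge(Model a, Model b) {
        double[][] centroids = new double[a.weights.length][a.centroids[0].length];
        double[] weights = new double[a.weights.length];
        for (int i = 0; i < weights.length; i++) {
            weights[i] = a.weights[i] + b.weights[i];
            for (int j = 0; j < centroids[i].length; j++) {
                centroids[i][j] = (a.centroids[i][j] * a.weights[i] + b.centroids[i][j] * b.weights[i])
                        / Math.max(weights[i], 1e-16);
            }
        }
        return new Model(centroids, weights);
    }

    /** Applies one global batch (trainData2 from the test) with decay factor 0.5 on p subtasks. */
    static Model run(int p) {
        // Assumed initial model: per-cluster means of trainData1, each with weight 3.
        Model init = new Model(new double[][] {{10.1, 0.1}, {-10.2, 0.2}}, new double[] {3, 3});
        double[][] batch = {
            {10.0, 100.0}, {10.0, 100.3}, {10.3, 100.0},
            {-10.0, -100.0}, {-10.0, -100.6}, {-10.6, -100.0}
        };
        Model result = null;
        for (int i = 0; i < p; i++) {
            double[][] local = // same even split as GlobalBatchSplitter
                    Arrays.copyOfRange(batch, i * batch.length / p, (i + 1) * batch.length / p);
            Model partial = localUpdate(init, local, 0.5, p);
            result = (result == null) ? partial : merge(result, partial);
        }
        return result;
    }

    public static void main(String[] args) {
        Model model = run(4);
        System.out.println(Arrays.toString(model.weights)); // [4.5, 4.5]
        System.out.println(Arrays.deepToString(model.centroids));
    }
}
```

The weights come out as `0.5 * 3 + 3 = 4.5` per cluster and the centroids approximately equal `(10.1, 200.3 / 3)` and `(-10.2, -200.2 / 3)`, matching the expected model data in `testDecayFactor`. `run(1)` yields the same centroids as `run(4)` up to rounding, which is consistent with the local updater dividing the decay factor by the parallelism before the global reducer sums the partial weights.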



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,437 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.function.Supplier;
+
+/**
+ * OnlineKMeans extends the functionality of {@link KMeans} to support training a K-Means model
+ * continuously on an unbounded stream of training data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids as the weighted average of the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of points assigned to
+ * them in the current batch. The weight of the original centroids is their accumulated weight
+ * from previous batches, multiplied by the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If the decay
+ * factor is 1, all batches are weighted equally. If the decay factor is 0, new centroids are
+ * determined entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelData =
+                KMeansModelData.getModelDataStream(initModelDataTable);
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize());
+
+        DataStream<KMeansModelData> finalModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelData);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Saves the metadata AND bounded model data table (if exists) to the given path. */
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        String initModelDataPath = ReadWriteUtils.getDataPath(path);
+        if (Files.exists(Paths.get(initModelDataPath))) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure, double decayFactor, int batchSize) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<KMeansModelData> modelData = variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<KMeansModelData> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator())
+                            .flatMap(new GlobalBatchSplitter(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(KMeansModelData.class),
+                                    new ModelDataLocalUpdater(distanceMeasure, decayFactor))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new ModelDataGlobalReducer());
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelData), DataStreamList.of(modelData));
+        }
+    }
+
+    private static class ModelDataGlobalReducer implements ReduceFunction<KMeansModelData> {
+        @Override
+        public KMeansModelData reduce(KMeansModelData modelData, KMeansModelData newModelData) {
+            DenseVector weights = modelData.weights;
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector newWeights = newModelData.weights;
+            DenseVector[] newCentroids = newModelData.centroids;
+
+            int k = newCentroids.length;
+            int dim = newCentroids[0].size();
+
+            for (int i = 0; i < k; i++) {
+                for (int j = 0; j < dim; j++) {
+                    centroids[i].values[j] =
+                            (centroids[i].values[j] * weights.values[i]
+                                            + newCentroids[i].values[j] * newWeights.values[i])
+                                    / Math.max(weights.values[i] + newWeights.values[i], 1e-16);
+                }
+                weights.values[i] += newWeights.values[i];
+            }
+
+            return new KMeansModelData(centroids, weights);
+        }
+    }
+
+    private static class ModelDataLocalUpdater extends AbstractStreamOperator<KMeansModelData>
+            implements TwoInputStreamOperator<DenseVector[], KMeansModelData, KMeansModelData> {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private ListState<DenseVector[]> localBatchState;
+        private ListState<KMeansModelData> modelDataState;
+
+        private ModelDataLocalUpdater(DistanceMeasure distanceMeasure, double decayFactor) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            localBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("localBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> pointsRecord) throws Exception {
+            localBatchState.add(pointsRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<KMeansModelData> modelDataRecord)
+                throws Exception {
+            modelDataState.add(modelDataRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        private void alignAndComputeModelData() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !localBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            KMeansModelData modelData =
+                    OperatorStateUtils.getUniqueElement(modelDataState, "modelData")
+                            .orElseThrow((Supplier<Exception>) NullPointerException::new);
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector weights = modelData.weights;
+            modelDataState.clear();
+
+            List<DenseVector[]> pointsList = IteratorUtils.toList(localBatchState.get().iterator());
+            DenseVector[] points = pointsList.remove(0);
+            localBatchState.update(pointsList);
+
+            int dim = centroids[0].size();
+            int k = centroids.length;
+            int parallelism = getRuntimeContext().getNumberOfParallelSubtasks();
+
+            // Computes new centroids.
+            DenseVector[] sums = new DenseVector[k];
+            int[] counts = new int[k];
+
+            for (int i = 0; i < k; i++) {
+                sums[i] = new DenseVector(dim);
+                counts[i] = 0;
+            }
+            for (DenseVector point : points) {
+                int closestCentroidId =
+                        KMeans.findClosestCentroidId(centroids, point, distanceMeasure);
+                counts[closestCentroidId]++;
+                for (int j = 0; j < dim; j++) {
+                    sums[closestCentroidId].values[j] += point.values[j];
+                }
+            }
+
+            // Considers weight and decay factor when updating centroids.
+            BLAS.scal(decayFactor / parallelism, weights);
+            for (int i = 0; i < k; i++) {
+                if (counts[i] == 0) {
+                    continue;
+                }
+
+                DenseVector centroid = centroids[i];
+                weights.values[i] = weights.values[i] + counts[i];
+                double lambda = counts[i] / weights.values[i];
+
+                BLAS.scal(1.0 - lambda, centroid);
+                BLAS.axpy(lambda / counts[i], sums[i], centroid);
+            }
+
+            output.collect(new StreamRecord<>(new KMeansModelData(centroids, weights)));
+        }
+    }
+
+    private static class FeaturesExtractor implements MapFunction<Row, DenseVector> {
+        private final String featuresCol;
+
+        private FeaturesExtractor(String featuresCol) {
+            this.featuresCol = featuresCol;
+        }
+
+        @Override
+        public DenseVector map(Row row) throws Exception {
+            return (DenseVector) row.getField(featuresCol);
+        }
+    }
+
+    // An operator that splits a global batch into evenly-sized local batches and distributes them
+    // to the downstream operator.
+    private static class GlobalBatchSplitter
+            implements FlatMapFunction<DenseVector[], DenseVector[]> {
+        private final int downStreamParallelism;
+
+        private GlobalBatchSplitter(int downStreamParallelism) {
+            this.downStreamParallelism = downStreamParallelism;
+        }
+
+        @Override
+        public void flatMap(DenseVector[] values, Collector<DenseVector[]> collector) {
+            // Calculate the batch sizes to be distributed on each subtask.
+            List<Integer> sizes = new ArrayList<>();
+            for (int i = 0; i < downStreamParallelism; i++) {
+                int start = i * values.length / downStreamParallelism;
+                int end = (i + 1) * values.length / downStreamParallelism;
+                sizes.add(end - start);
+            }
+
+            int offset = 0;
+            for (Integer size : sizes) {
+                collector.collect(Arrays.copyOfRange(values, offset, offset + size));
+                offset += size;
+            }
+        }
+    }
+
+    private static class GlobalBatchCreator
+            implements AllWindowFunction<DenseVector, DenseVector[], GlobalWindow> {
+        @Override
+        public void apply(
+                GlobalWindow timeWindow,
+                Iterable<DenseVector> iterable,
+                Collector<DenseVector[]> collector) {
+            List<DenseVector> points = IteratorUtils.toList(iterable.iterator());
+            collector.collect(points.toArray(new DenseVector[0]));
+        }
+    }
+
+    /**
+     * Sets the initial model data of the online training process with the provided model data
+     * table.
+     *
+     * <p>This would override the effect of previously invoked {@link #setInitialModelData(Table)}
+     * or {@link #setRandomCentroids(StreamTableEnvironment, int, int, double)}.
+     */
+    public OnlineKMeans setInitialModelData(Table initModelDataTable) {
+        this.initModelDataTable = initModelDataTable;
+        return this;
+    }
+
+    /**
+     * Sets the initial model data of the online training process with randomly created centroids.
+     *
+     * <p>This would override the effect of previously invoked {@link #setInitialModelData(Table)}
+     * or {@link #setRandomCentroids(StreamTableEnvironment, int, int, double)}.
+     *
+     * @param tEnv The stream table environment to create the centroids in.
+     * @param dim The dimension of the centroids to create.
+     * @param k The number of centroids to create.
+     * @param weight The weight of the centroids to create.
+     */
+    public OnlineKMeans setRandomCentroids(
+            StreamTableEnvironment tEnv, int dim, int k, double weight) {
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();

Review comment:
       You are right.
   
   Functionality-wise, it seems fine for Flink ML to depend on both Table and DataStream, because the Flink runtime already allows a Table to be converted to and from a DataStream.
   
   I do agree that it would be cleaner to consistently require a table environment in our public API, but changing this would require filing a FLIP. Since this does not cause a real problem, I think it is OK to keep using StreamExecutionEnvironment for now.
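For readers following the update logic, the per-batch path in the excerpt above (`GlobalBatchSplitter`, then `alignAndComputeModelData()` on each subtask, then `ModelDataGlobalReducer`) can be sketched in plain Java. This is a minimal, non-authoritative sketch: `double[]` stands in for Flink ML's `DenseVector`, squared Euclidean distance is assumed in place of the configurable `DistanceMeasure`, the Flink operators are replaced by ordinary method calls, and `OnlineKMeansStepSketch` and its methods are hypothetical names.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Hypothetical plain-Java sketch of one OnlineKMeans training step (not the Flink ML API). */
class OnlineKMeansStepSketch {

    /** Index of the nearest centroid, using squared Euclidean distance (an assumption). */
    static int closest(double[][] centroids, double[] point) {
        int best = 0;
        double bestDist = Double.POSITIVE_INFINITY;
        for (int i = 0; i < centroids.length; i++) {
            double dist = 0;
            for (int j = 0; j < point.length; j++) {
                double diff = centroids[i][j] - point[j];
                dist += diff * diff;
            }
            if (dist < bestDist) {
                bestDist = dist;
                best = i;
            }
        }
        return best;
    }

    /** Mirrors GlobalBatchSplitter: part i holds points [i * n / p, (i + 1) * n / p). */
    static List<double[][]> split(double[][] batch, int parallelism) {
        List<double[][]> parts = new ArrayList<>();
        for (int i = 0; i < parallelism; i++) {
            int start = i * batch.length / parallelism;
            int end = (i + 1) * batch.length / parallelism;
            parts.add(Arrays.copyOfRange(batch, start, end));
        }
        return parts;
    }

    /** Mirrors alignAndComputeModelData(): decayed mini-batch update of a local model, in place. */
    static void localUpdate(
            double[][] centroids, double[] weights, double[][] points, double decay, int parallelism) {
        int k = centroids.length;
        int dim = centroids[0].length;
        double[][] sums = new double[k][dim];
        int[] counts = new int[k];
        for (double[] point : points) {
            int c = closest(centroids, point);
            counts[c]++;
            for (int j = 0; j < dim; j++) {
                sums[c][j] += point[j];
            }
        }
        for (int i = 0; i < k; i++) {
            weights[i] *= decay / parallelism; // each subtask keeps 1/p of the decayed old weight
            if (counts[i] == 0) {
                continue; // no new points: centroid unchanged
            }
            weights[i] += counts[i];
            double lambda = counts[i] / weights[i];
            for (int j = 0; j < dim; j++) {
                // (1 - lambda) * old centroid + lambda * mean of the newly assigned points
                centroids[i][j] = (1 - lambda) * centroids[i][j] + lambda * sums[i][j] / counts[i];
            }
        }
    }

    /** Mirrors ModelDataGlobalReducer: folds model b into model a by weighted averaging. */
    static void merge(double[][] ca, double[] wa, double[][] cb, double[] wb) {
        for (int i = 0; i < ca.length; i++) {
            double total = Math.max(wa[i] + wb[i], 1e-16);
            for (int j = 0; j < ca[i].length; j++) {
                ca[i][j] = (ca[i][j] * wa[i] + cb[i][j] * wb[i]) / total;
            }
            wa[i] += wb[i];
        }
    }

    /** One step on a global batch (parallelism >= 1); centroids and weights are updated in place. */
    static void step(
            double[][] centroids, double[] weights, double[][] batch, double decay, int parallelism) {
        double[][] accC = null;
        double[] accW = null;
        for (double[][] part : split(batch, parallelism)) {
            // Every subtask starts from its own copy of the broadcast model.
            double[][] c = Arrays.stream(centroids).map(double[]::clone).toArray(double[][]::new);
            double[] w = weights.clone();
            localUpdate(c, w, part, decay, parallelism);
            if (accC == null) {
                accC = c;
                accW = w;
            } else {
                merge(accC, accW, c, w);
            }
        }
        System.arraycopy(accC, 0, centroids, 0, centroids.length);
        System.arraycopy(accW, 0, weights, 0, weights.length);
    }
}
```

In this sketch each subtask keeps only `decay / parallelism` of the old weight, so every partial model satisfies weight × centroid = (decay × oldWeight / p) × oldCentroid + (sum of its assigned points). The weighted average in `merge` should therefore reproduce the single-subtask update up to floating-point rounding, however the global batch is split.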




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r835702703



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,437 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.function.Supplier;
+
+/**
+ * OnlineKMeans extends the functionality of {@link KMeans} to support training a K-Means model
+ * continuously on an unbounded stream of training data.
+ *
+ * <p>OnlineKMeans updates the model with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids of the current batch are estimated, OnlineKMeans
+ * computes the new centroids as the weighted average of the original and the estimated centroids.
+ * The weight of the estimated centroids is the number of points assigned to them in the current
+ * batch. The weight of the original centroids is their accumulated weight (the number of points
+ * previously assigned to them), multiplied by the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters estimated so far. If the decay
+ * factor is 1, all batches are weighted equally. If the decay factor is 0, the new centroids are
+ * determined entirely by the most recent batch. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelData =
+                KMeansModelData.getModelDataStream(initModelDataTable);
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize());
+
+        DataStream<KMeansModelData> finalModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelData);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Saves the metadata AND bounded model data table (if it exists) to the given path. */
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        String initModelDataPath = ReadWriteUtils.getDataPath(path);
+        if (Files.exists(Paths.get(initModelDataPath))) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure, double decayFactor, int batchSize) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<KMeansModelData> modelData = variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<KMeansModelData> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator())
+                            .flatMap(new GlobalBatchSplitter(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(KMeansModelData.class),
+                                    new ModelDataLocalUpdater(distanceMeasure, decayFactor))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new ModelDataGlobalReducer());
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelData), DataStreamList.of(modelData));
+        }
+    }
+
+    private static class ModelDataGlobalReducer implements ReduceFunction<KMeansModelData> {
+        @Override
+        public KMeansModelData reduce(KMeansModelData modelData, KMeansModelData newModelData) {
+            DenseVector weights = modelData.weights;
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector newWeights = newModelData.weights;
+            DenseVector[] newCentroids = newModelData.centroids;
+
+            int k = newCentroids.length;
+            int dim = newCentroids[0].size();
+
+            for (int i = 0; i < k; i++) {
+                for (int j = 0; j < dim; j++) {
+                    centroids[i].values[j] =
+                            (centroids[i].values[j] * weights.values[i]
+                                            + newCentroids[i].values[j] * newWeights.values[i])
+                                    / Math.max(weights.values[i] + newWeights.values[i], 1e-16);
+                }
+                weights.values[i] += newWeights.values[i];
+            }
+
+            return new KMeansModelData(centroids, weights);
+        }
+    }
+
+    private static class ModelDataLocalUpdater extends AbstractStreamOperator<KMeansModelData>
+            implements TwoInputStreamOperator<DenseVector[], KMeansModelData, KMeansModelData> {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private ListState<DenseVector[]> localBatchState;
+        private ListState<KMeansModelData> modelDataState;
+
+        private ModelDataLocalUpdater(DistanceMeasure distanceMeasure, double decayFactor) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            localBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("localBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> pointsRecord) throws Exception {
+            localBatchState.add(pointsRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<KMeansModelData> modelDataRecord)
+                throws Exception {
+            modelDataState.add(modelDataRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        private void alignAndComputeModelData() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !localBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            KMeansModelData modelData =
+                    OperatorStateUtils.getUniqueElement(modelDataState, "modelData")
+                            .orElseThrow((Supplier<Exception>) NullPointerException::new);
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector weights = modelData.weights;
+            modelDataState.clear();
+
+            List<DenseVector[]> pointsList = IteratorUtils.toList(localBatchState.get().iterator());
+            DenseVector[] points = pointsList.remove(0);
+            localBatchState.update(pointsList);
+
+            int dim = centroids[0].size();
+            int k = centroids.length;
+            int parallelism = getRuntimeContext().getNumberOfParallelSubtasks();
+
+            // Computes new centroids.
+            DenseVector[] sums = new DenseVector[k];
+            int[] counts = new int[k];
+
+            for (int i = 0; i < k; i++) {
+                sums[i] = new DenseVector(dim);
+                counts[i] = 0;
+            }
+            for (DenseVector point : points) {
+                int closestCentroidId =
+                        KMeans.findClosestCentroidId(centroids, point, distanceMeasure);
+                counts[closestCentroidId]++;
+                for (int j = 0; j < dim; j++) {
+                    sums[closestCentroidId].values[j] += point.values[j];
+                }
+            }
+
+            // Considers weight and decay factor when updating centroids.
+            BLAS.scal(decayFactor / parallelism, weights);
+            for (int i = 0; i < k; i++) {
+                if (counts[i] == 0) {
+                    continue;
+                }
+
+                DenseVector centroid = centroids[i];
+                weights.values[i] = weights.values[i] + counts[i];
+                double lambda = counts[i] / weights.values[i];
+
+                BLAS.scal(1.0 - lambda, centroid);
+                BLAS.axpy(lambda / counts[i], sums[i], centroid);
+            }
+
+            output.collect(new StreamRecord<>(new KMeansModelData(centroids, weights)));
+        }
+    }
+
+    private static class FeaturesExtractor implements MapFunction<Row, DenseVector> {
+        private final String featuresCol;
+
+        private FeaturesExtractor(String featuresCol) {
+            this.featuresCol = featuresCol;
+        }
+
+        @Override
+        public DenseVector map(Row row) throws Exception {
+            return (DenseVector) row.getField(featuresCol);
+        }
+    }
+
+    // An operator that splits a global batch into evenly-sized local batches and distributes them
+    // to the downstream operator.
+    private static class GlobalBatchSplitter
+            implements FlatMapFunction<DenseVector[], DenseVector[]> {
+        private final int downStreamParallelism;
+
+        private GlobalBatchSplitter(int downStreamParallelism) {
+            this.downStreamParallelism = downStreamParallelism;
+        }
+
+        @Override
+        public void flatMap(DenseVector[] values, Collector<DenseVector[]> collector) {
+            // Calculate the batch sizes to be distributed on each subtask.
+            List<Integer> sizes = new ArrayList<>();
+            for (int i = 0; i < downStreamParallelism; i++) {
+                int start = i * values.length / downStreamParallelism;
+                int end = (i + 1) * values.length / downStreamParallelism;
+                sizes.add(end - start);
+            }
+
+            int offset = 0;
+            for (Integer size : sizes) {
+                collector.collect(Arrays.copyOfRange(values, offset, offset + size));
+                offset += size;
+            }
+        }
+    }
+
+    private static class GlobalBatchCreator
+            implements AllWindowFunction<DenseVector, DenseVector[], GlobalWindow> {
+        @Override
+        public void apply(
+                GlobalWindow timeWindow,
+                Iterable<DenseVector> iterable,
+                Collector<DenseVector[]> collector) {
+            List<DenseVector> points = IteratorUtils.toList(iterable.iterator());
+            collector.collect(points.toArray(new DenseVector[0]));
+        }
+    }
+
+    /**
+     * Sets the initial model data of the online training process with the provided model data
+     * table.
+     *
+     * <p>This would override the effect of previously invoked {@link #setInitialModelData(Table)}
+     * or {@link #setRandomCentroids(StreamTableEnvironment, int, int, double)}.
+     */
+    public OnlineKMeans setInitialModelData(Table initModelDataTable) {
+        this.initModelDataTable = initModelDataTable;
+        return this;
+    }
+
+    /**
+     * Sets the initial model data of the online training process with randomly created centroids.
+     *
+     * <p>This would override the effect of previously invoked {@link #setInitialModelData(Table)}
+     * or {@link #setRandomCentroids(StreamTableEnvironment, int, int, double)}.
+     *
+     * @param tEnv The stream table environment to create the centroids in.
+     * @param dim The dimension of the centroids to create.
+     * @param k The number of centroids to create.
+     * @param weight The weight of the centroids to create.
+     */
+    public OnlineKMeans setRandomCentroids(
+            StreamTableEnvironment tEnv, int dim, int k, double weight) {

Review comment:
       Following our offline discussion, I'll add a `KMeansUtils.generateRandomModelData()` method and remove the `setRandomCentroids` method from `OnlineKMeans`.
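The comment above plans a `KMeansUtils.generateRandomModelData()` helper whose signature is not shown in this thread. A minimal sketch of what such a helper might compute, assuming the same inputs as `setRandomCentroids(tEnv, dim, k, weight)` plus a hypothetical `seed`, coordinates drawn uniformly from [0, 1), and plain arrays instead of a model data `Table`. `RandomModelDataSketch` is a hypothetical name.

```java
import java.util.Arrays;
import java.util.Random;

/** Hypothetical holder for k-means model data: k centroids plus one weight per centroid. */
class RandomModelDataSketch {
    final double[][] centroids;
    final double[] weights;

    RandomModelDataSketch(double[][] centroids, double[] weights) {
        this.centroids = centroids;
        this.weights = weights;
    }

    /**
     * Creates k centroids of the given dimension, each coordinate drawn uniformly from [0, 1)
     * (an assumption about the distribution), and gives every centroid the same initial weight.
     */
    static RandomModelDataSketch generate(int dim, int k, double weight, long seed) {
        Random random = new Random(seed);
        double[][] centroids = new double[k][dim];
        for (double[] centroid : centroids) {
            for (int j = 0; j < dim; j++) {
                centroid[j] = random.nextDouble();
            }
        }
        double[] weights = new double[k];
        Arrays.fill(weights, weight);
        return new RandomModelDataSketch(centroids, weights);
    }
}
```

Passing a fixed seed keeps test runs reproducible. A real helper would still need to wrap these arrays in model data (such as the `KMeansModelData(centroids, weights)` used in the diff) and expose them as a `Table`, which this sketch leaves out.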







[GitHub] [flink-ml] zhipeng93 edited a comment on pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
zhipeng93 edited a comment on pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#issuecomment-1081346919


   @lindong28 Thanks for the review. This PR LGTM.





[GitHub] [flink-ml] yunfengzhou-hub commented on pull request #70: [FLINK-26313] Support Streaming KMeans in Flink ML

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#issuecomment-1057816374


   Hi @lindong28, I have created the PR for the streaming KMeans algorithm. Would you mind helping review this PR?





[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Support Online KMeans in Flink ML

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r829798468



##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java
##########
@@ -207,25 +211,21 @@ public void testSaveLoadAndPredict() throws Exception {
         KMeansModel loadedModel =
                 StageTestUtils.saveAndReload(env, model, tempFolder.newFolder().getAbsolutePath());
         Table output = loadedModel.transform(dataTable)[0];
-        assertEquals(
-                Collections.singletonList("centroids"),
-                loadedModel.getModelData()[0].getResolvedSchema().getColumnNames());
         assertEquals(
                 Arrays.asList("features", "prediction"),
                 output.getResolvedSchema().getColumnNames());
 
+        List<Row> results = IteratorUtils.toList(output.execute().collect());
         List<Set<DenseVector>> actualGroups =
-                executeAndCollect(output, kmeans.getFeaturesCol(), kmeans.getPredictionCol());
+                groupFeaturesByPrediction(
+                        results, kmeans.getFeaturesCol(), kmeans.getPredictionCol());
         assertTrue(CollectionUtils.isEqualCollection(expectedGroups, actualGroups));
     }
 
     @Test
     public void testGetModelData() throws Exception {
         KMeans kmeans = new KMeans().setMaxIter(2).setK(2);
         KMeansModel model = kmeans.fit(dataTable);
-        assertEquals(
-                Collections.singletonList("centroids"),

Review comment:
       I had previously changed the design of `KMeansModelData` so that it contained more than the centroids (by adding a weights field), and this removal was a result of that change. I have since restored `KMeansModelData`'s original structure but forgot to add this check back. I'll fix it.
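The updated test above compares clusters with `groupFeaturesByPrediction(results, featuresCol, predictionCol)`, whose body is not part of this excerpt. A plain-Java sketch of the same idea, assuming features are represented as `List<Double>` (so they compare by value inside a `Set`) and predictions as a parallel list of cluster ids; `GroupByPredictionSketch` is a hypothetical name.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/** Hypothetical plain-Java version of grouping clustered points by predicted cluster id. */
class GroupByPredictionSketch {

    /** Returns one set of feature vectors per distinct prediction; the ids themselves are dropped. */
    static List<Set<List<Double>>> groupFeaturesByPrediction(
            List<List<Double>> features, List<Integer> predictions) {
        Map<Integer, Set<List<Double>>> groups = new HashMap<>();
        for (int i = 0; i < features.size(); i++) {
            groups.computeIfAbsent(predictions.get(i), id -> new HashSet<>()).add(features.get(i));
        }
        return new ArrayList<>(groups.values());
    }
}
```

Grouping by prediction and then discarding the ids makes the check independent of which numeric id KMeans happens to assign to each cluster; the test then compares the groups with `CollectionUtils.isEqualCollection`.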







[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r831990504



##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/util/MockKVStore.java
##########
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util;
+
+import java.util.HashMap;
+import java.util.Map;
+
+/** Class that manages global key-value pairs used in unit tests. */
+@SuppressWarnings({"unchecked"})
+public class MockKVStore {

Review comment:
       Solved as described in the other comment about TestMetricReporter.
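The diff above describes `MockKVStore` as a class that manages global key-value pairs used in unit tests; its body is not shown. One common reason for such a class is that Flink serializes user functions before running them, so a mock source or sink cannot simply hold a reference to an object owned by the test; in a local, single-JVM test cluster, a static map keyed by a string can bridge the two. A minimal sketch under that assumption, with hypothetical class and method names:

```java
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;

/** Hypothetical JVM-wide key-value store shared between a test and the operators it launches. */
class GlobalKVStoreSketch {
    // Static, so every deserialized copy of a mock function in this JVM sees the same map.
    private static final Map<String, Object> STORE = new ConcurrentHashMap<>();

    /** Stores a non-null value (ConcurrentHashMap rejects nulls) under a fresh key; returns the key. */
    static String put(Object value) {
        String key = UUID.randomUUID().toString();
        STORE.put(key, value);
        return key;
    }

    /** Looks up a value; the unchecked cast mirrors the @SuppressWarnings on MockKVStore. */
    @SuppressWarnings("unchecked")
    static <T> T get(String key) {
        return (T) STORE.get(key);
    }

    /** Removes the entry so that state does not leak into later tests. */
    static void remove(String key) {
        STORE.remove(key);
    }
}
```

Only the string key needs to be captured by the (serialized) function; the object itself stays in the shared map.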







[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r832004245



##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java
##########
@@ -0,0 +1,511 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.core.execution.JobClient;
+import org.apache.flink.ml.clustering.kmeans.KMeans;
+import org.apache.flink.ml.clustering.kmeans.KMeansModel;
+import org.apache.flink.ml.clustering.kmeans.KMeansModelData;
+import org.apache.flink.ml.clustering.kmeans.OnlineKMeans;
+import org.apache.flink.ml.clustering.kmeans.OnlineKMeansModel;
+import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.util.GlobalBlockingQueues;
+import org.apache.flink.ml.util.MockSinkFunction;
+import org.apache.flink.ml.util.MockSourceFunction;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.ml.util.TestMetricReporter;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.CollectionUtils;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.apache.flink.ml.clustering.KMeansTest.groupFeaturesByPrediction;
+
+/** Tests {@link OnlineKMeans} and {@link OnlineKMeansModel}. */
+public class OnlineKMeansTest {

Review comment:
       Got it. I'll make the change.
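The test above imports `GlobalBlockingQueues`, `MockSourceFunction` and `MockSinkFunction`; per the PR description, blocking-queue-backed mock sources and sinks are used to control the order in which training and prediction data are consumed. The sketch below shows the general lockstep pattern with plain `java.util.concurrent` queues. It is an illustration only: the registry and method names are hypothetical, and a background thread stands in for the Flink job.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;

/** Hypothetical registry of named blocking queues, plus a demo of feeding a job in lockstep. */
class LockstepQueuesSketch {
    private static final Map<String, BlockingQueue<String>> QUEUES = new ConcurrentHashMap<>();
    private static final String END = "END"; // sentinel that tells the stand-in job to stop

    static BlockingQueue<String> queue(String id) {
        return QUEUES.computeIfAbsent(id, key -> new LinkedBlockingQueue<>());
    }

    /** Sends batches one at a time, waiting for the job's output before sending the next one. */
    static List<String> runLockstep(List<String> batches) {
        BlockingQueue<String> input = queue("train-input");
        BlockingQueue<String> output = queue("model-output");
        // Stand-in for mock source -> operator -> mock sink running inside the job.
        Thread job = new Thread(() -> {
            try {
                for (String batch = input.take(); !END.equals(batch); batch = input.take()) {
                    output.put("model after " + batch);
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        });
        job.setDaemon(true);
        job.start();
        List<String> results = new ArrayList<>();
        try {
            for (String batch : batches) {
                input.put(batch);
                results.add(output.poll(10, TimeUnit.SECONDS)); // null if the job stalls
            }
            input.put(END);
            job.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        } finally {
            QUEUES.remove("train-input");
            QUEUES.remove("model-output");
        }
        return results;
    }
}
```

Waiting on the output queue after each input makes the interleaving deterministic: the next batch (or a prediction request) is only sent once the previous model update has been observed.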




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [flink-ml] lindong28 commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
lindong28 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r835771045



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeansModel.java
##########
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.co.CoProcessFunction;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * OnlineKMeansModel can be regarded as an advanced {@link KMeansModel} operator that can update
+ * its model data continuously from a stream, using the model data provided by {@link OnlineKMeans}.
+ */
+public class OnlineKMeansModel
+        implements Model<OnlineKMeansModel>, KMeansModelParams<OnlineKMeansModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public OnlineKMeansModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel setModelData(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        modelDataTable = inputs[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), Types.INT),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol()));
+
+        DataStream<Row> predictionResult =
+                KMeansModelData.getModelDataStream(modelDataTable)
+                        .broadcast()
+                        .connect(tEnv.toDataStream(inputs[0]))
+                        .process(
+                                new PredictLabelFunction(
+                                        getFeaturesCol(),
+                                        DistanceMeasure.getInstance(getDistanceMeasure())),
+                                outputTypeInfo);
+
+        return new Table[] {tEnv.fromDataStream(predictionResult)};
+    }
+
+    /** A utility function used for prediction. */
+    private static class PredictLabelFunction extends CoProcessFunction<KMeansModelData, Row, Row> {
+        private final String featuresCol;
+
+        private final DistanceMeasure distanceMeasure;
+
+        private DenseVector[] centroids;
+
+        // TODO: replace this with a complete solution of reading first model data from unbounded
+        // model data stream before processing the first predict data.
+        private final List<Row> bufferedPoints = new ArrayList<>();

Review comment:
       The long-term solution is to `read first model data from unbounded model data stream before processing the first predict data`. If we cannot achieve this goal, OnlineKMeansModel won't be reliable for production use anyway, because buffering points without a bound can lead to OOM errors. So it does not seem helpful to checkpoint the bufferedPoints here.
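   To make the trade-off concrete, here is a minimal plain-Java sketch (no Flink dependencies) of the buffering behaviour discussed above, assuming the semantics of `PredictLabelFunction`: points that arrive before the first model data are held in memory and are labeled only once centroids arrive. `BufferedPredictorSketch` and its methods are hypothetical names, and squared Euclidean distance stands in for the configurable `DistanceMeasure`. Nothing bounds `bufferedPoints`, which is the OOM risk the comment refers to.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch of the buffering in PredictLabelFunction (not Flink ML code).
class BufferedPredictorSketch {
    // Null until the first model data arrives.
    private double[][] centroids;
    // Points waiting for a model; nothing bounds this list, hence the OOM risk.
    private final List<double[]> bufferedPoints = new ArrayList<>();
    // Predicted cluster index for each point, in the order the points were labeled.
    private final List<Integer> predictions = new ArrayList<>();

    // Model data side (the first input of the CoProcessFunction).
    void onModelData(double[][] newCentroids) {
        centroids = newCentroids;
        // Label every point that was waiting for the first model.
        for (double[] point : bufferedPoints) {
            predictions.add(closestCentroid(point));
        }
        bufferedPoints.clear();
    }

    // Predict data side (the second input of the CoProcessFunction).
    void onPoint(double[] point) {
        if (centroids == null) {
            bufferedPoints.add(point);
        } else {
            predictions.add(closestCentroid(point));
        }
    }

    int bufferedCount() {
        return bufferedPoints.size();
    }

    List<Integer> predictions() {
        return predictions;
    }

    // Index of the nearest centroid by squared Euclidean distance.
    private int closestCentroid(double[] point) {
        int best = 0;
        double bestDist = Double.MAX_VALUE;
        for (int i = 0; i < centroids.length; i++) {
            double dist = 0.0;
            for (int j = 0; j < point.length; j++) {
                double diff = point[j] - centroids[i][j];
                dist += diff * diff;
            }
            if (dist < bestDist) {
                bestDist = dist;
                best = i;
            }
        }
        return best;
    }
}
```

   Checkpointing `bufferedPoints` would only make this unbounded list part of the operator state; it would not remove the memory risk.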







[GitHub] [flink-ml] lindong28 commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
lindong28 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r836010275



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,407 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Supplier;
+
+/**
+ * OnlineKMeans extends the functionality of {@link KMeans}, supporting continuous training of a
+ * K-Means model on an unbounded stream of training data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids from the weighted average between the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of points assigned to
+ * them. The weight of the original centroids is the number of points previously assigned to
+ * them, multiplied by the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If the decay
+ * factor is 1, all batches are weighted equally. If the decay factor is 0, new centroids are
+ * determined entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelData =
+                KMeansModelData.getModelDataStream(initModelDataTable);
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getK(),
+                        getDecayFactor(),
+                        getGlobalBatchSize());
+
+        DataStream<KMeansModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Saves the metadata and the bounded initial model data table (if any) to the given path. */
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans onlineKMeans = ReadWriteUtils.loadStageParam(path);
+
+        String initModelDataPath = ReadWriteUtils.getDataPath(path);
+        if (Files.exists(Paths.get(initModelDataPath))) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            onlineKMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+        }
+
+        return onlineKMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final int k;
+        private final double decayFactor;
+        private final int batchSize;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure, int k, double decayFactor, int batchSize) {
+            this.distanceMeasure = distanceMeasure;
+            this.k = k;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<KMeansModelData> modelData = variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();

Review comment:
       Yes, I think the reason is related to performance.
   
   If `parallelism > batchSize`, some slots (and their CPU resources) are definitely wasted. Is there any reason a user would want this? If not, the user must have chosen this setup by mistake. Would it be more user-friendly to alert the user to this issue by throwing an exception?
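   A minimal plain-Java sketch of why `parallelism > batchSize` leaves work undistributed: `split` reuses the index arithmetic of `GlobalBatchSplitter#flatMap`, and `checkParallelism` is a hypothetical version of the up-front check proposed above. Integer arrays stand in for `DenseVector[]`; `BatchSplitSketch` and its methods are illustrative names, not part of Flink ML.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical sketch; not the Flink ML implementation.
class BatchSplitSketch {
    // Proposed validation: reject setups in which some local batch is always empty.
    static void checkParallelism(int parallelism, int globalBatchSize) {
        if (parallelism > globalBatchSize) {
            throw new IllegalArgumentException(
                    "globalBatchSize ("
                            + globalBatchSize
                            + ") must not be smaller than parallelism ("
                            + parallelism
                            + "), otherwise some subtasks receive empty local batches.");
        }
    }

    // Same start/end index arithmetic as GlobalBatchSplitter#flatMap.
    static List<int[]> split(int[] values, int parallelism) {
        List<int[]> batches = new ArrayList<>();
        for (int i = 0; i < parallelism; i++) {
            int start = i * values.length / parallelism;
            int end = (i + 1) * values.length / parallelism;
            batches.add(Arrays.copyOfRange(values, start, end));
        }
        return batches;
    }
}
```

   For example, splitting a global batch of 3 points across 4 subtasks yields local batch sizes 0, 1, 1, 1, so one subtask processes an empty batch for every global batch.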







[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r834972977



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModelParams.java
##########
@@ -21,27 +21,11 @@
 import org.apache.flink.ml.common.param.HasDistanceMeasure;
 import org.apache.flink.ml.common.param.HasFeaturesCol;
 import org.apache.flink.ml.common.param.HasPredictionCol;
-import org.apache.flink.ml.param.IntParam;
-import org.apache.flink.ml.param.Param;
-import org.apache.flink.ml.param.ParamValidators;
 
 /**
- * Params of {@link KMeansModel}.
+ * Params of {@link KMeansModel} and {@link OnlineKMeansModel}.
  *
  * @param <T> The class type of this instance.
  */
 public interface KMeansModelParams<T>
-        extends HasDistanceMeasure<T>, HasFeaturesCol<T>, HasPredictionCol<T> {
-
-    Param<Integer> K =

Review comment:
       In `KMeansModel`, `getK()` is never used, so even if the K of the input model data differs from the K parameter of `KMeansModel`, the operator would still function correctly. This is why I removed K from `KMeansModelParams`.
   
   I now agree that users may want to know the K of the model operator. To address the concern above, I'd also like to check that the number of centroids in the model data matches the K parameter in `KMeansModel` and `OnlineKMeansModel`, so that K is always consistent with the model data.
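   A minimal sketch of the consistency check proposed above, assuming the model data exposes its centroids as an array. `KConsistencyCheck` and `checkK` are hypothetical names, and `double[][]` stands in for `DenseVector[]`; in the actual operators such a check would presumably run wherever model data is received.

```java
// Hypothetical sketch; not the Flink ML implementation.
class KConsistencyCheck {
    // Fails fast when the model data and the K parameter disagree.
    static void checkK(int k, double[][] centroids) {
        if (centroids.length != k) {
            throw new IllegalStateException(
                    "Model data has " + centroids.length + " centroids, but K is " + k + ".");
        }
    }
}
```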







[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r834989556



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,437 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.function.Supplier;
+
+/**
+ * OnlineKMeans extends the functionality of {@link KMeans}, supporting continuous training of a
+ * K-Means model on an unbounded stream of training data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids from the weighted average between the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of points assigned to
+ * them. The weight of the original centroids is the number of points previously assigned to
+ * them, multiplied by the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If the decay
+ * factor is 1, all batches are weighted equally. If the decay factor is 0, new centroids are
+ * determined entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelData =
+                KMeansModelData.getModelDataStream(initModelDataTable);
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize());
+
+        DataStream<KMeansModelData> finalModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelData);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Saves the metadata and the bounded initial model data table (if any) to the given path. */
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        String initModelDataPath = ReadWriteUtils.getDataPath(path);
+        if (Files.exists(Paths.get(initModelDataPath))) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure, double decayFactor, int batchSize) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<KMeansModelData> modelData = variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<KMeansModelData> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator())
+                            .flatMap(new GlobalBatchSplitter(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(KMeansModelData.class),
+                                    new ModelDataLocalUpdater(distanceMeasure, decayFactor))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new ModelDataGlobalReducer());
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelData), DataStreamList.of(modelData));
+        }
+    }
+
+    private static class ModelDataGlobalReducer implements ReduceFunction<KMeansModelData> {
+        @Override
+        public KMeansModelData reduce(KMeansModelData modelData, KMeansModelData newModelData) {
+            DenseVector weights = modelData.weights;
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector newWeights = newModelData.weights;
+            DenseVector[] newCentroids = newModelData.centroids;
+
+            int k = newCentroids.length;
+            int dim = newCentroids[0].size();
+
+            for (int i = 0; i < k; i++) {
+                for (int j = 0; j < dim; j++) {
+                    centroids[i].values[j] =
+                            (centroids[i].values[j] * weights.values[i]
+                                            + newCentroids[i].values[j] * newWeights.values[i])
+                                    / Math.max(weights.values[i] + newWeights.values[i], 1e-16);
+                }
+                weights.values[i] += newWeights.values[i];
+            }
+
+            return new KMeansModelData(centroids, weights);
+        }
+    }
+
+    private static class ModelDataLocalUpdater extends AbstractStreamOperator<KMeansModelData>
+            implements TwoInputStreamOperator<DenseVector[], KMeansModelData, KMeansModelData> {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private ListState<DenseVector[]> localBatchState;
+        private ListState<KMeansModelData> modelDataState;
+
+        private ModelDataLocalUpdater(DistanceMeasure distanceMeasure, double decayFactor) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            localBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("localBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> pointsRecord) throws Exception {
+            localBatchState.add(pointsRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<KMeansModelData> modelDataRecord)
+                throws Exception {
+            modelDataState.add(modelDataRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        private void alignAndComputeModelData() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !localBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            KMeansModelData modelData =
+                    OperatorStateUtils.getUniqueElement(modelDataState, "modelData")
+                            .orElseThrow((Supplier<Exception>) NullPointerException::new);
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector weights = modelData.weights;
+            modelDataState.clear();
+
+            List<DenseVector[]> pointsList = IteratorUtils.toList(localBatchState.get().iterator());
+            DenseVector[] points = pointsList.remove(0);
+            localBatchState.update(pointsList);
+
+            int dim = centroids[0].size();
+            int k = centroids.length;
+            int parallelism = getRuntimeContext().getNumberOfParallelSubtasks();
+
+            // Computes new centroids.
+            DenseVector[] sums = new DenseVector[k];
+            int[] counts = new int[k];
+
+            for (int i = 0; i < k; i++) {
+                sums[i] = new DenseVector(dim);
+                counts[i] = 0;
+            }
+            for (DenseVector point : points) {
+                int closestCentroidId =
+                        KMeans.findClosestCentroidId(centroids, point, distanceMeasure);
+                counts[closestCentroidId]++;
+                for (int j = 0; j < dim; j++) {
+                    sums[closestCentroidId].values[j] += point.values[j];
+                }
+            }
+
+            // Considers weight and decay factor when updating centroids.
+            BLAS.scal(decayFactor / parallelism, weights);
+            for (int i = 0; i < k; i++) {
+                if (counts[i] == 0) {
+                    continue;
+                }
+
+                DenseVector centroid = centroids[i];
+                weights.values[i] = weights.values[i] + counts[i];
+                double lambda = counts[i] / weights.values[i];
+
+                BLAS.scal(1.0 - lambda, centroid);
+                BLAS.axpy(lambda / counts[i], sums[i], centroid);
+            }
+
+            output.collect(new StreamRecord<>(new KMeansModelData(centroids, weights)));
+        }
+    }
+
+    private static class FeaturesExtractor implements MapFunction<Row, DenseVector> {
+        private final String featuresCol;
+
+        private FeaturesExtractor(String featuresCol) {
+            this.featuresCol = featuresCol;
+        }
+
+        @Override
+        public DenseVector map(Row row) throws Exception {
+            return (DenseVector) row.getField(featuresCol);
+        }
+    }
+
+    // An operator that splits a global batch into evenly-sized local batches and distributes them
+    // to the downstream operator.
+    private static class GlobalBatchSplitter
+            implements FlatMapFunction<DenseVector[], DenseVector[]> {
+        private final int downStreamParallelism;
+
+        private GlobalBatchSplitter(int downStreamParallelism) {
+            this.downStreamParallelism = downStreamParallelism;
+        }
+
+        @Override
+        public void flatMap(DenseVector[] values, Collector<DenseVector[]> collector) {
+            // Calculate the batch sizes to be distributed on each subtask.
+            List<Integer> sizes = new ArrayList<>();
+            for (int i = 0; i < downStreamParallelism; i++) {
+                int start = i * values.length / downStreamParallelism;
+                int end = (i + 1) * values.length / downStreamParallelism;
+                sizes.add(end - start);
+            }
+
+            int offset = 0;
+            for (Integer size : sizes) {
+                collector.collect(Arrays.copyOfRange(values, offset, offset + size));
+                offset += size;
+            }
+        }
+    }
+
+    private static class GlobalBatchCreator
+            implements AllWindowFunction<DenseVector, DenseVector[], GlobalWindow> {
+        @Override
+        public void apply(
+                GlobalWindow timeWindow,
+                Iterable<DenseVector> iterable,
+                Collector<DenseVector[]> collector) {
+            List<DenseVector> points = IteratorUtils.toList(iterable.iterator());
+            collector.collect(points.toArray(new DenseVector[0]));
+        }
+    }
+
+    /**
+     * Sets the initial model data of the online training process with the provided model data
+     * table.
+     *
+     * <p>This would override the effect of previously invoked {@link #setInitialModelData(Table)}
+     * or {@link #setRandomCentroids(StreamTableEnvironment, int, int, double)}.
+     */
+    public OnlineKMeans setInitialModelData(Table initModelDataTable) {
+        this.initModelDataTable = initModelDataTable;
+        return this;
+    }
+
+    /**
+     * Sets the initial model data of the online training process with randomly created centroids.
+     *
+     * <p>This would override the effect of previously invoked {@link #setInitialModelData(Table)}
+     * or {@link #setRandomCentroids(StreamTableEnvironment, int, int, double)}.
+     *
+     * @param tEnv The stream table environment to create the centroids in.
+     * @param dim The dimension of the centroids to create.
+     * @param k The number of centroids to create.
+     * @param weight The weight of the centroids to create.
+     */
+    public OnlineKMeans setRandomCentroids(
+            StreamTableEnvironment tEnv, int dim, int k, double weight) {

Review comment:
       I have added `DIM` and `INIT_WEIGHT` parameters so that they are saved together with the other metadata, but I did not provide getter/setter methods for them, so that users are still encouraged to use the `setRandomCentroids(dim, weight)` API. Please check whether this is a proper design.
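For readers following the hunk quoted above, its two numeric steps can be checked in isolation. The sketch below is a hypothetical plain-Java model (class `OnlineKMeansSketch`, plain `double[]` arrays instead of Flink ML's `DenseVector`/`BLAS`, and squared Euclidean distance instead of a configurable `DistanceMeasure`), not the PR's code. It reproduces the `GlobalBatchSplitter` size arithmetic and the decayed centroid update, assuming a single subtask so that `decayFactor / parallelism` reduces to `decayFactor`. The quoted `BLAS.axpy(lambda / counts[i], sums[i], centroid)` is equivalent to adding `lambda * mean`, which is how the sketch writes it.

```java
import java.util.Arrays;

/** Hypothetical plain-array model of two steps in the quoted OnlineKMeans hunk. */
class OnlineKMeansSketch {

    /** Sizes of the local batches obtained by splitting n points across p subtasks. */
    static int[] splitSizes(int n, int p) {
        int[] sizes = new int[p];
        for (int i = 0; i < p; i++) {
            // Same integer arithmetic as GlobalBatchSplitter; sizes differ by at most one.
            int start = i * n / p;
            int end = (i + 1) * n / p;
            sizes[i] = end - start;
        }
        return sizes;
    }

    /** Index of the closest centroid under squared Euclidean distance (an assumption). */
    static int closestCentroid(double[][] centroids, double[] point) {
        int best = 0;
        double bestDist = Double.POSITIVE_INFINITY;
        for (int i = 0; i < centroids.length; i++) {
            double dist = 0.0;
            for (int j = 0; j < point.length; j++) {
                double diff = centroids[i][j] - point[j];
                dist += diff * diff;
            }
            if (dist < bestDist) {
                bestDist = dist;
                best = i;
            }
        }
        return best;
    }

    /**
     * Updates centroids and weights in place for one mini-batch: decays all weights, then for
     * each centroid that received points sets lambda = count / (decayedWeight + count) and
     * moves the centroid to (1 - lambda) * centroid + lambda * batchMean.
     */
    static void update(double[][] centroids, double[] weights, double[][] points, double decayFactor) {
        int k = centroids.length;
        int dim = centroids[0].length;
        double[][] sums = new double[k][dim];
        int[] counts = new int[k];
        for (double[] point : points) {
            int closest = closestCentroid(centroids, point);
            counts[closest]++;
            for (int j = 0; j < dim; j++) {
                sums[closest][j] += point[j];
            }
        }
        for (int i = 0; i < k; i++) {
            weights[i] *= decayFactor; // single subtask: decayFactor / parallelism == decayFactor
        }
        for (int i = 0; i < k; i++) {
            if (counts[i] == 0) {
                continue; // centroids without points keep their position and decayed weight
            }
            weights[i] += counts[i];
            double lambda = counts[i] / weights[i];
            for (int j = 0; j < dim; j++) {
                double mean = sums[i][j] / counts[i];
                centroids[i][j] = (1.0 - lambda) * centroids[i][j] + lambda * mean;
            }
        }
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(splitSizes(10, 3))); // [3, 3, 4]
        double[][] centroids = {{0.0, 0.0}, {10.0, 10.0}};
        double[] weights = {1.0, 1.0};
        update(centroids, weights, new double[][] {{1.0, 1.0}, {1.0, 1.0}}, 0.5);
        System.out.println(Arrays.toString(weights)); // [2.5, 0.5]
    }
}
```

In the example, both points fall into the first cluster, so its decayed weight 0.5 grows to 2.5, giving lambda = 0.8 and moving that centroid to (0.8, 0.8), while the second centroid only has its weight decayed.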




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [flink-ml] zhipeng93 commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r836037690



##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java
##########
@@ -177,11 +177,20 @@ public void testFewerDistinctPointsThanCluster() {
         KMeans kmeans = new KMeans().setK(2);
         KMeansModel model = kmeans.fit(input);
         Table output = model.transform(input)[0];
-        List<Set<DenseVector>> expectedGroups =
-                Collections.singletonList(Collections.singleton(Vectors.dense(0.0, 0.1)));
-        List<Set<DenseVector>> actualGroups =
-                executeAndCollect(output, kmeans.getFeaturesCol(), kmeans.getPredictionCol());
-        assertTrue(CollectionUtils.isEqualCollection(expectedGroups, actualGroups));
+
+        try {

Review comment:
       I agree with the definition of `The max number of clusters to create...`.
   
   If there are fewer distinct points than clusters, I would suggest not creating `k` centers by duplicating data points, for the following two reasons:
   - Existing libraries such as Spark ML and Alink do not do this.
   - There is no known use case for producing `k` centers when some of them are identical.
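A minimal sketch of the suggested behavior, treating `k` as an upper bound: deduplicate the input points and pick at most `k` of them as initial centroids, returning fewer centroids rather than duplicating a point. The helper `initCentroids`, its seeded random selection, and the plain `double[]` representation are assumptions for illustration; Flink ML's actual KMeans initialization may choose points differently.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Random;
import java.util.Set;

/** Hypothetical initializer that treats k as the maximum number of clusters. */
class DistinctInitSketch {

    /** Returns min(k, number of distinct points) distinct points, chosen at random. */
    static List<double[]> initCentroids(List<double[]> points, int k, long seed) {
        // double[] uses identity equality, so deduplicate on a value-based List<Double> key.
        Set<List<Double>> distinct = new LinkedHashSet<>();
        for (double[] point : points) {
            List<Double> key = new ArrayList<>(point.length);
            for (double v : point) {
                key.add(v);
            }
            distinct.add(key);
        }
        List<List<Double>> candidates = new ArrayList<>(distinct);
        Collections.shuffle(candidates, new Random(seed));

        // Never duplicate a point to reach k; return fewer centroids instead.
        int n = Math.min(k, candidates.size());
        List<double[]> centroids = new ArrayList<>(n);
        for (List<Double> candidate : candidates.subList(0, n)) {
            double[] centroid = new double[candidate.size()];
            for (int j = 0; j < centroid.length; j++) {
                centroid[j] = candidate.get(j);
            }
            centroids.add(centroid);
        }
        return centroids;
    }
}
```

If every input point equals `(0.0, 0.1)` and `k = 2`, this returns a single centroid, which is consistent with the single expected group in the assertion removed by the diff above.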







[GitHub] [flink-ml] lindong28 commented on a change in pull request #70: [FLINK-26313] Support Online KMeans in Flink ML

Posted by GitBox <gi...@apache.org>.
lindong28 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r830430626



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/StreamingKMeans.java
##########
@@ -0,0 +1,404 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * StreamingKMeans extends the functionality of {@link KMeans} by supporting continuous training
+ * of a K-Means model on an unbounded stream of training data.
+ */
+public class StreamingKMeans
+        implements Estimator<StreamingKMeans, StreamingKMeansModel>,
+                StreamingKMeansParams<StreamingKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public StreamingKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public StreamingKMeans(Table... initModelDataTables) {
+        Preconditions.checkArgument(initModelDataTables.length == 1);
+        this.initModelDataTable = initModelDataTables[0];
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public StreamingKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        Preconditions.checkArgument(HasBatchStrategy.COUNT_STRATEGY.equals(getBatchStrategy()));
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+        points.getTransformation().setParallelism(1);
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            initModelDataStream = createRandomCentroids(env, getDims(), getK(), getSeed());
+        } else {
+            initModelDataStream = KMeansModelData.getModelDataStream(initModelDataTable);
+        }
+        DataStream<Tuple2<KMeansModelData, DenseVector>> initModelDataWithWeightsStream =
+                initModelDataStream.map(new InitWeightAssigner(getInitWeights()));
+        initModelDataWithWeightsStream.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new StreamingKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getBatchSize(),
+                        getK());
+
+        DataStream<KMeansModelData> finalModelDataStream =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelDataWithWeightsStream),
+                                DataStreamList.of(points),
+                                body)
+                        .get(0);
+        finalModelDataStream = finalModelDataStream.union(initModelDataStream);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelDataStream);
+        StreamingKMeansModel model = new StreamingKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class InitWeightAssigner
+            implements MapFunction<KMeansModelData, Tuple2<KMeansModelData, DenseVector>> {
+        private final double[] initWeights;
+
+        private InitWeightAssigner(Double[] initWeights) {
+            this.initWeights = ArrayUtils.toPrimitive(initWeights);
+        }
+
+        @Override
+        public Tuple2<KMeansModelData, DenseVector> map(KMeansModelData modelData)
+                throws Exception {
+            return Tuple2.of(modelData, Vectors.dense(initWeights));
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static StreamingKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        StreamingKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        Path initModelDataPath = Paths.get(path, "data");
+        if (Files.exists(initModelDataPath)) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+            kMeans.setInitMode("direct");
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class StreamingKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+        private final int k;
+
+        public StreamingKMeansIterationBody(
+                DistanceMeasure distanceMeasure, double decayFactor, int batchSize, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+            this.k = k;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<Tuple2<KMeansModelData, DenseVector>> modelDataWithWeights =
+                    variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            DataStream<Tuple2<KMeansModelData, DenseVector>> newModelDataWithWeights =
+                    points.countWindowAll(batchSize)
+                            .aggregate(new MiniBatchCreator())
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "UpdateModelData",
+                                    new TupleTypeInfo<>(
+                                            TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new UpdateModelDataOperator(distanceMeasure, decayFactor, k))
+                            .setParallelism(1);
+
+            DataStream<KMeansModelData> newModelData =
+                    newModelDataWithWeights.map(
+                            (MapFunction<Tuple2<KMeansModelData, DenseVector>, KMeansModelData>)
+                                    x -> x.f0);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelDataWithWeights), DataStreamList.of(newModelData));
+        }
+    }
+
+    // TODO: change this single-threaded implementation to support training in a distributed way,
+    // after the model data version mechanism is implemented.
+    private static class UpdateModelDataOperator
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, DenseVector>>
+            implements TwoInputStreamOperator<
+                    DenseVector[],
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int k;
+        private ListState<DenseVector[]> miniBatchState;
+        private ListState<KMeansModelData> modelDataState;
+        private ListState<DenseVector> weightsState;
+
+        public UpdateModelDataOperator(DistanceMeasure distanceMeasure, double decayFactor, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.k = k;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            miniBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("miniBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", KMeansModelData.class));
+
+            weightsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "weights", DenseVectorTypeInfo.INSTANCE));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> streamRecord) throws Exception {
+            miniBatchState.add(streamRecord.getValue());
+            processElement();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<Tuple2<KMeansModelData, DenseVector>> streamRecord)
+                throws Exception {
+            modelDataState.add(streamRecord.getValue().f0);
+            weightsState.add(streamRecord.getValue().f1);
+            processElement();
+        }
+
+        private void processElement() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !miniBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            // Retrieves data from states.
+            List<KMeansModelData> modelDataList =
+                    IteratorUtils.toList(modelDataState.get().iterator());
+            if (modelDataList.size() != 1) {
+                throw new RuntimeException(
+                        "The operator received "
+                                + modelDataList.size()
+                                + " list of model data in this round");
+            }
+            DenseVector[] centroids = modelDataList.get(0).centroids;
+            modelDataState.clear();
+
+            List<DenseVector> weightsList = IteratorUtils.toList(weightsState.get().iterator());
+            if (weightsList.size() != 1) {
+                throw new RuntimeException(
+                        "The operator received "
+                                + weightsList.size()
+                                + " list of weights in this round");
+            }
+            DenseVector weights = weightsList.get(0);
+            weightsState.clear();
+
+            List<DenseVector[]> pointsList = IteratorUtils.toList(miniBatchState.get().iterator());
+            DenseVector[] points = pointsList.get(0);
+            pointsList.remove(0);
+            miniBatchState.clear();
+            miniBatchState.addAll(pointsList);

Review comment:
       I am not sure that the model data and the input data will arrive at this operator in lock-step once we take the physical transmission latency of that data into account.
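One way to reason about this concern is to model the buffering in the quoted `processElement1`/`processElement2`/`processElement` methods outside Flink. The hypothetical `MiniBatchModelPairer` below queues mini-batches in arrival order and holds at most one pending model version; each model version consumes exactly one batch, whichever side arrives first. It is only a sketch under these assumptions: it does not model checkpointing, broadcasting to multiple subtasks, or network delays, and it rejects a second model version that arrives before the first is consumed, just as the quoted code throws when more than one model data element is buffered.

```java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.function.BiFunction;

/**
 * Hypothetical model of the operator's buffering: mini-batches (B) are queued in arrival
 * order, at most one model version (M) is pending, and each model version consumes one batch.
 */
class MiniBatchModelPairer<B, M, R> {
    private final Deque<B> pendingBatches = new ArrayDeque<>();
    private final BiFunction<B, M, R> update;
    private M pendingModel; // null when no model version is waiting

    MiniBatchModelPairer(BiFunction<B, M, R> update) {
        this.update = update;
    }

    /** Mirrors processElement1: buffer the batch, then try to update. */
    R onBatch(B batch) {
        pendingBatches.addLast(batch);
        return tryUpdate();
    }

    /** Mirrors processElement2: store the model version, then try to update. */
    R onModel(M model) {
        if (pendingModel != null) {
            throw new IllegalStateException(
                    "Received a second model version before the first one was used");
        }
        pendingModel = model;
        return tryUpdate();
    }

    /** Returns the update result if both inputs are available, otherwise null. */
    private R tryUpdate() {
        if (pendingModel == null || pendingBatches.isEmpty()) {
            return null;
        }
        M model = pendingModel;
        pendingModel = null;
        return update.apply(pendingBatches.pollFirst(), model);
    }
}
```

For example, with `update = (batch, model) -> model + batch` over integers, two batches that arrive before any model version are both buffered, and each fed-back model version then consumes one of them in order. The open question in the comment is therefore whether the feedback edge can ever deliver two model versions without a batch in between, which is the case this sketch, like the quoted code, rejects.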







+            if (modelDataList.size() != 1) {
+                throw new RuntimeException(
+                        "The operator received "
+                                + modelDataList.size()
+                                + " list of model data in this round");
+            }
+            DenseVector[] centroids = modelDataList.get(0).centroids;
+            modelDataState.clear();
+
+            List<DenseVector> weightsList = IteratorUtils.toList(weightsState.get().iterator());
+            if (weightsList.size() != 1) {
+                throw new RuntimeException(
+                        "The operator received "
+                                + weightsList.size()
+                                + " list of weights in this round");
+            }
+            DenseVector weights = weightsList.get(0);
+            weightsState.clear();
+
+            List<DenseVector[]> pointsList = IteratorUtils.toList(miniBatchState.get().iterator());
+            DenseVector[] points = pointsList.get(0);
+            pointsList.remove(0);
+            miniBatchState.clear();
+            miniBatchState.addAll(pointsList);

Review comment:
       I am not sure this holds once we take into account the physical transmission latency of the data.
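To make the ordering assumption concrete, here is a plain-Java sketch with no Flink dependencies that mimics the buffering in the hunk above. Mini-batches are queued in arrival order, model-data records accumulate in a list, and an update round runs only once both are present, consuming the oldest mini-batch. The class `MiniBatchPairing` and the string labels standing in for mini-batches and model data are hypothetical. Like the operator, it throws if more than one model-data record is pending when a round runs. Whether that can happen under real transmission latency is exactly the open question raised above.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;

/** Hypothetical plain-Java model of UpdateModelDataOperator's buffering (not the PR's code). */
class MiniBatchPairing {
    // Mini-batches waiting for model data, oldest first (mirrors miniBatchState).
    private final Deque<String> pendingBatches = new ArrayDeque<>();
    // Model-data records received in the current round (mirrors modelDataState).
    private final List<String> pendingModels = new ArrayList<>();
    // Log of "model+batch" pairs, one entry per completed update round.
    private final List<String> rounds = new ArrayList<>();

    void onMiniBatch(String batch) {
        pendingBatches.addLast(batch);
        tryUpdate();
    }

    void onModelData(String model) {
        pendingModels.add(model);
        tryUpdate();
    }

    private void tryUpdate() {
        if (pendingModels.isEmpty() || pendingBatches.isEmpty()) {
            return; // wait until both inputs are available
        }
        if (pendingModels.size() != 1) {
            // Mirrors the RuntimeException thrown in the hunk.
            throw new IllegalStateException(
                    "Received " + pendingModels.size() + " model data records in this round");
        }
        String batch = pendingBatches.pollFirst(); // oldest mini-batch first
        rounds.add(pendingModels.get(0) + "+" + batch);
        pendingModels.clear();
    }

    List<String> rounds() {
        return rounds;
    }
}
```

In a real job, the model data produced by each round travels back through the iteration before the next round can run. In the sketch, the caller plays that role by calling `onModelData` again.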




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Support Online KMeans in Flink ML

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r829777954



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/StreamingKMeans.java
##########
@@ -0,0 +1,404 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * StreamingKMeans extends the functionality of {@link KMeans} by supporting continuous training of
+ * a K-Means model on an unbounded stream of training data.
+ */
+public class StreamingKMeans
+        implements Estimator<StreamingKMeans, StreamingKMeansModel>,
+                StreamingKMeansParams<StreamingKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public StreamingKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public StreamingKMeans(Table... initModelDataTables) {
+        Preconditions.checkArgument(initModelDataTables.length == 1);
+        this.initModelDataTable = initModelDataTables[0];
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public StreamingKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        Preconditions.checkArgument(HasBatchStrategy.COUNT_STRATEGY.equals(getBatchStrategy()));
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+        points.getTransformation().setParallelism(1);
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            initModelDataStream = createRandomCentroids(env, getDims(), getK(), getSeed());

Review comment:
       I agree. In fact, we only need to check `initModelDataTable == null` when initMode is `random`. If initMode is `direct`, the subsequent implementation will fail anyway unless `initModelDataTable != null`.
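A minimal sketch of the check as the comment describes it: reject an initial model-data table only when `initMode` is `random`, and leave `direct` to fail downstream. `InitModeCheck` is a hypothetical helper, and `Object` stands in for `org.apache.flink.table.api.Table` so the sketch compiles without Flink.

```java
import java.util.Objects;

/** Hypothetical helper illustrating the init-mode check; not the PR's code. */
final class InitModeCheck {
    private InitModeCheck() {}

    /** initModelDataTable is typed as Object here in place of Flink's Table. */
    static void validate(String initMode, Object initModelDataTable) {
        Objects.requireNonNull(initMode, "initMode");
        if ("random".equals(initMode) && initModelDataTable != null) {
            throw new IllegalArgumentException(
                    "initModelDataTable must not be set when initMode is 'random'.");
        }
        // No explicit check for "direct": a null initModelDataTable fails later when it is read.
    }
}
```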






[GitHub] [flink-ml] lindong28 commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
lindong28 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r836106432



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,407 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Supplier;
+
+/**
+ * OnlineKMeans extends the functionality of {@link KMeans} by supporting continuous training of a
+ * K-Means model on an unbounded stream of training data.
+ *
+ * <p>OnlineKMeans updates the model with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are computed,
+ * OnlineKMeans derives the new centroids as a weighted average of the original and the estimated
+ * centroids. The weight of each estimated centroid is the number of points assigned to it in the
+ * current batch. The weight of each original centroid is likewise a point count, additionally
+ * multiplied by the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters estimated so far. If the decay
+ * factor is 1, all batches are weighted equally. If the decay factor is 0, new centroids are
+ * determined entirely by the most recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelData =
+                KMeansModelData.getModelDataStream(initModelDataTable);
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getK(),
+                        getDecayFactor(),
+                        getGlobalBatchSize());
+
+        DataStream<KMeansModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        OnlineKMeansModel model = new OnlineKMeansModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Saves the metadata AND bounded model data table (if exists) to the given path. */
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        OnlineKMeans onlineKMeans = ReadWriteUtils.loadStageParam(path);
+
+        String initModelDataPath = ReadWriteUtils.getDataPath(path);
+        if (Files.exists(Paths.get(initModelDataPath))) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder());
+
+            onlineKMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream);
+        }
+
+        return onlineKMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final int k;
+        private final double decayFactor;
+        private final int batchSize;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure, int k, double decayFactor, int batchSize) {
+            this.distanceMeasure = distanceMeasure;
+            this.k = k;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<KMeansModelData> modelData = variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();

Review comment:
       Thanks for the discussion, @zhipeng93.
   
   To summarize our offline discussion: we agreed that there is no good reason for users to run the training process with `batchStrategy=size` and `parallelism > batchSize`. We will throw an exception here to prevent users from wasting resources. We will not throw this exception if the algorithm is configured with `batchStrategy=time`.
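A sketch of the agreed-upon check. It assumes that `batchStrategy=size` in the comment corresponds to the `count` strategy used in the code (the default asserted in the test), and that the global batch is split as evenly as possible across subtasks. `BatchSizeCheck`, `localBatchSizes`, and `checkCountStrategy` are hypothetical names, not the PR's code. The split shows why `parallelism > batchSize` wastes resources: under an even split, at least one subtask receives zero points per batch.

```java
/** Hypothetical helper illustrating the parallelism vs. batch size check; not the PR's code. */
final class BatchSizeCheck {
    private BatchSizeCheck() {}

    /** Splits globalBatchSize points as evenly as possible across parallelism subtasks (assumed). */
    static int[] localBatchSizes(int globalBatchSize, int parallelism) {
        int[] sizes = new int[parallelism];
        for (int i = 0; i < parallelism; i++) {
            // The first (globalBatchSize % parallelism) subtasks take one extra point.
            sizes[i] = globalBatchSize / parallelism + (i < globalBatchSize % parallelism ? 1 : 0);
        }
        return sizes;
    }

    /** Throws for the count strategy when some subtask would idle; time-based batching is exempt. */
    static void checkCountStrategy(String batchStrategy, int globalBatchSize, int parallelism) {
        if ("count".equals(batchStrategy) && parallelism > globalBatchSize) {
            throw new IllegalArgumentException(
                    "Parallelism "
                            + parallelism
                            + " exceeds the global batch size "
                            + globalBatchSize
                            + "; some subtasks would never receive training data.");
        }
    }
}
```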






[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r833162477



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeansParams.java
##########
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.common.param.HasDecayFactor;
+import org.apache.flink.ml.common.param.HasGlobalBatchSize;
+import org.apache.flink.ml.common.param.HasSeed;
+import org.apache.flink.ml.param.DoubleArrayParam;
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringParam;
+
+/**
+ * Params of {@link OnlineKMeans}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface OnlineKMeansParams<T>
+        extends HasBatchStrategy<T>,
+                HasGlobalBatchSize<T>,
+                HasDecayFactor<T>,
+                HasSeed<T>,
+                KMeansModelParams<T> {
+    Param<String> INIT_MODE =
+            new StringParam(
+                    "initMode",
+                    "How to initialize the model data of the online KMeans algorithm. Supported options: 'random', 'direct'.",
+                    "random",
+                    ParamValidators.inArray("random", "direct"));
+
+    Param<Integer> DIM =
+            new IntParam(
+                    "dim",
+                    "The number of dimensions of centroids. Used when initializing random centroids.",
+                    1,
+                    ParamValidators.gt(0));
+
+    Param<Double[]> INIT_WEIGHTS =

Review comment:
       I agree. I'll make the change.
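The `INIT_WEIGHTS` parameter appears to seed the per-centroid weights consumed by the decayed mini-batch update that the `OnlineKMeans` Javadoc describes. The following plain-Java sketch shows that rule under stated assumptions: Euclidean distance, weights decayed on every batch, and centroids with no assigned points left unchanged. `DecayedMiniBatchKMeans` is a hypothetical class, not the PR's implementation.

```java
/** Hypothetical sketch of the decayed mini-batch KMeans update; not the PR's code. */
final class DecayedMiniBatchKMeans {
    private DecayedMiniBatchKMeans() {}

    /** Updates centroids and weights in place with one mini-batch of points. */
    static void update(double[][] centroids, double[] weights, double[][] batch, double decayFactor) {
        int k = centroids.length;
        int dim = centroids[0].length;
        double[][] sums = new double[k][dim];
        int[] counts = new int[k];
        // Assign each point to its nearest centroid and accumulate per-cluster sums and counts.
        for (double[] p : batch) {
            int c = nearest(centroids, p);
            counts[c]++;
            for (int d = 0; d < dim; d++) {
                sums[c][d] += p[d];
            }
        }
        for (int i = 0; i < k; i++) {
            double oldWeight = weights[i] * decayFactor; // discounted history
            if (counts[i] > 0) {
                double newWeight = oldWeight + counts[i];
                for (int d = 0; d < dim; d++) {
                    // Weighted average of the old centroid and the batch mean (sums = count * mean).
                    centroids[i][d] = (oldWeight * centroids[i][d] + sums[i][d]) / newWeight;
                }
                weights[i] = newWeight;
            } else {
                weights[i] = oldWeight; // no new points: only the weight decays
            }
        }
    }

    /** Index of the centroid with the smallest squared Euclidean distance to p. */
    static int nearest(double[][] centroids, double[] p) {
        int best = 0;
        double bestDist = Double.MAX_VALUE;
        for (int i = 0; i < centroids.length; i++) {
            double dist = 0;
            for (int d = 0; d < p.length; d++) {
                double diff = p[d] - centroids[i][d];
                dist += diff * diff;
            }
            if (dist < bestDist) {
                bestDist = dist;
                best = i;
            }
        }
        return best;
    }
}
```

With `decayFactor = 1` this reduces to a running mean over all batches (given count-valued initial weights). With `decayFactor = 0` each updated centroid becomes the mean of its points in the current batch, matching the two extremes in the Javadoc.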






[GitHub] [flink-ml] yunfengzhou-hub commented on a change in pull request #70: [FLINK-26313] Add Transformer and Estimator of OnlineKMeans

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r832757562



##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java
##########
@@ -0,0 +1,487 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.RestOptions;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.clustering.kmeans.KMeans;
+import org.apache.flink.ml.clustering.kmeans.KMeansModel;
+import org.apache.flink.ml.clustering.kmeans.KMeansModelData;
+import org.apache.flink.ml.clustering.kmeans.OnlineKMeans;
+import org.apache.flink.ml.clustering.kmeans.OnlineKMeansModel;
+import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.util.InMemorySinkFunction;
+import org.apache.flink.ml.util.InMemorySourceFunction;
+import org.apache.flink.runtime.minicluster.MiniCluster;
+import org.apache.flink.runtime.minicluster.MiniClusterConfiguration;
+import org.apache.flink.runtime.testutils.InMemoryReporter;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.CollectionUtils;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import static org.apache.flink.ml.clustering.KMeansTest.groupFeaturesByPrediction;
+
+/** Tests {@link OnlineKMeans} and {@link OnlineKMeansModel}. */
+public class OnlineKMeansTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private static final DenseVector[] trainData1 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 0.0),
+                Vectors.dense(10.0, 0.3),
+                Vectors.dense(10.3, 0.0),
+                Vectors.dense(-10.0, 0.0),
+                Vectors.dense(-10.0, 0.6),
+                Vectors.dense(-10.6, 0.0)
+            };
+    private static final DenseVector[] trainData2 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 100.0),
+                Vectors.dense(10.0, 100.3),
+                Vectors.dense(10.3, 100.0),
+                Vectors.dense(-10.0, -100.0),
+                Vectors.dense(-10.0, -100.6),
+                Vectors.dense(-10.6, -100.0)
+            };
+    private static final DenseVector[] predictData =
+            new DenseVector[] {
+                Vectors.dense(10.0, 10.0),
+                Vectors.dense(10.3, 10.0),
+                Vectors.dense(10.0, 10.3),
+                Vectors.dense(-10.0, 10.0),
+                Vectors.dense(-10.3, 10.0),
+                Vectors.dense(-10.0, 10.3)
+            };
+    private static final List<Set<DenseVector>> expectedGroups1 =
+            Arrays.asList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3))),
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+    private static final List<Set<DenseVector>> expectedGroups2 =
+            Collections.singletonList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3),
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+
+    private static final int defaultParallelism = 4;
+    private static final int numTaskManagers = 2;
+    private static final int numSlotsPerTaskManager = 2;
+
+    private String currentModelDataVersion;
+
+    private InMemorySourceFunction<DenseVector> trainSource;
+    private InMemorySourceFunction<DenseVector> predictSource;
+    private InMemorySinkFunction<Row> outputSink;
+    private InMemorySinkFunction<KMeansModelData> modelDataSink;
+
+    private InMemoryReporter reporter;
+    private MiniCluster miniCluster;
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+
+    private Table offlineTrainTable;
+    private Table trainTable;
+    private Table predictTable;
+
+    @Before
+    public void before() throws Exception {
+        currentModelDataVersion = "0";
+
+        trainSource = new InMemorySourceFunction<>();
+        predictSource = new InMemorySourceFunction<>();
+        outputSink = new InMemorySinkFunction<>();
+        modelDataSink = new InMemorySinkFunction<>();
+
+        Configuration config = new Configuration();
+        config.set(RestOptions.BIND_PORT, "18081-19091");
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        reporter = InMemoryReporter.createWithRetainedMetrics();
+        reporter.addToConfiguration(config);
+
+        miniCluster =
+                new MiniCluster(
+                        new MiniClusterConfiguration.Builder()
+                                .setConfiguration(config)
+                                .setNumTaskManagers(numTaskManagers)
+                                .setNumSlotsPerTaskManager(numSlotsPerTaskManager)
+                                .build());
+        miniCluster.start();
+
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(defaultParallelism);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        Schema schema = Schema.newBuilder().column("f0", DataTypes.of(DenseVector.class)).build();
+
+        offlineTrainTable =
+                tEnv.fromDataStream(env.fromElements(trainData1), schema).as("features");
+        trainTable =
+                tEnv.fromDataStream(
+                                env.addSource(trainSource, DenseVectorTypeInfo.INSTANCE), schema)
+                        .as("features");
+        predictTable =
+                tEnv.fromDataStream(
+                                env.addSource(predictSource, DenseVectorTypeInfo.INSTANCE), schema)
+                        .as("features");
+    }
+
+    @After
+    public void after() throws Exception {
+        miniCluster.close();
+    }
+
+    /**
+     * Performs transform() on the provided model with predictTable, and adds sinks for
+     * OnlineKMeansModel's transform output and model data.
+     */
+    private void configTransformAndSink(OnlineKMeansModel onlineModel) {
+        Table outputTable = onlineModel.transform(predictTable)[0];
+        DataStream<Row> output = tEnv.toDataStream(outputTable);
+        output.addSink(outputSink);
+
+        Table modelDataTable = onlineModel.getModelData()[0];
+        DataStream<KMeansModelData> modelDataStream =
+                KMeansModelData.getModelDataStream(modelDataTable);
+        modelDataStream.addSink(modelDataSink);
+    }
+
+    /** Blocks the thread until the Model has set up the initial model data. */
+    private void waitInitModelDataSetup() {
+        while (reporter.findMetrics("modelDataVersion").size() < defaultParallelism) {
+            Thread.yield();
+        }
+        waitModelDataUpdate();
+    }
+
+    /** Blocks the thread until the Model has received the next model-data-update event. */
+    @SuppressWarnings("unchecked")
+    private void waitModelDataUpdate() {
+        do {
+            String tmpModelDataVersion =
+                    String.valueOf(
+                            reporter.findMetrics("modelDataVersion").values().stream()
+                                    .map(x -> Integer.parseInt(((Gauge<String>) x).getValue()))
+                                    .min(Integer::compareTo)
+                                    .orElse(Integer.parseInt(currentModelDataVersion)));
+            if (tmpModelDataVersion.equals(currentModelDataVersion)) {
+                Thread.yield();
+            } else {
+                currentModelDataVersion = tmpModelDataVersion;
+                break;
+            }
+        } while (true);
+    }
+
+    /**
+     * Inserts the default predict data into the predict queue, fetches the prediction results,
+     * and asserts that the grouping result is as expected.
+     *
+     * @param expectedGroups A list of feature sets, representing the expected grouping result
+     * @param featuresCol Name of the column in the table that contains the features
+     * @param predictionCol Name of the column in the table that contains the prediction result
+     */
+    private void predictAndAssert(
+            List<Set<DenseVector>> expectedGroups, String featuresCol, String predictionCol)
+            throws Exception {
+        predictSource.offerAll(OnlineKMeansTest.predictData);
+        List<Row> rawResult = outputSink.poll(OnlineKMeansTest.predictData.length);
+        List<Set<DenseVector>> actualGroups =
+                groupFeaturesByPrediction(rawResult, featuresCol, predictionCol);
+        Assert.assertTrue(CollectionUtils.isEqualCollection(expectedGroups, actualGroups));
+    }
+
+    @Test
+    public void testParam() {
+        OnlineKMeans onlineKMeans = new OnlineKMeans();
+        Assert.assertEquals("features", onlineKMeans.getFeaturesCol());
+        Assert.assertEquals("prediction", onlineKMeans.getPredictionCol());
+        Assert.assertEquals(EuclideanDistanceMeasure.NAME, onlineKMeans.getDistanceMeasure());
+        Assert.assertEquals("random", onlineKMeans.getInitMode());
+        Assert.assertEquals(2, onlineKMeans.getK());
+        Assert.assertEquals(1, onlineKMeans.getDims());
+        Assert.assertEquals("count", onlineKMeans.getBatchStrategy());
+        Assert.assertEquals(1, onlineKMeans.getBatchSize());
+        Assert.assertEquals(0., onlineKMeans.getDecayFactor(), 1e-5);
+        Assert.assertEquals(OnlineKMeans.class.getName().hashCode(), onlineKMeans.getSeed());
+
+        onlineKMeans
+                .setFeaturesCol("test_feature")
+                .setPredictionCol("test_prediction")
+                .setK(3)
+                .setDims(5)
+                .setBatchSize(5)
+                .setDecayFactor(0.25)
+                .setInitMode("direct")
+                .setSeed(100);
+
+        Assert.assertEquals("test_feature", onlineKMeans.getFeaturesCol());
+        Assert.assertEquals("test_prediction", onlineKMeans.getPredictionCol());
+        Assert.assertEquals(3, onlineKMeans.getK());
+        Assert.assertEquals(5, onlineKMeans.getDims());
+        Assert.assertEquals("count", onlineKMeans.getBatchStrategy());
+        Assert.assertEquals(5, onlineKMeans.getBatchSize());
+        Assert.assertEquals(0.25, onlineKMeans.getDecayFactor(), 1e-5);
+        Assert.assertEquals("direct", onlineKMeans.getInitMode());
+        Assert.assertEquals(100, onlineKMeans.getSeed());
+    }
+
+    @Test
+    public void testFitAndPredict() throws Exception {
+        OnlineKMeans onlineKMeans =
+                new OnlineKMeans()
+                        .setInitMode("random")
+                        .setDims(2)
+                        .setInitWeights(new Double[] {0., 0.})
+                        .setBatchSize(6)
+                        .setFeaturesCol("features")
+                        .setPredictionCol("prediction");
+        OnlineKMeansModel onlineModel = onlineKMeans.fit(trainTable);
+        configTransformAndSink(onlineModel);
+
+        miniCluster.submitJob(env.getStreamGraph().getJobGraph());
+        waitInitModelDataSetup();
+
+        trainSource.offerAll(trainData1);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups1, onlineKMeans.getFeaturesCol(), onlineKMeans.getPredictionCol());
+
+        trainSource.offerAll(trainData2);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups2, onlineKMeans.getFeaturesCol(), onlineKMeans.getPredictionCol());
+    }
+
+    @Test
+    public void testInitWithKMeans() throws Exception {
+        KMeans kMeans = new KMeans().setFeaturesCol("features").setPredictionCol("prediction");
+        KMeansModel kMeansModel = kMeans.fit(offlineTrainTable);
+
+        OnlineKMeans onlineKMeans =
+                new OnlineKMeans(kMeansModel.getModelData())
+                        .setFeaturesCol("features")
+                        .setPredictionCol("prediction")
+                        .setInitMode("direct")
+                        .setDims(2)
+                        .setInitWeights(new Double[] {0., 0.})
+                        .setBatchSize(6);
+
+        OnlineKMeansModel onlineModel = onlineKMeans.fit(trainTable);
+        configTransformAndSink(onlineModel);
+
+        miniCluster.submitJob(env.getStreamGraph().getJobGraph());
+        waitInitModelDataSetup();
+        predictAndAssert(
+                expectedGroups1, onlineKMeans.getFeaturesCol(), onlineKMeans.getPredictionCol());
+
+        trainSource.offerAll(trainData2);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups2, onlineKMeans.getFeaturesCol(), onlineKMeans.getPredictionCol());
+    }
+
+    @Test
+    public void testDecayFactor() throws Exception {
+        KMeans kMeans = new KMeans().setFeaturesCol("features").setPredictionCol("prediction");
+        KMeansModel kMeansModel = kMeans.fit(offlineTrainTable);
+
+        OnlineKMeans onlineKMeans =
+                new OnlineKMeans(kMeansModel.getModelData())
+                        .setDims(2)
+                        .setInitWeights(new Double[] {3., 3.})
+                        .setDecayFactor(0.5)
+                        .setBatchSize(6)
+                        .setFeaturesCol("features")
+                        .setPredictionCol("prediction");
+        OnlineKMeansModel onlineModel = onlineKMeans.fit(trainTable);
+        configTransformAndSink(onlineModel);
+
+        miniCluster.submitJob(env.getStreamGraph().getJobGraph());
+        // Discards the initial model data so the next poll() returns the updated model data.
+        modelDataSink.poll();
+
+        trainSource.offerAll(trainData2);
+        KMeansModelData actualModelData = modelDataSink.poll();
+
+        KMeansModelData expectedModelData =
+                new KMeansModelData(
+                        new DenseVector[] {
+                            Vectors.dense(10.1, 200.3 / 3), Vectors.dense(-10.2, -200.2 / 3)
+                        });
+
+        Assert.assertEquals(expectedModelData.centroids.length, actualModelData.centroids.length);
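+        // Sorts centroids by descending first coordinate, matching expectedModelData's order.
+        Arrays.sort(
+                actualModelData.centroids,
+                (o1, o2) -> Double.compare(o2.values[0], o1.values[0]));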
+        for (int i = 0; i < expectedModelData.centroids.length; i++) {
+            Assert.assertArrayEquals(
+                    expectedModelData.centroids[i].values,
+                    actualModelData.centroids[i].values,
+                    1e-5);
+        }
+    }
+
+    @Test
+    public void testSaveAndReload() throws Exception {
+        KMeans kMeans = new KMeans().setFeaturesCol("features").setPredictionCol("prediction");
+        KMeansModel kMeansModel = kMeans.fit(offlineTrainTable);
+
+        OnlineKMeans onlineKMeans =
+                new OnlineKMeans(kMeansModel.getModelData())
+                        .setFeaturesCol("features")
+                        .setPredictionCol("prediction")
+                        .setInitMode("direct")
+                        .setDims(2)
+                        .setInitWeights(new Double[] {0., 0.})
+                        .setBatchSize(6);
+
+        String savePath = tempFolder.newFolder().getAbsolutePath();
+        onlineKMeans.save(savePath);
+        miniCluster.executeJobBlocking(env.getStreamGraph().getJobGraph());
+        OnlineKMeans loadedKMeans = OnlineKMeans.load(env, savePath);
+
+        OnlineKMeansModel onlineModel = loadedKMeans.fit(trainTable);

Review comment:
       I think the following names would make the variables clearer. What do you think?
   - `kMeans`: `KMeans` operator
   - `model`: `KMeansModel` operator
   - `onlineKMeans`: `OnlineKMeans` operator before saving
   - `loadedOnlineKMeans`: `OnlineKMeans` operator after reloading
   - `onlineModel`: `OnlineKMeansModel` operator before saving
   - `loadedOnlineModel`: `OnlineKMeansModel` operator after reloading
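
A side note on `testDecayFactor`: sorting centroids by a `double` coordinate with a comparator such as `(o1, o2) -> (int) (o2.values[0] - o1.values[0])` is fragile, because the cast truncates toward zero, so any two values less than 1.0 apart compare as equal. `Double.compare` does not have this problem. The self-contained sketch below shows the difference; plain `double[]` rows stand in for `DenseVector.values`, and `CentroidSortDemo` is a hypothetical name used only for illustration.

```java
import java.util.Arrays;
import java.util.Comparator;

public class CentroidSortDemo {

    /** Fragile: the double difference is truncated to int, so gaps below 1.0 become 0 ("equal"). */
    static final Comparator<double[]> TRUNCATING_DESC = (o1, o2) -> (int) (o2[0] - o1[0]);

    /** Robust: Double.compare orders any two doubles correctly, including small gaps. */
    static final Comparator<double[]> DOUBLE_COMPARE_DESC =
            (o1, o2) -> Double.compare(o2[0], o1[0]);

    public static void main(String[] args) {
        double[][] truncated = {{0.2}, {0.9}, {0.5}};
        double[][] compared = {{0.2}, {0.9}, {0.5}};

        // Every pairwise difference here is below 1.0, so the truncating comparator returns 0.
        Arrays.sort(truncated, TRUNCATING_DESC);
        Arrays.sort(compared, DOUBLE_COMPARE_DESC);

        System.out.println(Arrays.deepToString(truncated)); // [[0.2], [0.9], [0.5]]
        System.out.println(Arrays.deepToString(compared)); // [[0.9], [0.5], [0.2]]
    }
}
```

Because `Arrays.sort` on object arrays is stable, the truncating comparator silently leaves such rows in their input order instead of failing, so the problem stays hidden whenever the centroids under test happen to be far apart.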

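`waitModelDataUpdate()` spins on `Thread.yield()` until the minimum `modelDataVersion` gauge value across subtasks changes, with no upper bound on the wait, so a job that fails before publishing new model data can hang the test instead of failing it. Below is a minimal sketch of the same minimum-version polling with a deadline. `VersionPoller`, `awaitNewMinVersion` and the `Supplier<String>` gauges are hypothetical stand-ins for the `reporter.findMetrics(...)` lookup, and the gauge values are assumed to be integer strings, as the test already assumes.

```java
import java.time.Duration;
import java.util.Arrays;
import java.util.Collection;
import java.util.OptionalInt;
import java.util.concurrent.TimeoutException;
import java.util.function.Supplier;

/** Hypothetical helper: waits until the minimum version over all subtask gauges changes. */
public class VersionPoller {

    static int awaitNewMinVersion(
            Collection<Supplier<String>> gauges, int lastSeen, Duration timeout)
            throws InterruptedException, TimeoutException {
        long deadline = System.nanoTime() + timeout.toNanos();
        while (true) {
            // The minimum only moves once every subtask has received the new model data.
            OptionalInt min = gauges.stream().mapToInt(g -> Integer.parseInt(g.get())).min();
            if (min.isPresent() && min.getAsInt() != lastSeen) {
                return min.getAsInt();
            }
            // Overflow-safe deadline check, as recommended for System.nanoTime().
            if (System.nanoTime() - deadline > 0) {
                throw new TimeoutException("model data version is still " + lastSeen);
            }
            Thread.sleep(10); // back off briefly instead of busy-spinning
        }
    }

    public static void main(String[] args) throws Exception {
        // Two fake subtask gauges: one already at version 1, one catching up after ~50 ms.
        long start = System.nanoTime();
        Supplier<String> fast = () -> "1";
        Supplier<String> slow = () -> System.nanoTime() - start > 50_000_000L ? "1" : "0";
        int version = awaitNewMinVersion(Arrays.asList(fast, slow), 0, Duration.ofSeconds(5));
        System.out.println("new version: " + version); // new version: 1
    }
}
```

Like the original loop, this compares with `!=` rather than `>`, so any change of the minimum counts as an update; the timeout turns a stuck job into a test failure with a readable message.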


