Posted to issues@flink.apache.org by GitBox <gi...@apache.org> on 2022/04/21 07:27:29 UTC

[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator for OnlineLogisticRegression

zhipeng93 commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r854745613


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionModelParams.java:
##########
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasPredictionCol;
+import org.apache.flink.ml.common.param.HasRawPredictionCol;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringParam;
+
+/**
+ * Params for {@link OnlineLogisticRegressionModel}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface OnlineLogisticRegressionModelParams<T>
+        extends HasFeaturesCol<T>, HasPredictionCol<T>, HasRawPredictionCol<T> {
+    Param<String> MODEL_VERSION_COL =

Review Comment:
   Could you explain the reason for introducing the `MODEL_VERSION_COL` param here? It is not consistent with the implementation of `OnlineKmeansModel`.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionParams.java:
##########
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.common.param.HasGlobalBatchSize;
+import org.apache.flink.ml.common.param.HasL1;
+import org.apache.flink.ml.common.param.HasL2;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+
+/**
+ * Params of {@link OnlineLogisticRegression}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface OnlineLogisticRegressionParams<T>
+        extends HasLabelCol<T>,
+                HasBatchStrategy<T>,
+                HasGlobalBatchSize<T>,
+                HasL1<T>,
+                HasL2<T>,
+                OnlineLogisticRegressionModelParams<T> {
+
+    Param<Double> ALPHA =

Review Comment:
   I am a bit torn on whether to introduce a `HasOptimMethod` interface, since OnlineLogisticRegression could be implemented via online methods like SGD and FTRL.
   
   So basically there are two possible solutions:
   - Introduce `HasOptimMethod` and parameters for different optimizers.
   - Rename the class to `OnlineLogisticRegressionWithFtrl`.
   
   I think option 1 is the better solution in the long run, but it definitely needs more design and engineering effort. Shall we introduce the param now and leave the other optimizers unimplemented?
   
   I would also like to hear more from others.
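
To make the first option concrete, here is a rough, dependency-free sketch of what "introduce the param now and leave the other optimizers unimplemented" could look like. Everything in it is hypothetical: `OptimMethodSketch`, the `OptimMethod` enum, the plain `Map`-backed param storage and `iterationBodyName()` are stand-ins for illustration and are not the Flink ML `Param` API. In the real code base the param would presumably be declared the same way as `ALPHA` in this PR.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical sketch only; none of these types exist in Flink ML. */
class OptimMethodSketch {

    /** Candidate optimizers. Only FTRL is backed by an implementation in this sketch. */
    enum OptimMethod {
        FTRL,
        SGD
    }

    /** Stand-in for a shared param interface; a plain map replaces the real param machinery. */
    interface HasOptimMethod<T> {
        String OPTIM_METHOD = "optimMethod";

        Map<String, Object> getParamMap();

        default OptimMethod getOptimMethod() {
            // FTRL is assumed to be the default optimizer.
            return (OptimMethod) getParamMap().getOrDefault(OPTIM_METHOD, OptimMethod.FTRL);
        }

        @SuppressWarnings("unchecked")
        default T setOptimMethod(OptimMethod value) {
            getParamMap().put(OPTIM_METHOD, value);
            return (T) this;
        }
    }

    /** Stand-in for the estimator: it dispatches on the configured optimizer. */
    static class Estimator implements HasOptimMethod<Estimator> {
        private final Map<String, Object> paramMap = new HashMap<>();

        @Override
        public Map<String, Object> getParamMap() {
            return paramMap;
        }

        /** Returns the name of the iteration body that fit() would build. */
        String iterationBodyName() {
            switch (getOptimMethod()) {
                case FTRL:
                    return "FtrlIterationBody";
                default:
                    // Optimizers that are declared but not implemented fail fast.
                    throw new UnsupportedOperationException(
                            getOptimMethod() + " is not implemented yet.");
            }
        }
    }

    public static void main(String[] args) {
        Estimator estimator = new Estimator();
        System.out.println(estimator.iterationBodyName());
        try {
            estimator.setOptimMethod(OptimMethod.SGD).iterationBodyName();
        } catch (UnsupportedOperationException e) {
            System.out.println(e.getMessage());
        }
    }
}
```

The class name stays optimizer-neutral under this approach, at the cost of exposing a param value (`SGD` here) that fails at fit time until it is implemented.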
   
   



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java:
##########
@@ -0,0 +1,428 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the OnlineLogisticRegression algorithm.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Online_machine_learning.

Review Comment:
   Could you add more Javadoc about online logistic regression? Also, could you please update the link here? It seems to be a general introduction to online learning.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java:
##########
@@ -0,0 +1,428 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the OnlineLogisticRegression algorithm.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Online_machine_learning.
+ */
+public class OnlineLogisticRegression
+        implements Estimator<OnlineLogisticRegression, OnlineLogisticRegressionModel>,
+                OnlineLogisticRegressionParams<OnlineLogisticRegression> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineLogisticRegression() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public OnlineLogisticRegressionModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<LogisticRegressionModelData> modelDataStream =
+                LogisticRegressionModelData.getModelDataStream(initModelDataTable);
+        DataStream<DenseVector[]> initModelData =
+                modelDataStream.map(
+                        (MapFunction<LogisticRegressionModelData, DenseVector[]>)
+                                value -> new DenseVector[] {value.coefficient});
+
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new FtrlIterationBody(
+                        getGlobalBatchSize(),
+                        getAlpha(),
+                        getBETA(),
+                        getL1(),
+                        getL2(),
+                        getFeaturesCol(),
+                        getLabelCol());
+
+        DataStream<LogisticRegressionModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData),
+                                DataStreamList.of(tEnv.toDataStream(inputs[0])),
+                                body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        OnlineLogisticRegressionModel model =
+                new OnlineLogisticRegressionModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Iteration body of ftrl algorithm. */

Review Comment:
   It would be nice to have a fuller introduction to the FTRL method here.
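
As a starting point for such an introduction, here is a minimal single-machine sketch of the per-coordinate FTRL-Proximal update for logistic regression, as I understand it from McMahan et al., "Ad Click Prediction: a View from the Trenches" (KDD 2013). The class `FtrlSketch`, its method names and the regularization values used in `main` are hypothetical. The `update` step is the textbook rule and is not taken from this PR; only `weight` mirrors the closed-form expression used in the feedback map function of this diff.

```java
import java.util.Arrays;

/** Hypothetical, dependency-free sketch of per-coordinate FTRL-Proximal. */
class FtrlSketch {
    private final double alpha;
    private final double beta;
    private final double l1;
    private final double l2;
    private final double[] z; // accumulated adjusted gradients, one entry per coordinate
    private final double[] n; // accumulated squared gradients, one entry per coordinate

    FtrlSketch(int dim, double alpha, double beta, double l1, double l2) {
        this.alpha = alpha;
        this.beta = beta;
        this.l1 = l1;
        this.l2 = l2;
        this.z = new double[dim];
        this.n = new double[dim];
    }

    /** Closed-form weight of one coordinate, same expression as in the diff. */
    static double weight(double z, double n, double alpha, double beta, double l1, double l2) {
        if (Math.abs(z) <= l1) {
            return 0.0; // the L1 term keeps this coordinate at exactly zero
        }
        return ((z < 0 ? -1 : 1) * l1 - z) / ((beta + Math.sqrt(n)) / alpha + l2);
    }

    /** Derives the current weight vector from the accumulators z and n. */
    double[] weights() {
        double[] w = new double[z.length];
        for (int i = 0; i < w.length; i++) {
            w[i] = weight(z[i], n[i], alpha, beta, l1, l2);
        }
        return w;
    }

    /** Predicted probability of label 1 for a dense feature vector. */
    double predict(double[] x) {
        double[] w = weights();
        double dot = 0.0;
        for (int i = 0; i < x.length; i++) {
            dot += w[i] * x[i];
        }
        return 1.0 / (1.0 + Math.exp(-dot));
    }

    /** One FTRL-Proximal step on a single example; label must be 0 or 1. */
    void update(double[] x, double label) {
        double[] w = weights();
        double p = predict(x);
        for (int i = 0; i < x.length; i++) {
            double g = (p - label) * x[i]; // gradient of the log loss for coordinate i
            double sigma = (Math.sqrt(n[i] + g * g) - Math.sqrt(n[i])) / alpha;
            z[i] += g - sigma * w[i];
            n[i] += g * g;
        }
    }

    public static void main(String[] args) {
        // alpha and beta use the defaults from this PR (0.1); l1 and l2 are arbitrary here.
        FtrlSketch model = new FtrlSketch(2, 0.1, 0.1, 0.1, 0.1);
        double[][] xs = {{1.0, 0.0}, {0.0, 1.0}};
        double[] ys = {1.0, 0.0};
        for (int epoch = 0; epoch < 50; epoch++) {
            for (int i = 0; i < xs.length; i++) {
                model.update(xs[i], ys[i]);
            }
        }
        System.out.println("weights = " + Arrays.toString(model.weights()));
        System.out.println("p(y=1 | x=[1,0]) = " + model.predict(xs[0]));
    }
}
```

The point worth documenting is that the optimizer state is the pair `(z, n)` rather than the weights themselves: the weights are recomputed from that state on demand, and any coordinate with `|z| <= l1` is set to exactly zero, which is where the sparsity of FTRL models comes from.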



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java:
##########
@@ -0,0 +1,428 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the OnlineLogisticRegression algorithm.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Online_machine_learning.
+ */
+public class OnlineLogisticRegression
+        implements Estimator<OnlineLogisticRegression, OnlineLogisticRegressionModel>,
+                OnlineLogisticRegressionParams<OnlineLogisticRegression> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineLogisticRegression() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public OnlineLogisticRegressionModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<LogisticRegressionModelData> modelDataStream =
+                LogisticRegressionModelData.getModelDataStream(initModelDataTable);
+        DataStream<DenseVector[]> initModelData =
+                modelDataStream.map(
+                        (MapFunction<LogisticRegressionModelData, DenseVector[]>)
+                                value -> new DenseVector[] {value.coefficient});
+
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new FtrlIterationBody(
+                        getGlobalBatchSize(),
+                        getAlpha(),
+                        getBETA(),
+                        getL1(),
+                        getL2(),
+                        getFeaturesCol(),
+                        getLabelCol());
+
+        DataStream<LogisticRegressionModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData),
+                                DataStreamList.of(tEnv.toDataStream(inputs[0])),
+                                body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        OnlineLogisticRegressionModel model =
+                new OnlineLogisticRegressionModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Iteration body of ftrl algorithm. */
+    private static class FtrlIterationBody implements IterationBody {
+        private final int batchSize;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+        private final String featureCol;
+        private final String labelCol;
+
+        public FtrlIterationBody(
+                int batchSize,
+                double alpha,
+                double beta,
+                double l1,
+                double l2,
+                String featureCol,
+                String labelCol) {
+            this.batchSize = batchSize;
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = l1;
+            this.l2 = l2;
+            this.featureCol = featureCol;
+            this.labelCol = labelCol;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<DenseVector[]> modelData = variableStreams.get(0);
+
+            DataStream<Row> points = dataStreams.get(0);
+            int parallelism = points.getParallelism();
+
+            Preconditions.checkState(
+                    parallelism <= batchSize,
+                    "There are more subtasks in the training process than the number "
+                            + "of elements in each batch. Some subtasks might be idling forever.");
+
+            DataStream<DenseVector[]> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator())
+                            .flatMap(new GlobalBatchSplitter(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(DenseVector[].class),
+                                    new ModelDataLocalUpdater(
+                                            alpha, beta, l1, l2, featureCol, labelCol))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new ModelDataGlobalReducer());
+            DataStream<DenseVector[]> feedbackModelData =
+                    newModelData
+                            .map(
+                                    (MapFunction<DenseVector[], DenseVector[]>)
+                                            value -> {
+                                                double[] z = value[1].values;
+                                                double[] n = value[2].values;
+                                                for (int i = 0; i < z.length; ++i) {
+                                                    if (Math.abs(z[i]) <= l1) {
+                                                        value[0].values[i] = 0.0;
+                                                    } else {
+                                                        value[0].values[i] =
+                                                                ((z[i] < 0 ? -1 : 1) * l1 - z[i])
+                                                                        / ((beta + Math.sqrt(n[i]))
+                                                                                        / alpha
+                                                                                + l2);
+                                                    }
+                                                }
+                                                return new DenseVector[] {value[0]};
+                                            })
+                            .setParallelism(1);
+
+            DataStream<LogisticRegressionModelData> outputModelData =
+                    feedbackModelData
+                            .map(
+                                    (MapFunction<DenseVector[], LogisticRegressionModelData>)
+                                            value ->
+                                                    new LogisticRegressionModelData(
+                                                            value[0], System.nanoTime()))

Review Comment:
   Can we add the model version when the infra for model version is ready?



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionModel.java:
##########
@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * A Model which classifies data using the model data computed by {@link OnlineLogisticRegression}.
+ */
+public class OnlineLogisticRegressionModel
+        implements Model<OnlineLogisticRegressionModel>,
+                OnlineLogisticRegressionModelParams<OnlineLogisticRegressionModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public OnlineLogisticRegressionModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(),
+                                Types.DOUBLE,
+                                TypeInformation.of(DenseVector.class),
+                                Types.LONG),
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldNames(),
+                                getPredictionCol(),
+                                getRawPredictionCol(),
+                                getModelVersionCol()));
+
+        DataStream<Row> predictionResult =
+                tEnv.toDataStream(inputs[0])
+                        .connect(
+                                LogisticRegressionModelData.getModelDataStream(modelDataTable)
+                                        .broadcast())
+                        .transform(
+                                "PredictLabelOperator",
+                                outputTypeInfo,
+                                new PredictLabelOperator(inputTypeInfo, getFeaturesCol()));
+
+        return new Table[] {tEnv.fromDataStream(predictionResult)};
+    }
+
+    /** A utility operator used for prediction. */
+    private static class PredictLabelOperator extends AbstractStreamOperator<Row>
+            implements TwoInputStreamOperator<Row, LogisticRegressionModelData, Row> {
+        private final RowTypeInfo inputTypeInfo;
+
+        private final String featuresCol;
+        private ListState<Row> bufferedPointsState;
+        private DenseVector coefficient;
+        private long modelDataVersion = 0;
+
+        public PredictLabelOperator(RowTypeInfo inputTypeInfo, String featuresCol) {
+            this.inputTypeInfo = inputTypeInfo;
+            this.featuresCol = featuresCol;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            bufferedPointsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("bufferedPoints", inputTypeInfo));
+        }
+
+        @Override
+        public void open() throws Exception {
+            super.open();
+
+            getRuntimeContext()
+                    .getMetricGroup()
+                    .gauge(
+                            "MODEL_DATA_VERSION_GAUGE_KEY",
+                            (Gauge<String>) () -> Long.toString(modelDataVersion));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<Row> streamRecord) throws Exception {
+            Row dataPoint = streamRecord.getValue();
+            while (coefficient == null) {

Review Comment:
   nit: could be `if`.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java:
##########
@@ -391,7 +391,8 @@ public void onIterationTerminated(Context context, Collector<double[]> collector
             feedbackBufferState.clear();
             if (getRuntimeContext().getIndexOfThisSubtask() == 0) {
                 updateModel();
-                context.output(modelDataOutputTag, new LogisticRegressionModelData(coefficient));
+                context.output(
+                        modelDataOutputTag, new LogisticRegressionModelData(coefficient, 0L));

Review Comment:
   The infra for online models (i.e., model version, barrier for model data) has not been added yet. Can we add this change when the infra is ready?
   
   @yunfengzhou-hub What do you think?



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionParams.java:
##########
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.common.param.HasGlobalBatchSize;
+import org.apache.flink.ml.common.param.HasL1;
+import org.apache.flink.ml.common.param.HasL2;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+
+/**
+ * Params of {@link OnlineLogisticRegression}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface OnlineLogisticRegressionParams<T>
+        extends HasLabelCol<T>,
+                HasBatchStrategy<T>,
+                HasGlobalBatchSize<T>,
+                HasL1<T>,
+                HasL2<T>,
+                OnlineLogisticRegressionModelParams<T> {
+
+    Param<Double> ALPHA =
+            new DoubleParam("alpha", "The parameter alpha of ftrl.", 0.1, ParamValidators.gt(0.0));
+
+    default Double getAlpha() {
+        return get(ALPHA);
+    }
+
+    default T setAlpha(Double value) {
+        return set(ALPHA, value);
+    }
+
+    Param<Double> BETA =
+            new DoubleParam("alpha", "The parameter beta of ftrl.", 0.1, ParamValidators.gt(0.0));
+
+    default Double getBETA() {

Review Comment:
   The naming of `getBETA` is not consistent with `getAlpha`.
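
A minimal sketch of the consistent naming follows. It uses a hypothetical stand-in class backed by a plain map, not the real Flink ML `WithParams` / `DoubleParam` types, so that it compiles on its own. It also gives each parameter its own name; in the quoted diff `BETA` is registered under the name "alpha", which looks like a copy-paste slip.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical stand-in for a params holder; this is not the Flink ML API. */
public class FtrlParamsSketch {
    // Distinct parameter names, so the two values cannot overwrite each other.
    static final String ALPHA = "alpha";
    static final String BETA = "beta";

    private final Map<String, Double> paramMap = new HashMap<>();

    public FtrlParamsSketch() {
        // Same defaults as in the quoted diff.
        paramMap.put(ALPHA, 0.1);
        paramMap.put(BETA, 0.1);
    }

    public Double getAlpha() {
        return paramMap.get(ALPHA);
    }

    public FtrlParamsSketch setAlpha(Double value) {
        return set(ALPHA, value);
    }

    // Camel case, matching getAlpha/setAlpha.
    public Double getBeta() {
        return paramMap.get(BETA);
    }

    public FtrlParamsSketch setBeta(Double value) {
        return set(BETA, value);
    }

    // Mirrors the gt(0.0) validator used in the quoted diff.
    private FtrlParamsSketch set(String name, Double value) {
        if (value == null || value <= 0.0) {
            throw new IllegalArgumentException(name + " must be greater than 0");
        }
        paramMap.put(name, value);
        return this;
    }

    public static void main(String[] args) {
        FtrlParamsSketch params = new FtrlParamsSketch().setAlpha(0.5).setBeta(1.0);
        System.out.println("alpha=" + params.getAlpha() + ", beta=" + params.getBeta());
    }
}
```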



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasL2.java:
##########
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
+
+/** Interface for the shared L2 param. */
+public interface HasL2<T> extends WithParams<T> {
+    Param<Double> L_2 = new DoubleParam("l2", "The l2 param.", 0.1, ParamValidators.gt(0.0));

Review Comment:
   Should we reuse the existing `HasReg` param from Logistic Regression, or should we rename `HasReg` to `HasL2`?
   
   Note that existing libraries handle this differently:
   - Spark uses `HasElasticNetParam` and `HasRegParam` to decide L1 and L2 [1]
   - Sklearn uses `penalty`, which can be `l1`, `l2`, or `elasticnet` [2]
   - Alink uses `HasL1` and `HasL2` [3]
   
   [1] https://github.com/apache/spark/blob/4dc12eb54544a12ff7ddf078ca8bcec9471212c3/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala#L53
   [2] https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html
   [3] https://github.com/alibaba/Alink/blob/master/core/src/main/java/com/alibaba/alink/params/classification/LinearBinaryClassTrainParams.java
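
To make the comparison concrete: a Spark-style pair of parameters and an explicit L1/L2 pair can describe the same penalty. Assuming Spark's documented objective, `regParam * (elasticNetParam * ||w||_1 + (1 - elasticNetParam) / 2 * ||w||_2^2)`, the conversion can be sketched as below. The class and method names are hypothetical and belong to none of the libraries above, and whether the `HasL2` in this PR carries the same 1/2 factor is an assumption.

```java
public class RegularizationSketch {

    /** Returns {l1, l2} for a Spark-style (regParam, elasticNetParam) pair. */
    static double[] toL1L2(double regParam, double elasticNetParam) {
        if (regParam < 0.0 || elasticNetParam < 0.0 || elasticNetParam > 1.0) {
            throw new IllegalArgumentException("regParam >= 0 and elasticNetParam in [0, 1]");
        }
        double l1 = regParam * elasticNetParam;
        double l2 = regParam * (1.0 - elasticNetParam);
        return new double[] {l1, l2};
    }

    /** Inverse mapping: returns {regParam, elasticNetParam} for an (l1, l2) pair. */
    static double[] toRegElasticNet(double l1, double l2) {
        if (l1 < 0.0 || l2 < 0.0) {
            throw new IllegalArgumentException("l1 and l2 must be non-negative");
        }
        double regParam = l1 + l2;
        // With no regularization at all, the mixing ratio is arbitrary; use 0.
        double elasticNetParam = regParam == 0.0 ? 0.0 : l1 / regParam;
        return new double[] {regParam, elasticNetParam};
    }

    public static void main(String[] args) {
        double[] l1l2 = toL1L2(0.5, 0.5);
        System.out.println("l1=" + l1l2[0] + ", l2=" + l1l2[1]);
    }
}
```

Either representation can express the other, so the choice is mostly about which API is less surprising to users coming from the existing `LogisticRegression`.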



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java:
##########
@@ -37,17 +37,19 @@
 import java.io.OutputStream;
 
 /**
- * Model data of {@link LogisticRegressionModel}.
+ * Model data of {@link LogisticRegressionModel} and {@link OnlineLogisticRegressionModel}.
  *
  * <p>This class also provides methods to convert model data from Table to Datastream, and classes
  * to save/load model data.
  */
 public class LogisticRegressionModelData {
 
     public DenseVector coefficient;
+    public long modelVersion;

Review Comment:
   The infra for online model (i.e., model version, barrier for model data) is not added yet. Can we add the change when the infra is ready?



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionModel.java:
##########
@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * A Model which classifies data using the model data computed by {@link OnlineLogisticRegression}.
+ */
+public class OnlineLogisticRegressionModel
+        implements Model<OnlineLogisticRegressionModel>,
+                OnlineLogisticRegressionModelParams<OnlineLogisticRegressionModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public OnlineLogisticRegressionModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(),
+                                Types.DOUBLE,
+                                TypeInformation.of(DenseVector.class),
+                                Types.LONG),
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldNames(),
+                                getPredictionCol(),
+                                getRawPredictionCol(),
+                                getModelVersionCol()));

Review Comment:
   The behavior here is inconsistent with `OnlineKMeansModel`. We probably need to discuss whether it is necessary to output the model version.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionModel.java:
##########
@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * A Model which classifies data using the model data computed by {@link OnlineLogisticRegression}.
+ */
+public class OnlineLogisticRegressionModel
+        implements Model<OnlineLogisticRegressionModel>,
+                OnlineLogisticRegressionModelParams<OnlineLogisticRegressionModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public OnlineLogisticRegressionModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(),
+                                Types.DOUBLE,
+                                TypeInformation.of(DenseVector.class),
+                                Types.LONG),
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldNames(),
+                                getPredictionCol(),
+                                getRawPredictionCol(),
+                                getModelVersionCol()));
+
+        DataStream<Row> predictionResult =
+                tEnv.toDataStream(inputs[0])
+                        .connect(
+                                LogisticRegressionModelData.getModelDataStream(modelDataTable)
+                                        .broadcast())
+                        .transform(
+                                "PredictLabelOperator",
+                                outputTypeInfo,
+                                new PredictLabelOperator(inputTypeInfo, getFeaturesCol()));
+
+        return new Table[] {tEnv.fromDataStream(predictionResult)};
+    }
+
+    /** A utility operator used for prediction. */
+    private static class PredictLabelOperator extends AbstractStreamOperator<Row>
+            implements TwoInputStreamOperator<Row, LogisticRegressionModelData, Row> {
+        private final RowTypeInfo inputTypeInfo;
+
+        private final String featuresCol;
+        private ListState<Row> bufferedPointsState;
+        private DenseVector coefficient;
+        private long modelDataVersion = 0;
+
+        public PredictLabelOperator(RowTypeInfo inputTypeInfo, String featuresCol) {
+            this.inputTypeInfo = inputTypeInfo;
+            this.featuresCol = featuresCol;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            bufferedPointsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("bufferedPoints", inputTypeInfo));
+        }
+
+        @Override
+        public void open() throws Exception {
+            super.open();
+
+            getRuntimeContext()
+                    .getMetricGroup()
+                    .gauge(
+                            "MODEL_DATA_VERSION_GAUGE_KEY",
+                            (Gauge<String>) () -> Long.toString(modelDataVersion));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<Row> streamRecord) throws Exception {
+            Row dataPoint = streamRecord.getValue();
+            while (coefficient == null) {
+                bufferedPointsState.add(dataPoint);
+                return;
+            }
+            Vector features = (Vector) dataPoint.getField(featuresCol);
+            Tuple2<Double, DenseVector> predictionResult = predictRaw(features, coefficient);
+            output.collect(
+                    new StreamRecord<>(
+                            Row.join(
+                                    dataPoint,
+                                    Row.of(
+                                            predictionResult.f0,
+                                            predictionResult.f1,
+                                            modelDataVersion))));
+        }
+
+        @Override
+        public void processElement2(StreamRecord<LogisticRegressionModelData> streamRecord)
+                throws Exception {
+            LogisticRegressionModelData modelData = streamRecord.getValue();
+
+            coefficient = modelData.coefficient;
+            modelDataVersion = modelData.modelVersion;
+
+            for (Row dataPoint : bufferedPointsState.get()) {
+                processElement1(new StreamRecord<>(dataPoint));
+            }
+            bufferedPointsState.clear();
+        }
+    }
+
+    /**
+     * The main logic that predicts one input record.
+     *
+     * @param feature The input feature.
+     * @param coefficient The model parameters.
+     * @return The prediction label and the raw probabilities.
+     */
+    private static Tuple2<Double, DenseVector> predictRaw(Vector feature, DenseVector coefficient) {

Review Comment:
   This logic is the same as that in `LogisticRegression`. Should we reuse the logic there?
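
One way to avoid the duplication is a single static helper that both the batch and the online model call. The sketch below is only an illustration of that idea: it uses plain `double[]` arrays instead of Flink ML's `Vector` / `DenseVector` and `Tuple2`, the class and method names are hypothetical, and the decision threshold of 0.5 is an assumption.

```java
public class LogisticPredictionSketch {

    /**
     * Predicts one record.
     *
     * @return {label, P(y = 0), P(y = 1)}
     */
    static double[] predict(double[] features, double[] coefficient) {
        if (features.length != coefficient.length) {
            throw new IllegalArgumentException("Feature and coefficient sizes differ");
        }
        // Linear score: dot product of features and coefficients.
        double dot = 0.0;
        for (int i = 0; i < features.length; i++) {
            dot += features[i] * coefficient[i];
        }
        // Sigmoid maps the score to the probability of the positive class.
        double prob = 1.0 / (1.0 + Math.exp(-dot));
        // dot >= 0 is equivalent to prob >= 0.5.
        double label = dot >= 0.0 ? 1.0 : 0.0;
        return new double[] {label, 1.0 - prob, prob};
    }

    public static void main(String[] args) {
        double[] result = predict(new double[] {1.0, 2.0}, new double[] {0.5, -0.1});
        System.out.println("label=" + result[0] + ", p1=" + result[2]);
    }
}
```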



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java:
##########
@@ -0,0 +1,428 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the OnlineLogisticRegression algorithm.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Online_machine_learning.
+ */
+public class OnlineLogisticRegression
+        implements Estimator<OnlineLogisticRegression, OnlineLogisticRegressionModel>,
+                OnlineLogisticRegressionParams<OnlineLogisticRegression> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineLogisticRegression() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public OnlineLogisticRegressionModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<LogisticRegressionModelData> modelDataStream =
+                LogisticRegressionModelData.getModelDataStream(initModelDataTable);
+        DataStream<DenseVector[]> initModelData =
+                modelDataStream.map(
+                        (MapFunction<LogisticRegressionModelData, DenseVector[]>)
+                                value -> new DenseVector[] {value.coefficient});
+
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new FtrlIterationBody(
+                        getGlobalBatchSize(),
+                        getAlpha(),
+                        getBETA(),
+                        getL1(),
+                        getL2(),
+                        getFeaturesCol(),
+                        getLabelCol());
+
+        DataStream<LogisticRegressionModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData),
+                                DataStreamList.of(tEnv.toDataStream(inputs[0])),
+                                body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        OnlineLogisticRegressionModel model =
+                new OnlineLogisticRegressionModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Iteration body of ftrl algorithm. */
+    private static class FtrlIterationBody implements IterationBody {
+        private final int batchSize;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+        private final String featureCol;
+        private final String labelCol;
+
+        public FtrlIterationBody(
+                int batchSize,
+                double alpha,
+                double beta,
+                double l1,
+                double l2,
+                String featureCol,
+                String labelCol) {
+            this.batchSize = batchSize;
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = l1;
+            this.l2 = l2;
+            this.featureCol = featureCol;
+            this.labelCol = labelCol;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<DenseVector[]> modelData = variableStreams.get(0);
+
+            DataStream<Row> points = dataStreams.get(0);
+            int parallelism = points.getParallelism();
+
+            Preconditions.checkState(
+                    parallelism <= batchSize,
+                    "There are more subtasks in the training process than the number "
+                            + "of elements in each batch. Some subtasks might be idling forever.");
+
+            DataStream<DenseVector[]> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator())
+                            .flatMap(new GlobalBatchSplitter(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(DenseVector[].class),
+                                    new ModelDataLocalUpdater(
+                                            alpha, beta, l1, l2, featureCol, labelCol))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new ModelDataGlobalReducer());
+            DataStream<DenseVector[]> feedbackModelData =
+                    newModelData
+                            .map(
+                                    (MapFunction<DenseVector[], DenseVector[]>)
+                                            value -> {
+                                                double[] z = value[1].values;
+                                                double[] n = value[2].values;
+                                                for (int i = 0; i < z.length; ++i) {
+                                                    if (Math.abs(z[i]) <= l1) {
+                                                        value[0].values[i] = 0.0;
+                                                    } else {
+                                                        value[0].values[i] =
+                                                                ((z[i] < 0 ? -1 : 1) * l1 - z[i])
+                                                                        / ((beta + Math.sqrt(n[i]))
+                                                                                        / alpha
+                                                                                + l2);
+                                                    }
+                                                }
+                                                return new DenseVector[] {value[0]};
+                                            })
+                            .setParallelism(1);
+
+            DataStream<LogisticRegressionModelData> outputModelData =
+                    feedbackModelData
+                            .map(
+                                    (MapFunction<DenseVector[], LogisticRegressionModelData>)
+                                            value ->
+                                                    new LogisticRegressionModelData(
+                                                            value[0], System.nanoTime()))
+                            .setParallelism(1);
+            return new IterationBodyResult(
+                    DataStreamList.of(feedbackModelData), DataStreamList.of(outputModelData));
+        }
+    }
+
+    /**
+     * Operator that collects a OnlineLogisticRegressionModelData from each upstream subtask, and
+     * outputs the weight average of collected model data.
+     */
+    private static class ModelDataGlobalReducer implements ReduceFunction<DenseVector[]> {
+        @Override
+        public DenseVector[] reduce(DenseVector[] modelData, DenseVector[] newModelData) {
+
+            for (int i = 0; i < newModelData[1].size(); ++i) {
+                newModelData[1].values[i] = modelData[1].values[i] + newModelData[1].values[i];
+                newModelData[2].values[i] = modelData[2].values[i] + newModelData[2].values[i];
+            }
+            return newModelData;
+        }
+    }
+
+    /** Updates local ftrl model. */
+    private static class ModelDataLocalUpdater extends AbstractStreamOperator<DenseVector[]>
+            implements TwoInputStreamOperator<Row[], DenseVector[], DenseVector[]> {
+        private ListState<DenseVector[]> modelDataState;
+        private ListState<Row[]> localBatchDataState;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+        private final String featureCol;
+        private final String labelCol;
+        private double[] weights;
+        private double[] nParam;
+        private double[] zParam;
+        private int[] denseVectorIndices;
+
+        public ModelDataLocalUpdater(
+                double alpha,
+                double beta,
+                double l1,
+                double l2,
+                String featureCol,
+                String labelCol) {
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = l1;
+            this.l2 = l2;
+            this.featureCol = featureCol;
+            this.labelCol = labelCol;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", DenseVector[].class));
+            TypeInformation<Row[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(TypeInformation.of(Row.class));
+            localBatchDataState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("localBatch", type));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<Row[]> pointsRecord) throws Exception {
+            if (!modelDataState.get().iterator().hasNext()) {
+                localBatchDataState.add(pointsRecord.getValue());
+                return;
+            } else {
+                for (Row[] dataPoints : localBatchDataState.get()) {
+                    updateModel(dataPoints);
+                }
+                localBatchDataState.clear();
+            }
+            updateModel(pointsRecord.getValue());
+        }
+
+        private void updateModel(Row[] points) throws Exception {
+            DenseVector[] modelData =
+                    OperatorStateUtils.getUniqueElement(modelDataState, "modelData").get();
+
+            for (Row point : points) {
+                Vector vec = point.getFieldAs(featureCol);
+                double label = point.getFieldAs(labelCol);
+                if (nParam == null) {
+                    nParam = new double[vec.size()];
+                    zParam = new double[nParam.length];
+                    weights = new double[nParam.length];
+                    if (vec instanceof DenseVector) {
+                        denseVectorIndices = new int[vec.size()];
+                        for (int i = 0; i < denseVectorIndices.length; ++i) {
+                            denseVectorIndices[i] = i;
+                        }
+                    }
+                }
+
+                System.arraycopy(modelData[0].values, 0, weights, 0, weights.length);
+                int[] indices;
+                double[] values;
+                if (vec instanceof DenseVector) {
+                    DenseVector denseVector = (DenseVector) vec;
+                    indices = denseVectorIndices;
+                    values = denseVector.values;
+                } else {
+                    SparseVector sparseVector = (SparseVector) vec;
+                    indices = sparseVector.indices;
+                    values = sparseVector.values;
+                }
+                double p = 0.0;
+                for (int i = 0; i < indices.length; ++i) {
+                    int idx = indices[i];
+                    if (Math.abs(zParam[idx]) <= l1) {
+                        weights[idx] = 0.0;
+                    } else {
+                        weights[idx] =
+                                ((zParam[idx] < 0 ? -1 : 1) * l1 - zParam[idx])
+                                        / ((beta + Math.sqrt(nParam[idx])) / alpha + l2);
+                    }
+                    p += weights[idx] * values[i];
+                }
+                p = 1 / (1 + Math.exp(-p));
+                for (int i = 0; i < indices.length; ++i) {
+                    int idx = indices[i];
+                    double g = (p - label) * values[i];
+                    double sigma =
+                            (Math.sqrt(nParam[idx] + g * g) - Math.sqrt(nParam[idx])) / alpha;
+                    zParam[idx] += g - sigma * weights[idx];
+                    nParam[idx] += g * g;
+                    weights[idx] += 1.0;
+                }
+            }
+            if (points.length > 0) {
+                output.collect(
+                        new StreamRecord<>(
+                                new DenseVector[] {
+                                    new DenseVector(weights),
+                                    new DenseVector(zParam),
+                                    new DenseVector(nParam)
+                                }));
+            }
+        }
+
+        @Override
+        public void processElement2(StreamRecord<DenseVector[]> modelDataRecord) throws Exception {
+            modelDataState.clear();
+            modelDataState.add(modelDataRecord.getValue());
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+        ReadWriteUtils.saveModelData(
+                LogisticRegressionModelData.getModelDataStream(initModelDataTable),
+                path,
+                new LogisticRegressionModelData.ModelDataEncoder());
+    }
+
+    public static OnlineLogisticRegression load(StreamTableEnvironment tEnv, String path)
+            throws IOException {
+        OnlineLogisticRegression onlineLogisticRegression = ReadWriteUtils.loadStageParam(path);
+        Table modelDataTable =
+                ReadWriteUtils.loadModelData(
+                        tEnv, path, new LogisticRegressionModelData.ModelDataDecoder());
+        onlineLogisticRegression.setInitialModelData(modelDataTable);
+        return onlineLogisticRegression;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    /**
+     * An operator that splits a global batch into evenly-sized local batches, and distributes them
+     * to downstream operator.
+     */
+    private static class GlobalBatchSplitter implements FlatMapFunction<Row[], Row[]> {

Review Comment:
   Can we reuse the code in `OnlineKMeans`, rather than copy-pasting it here?
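
As a rough sketch of what a shared utility could look like: the splitting logic can be made generic over the element type, so it is not tied to `Row`. The class below is hypothetical and not existing Flink ML code; it returns a list instead of emitting to a Flink `Collector` so that it stands alone.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class BatchSplitSketch {

    /**
     * Splits a global batch into {@code parallelism} local batches whose sizes differ by at
     * most one. The first {@code batch.length % parallelism} batches get the extra element.
     */
    static <T> List<T[]> split(T[] batch, int parallelism) {
        if (parallelism <= 0) {
            throw new IllegalArgumentException("parallelism must be positive");
        }
        List<T[]> result = new ArrayList<>(parallelism);
        int div = batch.length / parallelism;
        int mod = batch.length % parallelism;
        int offset = 0;
        for (int i = 0; i < parallelism; i++) {
            int size = i < mod ? div + 1 : div;
            result.add(Arrays.copyOfRange(batch, offset, offset + size));
            offset += size;
        }
        return result;
    }

    public static void main(String[] args) {
        Integer[] batch = {1, 2, 3, 4, 5, 6, 7};
        for (Integer[] local : split(batch, 3)) {
            System.out.println(Arrays.toString(local));
        }
    }
}
```

A `FlatMapFunction` wrapper in each algorithm would then only need to forward the resulting arrays to its collector.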


