Posted to issues@flink.apache.org by GitBox <gi...@apache.org> on 2022/04/12 01:31:32 UTC

[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator of Ftrl

yunfengzhou-hub commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r847844186


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/ftrl/Ftrl.java:
##########
@@ -0,0 +1,395 @@
+/*

Review Comment:
   Jira tickets for the algorithms that are planned to be added have all been created in advance. You can find them on https://issues.apache.org/jira/secure/RapidBoard.jspa?rapidView=541. The ticket for FTRL is FLINK-20790.



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/FtrlTest.java:
##########
@@ -0,0 +1,384 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.classification.ftrl.Ftrl;
+import org.apache.flink.ml.classification.ftrl.FtrlModel;
+import org.apache.flink.ml.classification.logisticregression.LogisticRegression;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static junit.framework.TestCase.assertEquals;
+
+/** Tests {@link Ftrl} and {@link FtrlModel}. */
+public class FtrlTest extends AbstractTestBase {

Review Comment:
   The test cases are organized differently from existing practice. Let's add tests that each cover one of the following situations:
   - getting/setting parameters.
   - the most common fit/transform process.
   - save/load.
   - getting/setting model data.
   - invalid inputs/corner cases.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LinearModelData.java:
##########
@@ -37,60 +38,68 @@
 import java.io.OutputStream;
 
 /**
- * Model data of {@link LogisticRegressionModel}.
+ * Model data of {@link LogisticRegressionModel}, {@link FtrlModel}.
  *
  * <p>This class also provides methods to convert model data from Table to Datastream, and classes
  * to save/load model data.
  */
-public class LogisticRegressionModelData {
+public class LinearModelData {

Review Comment:
   Could you please explain the relationship between FTRL, LogisticRegression, and other algorithms like LinearRegression? I'm not sure why we would rename `LogisticRegressionModelData` to `LinearModelData`.
   
   If after discussion we still agree that the renaming is reasonable, it would mean that the model data class belongs to neither the `logisticregression` nor the `ftrl` package. We would then need to move classes like this to a neutral package, for example one named `common`.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LinearModelData.java:
##########
@@ -37,60 +38,68 @@
 import java.io.OutputStream;
 
 /**
- * Model data of {@link LogisticRegressionModel}.
+ * Model data of {@link LogisticRegressionModel}, {@link FtrlModel}.
  *
  * <p>This class also provides methods to convert model data from Table to Datastream, and classes
  * to save/load model data.
  */
-public class LogisticRegressionModelData {
+public class LinearModelData {
 
     public DenseVector coefficient;
+    public long modelVersion;

Review Comment:
   The versioning mechanism is different from that in `OnlineKMeans`. Shall we adopt the same practice across both online algorithms?



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/ftrl/FtrlParams.java:
##########
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.ftrl;
+
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasGlobalBatchSize;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+
+/**
+ * Params of {@link Ftrl}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface FtrlParams<T>
+        extends HasLabelCol<T>, HasBatchStrategy<T>, HasGlobalBatchSize<T>, HasFeaturesCol<T> {

Review Comment:
   An `Estimator`'s param class should inherit the corresponding `Model`'s param class. In the current implementation, `Ftrl` cannot set the `rawPredictionCol` and `predictionCol` of the generated `FtrlModel`.
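
   A minimal sketch of the hierarchy this suggests, written with simplified stand-ins so that it compiles on its own. None of the interfaces below are the real Flink ML ones: `WithParams` here is a string-keyed toy, `FtrlModelParams` is a hypothetical name for the model's param interface, and `SketchFtrl` is only a placeholder for the estimator.

```java
import java.util.HashMap;
import java.util.Map;

// Toy stand-in for a param container (NOT the real Flink ML WithParams).
interface WithParams<T> {
    Map<String, Object> getParamMap();

    @SuppressWarnings("unchecked")
    default T set(String name, Object value) {
        getParamMap().put(name, value);
        return (T) this;
    }

    default Object get(String name, Object defaultValue) {
        return getParamMap().getOrDefault(name, defaultValue);
    }
}

interface HasFeaturesCol<T> extends WithParams<T> {
    default T setFeaturesCol(String value) { return set("featuresCol", value); }
    default String getFeaturesCol() { return (String) get("featuresCol", "features"); }
}

interface HasLabelCol<T> extends WithParams<T> {
    default T setLabelCol(String value) { return set("labelCol", value); }
    default String getLabelCol() { return (String) get("labelCol", "label"); }
}

interface HasPredictionCol<T> extends WithParams<T> {
    default T setPredictionCol(String value) { return set("predictionCol", value); }
    default String getPredictionCol() { return (String) get("predictionCol", "prediction"); }
}

interface HasRawPredictionCol<T> extends WithParams<T> {
    default T setRawPredictionCol(String value) { return set("rawPredictionCol", value); }
    default String getRawPredictionCol() { return (String) get("rawPredictionCol", "rawPrediction"); }
}

// Hypothetical param interface of the model.
interface FtrlModelParams<T>
        extends HasFeaturesCol<T>, HasPredictionCol<T>, HasRawPredictionCol<T> {}

// The estimator's params extend the model's params, so every model param
// can be configured on the estimator and copied to the fitted model.
interface FtrlParams<T> extends FtrlModelParams<T>, HasLabelCol<T> {}

// Placeholder estimator that only carries params.
class SketchFtrl implements FtrlParams<SketchFtrl> {
    private final Map<String, Object> paramMap = new HashMap<>();

    @Override
    public Map<String, Object> getParamMap() {
        return paramMap;
    }
}
```

   With this layout, `new SketchFtrl().setPredictionCol("pred").setRawPredictionCol("raw")` type-checks, which is the capability missing from the interface as written in the diff.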



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/ftrl/Ftrl.java:
##########
@@ -0,0 +1,395 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.ftrl;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.classification.logisticregression.LinearModelData;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/** An Estimator which implements the Ftrl algorithm. */
+public class Ftrl implements Estimator<Ftrl, FtrlModel>, FtrlParams<Ftrl> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public Ftrl() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public FtrlModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<Tuple2<Vector, Double>> points =
+                tEnv.toDataStream(inputs[0]).map(new ParseSample(getFeaturesCol(), getLabelCol()));
+
+        DataStream<LinearModelData> modelDataStream =
+                LinearModelData.getModelDataStream(initModelDataTable);
+        DataStream<DenseVector[]> initModelData = modelDataStream.map(new GetVectorData());
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new FtrlIterationBody(
+                        getGlobalBatchSize(), getAlpha(), getBETA(), getL1(), getL2());
+
+        DataStream<LinearModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        FtrlModel model = new FtrlModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Iteration body of ftrl. */
+    public static class FtrlIterationBody implements IterationBody {
+        private final int batchSize;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+
+        public FtrlIterationBody(int batchSize, double alpha, double beta, double l1, double l2) {
+            this.batchSize = batchSize;
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = l1;
+            this.l2 = l2;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<DenseVector[]> modelData = variableStreams.get(0);
+
+            DataStream<Tuple2<Vector, Double>> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+            DataStream<DenseVector[]> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator())
+                            .flatMap(new GlobalBatchSplitter(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(DenseVector[].class),
+                                    new FtrlLocalUpdater(alpha, beta, l1, l2))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new FtrlGlobalReducer())
+                            .map(
+                                    (MapFunction<DenseVector[], DenseVector[]>)
+                                            value -> new DenseVector[] {value[0]})
+                            .setParallelism(1);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelData),
+                    DataStreamList.of(
+                            modelData
+                                    .map(
+                                            new MapFunction<DenseVector[], LinearModelData>() {
+                                                long iter = 0L;
+
+                                                @Override
+                                                public LinearModelData map(DenseVector[] value) {
+                                                    return new LinearModelData(value[0], iter++);
+                                                }
+                                            })

Review Comment:
   Shall we move logic like this into a separate class or method? That would make the code easier to read.
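
   Roughly what I have in mind, as a self-contained sketch rather than the actual change. `VersionedModel` stands in for `LinearModelData`, `double[][]` stands in for `DenseVector[]`, and `java.util.function.Function` stands in for Flink's `MapFunction`; the class name `CreateVersionedModel` is made up for illustration.

```java
import java.util.function.Function;

// Stand-in for LinearModelData: a coefficient vector plus a version number.
final class VersionedModel {
    final double[] coefficient;
    final long modelVersion;

    VersionedModel(double[] coefficient, long modelVersion) {
        this.coefficient = coefficient;
        this.modelVersion = modelVersion;
    }
}

// Named replacement for the anonymous mapper. In the PR this would implement
// MapFunction<DenseVector[], LinearModelData> instead of Function.
final class CreateVersionedModel implements Function<double[][], VersionedModel> {
    // Number of models emitted so far by this instance.
    private long version = 0L;

    @Override
    public VersionedModel apply(double[][] value) {
        // Keep only the first vector and stamp it with the next version number.
        return new VersionedModel(value[0], version++);
    }
}
```

   The iteration body would then shrink to something like `modelData.map(new CreateVersionedModel())`, and the version counter gets a name and a place for documentation.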



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/FtrlTest.java:
##########
@@ -0,0 +1,384 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.classification.ftrl.Ftrl;
+import org.apache.flink.ml.classification.ftrl.FtrlModel;
+import org.apache.flink.ml.classification.logisticregression.LogisticRegression;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static junit.framework.TestCase.assertEquals;
+
+/** Tests {@link Ftrl} and {@link FtrlModel}. */
+public class FtrlTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private static final String LABEL_COL = "label";
+    private static final String PREDICT_COL = "prediction";
+    private static final String FEATURE_COL = "features";
+    private static final String MODEL_VERSION_COL = "modelVersion";
+    private Table trainDenseTable;
+    private static final List<Row> TRAIN_DENSE_ROWS =
+            Arrays.asList(
+                    Row.of(Vectors.dense(0.1, 2.), 0.),
+                    Row.of(Vectors.dense(0.2, 2.), 0.),
+                    Row.of(Vectors.dense(0.3, 2.), 0.),
+                    Row.of(Vectors.dense(0.4, 2.), 0.),
+                    Row.of(Vectors.dense(0.5, 2.), 0.),
+                    Row.of(Vectors.dense(11., 12.), 1.),
+                    Row.of(Vectors.dense(12., 11.), 1.),
+                    Row.of(Vectors.dense(13., 12.), 1.),
+                    Row.of(Vectors.dense(14., 12.), 1.),
+                    Row.of(Vectors.dense(15., 12.), 1.));
+
+    private static final List<Row> TRAIN_SPARSE_ROWS =
+            Arrays.asList(
+                    Row.of(
+                            Vectors.sparse(10, new int[] {1, 3, 4}, new double[] {1.0, 1.0, 1.0}),

Review Comment:
   Shall we extract `new double[] {1.0, 1.0, 1.0}` into a variable? It might make the code easier to read.
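
   One possible shape, sketched with a hypothetical helper class `TestVectors` that is not part of the PR. It returns a fresh array on every call, which avoids accidental aliasing in case a vector keeps a reference to the array it is given.

```java
import java.util.Arrays;

// Hypothetical test helper, for illustration only.
final class TestVectors {
    private TestVectors() {}

    // Returns a new array of length n filled with 1.0.
    static double[] ones(int n) {
        double[] values = new double[n];
        Arrays.fill(values, 1.0);
        return values;
    }
}
```

   A row in the test data could then be written as `Vectors.sparse(10, new int[] {1, 3, 4}, TestVectors.ones(3))`.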



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/FtrlTest.java:
##########
@@ -0,0 +1,384 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.classification.ftrl.Ftrl;
+import org.apache.flink.ml.classification.ftrl.FtrlModel;
+import org.apache.flink.ml.classification.logisticregression.LogisticRegression;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static junit.framework.TestCase.assertEquals;
+
+/** Tests {@link Ftrl} and {@link FtrlModel}. */
+public class FtrlTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private static final String LABEL_COL = "label";
+    private static final String PREDICT_COL = "prediction";
+    private static final String FEATURE_COL = "features";
+    private static final String MODEL_VERSION_COL = "modelVersion";
+    private Table trainDenseTable;
+    private static final List<Row> TRAIN_DENSE_ROWS =
+            Arrays.asList(
+                    Row.of(Vectors.dense(0.1, 2.), 0.),
+                    Row.of(Vectors.dense(0.2, 2.), 0.),
+                    Row.of(Vectors.dense(0.3, 2.), 0.),
+                    Row.of(Vectors.dense(0.4, 2.), 0.),
+                    Row.of(Vectors.dense(0.5, 2.), 0.),
+                    Row.of(Vectors.dense(11., 12.), 1.),
+                    Row.of(Vectors.dense(12., 11.), 1.),
+                    Row.of(Vectors.dense(13., 12.), 1.),
+                    Row.of(Vectors.dense(14., 12.), 1.),
+                    Row.of(Vectors.dense(15., 12.), 1.));
+
+    private static final List<Row> TRAIN_SPARSE_ROWS =
+            Arrays.asList(
+                    Row.of(
+                            Vectors.sparse(10, new int[] {1, 3, 4}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {0, 2, 3}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(
+                                    10, new int[] {0, 1, 3, 4}, new double[] {1.0, 1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {2, 3, 4}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {1, 3, 4}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {6, 7, 8}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {6, 8, 9}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {5, 8, 9}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {5, 6, 8}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {5, 6, 7}, new double[] {1.0, 1.0, 1.0}),
+                            1.));
+
+    private static final List<Row> PREDICT_DENSE_ROWS =
+            Arrays.asList(
+                    Row.of(Vectors.dense(0.8, 2.7), 0.),
+                    Row.of(Vectors.dense(0.8, 2.4), 0.),
+                    Row.of(Vectors.dense(0.7, 2.3), 0.),
+                    Row.of(Vectors.dense(0.4, 2.7), 0.),
+                    Row.of(Vectors.dense(0.5, 2.8), 0.),
+                    Row.of(Vectors.dense(10.2, 12.1), 1.),
+                    Row.of(Vectors.dense(13.3, 13.1), 1.),
+                    Row.of(Vectors.dense(13.5, 12.2), 1.),
+                    Row.of(Vectors.dense(14.9, 12.5), 1.),
+                    Row.of(Vectors.dense(15.5, 11.2), 1.));
+
+    private static final List<Row> PREDICT_SPARSE_ROWS =
+            Arrays.asList(
+                    Row.of(
+                            Vectors.sparse(10, new int[] {1, 2, 4}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {2, 3, 4}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(
+                                    10, new int[] {0, 1, 2, 4}, new double[] {1.0, 1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {1, 3, 4}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {6, 7, 9}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {7, 8, 9}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {5, 7, 9}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {5, 6, 7}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {5, 8, 9}, new double[] {1.0, 1.0, 1.0}),
+                            1.));
+
+    @Before
+    public void before() throws Exception {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        Schema schema =
+                Schema.newBuilder()
+                        .column("f0", DataTypes.of(DenseVector.class))
+                        .column("f1", DataTypes.DOUBLE())
+                        .build();
+        DataStream<Row> dataStream = env.fromCollection(TRAIN_DENSE_ROWS);
+        trainDenseTable = tEnv.fromDataStream(dataStream, schema).as(FEATURE_COL, LABEL_COL);
+    }
+
+    @Test
+    public void testFtrlWithInitLrModel() throws Exception {
+        Table initModel =
+                new LogisticRegression()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setLabelCol(LABEL_COL)
+                        .fit(trainDenseTable)
+                        .getModelData()[0];
+
+        Table onlinePredictTable =
+                tEnv.fromDataStream(
+                        env.addSource(
+                                new RandomSourceFunction(20, 20, TRAIN_DENSE_ROWS),
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(DenseVector.class), Types.DOUBLE
+                                        },
+                                        new String[] {FEATURE_COL, LABEL_COL})));
+
+        Table models =
+                new Ftrl()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initModel)
+                        .setGlobalBatchSize(10)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlinePredictTable)
+                        .getModelData()[0];
+        tEnv.toDataStream(models).print();
+        env.execute();
+    }
+
+    @Test
+    public void testFtrl() throws Exception {
+        Table initModel =
+                tEnv.fromDataStream(
+                        env.fromElements(Row.of(new DenseVector(new double[] {0.0, 0.0}), 0L)));
+        Table onlinePredictTable =
+                tEnv.fromDataStream(
+                        env.addSource(
+                                new RandomSourceFunction(20, 20, TRAIN_DENSE_ROWS),
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(DenseVector.class), Types.DOUBLE
+                                        },
+                                        new String[] {FEATURE_COL, LABEL_COL})));
+
+        Table models =
+                new Ftrl()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initModel)
+                        .setGlobalBatchSize(10)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlinePredictTable)
+                        .getModelData()[0];
+        tEnv.toDataStream(models).print();
+        env.execute();
+    }
+
+    @Test
+    public void testFtrlModel() throws Exception {
+        Table initModel =
+                tEnv.fromDataStream(
+                        env.fromElements(Row.of(new DenseVector(new double[] {0.0, 0.0}), 0L)));
+
+        Table onlineTrainTable =
+                tEnv.fromDataStream(
+                        env.addSource(
+                                new RandomSourceFunction(5, 2000, TRAIN_DENSE_ROWS),
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(Vector.class), Types.DOUBLE
+                                        },
+                                        new String[] {FEATURE_COL, LABEL_COL})));
+        Table onlinePredictTable =
+                tEnv.fromDataStream(
+                        env.addSource(
+                                new RandomSourceFunction(5, 2000, PREDICT_DENSE_ROWS),
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(Vector.class), Types.DOUBLE
+                                        },
+                                        new String[] {FEATURE_COL, LABEL_COL})));
+
+        FtrlModel model =
+                new Ftrl()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initModel)
+                        .setGlobalBatchSize(500)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlineTrainTable);
+        Table resultTable = model.setPredictionCol(PREDICT_COL).transform(onlinePredictTable)[0];
+        verifyPredictionResult(resultTable);
+    }
+
+    @Test
+    public void testFtrlModelSparse() throws Exception {
+        Table initModelSparse =
+                tEnv.fromDataStream(
+                        env.fromElements(
+                                Row.of(
+                                        new DenseVector(
+                                                new double[] {
+                                                    0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1
+                                                }),
+                                        0L)));
+
+        Table onlineTrainTable =
+                tEnv.fromDataStream(
+                        env.addSource(
+                                new RandomSourceFunction(5, 2000, TRAIN_SPARSE_ROWS),
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(Vector.class), Types.DOUBLE
+                                        },
+                                        new String[] {FEATURE_COL, LABEL_COL})));
+        Table onlinePredictTable =
+                tEnv.fromDataStream(
+                        env.addSource(
+                                new RandomSourceFunction(5, 2000, PREDICT_SPARSE_ROWS),
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(Vector.class), Types.DOUBLE
+                                        },
+                                        new String[] {FEATURE_COL, LABEL_COL})));
+
+        FtrlModel model =
+                new Ftrl()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initModelSparse)
+                        .setGlobalBatchSize(500)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlineTrainTable);
+        Table resultTable = model.setPredictionCol(PREDICT_COL).transform(onlinePredictTable)[0];
+        tEnv.toDataStream(model.getModelData()[0]).print();
+        verifyPredictionResult(resultTable);
+    }
+
+    private static void verifyPredictionResult(Table output) throws Exception {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment();
+        DataStream<Row> stream = tEnv.toDataStream(output);
+        List<Row> result = IteratorUtils.toList(stream.executeAndCollect());
+        Map<Long, Tuple2<Double, Double>> correctRatio = new HashMap<>();
+
+        for (Row row : result) {
+            long modelVersion = row.getFieldAs(MODEL_VERSION_COL);
+            Double pred = row.getFieldAs(PREDICT_COL);
+            Double label = row.getFieldAs(LABEL_COL);
+            if (correctRatio.containsKey(modelVersion)) {
+                Tuple2<Double, Double> t2 = correctRatio.get(modelVersion);
+                if (pred.equals(label)) {
+                    t2.f0 += 1.0;
+                }
+                t2.f1 += 1.0;
+            } else {
+                correctRatio.put(modelVersion, Tuple2.of(pred.equals(label) ? 1.0 : 0.0, 1.0));
+            }
+        }
+        for (Long id : correctRatio.keySet()) {
+            System.out.println(
+                    id
+                            + " : "
+                            + correctRatio.get(id).f0 / correctRatio.get(id).f1
+                            + " total sample num : "
+                            + correctRatio.get(id).f1);
+            if (id > 0L) {
+                assertEquals(1.0, correctRatio.get(id).f0 / correctRatio.get(id).f1, 1.0e-5);
+            }
+        }
+    }
+
+    /** Generates random data for ftrl train and predict. */
+    public static class RandomSourceFunction implements SourceFunction<Row> {
+        private volatile boolean isRunning = true;
+        private final long timeInterval;
+        private final long maxSize;
+        private final List<Row> data;
+
+        public RandomSourceFunction(long timeInterval, long maxSize, List<Row> data)
+                throws InterruptedException {
+            this.timeInterval = timeInterval;
+            this.maxSize = maxSize;
+            this.data = data;
+        }
+
+        @Override
+        public void run(SourceContext<Row> ctx) throws Exception {
+            int size = data.size();
+            for (int i = 0; i < maxSize; ++i) {
+                if (i == 0) {
+                    Thread.sleep(5000);

Review Comment:
   Shall we avoid using `Thread.sleep()` in test cases? If every unit test adopted this practice, the total time for `mvn install` would become very long.
   
   What we actually want here is to make sure the job has been initialized before any input is provided. `OnlineKMeansTest` has established a best practice for this kind of problem using `InMemorySourceFunction`; you can refer to these classes to see how to solve it.
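
   A rough sketch of the idea, not the actual `InMemorySourceFunction` API: the source blocks on a queue until the test hands it records, so nothing has to sleep. `QueueBackedSource` and `QueueBackedSourceDemo` are hypothetical names, and a plain thread stands in for the Flink job.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.function.Consumer;

/** Hypothetical test source: emits whatever the test offers, in order. */
final class QueueBackedSource<T> {
    // Sentinel object that tells run() to stop.
    private static final Object END = new Object();
    private final BlockingQueue<Object> queue = new LinkedBlockingQueue<>();

    void offer(T value) {
        queue.add(value);
    }

    void finish() {
        queue.add(END);
    }

    /** Blocks on take() until the test supplies a record; no sleep involved. */
    @SuppressWarnings("unchecked")
    void run(Consumer<T> out) throws InterruptedException {
        while (true) {
            Object next = queue.take();
            if (next == END) {
                return;
            }
            out.accept((T) next);
        }
    }
}

final class QueueBackedSourceDemo {
    static List<Integer> runDemo() {
        QueueBackedSource<Integer> source = new QueueBackedSource<>();
        // Written only by the consumer thread, read only after join().
        List<Integer> received = new ArrayList<>();
        CountDownLatch started = new CountDownLatch(1);

        // Stand-in for the running job (assumption: a plain thread, not Flink).
        Thread job = new Thread(() -> {
            started.countDown();
            try {
                source.run(received::add);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        });
        job.start();

        try {
            // Wait until the consumer is up instead of sleeping a fixed time.
            started.await();
            source.offer(1);
            source.offer(2);
            source.offer(3);
            source.finish();
            job.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        return received;
    }
}
```

   The test pushes input only after the latch confirms the consumer is running, which gives the same ordering guarantee as the 5-second sleep without the fixed delay.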



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/FtrlTest.java:
##########
@@ -0,0 +1,384 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.classification.ftrl.Ftrl;
+import org.apache.flink.ml.classification.ftrl.FtrlModel;
+import org.apache.flink.ml.classification.logisticregression.LogisticRegression;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static junit.framework.TestCase.assertEquals;
+
+/** Tests {@link Ftrl} and {@link FtrlModel}. */
+public class FtrlTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private static final String LABEL_COL = "label";
+    private static final String PREDICT_COL = "prediction";
+    private static final String FEATURE_COL = "features";
+    private static final String MODEL_VERSION_COL = "modelVersion";
+    private Table trainDenseTable;
+    private static final List<Row> TRAIN_DENSE_ROWS =
+            Arrays.asList(
+                    Row.of(Vectors.dense(0.1, 2.), 0.),
+                    Row.of(Vectors.dense(0.2, 2.), 0.),
+                    Row.of(Vectors.dense(0.3, 2.), 0.),
+                    Row.of(Vectors.dense(0.4, 2.), 0.),
+                    Row.of(Vectors.dense(0.5, 2.), 0.),
+                    Row.of(Vectors.dense(11., 12.), 1.),
+                    Row.of(Vectors.dense(12., 11.), 1.),
+                    Row.of(Vectors.dense(13., 12.), 1.),
+                    Row.of(Vectors.dense(14., 12.), 1.),
+                    Row.of(Vectors.dense(15., 12.), 1.));
+
+    private static final List<Row> TRAIN_SPARSE_ROWS =
+            Arrays.asList(
+                    Row.of(
+                            Vectors.sparse(10, new int[] {1, 3, 4}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {0, 2, 3}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(
+                                    10, new int[] {0, 1, 3, 4}, new double[] {1.0, 1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {2, 3, 4}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {1, 3, 4}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {6, 7, 8}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {6, 8, 9}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {5, 8, 9}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {5, 6, 8}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {5, 6, 7}, new double[] {1.0, 1.0, 1.0}),
+                            1.));
+
+    private static final List<Row> PREDICT_DENSE_ROWS =
+            Arrays.asList(
+                    Row.of(Vectors.dense(0.8, 2.7), 0.),
+                    Row.of(Vectors.dense(0.8, 2.4), 0.),
+                    Row.of(Vectors.dense(0.7, 2.3), 0.),
+                    Row.of(Vectors.dense(0.4, 2.7), 0.),
+                    Row.of(Vectors.dense(0.5, 2.8), 0.),
+                    Row.of(Vectors.dense(10.2, 12.1), 1.),
+                    Row.of(Vectors.dense(13.3, 13.1), 1.),
+                    Row.of(Vectors.dense(13.5, 12.2), 1.),
+                    Row.of(Vectors.dense(14.9, 12.5), 1.),
+                    Row.of(Vectors.dense(15.5, 11.2), 1.));
+
+    private static final List<Row> PREDICT_SPARSE_ROWS =
+            Arrays.asList(
+                    Row.of(
+                            Vectors.sparse(10, new int[] {1, 2, 4}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {2, 3, 4}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(
+                                    10, new int[] {0, 1, 2, 4}, new double[] {1.0, 1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {1, 3, 4}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {6, 7, 9}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {7, 8, 9}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {5, 7, 9}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {5, 6, 7}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {5, 8, 9}, new double[] {1.0, 1.0, 1.0}),
+                            1.));
+
+    @Before
+    public void before() throws Exception {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        Schema schema =
+                Schema.newBuilder()
+                        .column("f0", DataTypes.of(DenseVector.class))
+                        .column("f1", DataTypes.DOUBLE())
+                        .build();
+        DataStream<Row> dataStream = env.fromCollection(TRAIN_DENSE_ROWS);
+        trainDenseTable = tEnv.fromDataStream(dataStream, schema).as(FEATURE_COL, LABEL_COL);
+    }
+
+    @Test
+    public void testFtrlWithInitLrModel() throws Exception {
+        Table initModel =
+                new LogisticRegression()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setLabelCol(LABEL_COL)
+                        .fit(trainDenseTable)
+                        .getModelData()[0];
+
+        Table onlinePredictTable =
+                tEnv.fromDataStream(
+                        env.addSource(
+                                new RandomSourceFunction(20, 20, TRAIN_DENSE_ROWS),
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(DenseVector.class), Types.DOUBLE
+                                        },
+                                        new String[] {FEATURE_COL, LABEL_COL})));
+
+        Table models =
+                new Ftrl()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initModel)
+                        .setGlobalBatchSize(10)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlinePredictTable)
+                        .getModelData()[0];
+        tEnv.toDataStream(models).print();
+        env.execute();
+    }
+
+    @Test
+    public void testFtrl() throws Exception {
+        Table initModel =
+                tEnv.fromDataStream(
+                        env.fromElements(Row.of(new DenseVector(new double[] {0.0, 0.0}), 0L)));
+        Table onlinePredictTable =
+                tEnv.fromDataStream(
+                        env.addSource(
+                                new RandomSourceFunction(20, 20, TRAIN_DENSE_ROWS),
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(DenseVector.class), Types.DOUBLE
+                                        },
+                                        new String[] {FEATURE_COL, LABEL_COL})));
+
+        Table models =
+                new Ftrl()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initModel)
+                        .setGlobalBatchSize(10)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlinePredictTable)
+                        .getModelData()[0];
+        tEnv.toDataStream(models).print();
+        env.execute();

Review Comment:
   Let's verify the correctness of the output in all test cases. `env.execute()` only guarantees that the job does not throw an exception during execution; it does not mean the calculation result is correct.
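
   As an illustration of what such a check could look like once the output has been collected into a list, here is a minimal sketch in plain Java. `Prediction` and `AccuracyCheck` are hypothetical helper names and do not exist in flink-ml; the per-version grouping mirrors what `verifyPredictionResult` in this file already does.

```java
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

/** Hypothetical holder for one collected output row. */
final class Prediction {
    final long modelVersion;
    final double prediction;
    final double label;

    Prediction(long modelVersion, double prediction, double label) {
        this.modelVersion = modelVersion;
        this.prediction = prediction;
        this.label = label;
    }
}

final class AccuracyCheck {
    /** Returns, per model version, the fraction of rows whose prediction equals the label. */
    static Map<Long, Double> accuracyByVersion(List<Prediction> rows) {
        // counts[0] = correct predictions, counts[1] = total rows
        Map<Long, double[]> counts = new TreeMap<>();
        for (Prediction p : rows) {
            double[] c = counts.computeIfAbsent(p.modelVersion, k -> new double[2]);
            if (p.prediction == p.label) {
                c[0] += 1.0;
            }
            c[1] += 1.0;
        }
        Map<Long, Double> accuracy = new TreeMap<>();
        for (Map.Entry<Long, double[]> e : counts.entrySet()) {
            accuracy.put(e.getKey(), e.getValue()[0] / e.getValue()[1]);
        }
        return accuracy;
    }
}
```

   A test would then assert on the returned map (for example, `assertEquals(1.0, accuracy, 1.0e-5)` for every version above 0) rather than only printing the model stream.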



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/ftrl/FtrlModelParams.java:
##########
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.ftrl;
+
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasPredictionCol;
+import org.apache.flink.ml.common.param.HasRawPredictionCol;
+
+/**
+ * Params for {@link FtrlModel}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface FtrlModelParams<T>
+        extends HasFeaturesCol<T>, HasPredictionCol<T>, HasRawPredictionCol<T> {}

Review Comment:
   This interface is identical to `LogisticRegressionModelParams`. If FTRL is an online version of LogisticRegression, shall we consider reorganizing the code to avoid creating duplicate classes like this?
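
   One possible shape for that reorganization, sketched with simplified stand-in interfaces rather than the real flink-ml param classes: both models implement a single shared params interface. Every name below (`FeaturesColParam`, `SharedLinearModelParams`, `BatchModelSketch`, and so on) and every default value is hypothetical.

```java
/** Simplified stand-ins for the three column params (assumption: not the flink-ml types). */
interface FeaturesColParam<T> {
    default String getFeaturesCol() {
        return "features";
    }
}

interface PredictionColParam<T> {
    default String getPredictionCol() {
        return "prediction";
    }
}

interface RawPredictionColParam<T> {
    default String getRawPredictionCol() {
        return "rawPrediction";
    }
}

/** Hypothetical shared interface that replaces two identical per-model ones. */
interface SharedLinearModelParams<T>
        extends FeaturesColParam<T>, PredictionColParam<T>, RawPredictionColParam<T> {}

/** Stand-in for the batch model. */
final class BatchModelSketch implements SharedLinearModelParams<BatchModelSketch> {}

/** Stand-in for the online model. */
final class OnlineModelSketch implements SharedLinearModelParams<OnlineModelSketch> {}

final class SharedParamsDemo {
    /** Works for either model because both expose the same params. */
    static String describe(SharedLinearModelParams<?> params) {
        return params.getFeaturesCol()
                + ","
                + params.getPredictionCol()
                + ","
                + params.getRawPredictionCol();
    }
}
```

   Whether the shared interface lives with LogisticRegression or in a common package is a design choice for the PR author; the sketch only shows that a second identical interface is not required.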



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/FtrlTest.java:
##########
@@ -0,0 +1,384 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.classification.ftrl.Ftrl;
+import org.apache.flink.ml.classification.ftrl.FtrlModel;
+import org.apache.flink.ml.classification.logisticregression.LogisticRegression;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static junit.framework.TestCase.assertEquals;
+
+/** Tests {@link Ftrl} and {@link FtrlModel}. */
+public class FtrlTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private static final String LABEL_COL = "label";
+    private static final String PREDICT_COL = "prediction";
+    private static final String FEATURE_COL = "features";
+    private static final String MODEL_VERSION_COL = "modelVersion";
+    private Table trainDenseTable;
+    private static final List<Row> TRAIN_DENSE_ROWS =
+            Arrays.asList(
+                    Row.of(Vectors.dense(0.1, 2.), 0.),
+                    Row.of(Vectors.dense(0.2, 2.), 0.),
+                    Row.of(Vectors.dense(0.3, 2.), 0.),
+                    Row.of(Vectors.dense(0.4, 2.), 0.),
+                    Row.of(Vectors.dense(0.5, 2.), 0.),
+                    Row.of(Vectors.dense(11., 12.), 1.),
+                    Row.of(Vectors.dense(12., 11.), 1.),
+                    Row.of(Vectors.dense(13., 12.), 1.),
+                    Row.of(Vectors.dense(14., 12.), 1.),
+                    Row.of(Vectors.dense(15., 12.), 1.));
+
+    private static final List<Row> TRAIN_SPARSE_ROWS =
+            Arrays.asList(
+                    Row.of(
+                            Vectors.sparse(10, new int[] {1, 3, 4}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {0, 2, 3}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(
+                                    10, new int[] {0, 1, 3, 4}, new double[] {1.0, 1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {2, 3, 4}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {1, 3, 4}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {6, 7, 8}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {6, 8, 9}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {5, 8, 9}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {5, 6, 8}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {5, 6, 7}, new double[] {1.0, 1.0, 1.0}),
+                            1.));
+
+    private static final List<Row> PREDICT_DENSE_ROWS =
+            Arrays.asList(
+                    Row.of(Vectors.dense(0.8, 2.7), 0.),
+                    Row.of(Vectors.dense(0.8, 2.4), 0.),
+                    Row.of(Vectors.dense(0.7, 2.3), 0.),
+                    Row.of(Vectors.dense(0.4, 2.7), 0.),
+                    Row.of(Vectors.dense(0.5, 2.8), 0.),
+                    Row.of(Vectors.dense(10.2, 12.1), 1.),
+                    Row.of(Vectors.dense(13.3, 13.1), 1.),
+                    Row.of(Vectors.dense(13.5, 12.2), 1.),
+                    Row.of(Vectors.dense(14.9, 12.5), 1.),
+                    Row.of(Vectors.dense(15.5, 11.2), 1.));
+
+    private static final List<Row> PREDICT_SPARSE_ROWS =
+            Arrays.asList(
+                    Row.of(
+                            Vectors.sparse(10, new int[] {1, 2, 4}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {2, 3, 4}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(
+                                    10, new int[] {0, 1, 2, 4}, new double[] {1.0, 1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {1, 3, 4}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {6, 7, 9}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {7, 8, 9}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {5, 7, 9}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {5, 6, 7}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {5, 8, 9}, new double[] {1.0, 1.0, 1.0}),
+                            1.));
+
+    @Before
+    public void before() throws Exception {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        Schema schema =
+                Schema.newBuilder()
+                        .column("f0", DataTypes.of(DenseVector.class))
+                        .column("f1", DataTypes.DOUBLE())
+                        .build();
+        DataStream<Row> dataStream = env.fromCollection(TRAIN_DENSE_ROWS);
+        trainDenseTable = tEnv.fromDataStream(dataStream, schema).as(FEATURE_COL, LABEL_COL);
+    }
+
+    @Test
+    public void testFtrlWithInitLrModel() throws Exception {
+        Table initModel =
+                new LogisticRegression()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setLabelCol(LABEL_COL)
+                        .fit(trainDenseTable)
+                        .getModelData()[0];
+
+        Table onlinePredictTable =
+                tEnv.fromDataStream(
+                        env.addSource(
+                                new RandomSourceFunction(20, 20, TRAIN_DENSE_ROWS),
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(DenseVector.class), Types.DOUBLE
+                                        },
+                                        new String[] {FEATURE_COL, LABEL_COL})));
+
+        Table models =
+                new Ftrl()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initModel)
+                        .setGlobalBatchSize(10)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlinePredictTable)
+                        .getModelData()[0];
+        tEnv.toDataStream(models).print();
+        env.execute();
+    }
+
+    @Test
+    public void testFtrl() throws Exception {
+        Table initModel =
+                tEnv.fromDataStream(
+                        env.fromElements(Row.of(new DenseVector(new double[] {0.0, 0.0}), 0L)));
+        Table onlinePredictTable =
+                tEnv.fromDataStream(
+                        env.addSource(
+                                new RandomSourceFunction(20, 20, TRAIN_DENSE_ROWS),
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(DenseVector.class), Types.DOUBLE
+                                        },
+                                        new String[] {FEATURE_COL, LABEL_COL})));
+
+        Table models =
+                new Ftrl()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initModel)
+                        .setGlobalBatchSize(10)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlinePredictTable)
+                        .getModelData()[0];
+        tEnv.toDataStream(models).print();
+        env.execute();
+    }
+
+    @Test
+    public void testFtrlModel() throws Exception {
+        Table initModel =
+                tEnv.fromDataStream(
+                        env.fromElements(Row.of(new DenseVector(new double[] {0.0, 0.0}), 0L)));
+
+        Table onlineTrainTable =
+                tEnv.fromDataStream(
+                        env.addSource(
+                                new RandomSourceFunction(5, 2000, TRAIN_DENSE_ROWS),
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(Vector.class), Types.DOUBLE
+                                        },
+                                        new String[] {FEATURE_COL, LABEL_COL})));
+        Table onlinePredictTable =
+                tEnv.fromDataStream(
+                        env.addSource(
+                                new RandomSourceFunction(5, 2000, PREDICT_DENSE_ROWS),
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(Vector.class), Types.DOUBLE
+                                        },
+                                        new String[] {FEATURE_COL, LABEL_COL})));
+
+        FtrlModel model =
+                new Ftrl()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initModel)
+                        .setGlobalBatchSize(500)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlineTrainTable);
+        Table resultTable = model.setPredictionCol(PREDICT_COL).transform(onlinePredictTable)[0];
+        verifyPredictionResult(resultTable);
+    }
+
+    @Test
+    public void testFtrlModelSparse() throws Exception {
+        Table initModelSparse =
+                tEnv.fromDataStream(
+                        env.fromElements(
+                                Row.of(
+                                        new DenseVector(
+                                                new double[] {
+                                                    0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1
+                                                }),
+                                        0L)));
+
+        Table onlineTrainTable =
+                tEnv.fromDataStream(
+                        env.addSource(
+                                new RandomSourceFunction(5, 2000, TRAIN_SPARSE_ROWS),
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(Vector.class), Types.DOUBLE
+                                        },
+                                        new String[] {FEATURE_COL, LABEL_COL})));
+        Table onlinePredictTable =
+                tEnv.fromDataStream(
+                        env.addSource(
+                                new RandomSourceFunction(5, 2000, PREDICT_SPARSE_ROWS),
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(Vector.class), Types.DOUBLE
+                                        },
+                                        new String[] {FEATURE_COL, LABEL_COL})));
+
+        FtrlModel model =
+                new Ftrl()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initModelSparse)
+                        .setGlobalBatchSize(500)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlineTrainTable);
+        Table resultTable = model.setPredictionCol(PREDICT_COL).transform(onlinePredictTable)[0];
+        tEnv.toDataStream(model.getModelData()[0]).print();
+        verifyPredictionResult(resultTable);
+    }
+
+    private static void verifyPredictionResult(Table output) throws Exception {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment();
+        DataStream<Row> stream = tEnv.toDataStream(output);
+        List<Row> result = IteratorUtils.toList(stream.executeAndCollect());
+        Map<Long, Tuple2<Double, Double>> correctRatio = new HashMap<>();
+
+        for (Row row : result) {
+            long modelVersion = row.getFieldAs(MODEL_VERSION_COL);
+            Double pred = row.getFieldAs(PREDICT_COL);
+            Double label = row.getFieldAs(LABEL_COL);
+            if (correctRatio.containsKey(modelVersion)) {
+                Tuple2<Double, Double> t2 = correctRatio.get(modelVersion);
+                if (pred.equals(label)) {
+                    t2.f0 += 1.0;
+                }
+                t2.f1 += 1.0;
+            } else {
+                correctRatio.put(modelVersion, Tuple2.of(pred.equals(label) ? 1.0 : 0.0, 1.0));
+            }
+        }
+        for (Long id : correctRatio.keySet()) {
+            System.out.println(

Review Comment:
   Let's avoid printing debugging information in test cases.
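   For instance, the per-version accuracy could be asserted rather than printed. The sketch below is a minimal, self-contained illustration only: `AccuracyCheck`, its method names, and the `{modelVersion, prediction, label}` record layout are assumptions made for the example, not code from this PR.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical helper; names and record layout are illustrative, not from the PR. */
public class AccuracyCheck {

    /** Each record is {modelVersion, prediction, label}; returns accuracy per model version. */
    public static Map<Long, Double> accuracyByVersion(List<double[]> records) {
        Map<Long, long[]> counts = new HashMap<>(); // version -> {correct, total}
        for (double[] r : records) {
            long[] c = counts.computeIfAbsent((long) r[0], k -> new long[2]);
            if (r[1] == r[2]) {
                c[0]++;
            }
            c[1]++;
        }
        Map<Long, Double> accuracy = new HashMap<>();
        for (Map.Entry<Long, long[]> e : counts.entrySet()) {
            accuracy.put(e.getKey(), (double) e.getValue()[0] / e.getValue()[1]);
        }
        return accuracy;
    }

    /** Fails if any model version is below the threshold, instead of printing the ratio. */
    public static void assertAccuracyAtLeast(Map<Long, Double> accuracy, double threshold) {
        for (Map.Entry<Long, Double> e : accuracy.entrySet()) {
            if (e.getValue() < threshold) {
                throw new AssertionError(
                        "Model version " + e.getKey() + " has accuracy " + e.getValue()
                                + ", expected at least " + threshold);
            }
        }
    }
}
```

   A call such as `assertAccuracyAtLeast(accuracyByVersion(records), 0.9)` could then take the place of the `System.out.println` loop; the 0.9 is only a placeholder threshold.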



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/ftrl/FtrlParams.java:
##########
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.ftrl;
+
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasGlobalBatchSize;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+
+/**
+ * Params of {@link Ftrl}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface FtrlParams<T>
+        extends HasLabelCol<T>, HasBatchStrategy<T>, HasGlobalBatchSize<T>, HasFeaturesCol<T> {
+
+    Param<Integer> VECTOR_SIZE =
+            new IntParam("vectorSize", "The size of vector.", -1, ParamValidators.gt(-2));
+
+    default Integer getVectorSize() {

Review Comment:
   This param seems unused.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/ftrl/FtrlParams.java:
##########
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.ftrl;
+
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasGlobalBatchSize;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+
+/**
+ * Params of {@link Ftrl}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface FtrlParams<T>
+        extends HasLabelCol<T>, HasBatchStrategy<T>, HasGlobalBatchSize<T>, HasFeaturesCol<T> {
+
+    Param<Integer> VECTOR_SIZE =
+            new IntParam("vectorSize", "The size of vector.", -1, ParamValidators.gt(-2));
+
+    default Integer getVectorSize() {
+        return get(VECTOR_SIZE);
+    }
+
+    default T setVectorSize(Integer value) {
+        return set(VECTOR_SIZE, value);
+    }
+
+    Param<Double> L_1 =
+            new DoubleParam("l1", "The parameter l1 of ftrl.", 0.1, ParamValidators.gt(0.0));
+
+    default Double getL1() {
+        return get(L_1);
+    }
+
+    default T setL1(Double value) {
+        return set(L_1, value);
+    }
+
+    Param<Double> L_2 =
+            new DoubleParam("l2", "The parameter l2 of ftrl.", 0.1, ParamValidators.gt(0.0));
+
+    default Double getL2() {

Review Comment:
   It seems that these parameters are also used in other algorithms, such as softmax and multi-layer perceptron. Let's define them as common parameters.
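   A rough idea of what sharing could look like is sketched below. `ParamHolder` is a simplified stand-in invented here so the snippet compiles on its own; in the real code the shared interfaces would presumably sit next to `HasGlobalBatchSize` in `org.apache.flink.ml.common.param` and be built on `Param`/`DoubleParam`. `HasL1`, `HasL2`, `FtrlLike` and `SoftmaxLike` are hypothetical names as well. The 0.1 defaults and the "greater than 0" check mirror the definitions in this diff.

```java
import java.util.HashMap;
import java.util.Map;

public class CommonParamSketch {

    /** Simplified, hypothetical stand-in for the library's parameter holder. */
    public interface ParamHolder<T> {
        Map<String, Object> params();

        @SuppressWarnings("unchecked")
        default T set(String name, Object value) {
            params().put(name, value);
            return (T) this;
        }
    }

    /** L1 regularization parameter, declared once and shared by several algorithms. */
    public interface HasL1<T> extends ParamHolder<T> {
        default double getL1() {
            return (Double) params().getOrDefault("l1", 0.1);
        }

        default T setL1(double value) {
            if (value <= 0.0) {
                throw new IllegalArgumentException("l1 must be greater than 0");
            }
            return set("l1", value);
        }
    }

    /** L2 regularization parameter, declared once and shared by several algorithms. */
    public interface HasL2<T> extends ParamHolder<T> {
        default double getL2() {
            return (Double) params().getOrDefault("l2", 0.1);
        }

        default T setL2(double value) {
            if (value <= 0.0) {
                throw new IllegalArgumentException("l2 must be greater than 0");
            }
            return set("l2", value);
        }
    }

    /** Two unrelated algorithms reuse the same parameter definitions. */
    public static class FtrlLike implements HasL1<FtrlLike>, HasL2<FtrlLike> {
        private final Map<String, Object> params = new HashMap<>();

        @Override
        public Map<String, Object> params() {
            return params;
        }
    }

    public static class SoftmaxLike implements HasL1<SoftmaxLike>, HasL2<SoftmaxLike> {
        private final Map<String, Object> params = new HashMap<>();

        @Override
        public Map<String, Object> params() {
            return params;
        }
    }
}
```

   With this shape, `FtrlParams` would only need to extend the shared interfaces instead of declaring `L_1` and `L_2` itself.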



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/ftrl/FtrlParams.java:
##########
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.ftrl;
+
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasGlobalBatchSize;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+
+/**
+ * Params of {@link Ftrl}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface FtrlParams<T>
+        extends HasLabelCol<T>, HasBatchStrategy<T>, HasGlobalBatchSize<T>, HasFeaturesCol<T> {

Review Comment:
   It seems that some parameters provided by LogisticRegression, like `weightCol`, `reg` and `multiClass`, are not supported by FTRL for now. Is FTRL supposed to have weaker functionality than LogisticRegression when both are used for offline training? If not, do we plan to add support for these parameters in the future?



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/ftrl/Ftrl.java:
##########
@@ -0,0 +1,395 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.ftrl;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.classification.logisticregression.LinearModelData;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/** An Estimator which implements the Ftrl algorithm. */
+public class Ftrl implements Estimator<Ftrl, FtrlModel>, FtrlParams<Ftrl> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public Ftrl() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public FtrlModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<Tuple2<Vector, Double>> points =
+                tEnv.toDataStream(inputs[0]).map(new ParseSample(getFeaturesCol(), getLabelCol()));
+
+        DataStream<LinearModelData> modelDataStream =
+                LinearModelData.getModelDataStream(initModelDataTable);
+        DataStream<DenseVector[]> initModelData = modelDataStream.map(new GetVectorData());
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new FtrlIterationBody(
+                        getGlobalBatchSize(), getAlpha(), getBETA(), getL1(), getL2());
+
+        DataStream<LinearModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        FtrlModel model = new FtrlModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Iteration body of ftrl. */
+    public static class FtrlIterationBody implements IterationBody {
+        private final int batchSize;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+
+        public FtrlIterationBody(int batchSize, double alpha, double beta, double l1, double l2) {
+            this.batchSize = batchSize;
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = l1;
+            this.l2 = l2;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<DenseVector[]> modelData = variableStreams.get(0);
+
+            DataStream<Tuple2<Vector, Double>> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+            DataStream<DenseVector[]> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator())
+                            .flatMap(new GlobalBatchSplitter(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(DenseVector[].class),
+                                    new FtrlLocalUpdater(alpha, beta, l1, l2))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new FtrlGlobalReducer())
+                            .map(
+                                    (MapFunction<DenseVector[], DenseVector[]>)
+                                            value -> new DenseVector[] {value[0]})
+                            .setParallelism(1);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelData),
+                    DataStreamList.of(
+                            modelData
+                                    .map(
+                                            new MapFunction<DenseVector[], LinearModelData>() {
+                                                long iter = 0L;
+
+                                                @Override
+                                                public LinearModelData map(DenseVector[] value) {
+                                                    return new LinearModelData(value[0], iter++);
+                                                }
+                                            })
+                                    .setParallelism(1)));
+        }
+    }
+
+    /** Gets vector data. */
+    public static class GetVectorData implements MapFunction<LinearModelData, DenseVector[]> {
+        @Override
+        public DenseVector[] map(LinearModelData value) throws Exception {
+            return new DenseVector[] {value.coefficient};
+        }
+    }
+
+    /**
+     * Operator that collects a LogisticRegressionModelData from each upstream subtask, and outputs
+     * the weight average of collected model data.
+     */
+    public static class FtrlGlobalReducer implements ReduceFunction<DenseVector[]> {
+        @Override
+        public DenseVector[] reduce(DenseVector[] modelData, DenseVector[] newModelData) {
+            for (int i = 0; i < newModelData[0].size(); ++i) {
+                if ((modelData[1].values[i] + newModelData[1].values[i]) > 0.0) {
+                    newModelData[0].values[i] =
+                            (modelData[0].values[i] * modelData[1].values[i]
+                                            + newModelData[0].values[i] * newModelData[1].values[i])
+                                    / (modelData[1].values[i] + newModelData[1].values[i]);
+                }
+                newModelData[1].values[i] = modelData[1].values[i] + newModelData[1].values[i];
+            }
+            return newModelData;
+        }
+    }
+
+    /** Updates local ftrl model. */
+    public static class FtrlLocalUpdater extends AbstractStreamOperator<DenseVector[]>
+            implements TwoInputStreamOperator<
+                    Tuple2<Vector, Double>[], DenseVector[], DenseVector[]> {
+        private ListState<Tuple2<Vector, Double>[]> localBatchDataState;
+        private ListState<DenseVector[]> modelDataState;
+        private double[] n;
+        private double[] z;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+        private DenseVector weights;
+
+        public FtrlLocalUpdater(double alpha, double beta, double l1, double l2) {
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = l1;
+            this.l2 = l2;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<Tuple2<Vector, Double>[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            localBatchDataState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("localBatch", type));
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", DenseVector[].class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<Tuple2<Vector, Double>[]> pointsRecord)
+                throws Exception {
+            localBatchDataState.add(pointsRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<DenseVector[]> modelDataRecord) throws Exception {
+            modelDataState.add(modelDataRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        private void alignAndComputeModelData() throws Exception {

Review Comment:
   Would it be better to add some comments to this method, or divide this method further into several smaller methods? That could help improve readability.
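   One possible shape for such a split is sketched below on the textbook per-coordinate FTRL-Proximal update for logistic regression, not on the body of this method, which may differ. The class and method names (`FtrlSketch`, `weight`, `predict`, `update`, `updateCoordinate`) are hypothetical, and dense `double[]` features are assumed for brevity.

```java
/**
 * Textbook per-coordinate FTRL-Proximal for logistic regression, split into small methods.
 * Hypothetical sketch: names are illustrative and not taken from the PR.
 */
public class FtrlSketch {
    private final double alpha;
    private final double beta;
    private final double l1;
    private final double l2;
    private final double[] z; // accumulated adjusted gradients
    private final double[] n; // accumulated squared gradients

    public FtrlSketch(int dim, double alpha, double beta, double l1, double l2) {
        this.alpha = alpha;
        this.beta = beta;
        this.l1 = l1;
        this.l2 = l2;
        this.z = new double[dim];
        this.n = new double[dim];
    }

    /** Closed-form weight of coordinate i; L1 keeps it at zero while |z_i| <= l1. */
    public double weight(int i) {
        if (Math.abs(z[i]) <= l1) {
            return 0.0;
        }
        return -(z[i] - Math.signum(z[i]) * l1) / ((beta + Math.sqrt(n[i])) / alpha + l2);
    }

    /** Predicted probability of the positive class for a dense feature vector. */
    public double predict(double[] x) {
        double dot = 0.0;
        for (int i = 0; i < x.length; i++) {
            dot += x[i] * weight(i);
        }
        return 1.0 / (1.0 + Math.exp(-dot));
    }

    /** One online step for a sample with label y in {0, 1}. */
    public void update(double[] x, double y) {
        double p = predict(x);
        for (int i = 0; i < x.length; i++) {
            if (x[i] != 0.0) {
                updateCoordinate(i, (p - y) * x[i]);
            }
        }
    }

    /** Folds gradient g into the z and n accumulators of coordinate i. */
    private void updateCoordinate(int i, double g) {
        double sigma = (Math.sqrt(n[i] + g * g) - Math.sqrt(n[i])) / alpha;
        z[i] += g - sigma * weight(i);
        n[i] += g * g;
    }
}
```

   Keeping the weight formula, the prediction and the accumulator update in separate methods lets each one carry a one-line comment, which is the kind of readability gain meant here.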



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/ftrl/FtrlModel.java:
##########
@@ -0,0 +1,196 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.ftrl;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.classification.logisticregression.LinearModelData;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+/** A Model which classifies data using the model data computed by {@link Ftrl}. */
+public class FtrlModel implements Model<FtrlModel>, FtrlModelParams<FtrlModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public FtrlModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), Types.DOUBLE, Types.LONG),
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldNames(), getPredictionCol(), "modelVersion"));
+
+        DataStream<Row> predictionResult =
+                tEnv.toDataStream(inputs[0])
+                        .connect(LinearModelData.getModelDataStream(modelDataTable).broadcast())
+                        .transform(
+                                "PredictLabelOperator",
+                                outputTypeInfo,
+                                new PredictLabelOperator(inputTypeInfo, getFeaturesCol()));
+
+        return new Table[] {tEnv.fromDataStream(predictionResult)};
+    }
+
+    /** A utility operator used for prediction. */
+    private static class PredictLabelOperator extends AbstractStreamOperator<Row>
+            implements TwoInputStreamOperator<Row, LinearModelData, Row> {
+        private final RowTypeInfo inputTypeInfo;
+
+        private final String featuresCol;
+        private ListState<Row> bufferedPointsState;
+        private DenseVector coefficient;
+        private long modelDataVersion = 0;
+
+        public PredictLabelOperator(RowTypeInfo inputTypeInfo, String featuresCol) {
+            this.inputTypeInfo = inputTypeInfo;
+            this.featuresCol = featuresCol;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            bufferedPointsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("bufferedPoints", inputTypeInfo));
+        }
+
+        @Override
+        public void open() throws Exception {
+            super.open();
+
+            getRuntimeContext()
+                    .getMetricGroup()
+                    .gauge(
+                            "MODEL_DATA_VERSION_GAUGE_KEY",
+                            (Gauge<String>) () -> Long.toString(modelDataVersion));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<Row> streamRecord) throws Exception {
+            Row dataPoint = streamRecord.getValue();
+            // todo : predict data
+            Vector features = (Vector) dataPoint.getField(featuresCol);
+            Tuple2<Double, DenseVector> predictionResult = predictRaw(features, coefficient);
+            output.collect(
+                    new StreamRecord<>(
+                            Row.join(dataPoint, Row.of(predictionResult.f0, modelDataVersion))));
+        }
+
+        @Override
+        public void processElement2(StreamRecord<LinearModelData> streamRecord) throws Exception {
+            LinearModelData modelData = streamRecord.getValue();
+
+            coefficient = modelData.coefficient;
+            modelDataVersion = modelData.modelVersion;
+            // System.out.println("update model...");
+            // todo : receive model data.
+            // Preconditions.checkArgument(modelData.centroids.length <= k);
+            // centroids = modelData.centroids;
+            // modelDataVersion++;
+            // for (Row dataPoint : bufferedPointsState.get()) {
+            //	processElement1(new StreamRecord<>(dataPoint));
+            // }
+            // bufferedPointsState.clear();

Review Comment:
   Let's clean up this method and remove the unused, commented-out code.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/ftrl/Ftrl.java:
##########
@@ -0,0 +1,395 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.ftrl;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.classification.logisticregression.LinearModelData;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/** An Estimator which implements the Ftrl algorithm. */
+public class Ftrl implements Estimator<Ftrl, FtrlModel>, FtrlParams<Ftrl> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public Ftrl() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public FtrlModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<Tuple2<Vector, Double>> points =
+                tEnv.toDataStream(inputs[0]).map(new ParseSample(getFeaturesCol(), getLabelCol()));
+
+        DataStream<LinearModelData> modelDataStream =
+                LinearModelData.getModelDataStream(initModelDataTable);
+        DataStream<DenseVector[]> initModelData = modelDataStream.map(new GetVectorData());
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new FtrlIterationBody(
+                        getGlobalBatchSize(), getAlpha(), getBETA(), getL1(), getL2());
+
+        DataStream<LinearModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        FtrlModel model = new FtrlModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Iteration body of ftrl. */
+    public static class FtrlIterationBody implements IterationBody {
+        private final int batchSize;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+
+        public FtrlIterationBody(int batchSize, double alpha, double beta, double l1, double l2) {
+            this.batchSize = batchSize;
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = l1;
+            this.l2 = l2;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<DenseVector[]> modelData = variableStreams.get(0);
+
+            DataStream<Tuple2<Vector, Double>> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+            DataStream<DenseVector[]> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator())
+                            .flatMap(new GlobalBatchSplitter(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(DenseVector[].class),
+                                    new FtrlLocalUpdater(alpha, beta, l1, l2))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new FtrlGlobalReducer())
+                            .map(
+                                    (MapFunction<DenseVector[], DenseVector[]>)
+                                            value -> new DenseVector[] {value[0]})
+                            .setParallelism(1);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelData),
+                    DataStreamList.of(
+                            modelData
+                                    .map(
+                                            new MapFunction<DenseVector[], LinearModelData>() {
+                                                long iter = 0L;
+
+                                                @Override
+                                                public LinearModelData map(DenseVector[] value) {
+                                                    return new LinearModelData(value[0], iter++);
+                                                }
+                                            })
+                                    .setParallelism(1)));
+        }
+    }
+
+    /** Gets vector data. */
+    public static class GetVectorData implements MapFunction<LinearModelData, DenseVector[]> {
+        @Override
+        public DenseVector[] map(LinearModelData value) throws Exception {
+            return new DenseVector[] {value.coefficient};
+        }
+    }
+
+    /**
+     * Operator that collects a LogisticRegressionModelData from each upstream subtask, and outputs
+     * the weight average of collected model data.
+     */
+    public static class FtrlGlobalReducer implements ReduceFunction<DenseVector[]> {
+        @Override
+        public DenseVector[] reduce(DenseVector[] modelData, DenseVector[] newModelData) {
+            for (int i = 0; i < newModelData[0].size(); ++i) {
+                if ((modelData[1].values[i] + newModelData[1].values[i]) > 0.0) {
+                    newModelData[0].values[i] =
+                            (modelData[0].values[i] * modelData[1].values[i]
+                                            + newModelData[0].values[i] * newModelData[1].values[i])
+                                    / (modelData[1].values[i] + newModelData[1].values[i]);
+                }
+                newModelData[1].values[i] = modelData[1].values[i] + newModelData[1].values[i];
+            }
+            return newModelData;
+        }
+    }
+
+    /** Updates local ftrl model. */
+    public static class FtrlLocalUpdater extends AbstractStreamOperator<DenseVector[]>
+            implements TwoInputStreamOperator<
+                    Tuple2<Vector, Double>[], DenseVector[], DenseVector[]> {
+        private ListState<Tuple2<Vector, Double>[]> localBatchDataState;
+        private ListState<DenseVector[]> modelDataState;
+        private double[] n;
+        private double[] z;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+        private DenseVector weights;
+
+        public FtrlLocalUpdater(double alpha, double beta, double l1, double l2) {
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = l1;
+            this.l2 = l2;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<Tuple2<Vector, Double>[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            localBatchDataState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("localBatch", type));
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", DenseVector[].class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<Tuple2<Vector, Double>[]> pointsRecord)
+                throws Exception {
+            localBatchDataState.add(pointsRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<DenseVector[]> modelDataRecord) throws Exception {
+            modelDataState.add(modelDataRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        private void alignAndComputeModelData() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !localBatchDataState.get().iterator().hasNext()) {
+                return;
+            }
+
+            DenseVector[] modelData =
+                    OperatorStateUtils.getUniqueElement(modelDataState, "modelData").get();
+            modelDataState.clear();
+
+            List<Tuple2<Vector, Double>[]> pointsList =
+                    IteratorUtils.toList(localBatchDataState.get().iterator());
+            Tuple2<Vector, Double>[] points = pointsList.remove(0);
+            localBatchDataState.update(pointsList);
+
+            for (Tuple2<Vector, Double> point : points) {
+                if (n == null) {
+                    n = new double[point.f0.size()];
+                    z = new double[n.length];
+                    weights = new DenseVector(n.length);
+                }
+
+                double p = 0.0;
+                Arrays.fill(weights.values, 0.0);
+                if (point.f0 instanceof DenseVector) {
+                    DenseVector denseVector = (DenseVector) point.f0;
+                    for (int i = 0; i < denseVector.size(); ++i) {
+                        if (Math.abs(z[i]) <= l1) {
+                            modelData[0].values[i] = 0.0;
+                        } else {
+                            modelData[0].values[i] =
+                                    ((z[i] < 0 ? -1 : 1) * l1 - z[i])
+                                            / ((beta + Math.sqrt(n[i])) / alpha + l2);
+                        }
+                        p += modelData[0].values[i] * denseVector.values[i];
+                    }
+                    p = 1 / (1 + Math.exp(-p));
+                    for (int i = 0; i < denseVector.size(); ++i) {
+                        double g = (p - point.f1) * denseVector.values[i];
+                        double sigma = (Math.sqrt(n[i] + g * g) - Math.sqrt(n[i])) / alpha;
+                        z[i] += g - sigma * modelData[0].values[i];
+                        n[i] += g * g;
+                        weights.values[i] += 1.0;
+                    }
+                } else {
+                    SparseVector sparseVector = (SparseVector) point.f0;
+                    for (int i = 0; i < sparseVector.indices.length; ++i) {
+                        int idx = sparseVector.indices[i];
+                        if (Math.abs(z[idx]) <= l1) {
+                            modelData[0].values[idx] = 0.0;
+                        } else {
+                            modelData[0].values[idx] =
+                                    ((z[idx] < 0 ? -1 : 1) * l1 - z[idx])
+                                            / ((beta + Math.sqrt(n[idx])) / alpha + l2);
+                        }
+                        p += modelData[0].values[idx] * sparseVector.values[i];
+                    }
+                    p = 1 / (1 + Math.exp(-p));
+                    for (int i = 0; i < sparseVector.indices.length; ++i) {
+                        int idx = sparseVector.indices[i];
+                        double g = (p - point.f1) * sparseVector.values[i];
+                        double sigma = (Math.sqrt(n[idx] + g * g) - Math.sqrt(n[idx])) / alpha;
+                        z[idx] += g - sigma * modelData[0].values[idx];
+                        n[idx] += g * g;
+                        weights.values[idx] += 1.0;

Review Comment:
   It seems that `weights` is only updated here but never used afterwards.
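
   For illustration, a minimal sketch of how per-coordinate weights are consumed by a weighted-average merge such as the one in `FtrlGlobalReducer`. It assumes the local updater emits its `weights` vector alongside the coefficients, which the quoted hunk does not show. The class and method names are hypothetical, and plain `double[]` arrays stand in for `DenseVector`.

```java
/** Sketch only: per-coordinate weighted average of two partial models (hypothetical helper). */
final class WeightedAverageSketch {
    /**
     * Merges two (coefficient, weight) pairs coordinate by coordinate.
     * Returns {mergedCoefficient, mergedWeight}. When both weights of a coordinate are zero,
     * the coefficient of the second model is kept, mirroring the quoted reducer.
     */
    static double[][] merge(
            double[] coefA, double[] weightA, double[] coefB, double[] weightB) {
        double[] coef = new double[coefA.length];
        double[] weight = new double[coefA.length];
        for (int i = 0; i < coefA.length; ++i) {
            weight[i] = weightA[i] + weightB[i];
            coef[i] =
                    weight[i] > 0.0
                            ? (coefA[i] * weightA[i] + coefB[i] * weightB[i]) / weight[i]
                            : coefB[i];
        }
        return new double[][] {coef, weight};
    }
}
```

   If the weights never reach the reducer, every coordinate falls into the zero-weight case and the merge degenerates to keeping one subtask's coefficients.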



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/ftrl/FtrlModel.java:
##########
@@ -0,0 +1,196 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.ftrl;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.classification.logisticregression.LinearModelData;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+/** A Model which classifies data using the model data computed by {@link Ftrl}. */
+public class FtrlModel implements Model<FtrlModel>, FtrlModelParams<FtrlModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public FtrlModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), Types.DOUBLE, Types.LONG),
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldNames(), getPredictionCol(), "modelVersion"));
+
+        DataStream<Row> predictionResult =
+                tEnv.toDataStream(inputs[0])
+                        .connect(LinearModelData.getModelDataStream(modelDataTable).broadcast())
+                        .transform(
+                                "PredictLabelOperator",
+                                outputTypeInfo,
+                                new PredictLabelOperator(inputTypeInfo, getFeaturesCol()));
+
+        return new Table[] {tEnv.fromDataStream(predictionResult)};
+    }
+
+    /** A utility operator used for prediction. */
+    private static class PredictLabelOperator extends AbstractStreamOperator<Row>
+            implements TwoInputStreamOperator<Row, LinearModelData, Row> {
+        private final RowTypeInfo inputTypeInfo;
+
+        private final String featuresCol;
+        private ListState<Row> bufferedPointsState;
+        private DenseVector coefficient;
+        private long modelDataVersion = 0;
+
+        public PredictLabelOperator(RowTypeInfo inputTypeInfo, String featuresCol) {
+            this.inputTypeInfo = inputTypeInfo;
+            this.featuresCol = featuresCol;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            bufferedPointsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("bufferedPoints", inputTypeInfo));
+        }
+
+        @Override
+        public void open() throws Exception {
+            super.open();
+
+            getRuntimeContext()
+                    .getMetricGroup()
+                    .gauge(
+                            "MODEL_DATA_VERSION_GAUGE_KEY",
+                            (Gauge<String>) () -> Long.toString(modelDataVersion));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<Row> streamRecord) throws Exception {
+            Row dataPoint = streamRecord.getValue();
+            // todo : predict data
+            Vector features = (Vector) dataPoint.getField(featuresCol);
+            Tuple2<Double, DenseVector> predictionResult = predictRaw(features, coefficient);
+            output.collect(
+                    new StreamRecord<>(
+                            Row.join(dataPoint, Row.of(predictionResult.f0, modelDataVersion))));
+        }
+
+        @Override
+        public void processElement2(StreamRecord<LinearModelData> streamRecord) throws Exception {
+            LinearModelData modelData = streamRecord.getValue();
+
+            coefficient = modelData.coefficient;
+            modelDataVersion = modelData.modelVersion;
+            // System.out.println("update model...");
+            // todo : receive model data.
+            // Preconditions.checkArgument(modelData.centroids.length <= k);
+            // centroids = modelData.centroids;
+            // modelDataVersion++;
+            // for (Row dataPoint : bufferedPointsState.get()) {
+            //	processElement1(new StreamRecord<>(dataPoint));
+            // }
+            // bufferedPointsState.clear();
+        }
+    }
+
+    /**
+     * The main logic that predicts one input record.
+     *
+     * @param feature The input feature.
+     * @param coefficient The model parameters.
+     * @return The prediction label and the raw probabilities.
+     */
+    public static Tuple2<Double, DenseVector> predictRaw(Vector feature, DenseVector coefficient)

Review Comment:
   Methods like this are almost identical to those in `LogisticRegressionModel`. It would be better to reuse the logic that is already defined in `LogisticRegression` and `LogisticRegressionModel`.
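
   As a rough sketch of the kind of reuse meant here, the prediction logic could live in one static helper that both models call. The class name and signature below are hypothetical, not the existing `LogisticRegressionModel` API, and plain `double[]` arrays stand in for `Vector` and `DenseVector`.

```java
/** Sketch only: one prediction helper that several linear models could share (hypothetical). */
final class LinearPredictionSketch {
    /**
     * Predicts one record with a logistic model.
     * Returns {label, P(label = 0), P(label = 1)}.
     */
    static double[] predict(double[] features, double[] coefficient) {
        if (features.length != coefficient.length) {
            throw new IllegalArgumentException("Feature and coefficient sizes differ.");
        }
        double dot = 0.0;
        for (int i = 0; i < features.length; ++i) {
            dot += features[i] * coefficient[i];
        }
        // Sigmoid of the linear score gives the probability of the positive class.
        double prob = 1.0 / (1.0 + Math.exp(-dot));
        return new double[] {dot >= 0 ? 1.0 : 0.0, 1 - prob, prob};
    }
}
```

   With a helper like this, `FtrlModel` would only need to wire the model data stream to the operator and would not have to carry its own copy of the prediction code.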



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/ftrl/Ftrl.java:
##########
@@ -0,0 +1,395 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.ftrl;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.classification.logisticregression.LinearModelData;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/** An Estimator which implements the Ftrl algorithm. */
+public class Ftrl implements Estimator<Ftrl, FtrlModel>, FtrlParams<Ftrl> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public Ftrl() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public FtrlModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<Tuple2<Vector, Double>> points =
+                tEnv.toDataStream(inputs[0]).map(new ParseSample(getFeaturesCol(), getLabelCol()));
+
+        DataStream<LinearModelData> modelDataStream =
+                LinearModelData.getModelDataStream(initModelDataTable);
+        DataStream<DenseVector[]> initModelData = modelDataStream.map(new GetVectorData());
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new FtrlIterationBody(
+                        getGlobalBatchSize(), getAlpha(), getBETA(), getL1(), getL2());
+
+        DataStream<LinearModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        FtrlModel model = new FtrlModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Iteration body of ftrl. */
+    public static class FtrlIterationBody implements IterationBody {
+        private final int batchSize;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+
+        public FtrlIterationBody(int batchSize, double alpha, double beta, double l1, double l2) {
+            this.batchSize = batchSize;
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = l1;
+            this.l2 = l2;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<DenseVector[]> modelData = variableStreams.get(0);
+
+            DataStream<Tuple2<Vector, Double>> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+            DataStream<DenseVector[]> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator())
+                            .flatMap(new GlobalBatchSplitter(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(DenseVector[].class),
+                                    new FtrlLocalUpdater(alpha, beta, l1, l2))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new FtrlGlobalReducer())
+                            .map(
+                                    (MapFunction<DenseVector[], DenseVector[]>)
+                                            value -> new DenseVector[] {value[0]})
+                            .setParallelism(1);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelData),
+                    DataStreamList.of(
+                            modelData
+                                    .map(
+                                            new MapFunction<DenseVector[], LinearModelData>() {
+                                                long iter = 0L;
+
+                                                @Override
+                                                public LinearModelData map(DenseVector[] value) {
+                                                    return new LinearModelData(value[0], iter++);
+                                                }
+                                            })
+                                    .setParallelism(1)));
+        }
+    }
+
+    /** Gets vector data. */
+    public static class GetVectorData implements MapFunction<LinearModelData, DenseVector[]> {
+        @Override
+        public DenseVector[] map(LinearModelData value) throws Exception {
+            return new DenseVector[] {value.coefficient};
+        }
+    }
+
+    /**
+     * Operator that collects a LogisticRegressionModelData from each upstream subtask, and outputs
+     * the weight average of collected model data.
+     */
+    public static class FtrlGlobalReducer implements ReduceFunction<DenseVector[]> {
+        @Override
+        public DenseVector[] reduce(DenseVector[] modelData, DenseVector[] newModelData) {
+            for (int i = 0; i < newModelData[0].size(); ++i) {
+                if ((modelData[1].values[i] + newModelData[1].values[i]) > 0.0) {
+                    newModelData[0].values[i] =
+                            (modelData[0].values[i] * modelData[1].values[i]
+                                            + newModelData[0].values[i] * newModelData[1].values[i])
+                                    / (modelData[1].values[i] + newModelData[1].values[i]);
+                }
+                newModelData[1].values[i] = modelData[1].values[i] + newModelData[1].values[i];
+            }
+            return newModelData;
+        }
+    }
+
+    /** Updates local ftrl model. */
+    public static class FtrlLocalUpdater extends AbstractStreamOperator<DenseVector[]>
+            implements TwoInputStreamOperator<
+                    Tuple2<Vector, Double>[], DenseVector[], DenseVector[]> {
+        private ListState<Tuple2<Vector, Double>[]> localBatchDataState;
+        private ListState<DenseVector[]> modelDataState;
+        private double[] n;
+        private double[] z;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+        private DenseVector weights;
+
+        public FtrlLocalUpdater(double alpha, double beta, double l1, double l2) {
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = l1;
+            this.l2 = l2;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<Tuple2<Vector, Double>[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            localBatchDataState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("localBatch", type));
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", DenseVector[].class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<Tuple2<Vector, Double>[]> pointsRecord)
+                throws Exception {
+            localBatchDataState.add(pointsRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<DenseVector[]> modelDataRecord) throws Exception {
+            modelDataState.add(modelDataRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        private void alignAndComputeModelData() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !localBatchDataState.get().iterator().hasNext()) {
+                return;
+            }
+
+            DenseVector[] modelData =
+                    OperatorStateUtils.getUniqueElement(modelDataState, "modelData").get();
+            modelDataState.clear();
+
+            List<Tuple2<Vector, Double>[]> pointsList =
+                    IteratorUtils.toList(localBatchDataState.get().iterator());
+            Tuple2<Vector, Double>[] points = pointsList.remove(0);
+            localBatchDataState.update(pointsList);
+
+            for (Tuple2<Vector, Double> point : points) {
+                if (n == null) {
+                    n = new double[point.f0.size()];
+                    z = new double[n.length];
+                    weights = new DenseVector(n.length);
+                }
+
+                double p = 0.0;
+                Arrays.fill(weights.values, 0.0);
+                if (point.f0 instanceof DenseVector) {

Review Comment:
   It seems that some of the logic in the `if` and `else` branches is the same. Shall we move this code out of the if-else statement?
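
   One possible way to do this, sketched under assumptions: express both cases as (index, value) pairs so that a single loop serves dense and sparse inputs, with a dense vector simply passing the indices 0..d-1. The helper names are hypothetical, plain arrays stand in for `DenseVector` and `SparseVector`, and the `weights` bookkeeping from the PR is left out for brevity. The arithmetic follows the quoted update.

```java
import java.util.stream.IntStream;

/** Sketch only: one FTRL update loop shared by dense and sparse samples (hypothetical helper). */
final class FtrlUpdateSketch {
    /**
     * Applies one FTRL step for a single sample and updates coef, z and n in place.
     * indices and values list the features of the sample. Returns the predicted
     * probability computed before z and n are updated.
     */
    static double update(
            int[] indices, double[] values, double label,
            double[] coef, double[] z, double[] n,
            double alpha, double beta, double l1, double l2) {
        double p = 0.0;
        for (int i = 0; i < indices.length; ++i) {
            int idx = indices[i];
            if (Math.abs(z[idx]) <= l1) {
                coef[idx] = 0.0;
            } else {
                coef[idx] =
                        ((z[idx] < 0 ? -1 : 1) * l1 - z[idx])
                                / ((beta + Math.sqrt(n[idx])) / alpha + l2);
            }
            p += coef[idx] * values[i];
        }
        p = 1 / (1 + Math.exp(-p));
        for (int i = 0; i < indices.length; ++i) {
            int idx = indices[i];
            double g = (p - label) * values[i];
            double sigma = (Math.sqrt(n[idx] + g * g) - Math.sqrt(n[idx])) / alpha;
            z[idx] += g - sigma * coef[idx];
            n[idx] += g * g;
        }
        return p;
    }

    /** Index array 0..size-1, so that a dense vector can reuse the same loop. */
    static int[] denseIndices(int size) {
        return IntStream.range(0, size).toArray();
    }
}
```

   The `instanceof` check would then only decide which index and value arrays to pass, and the update itself would exist once.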



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/ftrl/Ftrl.java:
##########
@@ -0,0 +1,395 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.ftrl;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.classification.logisticregression.LinearModelData;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/** An Estimator which implements the Ftrl algorithm. */
+public class Ftrl implements Estimator<Ftrl, FtrlModel>, FtrlParams<Ftrl> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public Ftrl() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public FtrlModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<Tuple2<Vector, Double>> points =
+                tEnv.toDataStream(inputs[0]).map(new ParseSample(getFeaturesCol(), getLabelCol()));
+
+        DataStream<LinearModelData> modelDataStream =
+                LinearModelData.getModelDataStream(initModelDataTable);
+        DataStream<DenseVector[]> initModelData = modelDataStream.map(new GetVectorData());
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new FtrlIterationBody(
+                        getGlobalBatchSize(), getAlpha(), getBETA(), getL1(), getL2());
+
+        DataStream<LinearModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        FtrlModel model = new FtrlModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Iteration body of ftrl. */
+    public static class FtrlIterationBody implements IterationBody {
+        private final int batchSize;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+
+        public FtrlIterationBody(int batchSize, double alpha, double beta, double l1, double l2) {
+            this.batchSize = batchSize;
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = l1;
+            this.l2 = l2;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<DenseVector[]> modelData = variableStreams.get(0);
+
+            DataStream<Tuple2<Vector, Double>> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+            DataStream<DenseVector[]> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator())
+                            .flatMap(new GlobalBatchSplitter(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(DenseVector[].class),
+                                    new FtrlLocalUpdater(alpha, beta, l1, l2))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new FtrlGlobalReducer())
+                            .map(
+                                    (MapFunction<DenseVector[], DenseVector[]>)
+                                            value -> new DenseVector[] {value[0]})
+                            .setParallelism(1);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelData),
+                    DataStreamList.of(
+                            modelData
+                                    .map(
+                                            new MapFunction<DenseVector[], LinearModelData>() {
+                                                long iter = 0L;
+
+                                                @Override
+                                                public LinearModelData map(DenseVector[] value) {
+                                                    return new LinearModelData(value[0], iter++);
+                                                }
+                                            })
+                                    .setParallelism(1)));
+        }
+    }
+
+    /** Gets vector data. */
+    public static class GetVectorData implements MapFunction<LinearModelData, DenseVector[]> {
+        @Override
+        public DenseVector[] map(LinearModelData value) throws Exception {
+            return new DenseVector[] {value.coefficient};
+        }
+    }
+
+    /**
+     * Operator that collects a LogisticRegressionModelData from each upstream subtask, and outputs
+     * the weight average of collected model data.
+     */
+    public static class FtrlGlobalReducer implements ReduceFunction<DenseVector[]> {
+        @Override
+        public DenseVector[] reduce(DenseVector[] modelData, DenseVector[] newModelData) {
+            for (int i = 0; i < newModelData[0].size(); ++i) {
+                if ((modelData[1].values[i] + newModelData[1].values[i]) > 0.0) {
+                    newModelData[0].values[i] =
+                            (modelData[0].values[i] * modelData[1].values[i]
+                                            + newModelData[0].values[i] * newModelData[1].values[i])
+                                    / (modelData[1].values[i] + newModelData[1].values[i]);
+                }
+                newModelData[1].values[i] = modelData[1].values[i] + newModelData[1].values[i];
+            }
+            return newModelData;
+        }
+    }
+
+    /** Updates local ftrl model. */
+    public static class FtrlLocalUpdater extends AbstractStreamOperator<DenseVector[]>
+            implements TwoInputStreamOperator<
+                    Tuple2<Vector, Double>[], DenseVector[], DenseVector[]> {
+        private ListState<Tuple2<Vector, Double>[]> localBatchDataState;
+        private ListState<DenseVector[]> modelDataState;
+        private double[] n;
+        private double[] z;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+        private DenseVector weights;
+
+        public FtrlLocalUpdater(double alpha, double beta, double l1, double l2) {
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = l1;
+            this.l2 = l2;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<Tuple2<Vector, Double>[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            localBatchDataState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("localBatch", type));
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", DenseVector[].class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<Tuple2<Vector, Double>[]> pointsRecord)
+                throws Exception {
+            localBatchDataState.add(pointsRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<DenseVector[]> modelDataRecord) throws Exception {
+            modelDataState.add(modelDataRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        private void alignAndComputeModelData() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !localBatchDataState.get().iterator().hasNext()) {
+                return;
+            }
+
+            DenseVector[] modelData =
+                    OperatorStateUtils.getUniqueElement(modelDataState, "modelData").get();
+            modelDataState.clear();
+
+            List<Tuple2<Vector, Double>[]> pointsList =
+                    IteratorUtils.toList(localBatchDataState.get().iterator());
+            Tuple2<Vector, Double>[] points = pointsList.remove(0);
+            localBatchDataState.update(pointsList);
+
+            for (Tuple2<Vector, Double> point : points) {
+                if (n == null) {
+                    n = new double[point.f0.size()];
+                    z = new double[n.length];
+                    weights = new DenseVector(n.length);
+                }
+
+                double p = 0.0;
+                Arrays.fill(weights.values, 0.0);
+                if (point.f0 instanceof DenseVector) {
+                    DenseVector denseVector = (DenseVector) point.f0;
+                    for (int i = 0; i < denseVector.size(); ++i) {
+                        if (Math.abs(z[i]) <= l1) {
+                            modelData[0].values[i] = 0.0;
+                        } else {
+                            modelData[0].values[i] =
+                                    ((z[i] < 0 ? -1 : 1) * l1 - z[i])
+                                            / ((beta + Math.sqrt(n[i])) / alpha + l2);
+                        }
+                        p += modelData[0].values[i] * denseVector.values[i];
+                    }
+                    p = 1 / (1 + Math.exp(-p));
+                    for (int i = 0; i < denseVector.size(); ++i) {
+                        double g = (p - point.f1) * denseVector.values[i];
+                        double sigma = (Math.sqrt(n[i] + g * g) - Math.sqrt(n[i])) / alpha;
+                        z[i] += g - sigma * modelData[0].values[i];
+                        n[i] += g * g;
+                        weights.values[i] += 1.0;
+                    }
+                } else {
+                    SparseVector sparseVector = (SparseVector) point.f0;
+                    for (int i = 0; i < sparseVector.indices.length; ++i) {
+                        int idx = sparseVector.indices[i];
+                        if (Math.abs(z[idx]) <= l1) {
+                            modelData[0].values[idx] = 0.0;
+                        } else {
+                            modelData[0].values[idx] =
+                                    ((z[idx] < 0 ? -1 : 1) * l1 - z[idx])
+                                            / ((beta + Math.sqrt(n[idx])) / alpha + l2);
+                        }
+                        p += modelData[0].values[idx] * sparseVector.values[i];
+                    }
+                    p = 1 / (1 + Math.exp(-p));
+                    for (int i = 0; i < sparseVector.indices.length; ++i) {
+                        int idx = sparseVector.indices[i];
+                        double g = (p - point.f1) * sparseVector.values[i];
+                        double sigma = (Math.sqrt(n[idx] + g * g) - Math.sqrt(n[idx])) / alpha;
+                        z[idx] += g - sigma * modelData[0].values[idx];
+                        n[idx] += g * g;
+                        weights.values[idx] += 1.0;
+                    }
+                }
+            }
+            output.collect(new StreamRecord<>(new DenseVector[] {modelData[0], weights}));
+        }
+    }
+
+    /** Parses samples of input data. */
+    public static class ParseSample extends RichMapFunction<Row, Tuple2<Vector, Double>> {
+        private static final long serialVersionUID = 3738888745125082777L;

Review Comment:
   Is a variable like this (`serialVersionUID`) required?
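
   For context on the question, here is a minimal sketch in plain Java (no Flink dependency) of what happens with and without an explicit `serialVersionUID`. The class names `WithoutUid` and `WithUid` and the helpers `roundTrip` and `uidOf` are hypothetical and only for illustration; they are not part of this PR. Whether flink-ml's own conventions expect the field is a separate matter that this sketch does not settle.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.ObjectStreamClass;
import java.io.Serializable;

public class SerialVersionUidSketch {

    /** Hypothetical serializable class that does NOT declare serialVersionUID. */
    static class WithoutUid implements Serializable {
        final String featuresCol;

        WithoutUid(String featuresCol) {
            this.featuresCol = featuresCol;
        }
    }

    /** Hypothetical serializable class that declares serialVersionUID explicitly. */
    static class WithUid implements Serializable {
        private static final long serialVersionUID = 1L;
        final String featuresCol;

        WithUid(String featuresCol) {
            this.featuresCol = featuresCol;
        }
    }

    /** Serializes a value to bytes and reads it back with standard Java serialization. */
    static <T extends Serializable> T roundTrip(T value)
            throws IOException, ClassNotFoundException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
            out.writeObject(value);
        }
        try (ObjectInputStream in =
                new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {
            @SuppressWarnings("unchecked")
            T copy = (T) in.readObject();
            return copy;
        }
    }

    /** Returns the UID the JVM uses for the class: the declared one, or a computed default. */
    static long uidOf(Class<?> cls) {
        return ObjectStreamClass.lookup(cls).getSerialVersionUID();
    }

    public static void main(String[] args) throws Exception {
        // Both classes serialize fine within a single build of the code.
        System.out.println(roundTrip(new WithoutUid("features")).featuresCol);
        System.out.println(roundTrip(new WithUid("features")).featuresCol);

        // The declared UID is stable; the default one is derived from the class
        // structure and may change when the class is edited or recompiled differently.
        System.out.println("declared UID: " + uidOf(WithUid.class));
        System.out.println("computed UID: " + uidOf(WithoutUid.class));
    }
}
```

   So the field is not required for the class to be serializable. It mainly matters when serialized instances written by one version of the class need to be read by another version.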


