Posted to issues@flink.apache.org by GitBox <gi...@apache.org> on 2022/04/11 08:40:13 UTC

[GitHub] [flink-ml] weibozhao opened a new pull request, #83: [FLINK-27170] Add Transformer and Estimator of Ftrl

weibozhao opened a new pull request, #83:
URL: https://github.com/apache/flink-ml/pull/83

   Add Transformer and Estimator of Ftrl


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [flink-ml] weibozhao commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator for OnlineLogisticRegression

Posted by GitBox <gi...@apache.org>.
weibozhao commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r884801285


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java:
##########
@@ -0,0 +1,434 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.FilterFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the FTRL-Proximal online learning algorithm proposed by H. Brendan
+ * McMahan et al.
+ *
+ * <p>See <a href="https://doi.org/10.1145/2487575.2488200">H. Brendan McMahan et al., Ad click
+ * prediction: a view from the trenches.</a>
+ */
+public class OnlineLogisticRegression
+        implements Estimator<OnlineLogisticRegression, OnlineLogisticRegressionModel>,
+                OnlineLogisticRegressionParams<OnlineLogisticRegression> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineLogisticRegression() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public OnlineLogisticRegressionModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<LogisticRegressionModelData> modelDataStream =
+                LogisticRegressionModelData.getModelDataStream(initModelDataTable);
+
+        DataStream<Row> points =
+                tEnv.toDataStream(inputs[0])
+                        .map(new FeaturesExtractor(getFeaturesCol(), getLabelCol()));
+
+        DataStream<DenseVector> initModelData =
+                modelDataStream.map(
+                        (MapFunction<LogisticRegressionModelData, DenseVector>)
+                                value -> value.coefficient);

Review Comment:
   modelVersion has been written to state. If a restart happens, the algorithm reads the model version back from the checkpoint. Taking your example, if the restart happens at record 301, the model data (without the model version) is read from the checkpoint in CalculateLocalGradient(), and the modelVersion (modelVersion=3) is read from the checkpoint in CreateLrModelData() before the model is emitted to the sink. 
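A plain-Java sketch of the recovery argument above. This is not Flink code: `ModelVersionRecoverySketch`, `Snapshot` and `recoverAfterFailure` are hypothetical names. The model data and its version are captured together at each checkpoint, so a restart restores both rather than resetting the version to 0. The sketch assumes a global batch size of 100 and a checkpoint after every emitted model, which is what makes a failure at record 301 restore modelVersion=3. Per the comment, the PR restores the model data in CalculateLocalGradient() and the version in CreateLrModelData().

```java
import java.util.Arrays;

/**
 * Hypothetical sketch: coefficients and the model version are checkpointed together, so recovery
 * restores the version instead of restarting it from 0.
 */
class ModelVersionRecoverySketch {
    /** What a checkpoint would persist: the coefficients and the version that produced them. */
    static final class Snapshot {
        final double[] coefficient;
        final long modelVersion;

        Snapshot(double[] coefficient, long modelVersion) {
            this.coefficient = coefficient.clone(); // copy, since the live array keeps changing
            this.modelVersion = modelVersion;
        }
    }

    /**
     * Processes records 1..failAtRecord-1, emits a new model (bumping the version) after every
     * full batch, checkpoints after each emitted model, then fails at failAtRecord. Returns the
     * snapshot that a restarted job would restore.
     */
    static Snapshot recoverAfterFailure(int failAtRecord, int batchSize, double[] deltaPerBatch) {
        double[] coefficient = new double[deltaPerBatch.length];
        long modelVersion = 0L;
        Snapshot lastCheckpoint = new Snapshot(coefficient, modelVersion);
        for (int record = 1; record < failAtRecord; record++) {
            if (record % batchSize == 0) {
                for (int i = 0; i < coefficient.length; i++) {
                    coefficient[i] += deltaPerBatch[i];
                }
                modelVersion++;
                lastCheckpoint = new Snapshot(coefficient, modelVersion);
            }
        }
        // In-memory state is lost at the failure; only the last checkpoint survives.
        return lastCheckpoint;
    }

    public static void main(String[] args) {
        Snapshot restored = recoverAfterFailure(301, 100, new double[] {0.5, -0.25});
        System.out.println(
                "modelVersion=" + restored.modelVersion
                        + " coefficient=" + Arrays.toString(restored.coefficient));
    }
}
```

Running `main` prints `modelVersion=3 coefficient=[1.5, -0.75]`: three completed batches before record 301, each bumping the version once.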





[GitHub] [flink-ml] weibozhao commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator for OnlineLogisticRegression

Posted by GitBox <gi...@apache.org>.
weibozhao commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r882477151


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java:
##########
@@ -18,40 +18,85 @@
 
 package org.apache.flink.ml.classification.logisticregression;
 
+import org.apache.flink.api.common.functions.MapFunction;
 import org.apache.flink.api.common.serialization.Encoder;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
 import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.core.memory.DataInputViewStreamWrapper;
 import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.ml.common.datastream.TableUtils;
 import org.apache.flink.ml.linalg.DenseVector;
 import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
 import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.table.api.Table;
 import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
 import org.apache.flink.table.api.internal.TableImpl;
 
 import java.io.EOFException;
 import java.io.IOException;
 import java.io.OutputStream;
+import java.util.Random;
 
 /**
- * Model data of {@link LogisticRegressionModel}.
+ * Model data of {@link LogisticRegressionModel} and {@link OnlineLogisticRegressionModel}.
  *
  * <p>This class also provides methods to convert model data from Table to Datastream, and classes
  * to save/load model data.
  */
 public class LogisticRegressionModelData {
 
     public DenseVector coefficient;
+    public long modelVersion;
 
-    public LogisticRegressionModelData(DenseVector coefficient) {
+    public LogisticRegressionModelData(DenseVector coefficient, long modelVersion) {
         this.coefficient = coefficient;
+        this.modelVersion = modelVersion;
     }
 
     public LogisticRegressionModelData() {}
 
+    public LogisticRegressionModelData(DenseVector coefficient) {
+        this.coefficient = coefficient;
+        this.modelVersion = 0L;

Review Comment:
   I think leaving a TODO is the better choice.
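One way the TODO could look, sketched with a hypothetical `ModelDataSketch` class that uses `double[]` in place of Flink ML's `DenseVector`. The single-argument constructor delegates to the versioned one with version 0 and carries the TODO. The TODO wording is a guess and is not taken from the PR.

```java
/** Hypothetical stand-in for LogisticRegressionModelData, with double[] instead of DenseVector. */
class ModelDataSketch {
    public double[] coefficient;
    public long modelVersion;

    /** Empty constructor, kept as in the original class. */
    public ModelDataSketch() {}

    public ModelDataSketch(double[] coefficient, long modelVersion) {
        this.coefficient = coefficient;
        this.modelVersion = modelVersion;
    }

    // TODO: remove this constructor once every caller supplies an explicit model version.
    public ModelDataSketch(double[] coefficient) {
        this(coefficient, 0L); // delegate so both constructors initialize fields the same way
    }

    public static void main(String[] args) {
        ModelDataSketch offline = new ModelDataSketch(new double[] {1.0, 2.0});
        ModelDataSketch online = new ModelDataSketch(new double[] {1.0, 2.0}, 5L);
        System.out.println(offline.modelVersion + " " + online.modelVersion);
    }
}
```

Delegating with `this(...)` keeps the default in one place, so a later cleanup only has to delete a single constructor.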





[GitHub] [flink-ml] weibozhao commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator for OnlineLogisticRegression

Posted by GitBox <gi...@apache.org>.
weibozhao commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r854901440


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionParams.java:
##########
@@ -0,0 +1,64 @@
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.common.param.HasGlobalBatchSize;
+import org.apache.flink.ml.common.param.HasL1;
+import org.apache.flink.ml.common.param.HasL2;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+
+/**
+ * Params of {@link OnlineLogisticRegression}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface OnlineLogisticRegressionParams<T>
+        extends HasLabelCol<T>,
+                HasBatchStrategy<T>,
+                HasGlobalBatchSize<T>,
+                HasL1<T>,
+                HasL2<T>,
+                OnlineLogisticRegressionModelParams<T> {
+
+    Param<Double> ALPHA =

Review Comment:
   OK, I will add this param.
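For reference on what `ALPHA` controls: in FTRL-Proximal as described by McMahan et al., the per-coordinate learning rate is alpha / (beta + sqrt(n_i)), where n_i is the sum of squared gradients seen so far for coordinate i. The helper below is a hypothetical plain-Java illustration of that formula, not code from the PR.

```java
/** Hypothetical helper showing how FTRL's alpha and beta set each coordinate's step size. */
class FtrlLearningRate {
    /**
     * Per-coordinate learning rate eta_i = alpha / (beta + sqrt(n_i)), where n_i is the running
     * sum of squared gradients for coordinate i. Larger alpha gives larger steps; beta keeps the
     * early rates from being too large and avoids dividing by zero when n_i == 0.
     */
    static double learningRate(double alpha, double beta, double sumSquaredGradients) {
        return alpha / (beta + Math.sqrt(sumSquaredGradients));
    }

    public static void main(String[] args) {
        double alpha = 0.1;
        double beta = 1.0;
        // The step size shrinks as a coordinate accumulates gradient mass.
        for (double n : new double[] {0.0, 9.0, 81.0}) {
            System.out.println("n=" + n + " eta=" + learningRate(alpha, beta, n));
        }
    }
}
```

Because the rate is per coordinate, frequently seen features get smaller steps while rare features keep larger ones.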





[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator for OnlineLogisticRegression

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r854745613


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionModelParams.java:
##########
@@ -0,0 +1,50 @@
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasPredictionCol;
+import org.apache.flink.ml.common.param.HasRawPredictionCol;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringParam;
+
+/**
+ * Params for {@link OnlineLogisticRegressionModel}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface OnlineLogisticRegressionModelParams<T>
+        extends HasFeaturesCol<T>, HasPredictionCol<T>, HasRawPredictionCol<T> {
+    Param<String> MODEL_VERSION_COL =

Review Comment:
   Could you explain the reason for introducing the `MODEL_VERSION_COL` param here? It is not consistent with the implementation of `OnlineKmeansModel`.
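To make the question concrete: a model-version column would record, for every prediction row, which model version produced it. The plain-Java sketch below is hypothetical (`VersionedPrediction`, `PredictionRow` and `predict` are not from the PR). It scores one example with logistic regression and carries the version next to the prediction and raw prediction. The `[1 - p, p]` layout of the raw prediction is an assumption.

```java
/** Hypothetical output row that carries a model-version column next to the predictions. */
class VersionedPrediction {
    static final class PredictionRow {
        final double prediction;
        final double[] rawPrediction; // assumed layout: [P(y = 0), P(y = 1)]
        final long modelVersion;

        PredictionRow(double prediction, double[] rawPrediction, long modelVersion) {
            this.prediction = prediction;
            this.rawPrediction = rawPrediction;
            this.modelVersion = modelVersion;
        }
    }

    /** Scores one example with the given coefficients and tags it with the model's version. */
    static PredictionRow predict(double[] features, double[] coefficient, long modelVersion) {
        double dot = 0.0;
        for (int i = 0; i < features.length; i++) {
            dot += features[i] * coefficient[i];
        }
        double p = 1.0 / (1.0 + Math.exp(-dot)); // sigmoid(w . x)
        return new PredictionRow(p >= 0.5 ? 1.0 : 0.0, new double[] {1.0 - p, p}, modelVersion);
    }

    public static void main(String[] args) {
        PredictionRow row = predict(new double[] {1.0, 2.0}, new double[] {0.5, -0.1}, 3L);
        System.out.println(
                row.prediction + " p=" + row.rawPrediction[1] + " version=" + row.modelVersion);
    }
}
```

The trade-off is one extra output column that `OnlineKmeansModel` does not have.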



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionParams.java:
##########
@@ -0,0 +1,64 @@
+    Param<Double> ALPHA =

Review Comment:
   I am a bit torn about whether to introduce the `HasOptimMethod` interface, since OnlineLogisticRegression could be implemented via online methods such as SGD and FTRL.
   
   So basically there are two possible solutions:
   - introduce `HasOptimMethod` and parameters for the different optimizers.
   - Rename the class to `OnlineLogisticRegressionWithFtrl`.
   
   I think option 1 is the better solution in the long run, but it definitely needs more design and engineering effort. Shall we introduce the param now and leave the other optimizers unimplemented? 
   
   I would also like to hear from others.
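Option 1 could start as small as the sketch below. All names are hypothetical (`OptimMethod` is not an existing Flink ML class or param). The param accepts both `FTRL` and `SGD`, but only FTRL is marked as implemented, so asking for SGD fails fast instead of silently falling back.

```java
import java.util.Locale;

class OptimMethodDemo {
    public static void main(String[] args) {
        System.out.println(OptimMethod.parseSupported("ftrl"));
        try {
            OptimMethod.parseSupported("sgd");
        } catch (UnsupportedOperationException e) {
            System.out.println(e.getMessage());
        }
    }
}

/** Hypothetical values an optimizer-method param could accept. */
enum OptimMethod {
    FTRL(true),
    SGD(false); // declared so the param's value set is stable, but not implemented yet

    final boolean implemented;

    OptimMethod(boolean implemented) {
        this.implemented = implemented;
    }

    /**
     * Parses a param value such as "ftrl". Unknown names throw IllegalArgumentException (from
     * valueOf); known but unimplemented methods throw UnsupportedOperationException.
     */
    static OptimMethod parseSupported(String value) {
        OptimMethod method = OptimMethod.valueOf(value.trim().toUpperCase(Locale.ROOT));
        if (!method.implemented) {
            throw new UnsupportedOperationException(
                    "Optimizer " + method + " is not supported yet.");
        }
        return method;
    }
}
```

Running `main` prints `FTRL` followed by `Optimizer SGD is not supported yet.`. Adding SGD later would then mean flipping the flag and wiring in an implementation, without changing the param's accepted values.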
   
   



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java:
##########
@@ -0,0 +1,428 @@
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the OnlineLogisticRegression algorithm.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Online_machine_learning.

Review Comment:
   Could you add more Javadoc about online logistic regression? Also, could you please update the link here? It points to a general introduction to online learning.
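A possible starting point for that Javadoc: online logistic regression updates the weights after each example (or mini-batch) from the log-loss gradient, (sigmoid(w . x) - y) * x with y in {0, 1}. The sketch below shows that update with plain SGD and a fixed learning rate. The class and method names are hypothetical, and the PR turns the same gradients into weights with FTRL rather than SGD.

```java
/** Hypothetical plain-SGD sketch of online logistic regression over a stream of (x, y). */
class OnlineLogisticRegressionSketch {
    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /** Log-loss gradient for one example: (sigmoid(w . x) - y) * x, with y in {0, 1}. */
    static double[] logLossGradient(double[] w, double[] x, double y) {
        double dot = 0.0;
        for (int i = 0; i < w.length; i++) {
            dot += w[i] * x[i];
        }
        double error = sigmoid(dot) - y;
        double[] grad = new double[w.length];
        for (int i = 0; i < w.length; i++) {
            grad[i] = error * x[i];
        }
        return grad;
    }

    /** Applies one update in place: w <- w - learningRate * gradient. */
    static void sgdStep(double[] w, double[] x, double y, double learningRate) {
        double[] grad = logLossGradient(w, x, y);
        for (int i = 0; i < w.length; i++) {
            w[i] -= learningRate * grad[i];
        }
    }

    public static void main(String[] args) {
        double[] w = new double[2];
        double[][] xs = {{1.0, 2.0}, {1.0, -1.0}, {1.0, 3.0}};
        double[] ys = {1.0, 0.0, 1.0};
        for (int t = 0; t < xs.length; t++) {
            sgdStep(w, xs[t], ys[t], 0.5); // the model changes after every example
        }
        System.out.println(w[0] + " " + w[1]);
    }
}
```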



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java:
##########
@@ -0,0 +1,428 @@
+public class OnlineLogisticRegression
+        implements Estimator<OnlineLogisticRegression, OnlineLogisticRegressionModel>,
+                OnlineLogisticRegressionParams<OnlineLogisticRegression> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineLogisticRegression() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public OnlineLogisticRegressionModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<LogisticRegressionModelData> modelDataStream =
+                LogisticRegressionModelData.getModelDataStream(initModelDataTable);
+        DataStream<DenseVector[]> initModelData =
+                modelDataStream.map(
+                        (MapFunction<LogisticRegressionModelData, DenseVector[]>)
+                                value -> new DenseVector[] {value.coefficient});
+
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new FtrlIterationBody(
+                        getGlobalBatchSize(),
+                        getAlpha(),
+                        getBETA(),
+                        getL1(),
+                        getL2(),
+                        getFeaturesCol(),
+                        getLabelCol());
+
+        DataStream<LogisticRegressionModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData),
+                                DataStreamList.of(tEnv.toDataStream(inputs[0])),
+                                body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        OnlineLogisticRegressionModel model =
+                new OnlineLogisticRegressionModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Iteration body of ftrl algorithm. */

Review Comment:
   It would be nice to have a more detailed introduction to the FTRL method.
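A compact description of FTRL-Proximal that could seed that introduction, following Algorithm 1 of McMahan et al., "Ad click prediction: a view from the trenches" (the paper that the later revision of the class Javadoc cites). Each coordinate i keeps two accumulators, z_i and n_i. Given the log-loss gradient g_i, the update is sigma_i = (sqrt(n_i + g_i^2) - sqrt(n_i)) / alpha, then z_i += g_i - sigma_i * w_i and n_i += g_i^2. Weights are recomputed in closed form: w_i = 0 when |z_i| <= l1, otherwise w_i = (sgn(z_i) * l1 - z_i) / ((beta + sqrt(n_i)) / alpha + l2). That is the same expression the feedback `map` in `FtrlIterationBody` evaluates. The plain-Java sketch below is single-threaded with hypothetical names. It processes one example at a time and leaves out the PR's global batching and cross-subtask reduction (`ModelDataGlobalReducer`).

```java
/** Hypothetical single-machine FTRL-Proximal learner for logistic regression (labels 0 or 1). */
class FtrlProximalSketch {
    private final double alpha;
    private final double beta;
    private final double l1;
    private final double l2;
    private final double[] z; // per-coordinate accumulated, adjusted gradients
    private final double[] n; // per-coordinate sums of squared gradients

    FtrlProximalSketch(int dim, double alpha, double beta, double l1, double l2) {
        this.alpha = alpha;
        this.beta = beta;
        this.l1 = l1;
        this.l2 = l2;
        this.z = new double[dim];
        this.n = new double[dim];
    }

    /** Closed-form weight for coordinate i; exactly zero while |z_i| <= l1 (L1 sparsity). */
    double weight(int i) {
        if (Math.abs(z[i]) <= l1) {
            return 0.0;
        }
        return (Math.signum(z[i]) * l1 - z[i]) / ((beta + Math.sqrt(n[i])) / alpha + l2);
    }

    /** Predicts with the current weights, then updates z and n; returns the prediction. */
    double train(double[] x, double y) {
        double[] w = new double[x.length];
        double dot = 0.0;
        for (int i = 0; i < x.length; i++) {
            w[i] = weight(i);
            dot += w[i] * x[i];
        }
        double p = 1.0 / (1.0 + Math.exp(-dot));
        for (int i = 0; i < x.length; i++) {
            if (x[i] == 0.0) {
                continue; // zero gradient leaves z_i and n_i unchanged, so skip the coordinate
            }
            double g = (p - y) * x[i];
            double sigma = (Math.sqrt(n[i] + g * g) - Math.sqrt(n[i])) / alpha;
            z[i] += g - sigma * w[i];
            n[i] += g * g;
        }
        return p;
    }

    public static void main(String[] args) {
        FtrlProximalSketch ftrl = new FtrlProximalSketch(2, 0.5, 1.0, 0.1, 0.01);
        double[][] xs = {{1.0, 0.0}, {1.0, 1.0}, {0.0, 1.0}, {1.0, 1.0}};
        double[] ys = {1.0, 1.0, 0.0, 1.0};
        for (int t = 0; t < xs.length; t++) {
            double p = ftrl.train(xs[t], ys[t]);
            System.out.println("p=" + p + " w=[" + ftrl.weight(0) + ", " + ftrl.weight(1) + "]");
        }
    }
}
```

Storing z and n rather than w is what lets the L1 term produce exact zeros: a weight stays 0 until its accumulated gradient |z_i| exceeds l1.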



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java:
##########
@@ -0,0 +1,428 @@
+    /** Iteration body of ftrl algorithm. */
+    private static class FtrlIterationBody implements IterationBody {
+        private final int batchSize;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+        private final String featureCol;
+        private final String labelCol;
+
+        public FtrlIterationBody(
+                int batchSize,
+                double alpha,
+                double beta,
+                double l1,
+                double l2,
+                String featureCol,
+                String labelCol) {
+            this.batchSize = batchSize;
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = l1;
+            this.l2 = l2;
+            this.featureCol = featureCol;
+            this.labelCol = labelCol;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<DenseVector[]> modelData = variableStreams.get(0);
+
+            DataStream<Row> points = dataStreams.get(0);
+            int parallelism = points.getParallelism();
+
+            Preconditions.checkState(
+                    parallelism <= batchSize,
+                    "There are more subtasks in the training process than the number "
+                            + "of elements in each batch. Some subtasks might be idling forever.");
+
+            DataStream<DenseVector[]> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator())
+                            .flatMap(new GlobalBatchSplitter(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(DenseVector[].class),
+                                    new ModelDataLocalUpdater(
+                                            alpha, beta, l1, l2, featureCol, labelCol))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new ModelDataGlobalReducer());
+            DataStream<DenseVector[]> feedbackModelData =
+                    newModelData
+                            .map(
+                                    (MapFunction<DenseVector[], DenseVector[]>)
+                                            value -> {
+                                                double[] z = value[1].values;
+                                                double[] n = value[2].values;
+                                                for (int i = 0; i < z.length; ++i) {
+                                                    if (Math.abs(z[i]) <= l1) {
+                                                        value[0].values[i] = 0.0;
+                                                    } else {
+                                                        value[0].values[i] =
+                                                                ((z[i] < 0 ? -1 : 1) * l1 - z[i])
+                                                                        / ((beta + Math.sqrt(n[i]))
+                                                                                        / alpha
+                                                                                + l2);
+                                                    }
+                                                }
+                                                return new DenseVector[] {value[0]};
+                                            })
+                            .setParallelism(1);
+
+            DataStream<LogisticRegressionModelData> outputModelData =
+                    feedbackModelData
+                            .map(
+                                    (MapFunction<DenseVector[], LogisticRegressionModelData>)
+                                            value ->
+                                                    new LogisticRegressionModelData(
+                                                            value[0], System.nanoTime()))

Review Comment:
   Can we add the model version once the model-version infra is ready?
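For readers checking the `map` in the hunk above: it recovers each weight from the accumulators `z` and `n` with the FTRL-Proximal closed form. The weight is exactly 0 when |z_i| <= l1; otherwise it is (sgn(z_i) * l1 - z_i) / ((beta + sqrt(n_i)) / alpha + l2). The sketch below isolates that per-coordinate step on plain `double[]` arrays so it can be checked in isolation. The class and method names are hypothetical and not part of Flink ML.

```java
import java.util.Arrays;

/** Hypothetical helper mirroring the per-coordinate FTRL-Proximal weight recovery in the diff. */
final class FtrlWeightSketch {

    /**
     * Computes the weights from the FTRL accumulators z and n.
     * Coordinates with |z_i| <= l1 are set to exactly zero, which is where L1 sparsity comes from.
     */
    static double[] weights(double[] z, double[] n, double alpha, double beta, double l1, double l2) {
        if (z.length != n.length) {
            throw new IllegalArgumentException("z and n must have the same length");
        }
        double[] w = new double[z.length];
        for (int i = 0; i < z.length; ++i) {
            if (Math.abs(z[i]) <= l1) {
                w[i] = 0.0;
            } else {
                // Same expression as the reviewed code: (sgn(z) * l1 - z) / ((beta + sqrt(n)) / alpha + l2).
                w[i] = (Math.signum(z[i]) * l1 - z[i]) / ((beta + Math.sqrt(n[i])) / alpha + l2);
            }
        }
        return w;
    }

    public static void main(String[] args) {
        double[] z = {0.05, -2.0, 3.0};
        double[] n = {1.0, 4.0, 9.0};
        // The first coordinate is clamped to zero because |0.05| <= l1 = 0.1.
        System.out.println(Arrays.toString(weights(z, n, 0.1, 1.0, 0.1, 0.1)));
    }
}
```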



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionModel.java:
##########
@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * A Model which classifies data using the model data computed by {@link OnlineLogisticRegression}.
+ */
+public class OnlineLogisticRegressionModel
+        implements Model<OnlineLogisticRegressionModel>,
+                OnlineLogisticRegressionModelParams<OnlineLogisticRegressionModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public OnlineLogisticRegressionModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(),
+                                Types.DOUBLE,
+                                TypeInformation.of(DenseVector.class),
+                                Types.LONG),
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldNames(),
+                                getPredictionCol(),
+                                getRawPredictionCol(),
+                                getModelVersionCol()));
+
+        DataStream<Row> predictionResult =
+                tEnv.toDataStream(inputs[0])
+                        .connect(
+                                LogisticRegressionModelData.getModelDataStream(modelDataTable)
+                                        .broadcast())
+                        .transform(
+                                "PredictLabelOperator",
+                                outputTypeInfo,
+                                new PredictLabelOperator(inputTypeInfo, getFeaturesCol()));
+
+        return new Table[] {tEnv.fromDataStream(predictionResult)};
+    }
+
+    /** A utility operator used for prediction. */
+    private static class PredictLabelOperator extends AbstractStreamOperator<Row>
+            implements TwoInputStreamOperator<Row, LogisticRegressionModelData, Row> {
+        private final RowTypeInfo inputTypeInfo;
+
+        private final String featuresCol;
+        private ListState<Row> bufferedPointsState;
+        private DenseVector coefficient;
+        private long modelDataVersion = 0;
+
+        public PredictLabelOperator(RowTypeInfo inputTypeInfo, String featuresCol) {
+            this.inputTypeInfo = inputTypeInfo;
+            this.featuresCol = featuresCol;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            bufferedPointsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("bufferedPoints", inputTypeInfo));
+        }
+
+        @Override
+        public void open() throws Exception {
+            super.open();
+
+            getRuntimeContext()
+                    .getMetricGroup()
+                    .gauge(
+                            "MODEL_DATA_VERSION_GAUGE_KEY",
+                            (Gauge<String>) () -> Long.toString(modelDataVersion));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<Row> streamRecord) throws Exception {
+            Row dataPoint = streamRecord.getValue();
+            while (coefficient == null) {

Review Comment:
   nit: could be `if`.
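To illustrate the nit: the loop body always returns, so the `while` runs at most once and reads more clearly as an `if`. Below is a minimal Flink-free sketch of the same buffer-then-replay pattern, assuming an in-memory list in place of the operator's `ListState` and a raw dot product in place of the real prediction. All names are hypothetical.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

/** Hypothetical, Flink-free stand-in for the buffering in PredictLabelOperator. */
final class BufferUntilModelSketch {
    private final List<double[]> buffered = new ArrayList<>(); // replaces the operator's ListState
    private final List<Double> outputs = new ArrayList<>();
    private double[] coefficient; // stays null until the first model arrives

    /** Counterpart of processElement1: buffer while no model is known, otherwise score. */
    void onDataPoint(double[] features) {
        // `if` is enough: the branch always returns, so a `while` could never loop twice.
        if (coefficient == null) {
            buffered.add(features);
            return;
        }
        double dot = 0.0;
        for (int i = 0; i < features.length; ++i) {
            dot += features[i] * coefficient[i];
        }
        outputs.add(dot);
    }

    /** Counterpart of processElement2: install the model, then replay and clear the buffer. */
    void onModel(double[] newCoefficient) {
        coefficient = Objects.requireNonNull(newCoefficient);
        for (double[] point : buffered) {
            onDataPoint(point); // coefficient is non-null now, so nothing is re-buffered
        }
        buffered.clear();
    }

    List<Double> outputs() {
        return outputs;
    }

    int bufferedCount() {
        return buffered.size();
    }

    public static void main(String[] args) {
        BufferUntilModelSketch op = new BufferUntilModelSketch();
        op.onDataPoint(new double[] {1.0, 2.0});
        op.onDataPoint(new double[] {3.0, 4.0});
        System.out.println(op.bufferedCount() + " buffered, " + op.outputs().size() + " scored"); // 2 buffered, 0 scored
        op.onModel(new double[] {0.5, 0.25});
        System.out.println(op.outputs()); // [1.0, 2.5]
    }
}
```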



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java:
##########
@@ -391,7 +391,8 @@ public void onIterationTerminated(Context context, Collector<double[]> collector
             feedbackBufferState.clear();
             if (getRuntimeContext().getIndexOfThisSubtask() == 0) {
                 updateModel();
-                context.output(modelDataOutputTag, new LogisticRegressionModelData(coefficient));
+                context.output(
+                        modelDataOutputTag, new LogisticRegressionModelData(coefficient, 0L));

Review Comment:
   The infra for online models (i.e., model versions and barriers for model data) has not been added yet. Can we make this change once that infra is ready?
   
   @yunfengzhou-hub What do you think?



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionParams.java:
##########
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.common.param.HasGlobalBatchSize;
+import org.apache.flink.ml.common.param.HasL1;
+import org.apache.flink.ml.common.param.HasL2;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+
+/**
+ * Params of {@link OnlineLogisticRegression}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface OnlineLogisticRegressionParams<T>
+        extends HasLabelCol<T>,
+                HasBatchStrategy<T>,
+                HasGlobalBatchSize<T>,
+                HasL1<T>,
+                HasL2<T>,
+                OnlineLogisticRegressionModelParams<T> {
+
+    Param<Double> ALPHA =
+            new DoubleParam("alpha", "The parameter alpha of ftrl.", 0.1, ParamValidators.gt(0.0));
+
+    default Double getAlpha() {
+        return get(ALPHA);
+    }
+
+    default T setAlpha(Double value) {
+        return set(ALPHA, value);
+    }
+
+    Param<Double> BETA =
+            new DoubleParam("alpha", "The parameter beta of ftrl.", 0.1, ParamValidators.gt(0.0));
+
+    default Double getBETA() {

Review Comment:
   The naming of `getBETA` is not consistent with `getAlpha`.
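Related to the naming: in the hunk above, `BETA` is also registered under the name "alpha", which looks like a copy-paste slip and would likely collide with `ALPHA` wherever params are saved or looked up by name. Below is a dependency-free sketch of consistently named accessors, assuming a simplified param holder keyed by name; `NamedParam` and `FtrlParamsSketch` are hypothetical stand-ins, not Flink ML's `Param`/`WithParams` API.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical, dependency-free sketch of FTRL params with consistently named accessors. */
final class FtrlParamsSketch {

    /** Minimal stand-in for a param: just a registered name and a default value. */
    static final class NamedParam {
        final String name;
        final double defaultValue;

        NamedParam(String name, double defaultValue) {
            this.name = name;
            this.defaultValue = defaultValue;
        }
    }

    static final NamedParam ALPHA = new NamedParam("alpha", 0.1);
    // Registered as "beta": values here are keyed by name, so reusing "alpha"
    // would make setBeta silently overwrite alpha.
    static final NamedParam BETA = new NamedParam("beta", 0.1);

    private final Map<String, Double> values = new HashMap<>();

    private double get(NamedParam param) {
        return values.getOrDefault(param.name, param.defaultValue);
    }

    private FtrlParamsSketch set(NamedParam param, double value) {
        // Mirrors the gt(0.0) validator used for ALPHA and BETA in the diff.
        if (!(value > 0.0)) {
            throw new IllegalArgumentException(param.name + " must be > 0");
        }
        values.put(param.name, value);
        return this;
    }

    // Camel-case accessors for both params: getBeta rather than getBETA.
    double getAlpha() {
        return get(ALPHA);
    }

    FtrlParamsSketch setAlpha(double value) {
        return set(ALPHA, value);
    }

    double getBeta() {
        return get(BETA);
    }

    FtrlParamsSketch setBeta(double value) {
        return set(BETA, value);
    }

    public static void main(String[] args) {
        FtrlParamsSketch params = new FtrlParamsSketch().setAlpha(0.5).setBeta(2.0);
        System.out.println(params.getAlpha() + " " + params.getBeta()); // 0.5 2.0
    }
}
```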



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasL2.java:
##########
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
+
+/** Interface for the shared L2 param. */
+public interface HasL2<T> extends WithParams<T> {
+    Param<Double> L_2 = new DoubleParam("l2", "The l2 param.", 0.1, ParamValidators.gt(0.0));

Review Comment:
   Should we reuse the existing `HasReg` param in Logistic Regression, or rename `HasReg` to `HasL2`?
   
   Note that existing libraries handle this differently:
   - Spark uses `HasElasticNetParam` and `HasRegParam` to decide L1 and L2 [1]
   - Sklearn uses `penalty`, which could be l1, l2, elastic [2]
   - Alink uses `HasL1` and `HasL2` [3]
   
   [1] https://github.com/apache/spark/blob/4dc12eb54544a12ff7ddf078ca8bcec9471212c3/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala#L53
   [2] https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html
   [3] https://github.com/alibaba/Alink/blob/master/core/src/main/java/com/alibaba/alink/params/classification/LinearBinaryClassTrainParams.java
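If a single regularization strength plus an elastic-net mixing param were kept (Spark's approach in [1]), separate L1/L2 strengths can still be derived from them. The sketch below assumes Spark's penalty form, regParam * (elasticNetParam * ||w||_1 + (1 - elasticNetParam) / 2 * ||w||_2^2); the helper name is hypothetical.

```java
/**
 * Hypothetical helper: converts Spark-style (regParam, elasticNetParam) into separate L1/L2
 * strengths, assuming the penalty
 * regParam * (elasticNetParam * ||w||_1 + (1 - elasticNetParam) / 2 * ||w||_2^2).
 */
final class ElasticNetSplitSketch {

    /** Returns {l1, l2}. */
    static double[] toL1L2(double regParam, double elasticNetParam) {
        if (regParam < 0.0) {
            throw new IllegalArgumentException("regParam must be >= 0");
        }
        if (elasticNetParam < 0.0 || elasticNetParam > 1.0) {
            throw new IllegalArgumentException("elasticNetParam must be in [0, 1]");
        }
        double l1 = elasticNetParam * regParam; // weight on ||w||_1
        double l2 = (1.0 - elasticNetParam) * regParam; // weight on (1/2) * ||w||_2^2
        return new double[] {l1, l2};
    }

    public static void main(String[] args) {
        double[] pureL2 = toL1L2(0.1, 0.0);
        double[] mixed = toL1L2(0.2, 0.25);
        System.out.println("pure L2: l1=" + pureL2[0] + ", l2=" + pureL2[1]);
        System.out.println("mixed:   l1=" + mixed[0] + ", l2=" + mixed[1]);
    }
}
```

Note that the 1/2 factor on the squared norm matters when the resulting `l2` is passed to an optimizer that writes its L2 term without it.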



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java:
##########
@@ -37,17 +37,19 @@
 import java.io.OutputStream;
 
 /**
- * Model data of {@link LogisticRegressionModel}.
+ * Model data of {@link LogisticRegressionModel} and {@link OnlineLogisticRegressionModel}.
  *
  * <p>This class also provides methods to convert model data from Table to Datastream, and classes
  * to save/load model data.
  */
 public class LogisticRegressionModelData {
 
     public DenseVector coefficient;
+    public long modelVersion;

Review Comment:
   The infra for online models (i.e., model versions and barriers for model data) has not been added yet. Can we make this change once that infra is ready?



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionModel.java:
##########
@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * A Model which classifies data using the model data computed by {@link OnlineLogisticRegression}.
+ */
+public class OnlineLogisticRegressionModel
+        implements Model<OnlineLogisticRegressionModel>,
+                OnlineLogisticRegressionModelParams<OnlineLogisticRegressionModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public OnlineLogisticRegressionModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(),
+                                Types.DOUBLE,
+                                TypeInformation.of(DenseVector.class),
+                                Types.LONG),
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldNames(),
+                                getPredictionCol(),
+                                getRawPredictionCol(),
+                                getModelVersionCol()));

Review Comment:
   The behavior here is inconsistent with `OnlineKmeansModel`. We probably need to discuss whether it is necessary to output the model version.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionModel.java:
##########
@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * A Model which classifies data using the model data computed by {@link OnlineLogisticRegression}.
+ */
+public class OnlineLogisticRegressionModel
+        implements Model<OnlineLogisticRegressionModel>,
+                OnlineLogisticRegressionModelParams<OnlineLogisticRegressionModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public OnlineLogisticRegressionModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(),
+                                Types.DOUBLE,
+                                TypeInformation.of(DenseVector.class),
+                                Types.LONG),
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldNames(),
+                                getPredictionCol(),
+                                getRawPredictionCol(),
+                                getModelVersionCol()));
+
+        DataStream<Row> predictionResult =
+                tEnv.toDataStream(inputs[0])
+                        .connect(
+                                LogisticRegressionModelData.getModelDataStream(modelDataTable)
+                                        .broadcast())
+                        .transform(
+                                "PredictLabelOperator",
+                                outputTypeInfo,
+                                new PredictLabelOperator(inputTypeInfo, getFeaturesCol()));
+
+        return new Table[] {tEnv.fromDataStream(predictionResult)};
+    }
+
+    /** A utility operator used for prediction. */
+    private static class PredictLabelOperator extends AbstractStreamOperator<Row>
+            implements TwoInputStreamOperator<Row, LogisticRegressionModelData, Row> {
+        private final RowTypeInfo inputTypeInfo;
+
+        private final String featuresCol;
+        private ListState<Row> bufferedPointsState;
+        private DenseVector coefficient;
+        private long modelDataVersion = 0;
+
+        public PredictLabelOperator(RowTypeInfo inputTypeInfo, String featuresCol) {
+            this.inputTypeInfo = inputTypeInfo;
+            this.featuresCol = featuresCol;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            bufferedPointsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("bufferedPoints", inputTypeInfo));
+        }
+
+        @Override
+        public void open() throws Exception {
+            super.open();
+
+            getRuntimeContext()
+                    .getMetricGroup()
+                    .gauge(
+                            "MODEL_DATA_VERSION_GAUGE_KEY",
+                            (Gauge<String>) () -> Long.toString(modelDataVersion));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<Row> streamRecord) throws Exception {
+            Row dataPoint = streamRecord.getValue();
+            while (coefficient == null) {
+                bufferedPointsState.add(dataPoint);
+                return;
+            }
+            Vector features = (Vector) dataPoint.getField(featuresCol);
+            Tuple2<Double, DenseVector> predictionResult = predictRaw(features, coefficient);
+            output.collect(
+                    new StreamRecord<>(
+                            Row.join(
+                                    dataPoint,
+                                    Row.of(
+                                            predictionResult.f0,
+                                            predictionResult.f1,
+                                            modelDataVersion))));
+        }
+
+        @Override
+        public void processElement2(StreamRecord<LogisticRegressionModelData> streamRecord)
+                throws Exception {
+            LogisticRegressionModelData modelData = streamRecord.getValue();
+
+            coefficient = modelData.coefficient;
+            modelDataVersion = modelData.modelVersion;
+
+            for (Row dataPoint : bufferedPointsState.get()) {
+                processElement1(new StreamRecord<>(dataPoint));
+            }
+            bufferedPointsState.clear();
+        }
+    }
+
+    /**
+     * The main logic that predicts one input record.
+     *
+     * @param feature The input feature.
+     * @param coefficient The model parameters.
+     * @return The prediction label and the raw probabilities.
+     */
+    private static Tuple2<Double, DenseVector> predictRaw(Vector feature, DenseVector coefficient) {

Review Comment:
   This logic is the same as that in `LogisticRegression`. Should we reuse the logic there?
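One way to share the logic would be a single package-private static helper that both models call. The hunk shows only the signature of `predictRaw`, so the body below is an assumption: standard sigmoid scoring with a 0.5 threshold, returning the label plus the two class probabilities. It uses plain arrays instead of Flink ML's `Vector`/`DenseVector`, and the class name is hypothetical.

```java
/** Hypothetical shared helper for binary logistic-regression scoring on dense arrays. */
final class LogisticPredictionSketch {

    /**
     * Returns {predictedLabel, P(label = 0), P(label = 1)}.
     * Assumes the standard sigmoid link and a 0.5 probability threshold.
     */
    static double[] predictRaw(double[] features, double[] coefficient) {
        if (features.length != coefficient.length) {
            throw new IllegalArgumentException("dimension mismatch");
        }
        double dot = 0.0;
        for (int i = 0; i < features.length; ++i) {
            dot += features[i] * coefficient[i];
        }
        double prob = 1.0 / (1.0 + Math.exp(-dot)); // sigmoid(dot) = P(label = 1)
        double label = dot >= 0.0 ? 1.0 : 0.0; // sigmoid(dot) >= 0.5 iff dot >= 0
        return new double[] {label, 1.0 - prob, prob};
    }

    public static void main(String[] args) {
        double[] result = predictRaw(new double[] {1.0, 2.0}, new double[] {0.5, -0.25});
        System.out.println(result[0] + " " + result[1] + " " + result[2]); // 1.0 0.5 0.5
    }
}
```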



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java:
##########
@@ -0,0 +1,428 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the OnlineLogisticRegression algorithm.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Online_machine_learning.
+ */
+public class OnlineLogisticRegression
+        implements Estimator<OnlineLogisticRegression, OnlineLogisticRegressionModel>,
+                OnlineLogisticRegressionParams<OnlineLogisticRegression> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineLogisticRegression() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public OnlineLogisticRegressionModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<LogisticRegressionModelData> modelDataStream =
+                LogisticRegressionModelData.getModelDataStream(initModelDataTable);
+        DataStream<DenseVector[]> initModelData =
+                modelDataStream.map(
+                        (MapFunction<LogisticRegressionModelData, DenseVector[]>)
+                                value -> new DenseVector[] {value.coefficient});
+
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new FtrlIterationBody(
+                        getGlobalBatchSize(),
+                        getAlpha(),
+                        getBETA(),
+                        getL1(),
+                        getL2(),
+                        getFeaturesCol(),
+                        getLabelCol());
+
+        DataStream<LogisticRegressionModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData),
+                                DataStreamList.of(tEnv.toDataStream(inputs[0])),
+                                body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        OnlineLogisticRegressionModel model =
+                new OnlineLogisticRegressionModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Iteration body of the FTRL algorithm. */
+    private static class FtrlIterationBody implements IterationBody {
+        private final int batchSize;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+        private final String featureCol;
+        private final String labelCol;
+
+        public FtrlIterationBody(
+                int batchSize,
+                double alpha,
+                double beta,
+                double l1,
+                double l2,
+                String featureCol,
+                String labelCol) {
+            this.batchSize = batchSize;
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = l1;
+            this.l2 = l2;
+            this.featureCol = featureCol;
+            this.labelCol = labelCol;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<DenseVector[]> modelData = variableStreams.get(0);
+
+            DataStream<Row> points = dataStreams.get(0);
+            int parallelism = points.getParallelism();
+
+            Preconditions.checkState(
+                    parallelism <= batchSize,
+                    "There are more subtasks in the training process than the number "
+                            + "of elements in each batch. Some subtasks might be idling forever.");
+
+            DataStream<DenseVector[]> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator())
+                            .flatMap(new GlobalBatchSplitter(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(DenseVector[].class),
+                                    new ModelDataLocalUpdater(
+                                            alpha, beta, l1, l2, featureCol, labelCol))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new ModelDataGlobalReducer());
+            DataStream<DenseVector[]> feedbackModelData =
+                    newModelData
+                            .map(
+                                    (MapFunction<DenseVector[], DenseVector[]>)
+                                            value -> {
+                                                double[] z = value[1].values;
+                                                double[] n = value[2].values;
+                                                for (int i = 0; i < z.length; ++i) {
+                                                    if (Math.abs(z[i]) <= l1) {
+                                                        value[0].values[i] = 0.0;
+                                                    } else {
+                                                        value[0].values[i] =
+                                                                ((z[i] < 0 ? -1 : 1) * l1 - z[i])
+                                                                        / ((beta + Math.sqrt(n[i]))
+                                                                                        / alpha
+                                                                                + l2);
+                                                    }
+                                                }
+                                                return new DenseVector[] {value[0]};
+                                            })
+                            .setParallelism(1);
+
+            DataStream<LogisticRegressionModelData> outputModelData =
+                    feedbackModelData
+                            .map(
+                                    (MapFunction<DenseVector[], LogisticRegressionModelData>)
+                                            value ->
+                                                    new LogisticRegressionModelData(
+                                                            value[0], System.nanoTime()))
+                            .setParallelism(1);
+            return new IterationBodyResult(
+                    DataStreamList.of(feedbackModelData), DataStreamList.of(outputModelData));
+        }
+    }
+
+    /**
+     * Operator that collects the model data from each upstream subtask, and outputs the
+     * element-wise sum of the collected z and n parameters.
+     */
+    private static class ModelDataGlobalReducer implements ReduceFunction<DenseVector[]> {
+        @Override
+        public DenseVector[] reduce(DenseVector[] modelData, DenseVector[] newModelData) {
+
+            for (int i = 0; i < newModelData[1].size(); ++i) {
+                newModelData[1].values[i] = modelData[1].values[i] + newModelData[1].values[i];
+                newModelData[2].values[i] = modelData[2].values[i] + newModelData[2].values[i];
+            }
+            return newModelData;
+        }
+    }
+
+    /** Updates the local FTRL model. */
+    private static class ModelDataLocalUpdater extends AbstractStreamOperator<DenseVector[]>
+            implements TwoInputStreamOperator<Row[], DenseVector[], DenseVector[]> {
+        private ListState<DenseVector[]> modelDataState;
+        private ListState<Row[]> localBatchDataState;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+        private final String featureCol;
+        private final String labelCol;
+        private double[] weights;
+        private double[] nParam;
+        private double[] zParam;
+        private int[] denseVectorIndices;
+
+        public ModelDataLocalUpdater(
+                double alpha,
+                double beta,
+                double l1,
+                double l2,
+                String featureCol,
+                String labelCol) {
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = l1;
+            this.l2 = l2;
+            this.featureCol = featureCol;
+            this.labelCol = labelCol;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", DenseVector[].class));
+            TypeInformation<Row[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(TypeInformation.of(Row.class));
+            localBatchDataState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("localBatch", type));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<Row[]> pointsRecord) throws Exception {
+            if (!modelDataState.get().iterator().hasNext()) {
+                localBatchDataState.add(pointsRecord.getValue());
+                return;
+            } else {
+                for (Row[] dataPoints : localBatchDataState.get()) {
+                    updateModel(dataPoints);
+                }
+                localBatchDataState.clear();
+            }
+            updateModel(pointsRecord.getValue());
+        }
+
+        private void updateModel(Row[] points) throws Exception {
+            DenseVector[] modelData =
+                    OperatorStateUtils.getUniqueElement(modelDataState, "modelData").get();
+
+            for (Row point : points) {
+                Vector vec = point.getFieldAs(featureCol);
+                double label = point.getFieldAs(labelCol);
+                if (nParam == null) {
+                    nParam = new double[vec.size()];
+                    zParam = new double[nParam.length];
+                    weights = new double[nParam.length];
+                    if (vec instanceof DenseVector) {
+                        denseVectorIndices = new int[vec.size()];
+                        for (int i = 0; i < denseVectorIndices.length; ++i) {
+                            denseVectorIndices[i] = i;
+                        }
+                    }
+                }
+
+                System.arraycopy(modelData[0].values, 0, weights, 0, weights.length);
+                int[] indices;
+                double[] values;
+                if (vec instanceof DenseVector) {
+                    DenseVector denseVector = (DenseVector) vec;
+                    indices = denseVectorIndices;
+                    values = denseVector.values;
+                } else {
+                    SparseVector sparseVector = (SparseVector) vec;
+                    indices = sparseVector.indices;
+                    values = sparseVector.values;
+                }
+                double p = 0.0;
+                for (int i = 0; i < indices.length; ++i) {
+                    int idx = indices[i];
+                    if (Math.abs(zParam[idx]) <= l1) {
+                        weights[idx] = 0.0;
+                    } else {
+                        weights[idx] =
+                                ((zParam[idx] < 0 ? -1 : 1) * l1 - zParam[idx])
+                                        / ((beta + Math.sqrt(nParam[idx])) / alpha + l2);
+                    }
+                    p += weights[idx] * values[i];
+                }
+                p = 1 / (1 + Math.exp(-p));
+                for (int i = 0; i < indices.length; ++i) {
+                    int idx = indices[i];
+                    double g = (p - label) * values[i];
+                    double sigma =
+                            (Math.sqrt(nParam[idx] + g * g) - Math.sqrt(nParam[idx])) / alpha;
+                    zParam[idx] += g - sigma * weights[idx];
+                    nParam[idx] += g * g;
+                    weights[idx] += 1.0;
+                }
+            }
+            if (points.length > 0) {
+                output.collect(
+                        new StreamRecord<>(
+                                new DenseVector[] {
+                                    new DenseVector(weights),
+                                    new DenseVector(zParam),
+                                    new DenseVector(nParam)
+                                }));
+            }
+        }
+
+        @Override
+        public void processElement2(StreamRecord<DenseVector[]> modelDataRecord) throws Exception {
+            modelDataState.clear();
+            modelDataState.add(modelDataRecord.getValue());
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+        ReadWriteUtils.saveModelData(
+                LogisticRegressionModelData.getModelDataStream(initModelDataTable),
+                path,
+                new LogisticRegressionModelData.ModelDataEncoder());
+    }
+
+    public static OnlineLogisticRegression load(StreamTableEnvironment tEnv, String path)
+            throws IOException {
+        OnlineLogisticRegression onlineLogisticRegression = ReadWriteUtils.loadStageParam(path);
+        Table modelDataTable =
+                ReadWriteUtils.loadModelData(
+                        tEnv, path, new LogisticRegressionModelData.ModelDataDecoder());
+        onlineLogisticRegression.setInitialModelData(modelDataTable);
+        return onlineLogisticRegression;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    /**
+     * An operator that splits a global batch into evenly-sized local batches, and distributes them
+     * to the downstream operator.
+     */
+    private static class GlobalBatchSplitter implements FlatMapFunction<Row[], Row[]> {

Review Comment:
   Can we reuse the code snippet in `OnlineKmeans` rather than copy-pasting it here?
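Whichever class ends up owning the batch splitting, its core logic is small. The later revision of this file, quoted further down the thread, calls `DataStreamUtils.generateBatchData(points, parallelism, batchSize)` instead of defining a local splitter. Below is a hypothetical, Flink-free sketch of the splitting step only. `GlobalBatchSplitSketch` and `split` are illustrative names, not Flink ML APIs. The sketch assumes that "evenly-sized" means sizes differ by at most one, with the remainder going to the first tasks.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Hypothetical, Flink-free sketch of splitting one global batch into local batches. */
public class GlobalBatchSplitSketch {

    /**
     * Splits {@code globalBatch} into {@code numTasks} local batches whose sizes differ by at most
     * one. The first {@code globalBatch.length % numTasks} batches each receive one extra element.
     */
    public static <T> List<T[]> split(T[] globalBatch, int numTasks) {
        if (numTasks <= 0) {
            throw new IllegalArgumentException("numTasks must be positive");
        }
        List<T[]> localBatches = new ArrayList<>(numTasks);
        int baseSize = globalBatch.length / numTasks;
        int remainder = globalBatch.length % numTasks;
        int offset = 0;
        for (int task = 0; task < numTasks; ++task) {
            int size = baseSize + (task < remainder ? 1 : 0);
            // copyOfRange keeps the runtime component type of the input array.
            localBatches.add(Arrays.copyOfRange(globalBatch, offset, offset + size));
            offset += size;
        }
        return localBatches;
    }

    public static void main(String[] args) {
        Integer[] globalBatch = {1, 2, 3, 4, 5, 6, 7};
        for (Integer[] localBatch : split(globalBatch, 3)) {
            System.out.println(Arrays.toString(localBatch));
        }
    }
}
```

For the seven-element batch in `main` and three tasks, the local batches have sizes 3, 2 and 2. The `parallelism <= batchSize` precondition in the iteration body ensures that no local batch is empty.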





[GitHub] [flink-ml] weibozhao commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator for OnlineLogisticRegression

Posted by GitBox <gi...@apache.org>.
weibozhao commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r882471104


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionParams.java:
##########
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.common.param.HasElasticNet;
+import org.apache.flink.ml.common.param.HasGlobalBatchSize;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasReg;
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+
+/**
+ * Params of {@link OnlineLogisticRegression}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface OnlineLogisticRegressionParams<T>

Review Comment:
   I will add this param later.
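This interface imports `HasReg` and `HasElasticNet`. In the revision of `OnlineLogisticRegression.java` quoted later in this thread, `FtrlIterationBody` turns those two values into the FTRL penalty strengths as `l1 = elasticNet * reg` and `l2 = (1 - elasticNet) * reg`. The hypothetical helper below shows that mapping in isolation. `ElasticNetSplitSketch.toL1L2` is an illustrative name, not part of Flink ML. Its range check assumes `reg >= 0` and `0 <= elasticNet <= 1`. That is an assumption about the params' validators, which are not shown here.

```java
/** Hypothetical helper mirroring how FtrlIterationBody derives l1 and l2 from reg and elasticNet. */
public class ElasticNetSplitSketch {

    /** Returns {l1, l2} where l1 = elasticNet * reg and l2 = (1 - elasticNet) * reg. */
    public static double[] toL1L2(double reg, double elasticNet) {
        // Assumed valid ranges; the actual param validators are not shown in this thread.
        if (reg < 0.0 || elasticNet < 0.0 || elasticNet > 1.0) {
            throw new IllegalArgumentException("expected reg >= 0 and 0 <= elasticNet <= 1");
        }
        return new double[] {elasticNet * reg, (1.0 - elasticNet) * reg};
    }

    public static void main(String[] args) {
        // elasticNet = 1.0 gives a pure L1 penalty, 0.0 gives a pure L2 penalty.
        for (double elasticNet : new double[] {0.0, 0.5, 1.0}) {
            double[] l1l2 = toL1L2(0.1, elasticNet);
            System.out.println("elasticNet=" + elasticNet + " -> l1=" + l1l2[0] + ", l2=" + l1l2[1]);
        }
    }
}
```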





[GitHub] [flink-ml] weibozhao commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator for OnlineLogisticRegression

Posted by GitBox <gi...@apache.org>.
weibozhao commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r882484808


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java:
##########
@@ -0,0 +1,434 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the FTRL-Proximal online learning algorithm proposed by H. Brendan
+ * McMahan et al.
+ *
+ * <p>See <a href="https://doi.org/10.1145/2487575.2488200">H. Brendan McMahan et al., Ad click
+ * prediction: a view from the trenches.</a>
+ */
+public class OnlineLogisticRegression
+        implements Estimator<OnlineLogisticRegression, OnlineLogisticRegressionModel>,
+                OnlineLogisticRegressionParams<OnlineLogisticRegression> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineLogisticRegression() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineLogisticRegressionModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<LogisticRegressionModelData> modelDataStream =
+                LogisticRegressionModelData.getModelDataStream(initModelDataTable);
+
+        DataStream<Row> points =
+                tEnv.toDataStream(inputs[0])
+                        .map(new FeaturesExtractor(getFeaturesCol(), getLabelCol()));
+
+        DataStream<DenseVector> initModelData =
+                modelDataStream.map(
+                        (MapFunction<LogisticRegressionModelData, DenseVector>)
+                                value -> value.coefficient);
+
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new FtrlIterationBody(
+                        getGlobalBatchSize(), getAlpha(), getBeta(), getReg(), getElasticNet());
+
+        DataStream<LogisticRegressionModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        OnlineLogisticRegressionModel model =
+                new OnlineLogisticRegressionModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class FeaturesExtractor implements MapFunction<Row, Row> {
+        private final String featuresCol;
+        private final String labelCol;
+
+        private FeaturesExtractor(String featuresCol, String labelCol) {
+            this.featuresCol = featuresCol;
+            this.labelCol = labelCol;
+        }
+
+        @Override
+        public Row map(Row row) throws Exception {
+            return Row.of(row.getField(featuresCol), row.getField(labelCol));
+        }
+    }
+
+    /**
+     * Implementation of the FTRL optimizer. In this implementation, gradients are calculated in
+     * distributed workers and reduced to one gradient. The reduced gradient is used to update the
+     * model with the FTRL method.
+     *
+     * <p>See https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Ftrl
+     *
+     * <p>TODO: make FTRL a common optimizer and place it in org.apache.flink.ml.common in the
+     * future.
+     */
+    private static class FtrlIterationBody implements IterationBody {
+        private final int batchSize;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+
+        public FtrlIterationBody(
+                int batchSize, double alpha, double beta, double reg, double elasticNet) {
+            this.batchSize = batchSize;
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = elasticNet * reg;
+            this.l2 = (1 - elasticNet) * reg;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<DenseVector> modelData = variableStreams.get(0);
+
+            DataStream<Row> points = dataStreams.get(0);
+            int parallelism = points.getParallelism();
+            Preconditions.checkState(
+                    parallelism <= batchSize,
+                    "There are more subtasks in the training process than the number "
+                            + "of elements in each batch. Some subtasks might be idling forever.");
+
+            DataStream<DenseVector[]> newGradient =
+                    DataStreamUtils.generateBatchData(points, parallelism, batchSize)
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "LocalGradientCalculator",
+                                    TypeInformation.of(DenseVector[].class),
+                                    new CalculateLocalGradient())
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(
+                                    (ReduceFunction<DenseVector[]>)
+                                            (gradientInfo, newGradientInfo) -> {
+                                                for (int i = 0;
+                                                        i < newGradientInfo[1].size();
+                                                        ++i) {
+                                                    newGradientInfo[0].values[i] =
+                                                            gradientInfo[0].values[i]
+                                                                    + newGradientInfo[0].values[i];
+                                                    newGradientInfo[1].values[i] =
+                                                            gradientInfo[1].values[i]
+                                                                    + newGradientInfo[1].values[i];
+                                                    if (newGradientInfo[2] == null) {
+                                                        newGradientInfo[2] = gradientInfo[2];
+                                                    }
+                                                }
+                                                return newGradientInfo;
+                                            });
+            DataStream<DenseVector> feedbackModelData =
+                    newGradient
+                            .transform(
+                                    "ModelDataUpdater",
+                                    TypeInformation.of(DenseVector.class),
+                                    new UpdateModel(alpha, beta, l1, l2))
+                            .setParallelism(1);
+
+            DataStream<LogisticRegressionModelData> outputModelData =
+                    feedbackModelData.map(new CreateLrModelData()).setParallelism(1);
+            return new IterationBodyResult(
+                    DataStreamList.of(feedbackModelData), DataStreamList.of(outputModelData));
+        }
+    }
+
+    private static class CreateLrModelData
+            implements MapFunction<DenseVector, LogisticRegressionModelData>, CheckpointedFunction {
+        private Long modelVersion = 1L;
+        private transient ListState<Long> modelVersionState;
+
+        @Override
+        public LogisticRegressionModelData map(DenseVector denseVector) throws Exception {
+            return new LogisticRegressionModelData(denseVector, modelVersion++);
+        }
+
+        @Override
+        public void snapshotState(FunctionSnapshotContext functionSnapshotContext)
+                throws Exception {
+            modelVersionState.update(Collections.singletonList(modelVersion));
+        }
+
+        @Override
+        public void initializeState(FunctionInitializationContext context) throws Exception {
+            modelVersionState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelVersionState", Long.class));
+        }
+    }
+
+    /** Updates model. */
+    private static class UpdateModel extends AbstractStreamOperator<DenseVector>
+            implements OneInputStreamOperator<DenseVector[], DenseVector> {
+        private ListState<double[]> nParamState;
+        private ListState<double[]> zParamState;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+        private double[] nParam;
+        private double[] zParam;
+
+        public UpdateModel(double alpha, double beta, double l1, double l2) {
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = l1;
+            this.l2 = l2;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            nParamState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("nParamState", double[].class));
+            zParamState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("zParamState", double[].class));
+        }
+
+        @Override
+        public void processElement(StreamRecord<DenseVector[]> streamRecord) throws Exception {
+            DenseVector[] gradientInfo = streamRecord.getValue();
+            double[] coefficient = gradientInfo[2].values;
+            double[] g = gradientInfo[0].values;
+            for (int i = 0; i < g.length; ++i) {
+                if (gradientInfo[1].values[i] != 0.0) {
+                    g[i] = g[i] / gradientInfo[1].values[i];
+                }
+            }
+            if (zParam == null) {
+                zParam = new double[g.length];
+                nParam = new double[g.length];
+                nParamState.add(nParam);
+                zParamState.add(zParam);
+            }
+
+            for (int i = 0; i < zParam.length; ++i) {
+                double sigma = (Math.sqrt(nParam[i] + g[i] * g[i]) - Math.sqrt(nParam[i])) / alpha;
+                zParam[i] += g[i] - sigma * coefficient[i];
+                nParam[i] += g[i] * g[i];
+
+                if (Math.abs(zParam[i]) <= l1) {
+                    coefficient[i] = 0.0;
+                } else {
+                    coefficient[i] =
+                            ((zParam[i] < 0 ? -1 : 1) * l1 - zParam[i])
+                                    / ((beta + Math.sqrt(nParam[i])) / alpha + l2);
+                }
+            }
+            output.collect(new StreamRecord<>(new DenseVector(coefficient)));
+        }
+    }
+
+    private static class CalculateLocalGradient extends AbstractStreamOperator<DenseVector[]>
+            implements TwoInputStreamOperator<Row[], DenseVector, DenseVector[]> {
+        private ListState<DenseVector> modelDataState;
+        private ListState<Row[]> localBatchDataState;
+        private double[] gradient;
+        private double[] weight;
+        private int[] denseVectorIndices;
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", DenseVector.class));
+            TypeInformation<Row[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(TypeInformation.of(Row.class));
+            localBatchDataState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("localBatch", type));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<Row[]> pointsRecord) throws Exception {
+            localBatchDataState.add(pointsRecord.getValue());
+            calculateGradient();
+        }
+
+        private void calculateGradient() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !localBatchDataState.get().iterator().hasNext()) {
+                return;
+            }
+            DenseVector modelData =
+                    OperatorStateUtils.getUniqueElement(modelDataState, "modelData").get();
+            modelDataState.clear();
+
+            List<Row[]> pointsList = IteratorUtils.toList(localBatchDataState.get().iterator());
+            Row[] points = pointsList.remove(0);
+            localBatchDataState.update(pointsList);
+
+            for (Row point : points) {
+                Vector vec = point.getFieldAs(0);
+                double label = point.getFieldAs(1);
+                if (gradient == null) {
+                    gradient = new double[vec.size()];
+                    weight = new double[gradient.length];
+                    if (vec instanceof DenseVector) {
+                        denseVectorIndices = new int[vec.size()];
+                        for (int i = 0; i < denseVectorIndices.length; ++i) {
+                            denseVectorIndices[i] = i;
+                        }
+                    }
+                }
+
+                int[] indices;
+                double[] values;
+                if (vec instanceof DenseVector) {
+                    DenseVector denseVector = (DenseVector) vec;
+                    indices = denseVectorIndices;
+                    values = denseVector.values;
+                } else {
+                    SparseVector sparseVector = (SparseVector) vec;
+                    indices = sparseVector.indices;
+                    values = sparseVector.values;
+                }
+                double p = 0.0;
+                for (int i = 0; i < indices.length; ++i) {
+                    int idx = indices[i];
+                    p += modelData.values[idx] * values[i];
+                }
+                p = 1 / (1 + Math.exp(-p));
+                for (int i = 0; i < indices.length; ++i) {
+                    int idx = indices[i];
+                    gradient[idx] += (p - label) * values[i];
+                    weight[idx] += 1.0;
+                }
+            }
+
+            if (points.length > 0) {
+                output.collect(
+                        new StreamRecord<>(
+                                new DenseVector[] {
+                                    new DenseVector(gradient),
+                                    new DenseVector(weight),

Review Comment:
   The `weight` variable will be renamed to `weightSum`. As mentioned in my comment above, I added a weightCol param, so this variable is just the sum of the weights.
   
   I will add Java doc for the code that reduces the gradients.
   
   In my code, the model data is used differently from SGD. In online LR, the model data is used in two places: calculating the gradient locally and updating the model serially. SGD updates the model differently. If I changed the model update to follow SGD, a lot of code would have to be rewritten without an obvious benefit, so I will keep the current model output format.
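As a minimal sketch of the local gradient computation and the gradient reduction discussed in this comment, the following uses plain Java arrays in place of Flink's `DenseVector`/`SparseVector` and operator state. It accumulates the logistic-loss gradient and a per-coordinate `weightSum` for one local batch, then sums two partial results element-wise. The class name `LocalGradientSketch` is hypothetical, and scaling both terms by a per-sample weight taken from `weightCol` is an assumption about how the renamed `weightSum` would be filled (the diff above adds `1.0` per coordinate).

```java
/**
 * Hypothetical sketch (plain arrays, not the PR's Flink operator) of how one subtask could
 * accumulate the logistic-loss gradient and per-coordinate weight sum of a local batch, and
 * how two {gradient, weightSum} pairs could be reduced element-wise.
 */
final class LocalGradientSketch {

    /**
     * Adds the gradient of one local batch to gradient and weightSum. Point k has feature
     * indices indices[k] (0..n-1 for a dense point) and aligned values values[k];
     * sampleWeights[k] stands in for the weightCol value (assumption).
     */
    static void accumulate(
            double[] coefficient,
            int[][] indices,
            double[][] values,
            double[] labels,
            double[] sampleWeights,
            double[] gradient,
            double[] weightSum) {
        for (int k = 0; k < labels.length; k++) {
            // Predicted probability p = sigmoid(coefficient . x).
            double margin = 0.0;
            for (int i = 0; i < indices[k].length; i++) {
                margin += coefficient[indices[k][i]] * values[k][i];
            }
            double p = 1.0 / (1.0 + Math.exp(-margin));
            // d(logloss)/dw_idx = (p - label) * x_idx, scaled by the sample weight.
            for (int i = 0; i < indices[k].length; i++) {
                int idx = indices[k][i];
                gradient[idx] += sampleWeights[k] * (p - labels[k]) * values[k][i];
                weightSum[idx] += sampleWeights[k];
            }
        }
    }

    /** Element-wise sum of two {gradient, weightSum} pairs, as a reduce step would do. */
    static double[][] reduce(double[][] a, double[][] b) {
        double[][] out = new double[a.length][];
        for (int j = 0; j < a.length; j++) {
            out[j] = new double[a[j].length];
            for (int i = 0; i < a[j].length; i++) {
                out[j][i] = a[j][i] + b[j][i];
            }
        }
        return out;
    }
}
```

Dense points pass the indices `0..n-1`, mirroring `denseVectorIndices` in the diff, so a single loop serves both dense and sparse input.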




[GitHub] [flink-ml] weibozhao commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator for OnlineLogisticRegression

Posted by GitBox <gi...@apache.org>.
weibozhao commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r881284890


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionParams.java:
##########
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.common.param.HasElasticNet;
+import org.apache.flink.ml.common.param.HasGlobalBatchSize;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasReg;
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+
+/**
+ * Params of {@link OnlineLogisticRegression}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface OnlineLogisticRegressionParams<T>
+        extends HasLabelCol<T>,
+                HasBatchStrategy<T>,
+                HasGlobalBatchSize<T>,
+                HasReg<T>,
+                HasElasticNet<T>,
+                OnlineLogisticRegressionModelParams<T> {
+
+    Param<Double> ALPHA =
+            new DoubleParam("alpha", "The parameter alpha of ftrl.", 0.1, ParamValidators.gt(0.0));
+
+    default Double getAlpha() {
+        return get(ALPHA);
+    }
+
+    default T setAlpha(Double value) {
+        return set(ALPHA, value);
+    }
+
+    Param<Double> BETA =
+            new DoubleParam("beta", "The parameter beta of ftrl.", 0.1, ParamValidators.gt(0.0));
+
+    default Double getBeta() {
+        return get(BETA);
+    }
+
+    default T setBeta(Double value) {
+        return set(BETA, value);
+    }
+
+    Param<Integer> MODEL_SAVE_INTERVAL =
+            new IntParam(
+                    "modelSaveInterval",
+                    "The iteration steps between two output models.",

Review Comment:
   Maybe we can discuss it with @zhipeng93 this afternoon.




[GitHub] [flink-ml] lindong28 commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator for OnlineLogisticRegression

Posted by GitBox <gi...@apache.org>.
lindong28 commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r883628874


##########
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java:
##########
@@ -182,4 +193,63 @@ public void snapshotState(StateSnapshotContext context) throws Exception {
             }
         }
     }
+
+    /**
+     * A function that generate several data batches and distribute them to downstream operator.
+     *
+     * @param <T> Data type of batch data.
+     */
+    public static <T> DataStream<T[]> generateBatchData(

Review Comment:
   The current function name and Java doc do not seem to capture the key functionality of this method, i.e. splitting the input data into global batches of `batchSize`, where each global batch is further split into `downStreamParallelism` local batches for downstream operators.
   
   Also, the previous code that used `GlobalBatchSplitter` seems a bit more readable than the current version, which puts everything into one method with deeper indentation.
   
   Could you preserve the classes `GlobalBatchSplitter` and `GlobalBatchCreator`, and update the method name and its Java doc to make it a bit more self-explanatory?
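The behaviour the reviewer describes (global batches of `batchSize`, each split into `downStreamParallelism` local batches) can be sketched over an in-memory list as follows. This is not the PR's `GlobalBatchSplitter`/`GlobalBatchCreator`: the class and method names are hypothetical, and both the choice to spread the remainder so that local batch sizes differ by at most one and the choice to keep a trailing partial batch are assumptions.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical sketch of global-batch creation and local-batch splitting over a list. */
final class GlobalBatchSketch {

    /** Groups records into global batches of batchSize; a trailing partial batch is kept. */
    static <T> List<List<T>> createGlobalBatches(List<T> records, int batchSize) {
        List<List<T>> globalBatches = new ArrayList<>();
        for (int start = 0; start < records.size(); start += batchSize) {
            int end = Math.min(start + batchSize, records.size());
            globalBatches.add(new ArrayList<>(records.subList(start, end)));
        }
        return globalBatches;
    }

    /** Splits one global batch into `parallelism` local batches of near-equal size. */
    static <T> List<List<T>> splitGlobalBatch(List<T> globalBatch, int parallelism) {
        List<List<T>> localBatches = new ArrayList<>(parallelism);
        int base = globalBatch.size() / parallelism;
        int remainder = globalBatch.size() % parallelism;
        int start = 0;
        for (int i = 0; i < parallelism; i++) {
            // The first `remainder` local batches receive one extra record.
            int size = base + (i < remainder ? 1 : 0);
            localBatches.add(new ArrayList<>(globalBatch.subList(start, start + size)));
            start += size;
        }
        return localBatches;
    }
}
```

With 10 records, `batchSize = 4` and a parallelism of 3, the global batches are `[0..3]`, `[4..7]` and `[8, 9]`; the first splits into `[[0, 1], [2], [3]]` and the last into `[[8], [9], []]`. Keeping the two steps as separate units is what makes each one easy to name and document.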



##########
flink-ml-core/src/test/java/org/apache/flink/ml/linalg/BLASTest.java:
##########
@@ -70,7 +70,18 @@ public void testAxpyK() {
     @Test
     public void testDot() {
         DenseVector anotherDenseVec = Vectors.dense(1, 2, 3, 4, 5);
+        SparseVector sparseVector1 =
+                Vectors.sparse(5, new int[] {1, 2, 4}, new double[] {1., 1., 4.});
+        SparseVector sparseVector2 =
+                Vectors.sparse(5, new int[] {1, 3, 4}, new double[] {1., 2., 1.});
+        // Tests Dot(dense, dense).

Review Comment:
   nits: Since the method name is `dot(...)`, would it be more intuitive to use `dot(dense, dense)` here?
   
   Same for the lines below.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java:
##########
@@ -161,7 +162,8 @@ public Row map(Row dataPoint) {
      * @param coefficient The model parameters.
      * @return The prediction label and the raw probabilities.
      */
-    private static Row predictOneDataPoint(DenseVector feature, DenseVector coefficient) {
+    protected static Row predictOneDataPoint(Vector feature, DenseVector coefficient) {
+

Review Comment:
   Is this empty line necessary?



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionParams.java:
##########
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.common.param.HasElasticNet;
+import org.apache.flink.ml.common.param.HasGlobalBatchSize;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasReg;
+import org.apache.flink.ml.common.param.HasWeightCol;
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+
+/**
+ * Params of {@link OnlineLogisticRegression}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface OnlineLogisticRegressionParams<T>
+        extends HasLabelCol<T>,
+                HasWeightCol<T>,
+                HasBatchStrategy<T>,
+                HasGlobalBatchSize<T>,
+                HasReg<T>,
+                HasElasticNet<T>,
+                OnlineLogisticRegressionModelParams<T> {
+
+    Param<Double> ALPHA =
+            new DoubleParam("alpha", "The parameter alpha of ftrl.", 0.1, ParamValidators.gt(0.0));
+
+    default Double getAlpha() {
+        return get(ALPHA);
+    }
+
+    default T setAlpha(Double value) {
+        return set(ALPHA, value);
+    }
+
+    Param<Double> BETA =
+            new DoubleParam("beta", "The parameter beta of ftrl.", 0.1, ParamValidators.gt(0.0));

Review Comment:
   Hmm, how would users know what FTRL is when they read this doc?
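For context on what the `ALPHA` and `BETA` params control, here is a sketch of the per-coordinate FTRL-Proximal update as described by McMahan et al., "Ad Click Prediction: a View from the Trenches" (KDD 2013). The class `FtrlSketch` is hypothetical. Treating `alpha`/`beta` as the paper's per-coordinate learning-rate parameters, and deriving `l1`/`l2` from `reg` and `elasticNet` (for example `l1 = elasticNet * reg`, `l2 = (1 - elasticNet) * reg`), are assumptions that the diff does not confirm.

```java
/**
 * Hypothetical sketch of the per-coordinate FTRL-Proximal update (McMahan et al., KDD 2013).
 * Not the PR's implementation; the mapping from the PR's params is an assumption.
 */
final class FtrlSketch {
    private final double alpha;
    private final double beta;
    private final double l1;
    private final double l2;
    private final double[] z; // accumulated adjusted gradients
    private final double[] n; // accumulated squared gradients
    private final double[] w; // current coefficients

    FtrlSketch(int dim, double alpha, double beta, double l1, double l2) {
        this.alpha = alpha;
        this.beta = beta;
        this.l1 = l1;
        this.l2 = l2;
        this.z = new double[dim];
        this.n = new double[dim];
        this.w = new double[dim];
    }

    /** Applies one gradient vector, e.g. a reduced mini-batch gradient, coordinate by coordinate. */
    void update(double[] g) {
        for (int i = 0; i < g.length; i++) {
            // sigma_i = (sqrt(n_i + g_i^2) - sqrt(n_i)) / alpha
            double sigma = (Math.sqrt(n[i] + g[i] * g[i]) - Math.sqrt(n[i])) / alpha;
            z[i] += g[i] - sigma * w[i];
            n[i] += g[i] * g[i];
            if (Math.abs(z[i]) <= l1) {
                // The L1 term clips small coordinates to exactly zero.
                w[i] = 0.0;
            } else {
                w[i] = -(z[i] - Math.signum(z[i]) * l1) / ((beta + Math.sqrt(n[i])) / alpha + l2);
            }
        }
    }

    double[] weights() {
        return w.clone();
    }
}
```

With `alpha = 0.1`, `beta = 1.0`, no regularization and a first gradient of `1.0`, the coefficient moves from `0` to `-0.05`; with `l1 = 2.0` it stays exactly `0`, which is how FTRL yields sparse models.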



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java:
##########
@@ -18,40 +18,86 @@
 
 package org.apache.flink.ml.classification.logisticregression;
 
+import org.apache.flink.api.common.functions.MapFunction;
 import org.apache.flink.api.common.serialization.Encoder;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
 import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.core.memory.DataInputViewStreamWrapper;
 import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.ml.common.datastream.TableUtils;
 import org.apache.flink.ml.linalg.DenseVector;
 import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
 import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.table.api.Table;
 import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
 import org.apache.flink.table.api.internal.TableImpl;
 
 import java.io.EOFException;
 import java.io.IOException;
 import java.io.OutputStream;
+import java.util.Random;
 
 /**
- * Model data of {@link LogisticRegressionModel}.
+ * Model data of {@link LogisticRegressionModel} and {@link OnlineLogisticRegressionModel}.
  *
  * <p>This class also provides methods to convert model data from Table to Datastream, and classes
  * to save/load model data.
  */
 public class LogisticRegressionModelData {
 
     public DenseVector coefficient;
+    public long modelVersion;
 
-    public LogisticRegressionModelData(DenseVector coefficient) {
+    /* todo : modelVersion is a new idea to manage model, other algorithm if needed should add this parameter later. */
+    public LogisticRegressionModelData(DenseVector coefficient, long modelVersion) {
         this.coefficient = coefficient;
+        this.modelVersion = modelVersion;
     }
 
     public LogisticRegressionModelData() {}
 
+    public LogisticRegressionModelData(DenseVector coefficient) {

Review Comment:
   There is only one place that calls this constructor.
   
   Instead of adding this constructor for the specific case where modelVersion=0, would it be simpler to update the caller as follows, so that this class is simpler and more consistent with other model classes?
   
   ```
   DataStream<LogisticRegressionModelData> modelData =
           rawModelData.map(vector -> new LogisticRegressionModelData(vector, 0));
   ```



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java:
##########
@@ -18,40 +18,86 @@
 
 package org.apache.flink.ml.classification.logisticregression;
 
+import org.apache.flink.api.common.functions.MapFunction;
 import org.apache.flink.api.common.serialization.Encoder;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
 import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.core.memory.DataInputViewStreamWrapper;
 import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.ml.common.datastream.TableUtils;
 import org.apache.flink.ml.linalg.DenseVector;
 import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
 import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.table.api.Table;
 import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
 import org.apache.flink.table.api.internal.TableImpl;
 
 import java.io.EOFException;
 import java.io.IOException;
 import java.io.OutputStream;
+import java.util.Random;
 
 /**
- * Model data of {@link LogisticRegressionModel}.
+ * Model data of {@link LogisticRegressionModel} and {@link OnlineLogisticRegressionModel}.
  *
  * <p>This class also provides methods to convert model data from Table to Datastream, and classes
  * to save/load model data.
  */
 public class LogisticRegressionModelData {
 
     public DenseVector coefficient;
+    public long modelVersion;
 
-    public LogisticRegressionModelData(DenseVector coefficient) {
+    /* todo : modelVersion is a new idea to manage model, other algorithm if needed should add this parameter later. */

Review Comment:
   It is not clear what the actionable item for this TODO is.
   
   Since `LogisticRegressionModelData` already has `modelVersion`, is this TODO still needed?



##########
flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java:
##########
@@ -65,12 +65,54 @@ public static void hDot(Vector x, Vector y) {
         }
     }
 
-    /** x \cdot y . */
-    public static double dot(DenseVector x, DenseVector y) {
+    /** Computes the dot of the two vectors (y = y \dot x). */

Review Comment:
   The Java doc seems incorrect since this method actually does not update `y`.
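For reference, a dot product returns the scalar `x · y` and modifies neither operand. The sketch below uses plain arrays as stand-ins for Flink ML's `DenseVector` and `SparseVector` (it is not the `BLAS` API), and it assumes sparse indices are sorted in ascending order. It can be checked against the `testDot` vectors in the `BLASTest` hunk of this message: `sparseVector1 · sparseVector2 = 5` and `sparseVector1 · anotherDenseVec = 25`.

```java
/**
 * Plain-array stand-ins for dot products; not Flink ML's BLAS API.
 * Each method returns x . y and modifies neither input.
 * Sparse vectors are (indices, values) pairs with indices sorted ascending (assumption).
 */
final class DotSketch {

    /** dense . dense */
    static double dot(double[] x, double[] y) {
        if (x.length != y.length) {
            throw new IllegalArgumentException("Vector sizes differ.");
        }
        double sum = 0.0;
        for (int i = 0; i < x.length; i++) {
            sum += x[i] * y[i];
        }
        return sum;
    }

    /** sparse . dense: only the stored entries of x contribute. */
    static double dot(int[] xIndices, double[] xValues, double[] y) {
        double sum = 0.0;
        for (int i = 0; i < xIndices.length; i++) {
            sum += xValues[i] * y[xIndices[i]];
        }
        return sum;
    }

    /** sparse . sparse: merge the sorted index arrays; only shared indices contribute. */
    static double dot(int[] xIndices, double[] xValues, int[] yIndices, double[] yValues) {
        double sum = 0.0;
        int i = 0;
        int j = 0;
        while (i < xIndices.length && j < yIndices.length) {
            if (xIndices[i] == yIndices[j]) {
                sum += xValues[i++] * yValues[j++];
            } else if (xIndices[i] < yIndices[j]) {
                i++;
            } else {
                j++;
            }
        }
        return sum;
    }
}
```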



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java:
##########
@@ -0,0 +1,418 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/** An Estimator which implements the online logistic regression algorithm. */

Review Comment:
   Should we provide a reference/link to the original paper, so that users know what the algorithm is and why it is useful?
   
   Feel free to refer to KMeans.scala in Spark ML for an example doc.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java:
##########
@@ -18,40 +18,86 @@
 
 package org.apache.flink.ml.classification.logisticregression;
 
+import org.apache.flink.api.common.functions.MapFunction;
 import org.apache.flink.api.common.serialization.Encoder;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
 import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.core.memory.DataInputViewStreamWrapper;
 import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.ml.common.datastream.TableUtils;
 import org.apache.flink.ml.linalg.DenseVector;
 import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
 import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.table.api.Table;
 import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
 import org.apache.flink.table.api.internal.TableImpl;
 
 import java.io.EOFException;
 import java.io.IOException;
 import java.io.OutputStream;
+import java.util.Random;
 
 /**
- * Model data of {@link LogisticRegressionModel}.
+ * Model data of {@link LogisticRegressionModel} and {@link OnlineLogisticRegressionModel}.
  *
  * <p>This class also provides methods to convert model data from Table to Datastream, and classes
  * to save/load model data.
  */
 public class LogisticRegressionModelData {
 
     public DenseVector coefficient;
+    public long modelVersion;
 
-    public LogisticRegressionModelData(DenseVector coefficient) {
+    /* todo : modelVersion is a new idea to manage model, other algorithm if needed should add this parameter later. */
+    public LogisticRegressionModelData(DenseVector coefficient, long modelVersion) {
         this.coefficient = coefficient;
+        this.modelVersion = modelVersion;
     }
 
     public LogisticRegressionModelData() {}
 
+    public LogisticRegressionModelData(DenseVector coefficient) {
+        this.coefficient = coefficient;
+        this.modelVersion = 0L;
+    }
+
+    /**
+     * Generates a Table containing a {@link LogisticRegressionModelData} instance with randomly
+     * generated coefficient.
+     *
+     * @param tEnv The environment where to create the table.
+     * @param dim The size of generated coefficient.
+     * @param seed Random seed.
+     */
+    public static Table generateRandomModelData(StreamTableEnvironment tEnv, int dim, int seed) {
+        StreamExecutionEnvironment env = TableUtils.getExecutionEnvironment(tEnv);
+        return tEnv.fromDataStream(env.fromElements(1).map(new GenerateRandomModel(dim, seed)));
+    }
+
+    private static class GenerateRandomModel

Review Comment:
   We typically use noun instead of verb as the class name. How about `RandomModelDataGenerator`?
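Here is a rough sketch of what a `RandomModelDataGenerator` could compute, assuming it only needs `dim` and `seed` as in `generateRandomModelData` above: a coefficient vector filled from a seeded `java.util.Random`. In the PR this logic would sit inside a Flink `MapFunction` that produces `LogisticRegressionModelData`. Here it is a plain class returning `double[]`, and because the class body is not shown, the uniform `[0, 1)` distribution is an assumption.

```java
import java.util.Random;

/**
 * Hypothetical plain-Java version of RandomModelDataGenerator; the Flink MapFunction wrapper
 * and LogisticRegressionModelData are omitted. Uniform [0, 1) values are an assumption.
 */
final class RandomModelDataGenerator {
    private final int dim;
    private final int seed;

    RandomModelDataGenerator(int dim, int seed) {
        this.dim = dim;
        this.seed = seed;
    }

    /** Returns a coefficient vector of size dim; the same (dim, seed) always gives the same values. */
    double[] generate() {
        Random random = new Random(seed);
        double[] coefficient = new double[dim];
        for (int i = 0; i < dim; i++) {
            coefficient[i] = random.nextDouble();
        }
        return coefficient;
    }
}
```

A noun-style name reads naturally at the call site, e.g. `new RandomModelDataGenerator(dim, seed)`, and seeding inside `generate()` keeps repeated calls reproducible.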



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionParams.java:
##########
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.common.param.HasElasticNet;
+import org.apache.flink.ml.common.param.HasGlobalBatchSize;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasReg;
+import org.apache.flink.ml.common.param.HasWeightCol;
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+
+/**
+ * Params of {@link OnlineLogisticRegression}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface OnlineLogisticRegressionParams<T>
+        extends HasLabelCol<T>,
+                HasWeightCol<T>,
+                HasBatchStrategy<T>,
+                HasGlobalBatchSize<T>,
+                HasReg<T>,
+                HasElasticNet<T>,
+                OnlineLogisticRegressionModelParams<T> {
+
+    Param<Double> ALPHA =
+            new DoubleParam("alpha", "The parameter alpha of ftrl.", 0.1, ParamValidators.gt(0.0));
+
+    default Double getAlpha() {
+        return get(ALPHA);
+    }
+
+    default T setAlpha(Double value) {
+        return set(ALPHA, value);
+    }
+
+    Param<Double> BETA =

Review Comment:
   Since we typically declare variables before methods, could this variable be moved above `getAlpha()`?
   
   Feel free to refer to `StandardScalerParams` as an example.




[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator for OnlineLogisticRegression

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r854745613


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionModelParams.java:
##########
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasPredictionCol;
+import org.apache.flink.ml.common.param.HasRawPredictionCol;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringParam;
+
+/**
+ * Params for {@link OnlineLogisticRegressionModel}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface OnlineLogisticRegressionModelParams<T>
+        extends HasFeaturesCol<T>, HasPredictionCol<T>, HasRawPredictionCol<T> {
+    Param<String> MODEL_VERSION_COL =

Review Comment:
   Could you explain the reason for introducing the `MODEL_VERSION_COL` param here? It is not consistent with the implementation of `OnlineKmeansModel`.




[GitHub] [flink-ml] weibozhao commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator for OnlineLogisticRegression

Posted by GitBox <gi...@apache.org>.
weibozhao commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r867573941


##########
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/OnlineLogisticRegressionTest.java:
##########
@@ -0,0 +1,444 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.classification.logisticregression.LogisticRegression;
+import org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegression;
+import org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegressionModel;
+import org.apache.flink.ml.feature.MinMaxScalerTest;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static junit.framework.TestCase.assertEquals;
+
+/** Tests {@link OnlineLogisticRegression} and {@link OnlineLogisticRegressionModel}. */
+public class OnlineLogisticRegressionTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table trainDenseTable;
+    private static final String LABEL_COL = "label";
+    private static final String PREDICT_COL = "prediction";
+    private static final String FEATURE_COL = "features";
+    private static final String MODEL_VERSION_COL = "modelVersion";
+    private static final double[] ONE_ARRAY = new double[] {1.0, 1.0, 1.0};
+    private static final List<Row> TRAIN_DENSE_ROWS =
+            Arrays.asList(
+                    Row.of(Vectors.dense(0.1, 2.), 0.),
+                    Row.of(Vectors.dense(0.2, 2.), 0.),
+                    Row.of(Vectors.dense(0.3, 2.), 0.),
+                    Row.of(Vectors.dense(0.4, 2.), 0.),
+                    Row.of(Vectors.dense(0.5, 2.), 0.),
+                    Row.of(Vectors.dense(11., 12.), 1.),
+                    Row.of(Vectors.dense(12., 11.), 1.),
+                    Row.of(Vectors.dense(13., 12.), 1.),
+                    Row.of(Vectors.dense(14., 12.), 1.),
+                    Row.of(Vectors.dense(15., 12.), 1.));
+
+    private static final List<Row> PREDICT_DENSE_ROWS =
+            Arrays.asList(
+                    Row.of(Vectors.dense(0.8, 2.7), 0.),
+                    Row.of(Vectors.dense(0.8, 2.4), 0.),
+                    Row.of(Vectors.dense(0.7, 2.3), 0.),
+                    Row.of(Vectors.dense(0.4, 2.7), 0.),
+                    Row.of(Vectors.dense(0.5, 2.8), 0.),
+                    Row.of(Vectors.dense(10.2, 12.1), 1.),
+                    Row.of(Vectors.dense(13.3, 13.1), 1.),
+                    Row.of(Vectors.dense(13.5, 12.2), 1.),
+                    Row.of(Vectors.dense(14.9, 12.5), 1.),
+                    Row.of(Vectors.dense(15.5, 11.2), 1.));
+
+    private static final List<Row> TRAIN_SPARSE_ROWS =
+            Arrays.asList(
+                    Row.of(Vectors.sparse(10, new int[] {1, 3, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {0, 2, 3}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {0, 3, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {2, 3, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {1, 3, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {6, 7, 8}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {6, 8, 9}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {5, 8, 9}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {5, 6, 8}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {5, 6, 7}, ONE_ARRAY), 1.));
+
+    private static final List<Row> PREDICT_SPARSE_ROWS =
+            Arrays.asList(
+                    Row.of(Vectors.sparse(10, new int[] {1, 2, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {2, 3, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {0, 2, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {1, 3, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {6, 7, 9}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {7, 8, 9}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {5, 7, 9}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {5, 6, 7}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {5, 8, 9}, ONE_ARRAY), 1.));
+
+    private Table initDenseModel;
+
+    @Before
+    public void before() throws Exception {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        DataStream<Row> dataStream = env.fromCollection(TRAIN_DENSE_ROWS);
+        trainDenseTable = tEnv.fromDataStream(dataStream).as(FEATURE_COL, LABEL_COL);
+        initDenseModel =
+                tEnv.fromDataStream(
+                        env.fromElements(Row.of(new DenseVector(new double[] {0.0, 0.0}), 0L)));
+    }
+
+    @Test
+    public void testFit() throws Exception {
+        Table onlinePredictTable = getTable(1, 1000, TRAIN_DENSE_ROWS, 4, false);
+        Table models =
+                new OnlineLogisticRegression()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initDenseModel)
+                        .setGlobalBatchSize(100)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlinePredictTable)
+                        .getModelData()[0];
+        List<Row> modelList = IteratorUtils.toList(tEnv.toDataStream(models).executeAndCollect());
+        assertEquals(10, modelList.size());
+    }
+
+    @Test
+    public void testFitWithInitLrModel() throws Exception {
+        Table initLrModel =
+                new LogisticRegression()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setLabelCol(LABEL_COL)
+                        .fit(trainDenseTable)
+                        .getModelData()[0];
+        Table onlineTrainTable = getTable(50, 1000, TRAIN_DENSE_ROWS, 4, false);
+        Table models =
+                new OnlineLogisticRegression()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initLrModel)
+                        .setGlobalBatchSize(100)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlineTrainTable)
+                        .getModelData()[0];
+        List<Row> modelList = IteratorUtils.toList(tEnv.toDataStream(models).executeAndCollect());
+        assertEquals(10, modelList.size());
+    }
+
+    @Test
+    public void testDenseFitAndPredict() throws Exception {
+        Table onlineTrainTable = getTable(2, 2000, TRAIN_DENSE_ROWS, 2, false);
+        Table onlinePredictTable = getTable(2, 3000, PREDICT_DENSE_ROWS, 2, false);
+        OnlineLogisticRegressionModel model =
+                new OnlineLogisticRegression()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initDenseModel)
+                        .setGlobalBatchSize(500)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlineTrainTable);
+        Table resultTable = model.setPredictionCol(PREDICT_COL).transform(onlinePredictTable)[0];
+        verifyPredictionResult(resultTable, 4);
+    }
+
+    @Test
+    public void testSparseFitAndPredict() throws Exception {
+        double[] doubleArray = new double[] {0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1};
+        Table initSparseModel =
+                tEnv.fromDataStream(env.fromElements(Row.of(new DenseVector(doubleArray), 0L)));
+        Table onlineTrainTable = getTable(5, 2000, TRAIN_SPARSE_ROWS, 4, true);
+        Table onlinePredictTable = getTable(5, 3000, PREDICT_SPARSE_ROWS, 4, true);
+        OnlineLogisticRegressionModel model =
+                new OnlineLogisticRegression()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initSparseModel)
+                        .setGlobalBatchSize(500)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlineTrainTable);
+        Table resultTable = model.setPredictionCol(PREDICT_COL).transform(onlinePredictTable)[0];
+        verifyPredictionResult(resultTable, 4);
+    }
+
+    @Test
+    public void testParam() {
+        OnlineLogisticRegression onlineLogisticRegression = new OnlineLogisticRegression();
+        Assert.assertEquals("features", onlineLogisticRegression.getFeaturesCol());
+        Assert.assertEquals("count", onlineLogisticRegression.getBatchStrategy());
+        Assert.assertEquals("label", onlineLogisticRegression.getLabelCol());
+        Assert.assertEquals(0.1, onlineLogisticRegression.getL1(), 1.0e-5);
+        Assert.assertEquals(0.1, onlineLogisticRegression.getL2(), 1.0e-5);
+        Assert.assertEquals(0.1, onlineLogisticRegression.getAlpha(), 1.0e-5);
+        Assert.assertEquals(0.1, onlineLogisticRegression.getBeta(), 1.0e-5);
+        Assert.assertEquals(32, onlineLogisticRegression.getGlobalBatchSize());
+
+        onlineLogisticRegression
+                .setFeaturesCol("test_feature")
+                .setLabelCol("test_label")
+                .setGlobalBatchSize(5)
+                .setL1(0.25)
+                .setL2(0.25)
+                .setAlpha(0.25)
+                .setBeta(0.25);
+
+        Assert.assertEquals("test_feature", onlineLogisticRegression.getFeaturesCol());
+        Assert.assertEquals("test_label", onlineLogisticRegression.getLabelCol());
+        Assert.assertEquals(0.25, onlineLogisticRegression.getL1(), 1.0e-5);
+        Assert.assertEquals(0.25, onlineLogisticRegression.getL2(), 1.0e-5);
+        Assert.assertEquals(0.25, onlineLogisticRegression.getAlpha(), 1.0e-5);
+        Assert.assertEquals(0.25, onlineLogisticRegression.getBeta(), 1.0e-5);
+        Assert.assertEquals(5, onlineLogisticRegression.getGlobalBatchSize());
+
+        OnlineLogisticRegressionModel onlineLogisticRegressionModel =
+                new OnlineLogisticRegressionModel();
+        Assert.assertEquals("features", onlineLogisticRegressionModel.getFeaturesCol());
+        Assert.assertEquals("modelVersion", onlineLogisticRegressionModel.getModelVersionCol());
+        Assert.assertEquals("prediction", onlineLogisticRegressionModel.getPredictionCol());
+        Assert.assertEquals("rawPrediction", onlineLogisticRegressionModel.getRawPredictionCol());
+
+        onlineLogisticRegressionModel
+                .setFeaturesCol("test_feature")
+                .setPredictionCol("pred")
+                .setModelVersionCol("version")
+                .setRawPredictionCol("raw");
+
+        Assert.assertEquals("test_feature", onlineLogisticRegressionModel.getFeaturesCol());
+        Assert.assertEquals("version", onlineLogisticRegressionModel.getModelVersionCol());
+        Assert.assertEquals("pred", onlineLogisticRegressionModel.getPredictionCol());
+        Assert.assertEquals("raw", onlineLogisticRegressionModel.getRawPredictionCol());
+    }
+
+    @Test
+    public void testBatchSizeLessThanParallelism() throws Exception {
+        Table onlinePredictTable = getTable(1, 20, TRAIN_DENSE_ROWS, 4, true);
+        try {
+            new OnlineLogisticRegression()
+                    .setFeaturesCol(FEATURE_COL)
+                    .setInitialModelData(initDenseModel)
+                    .setGlobalBatchSize(2)
+                    .setLabelCol(LABEL_COL)
+                    .fit(onlinePredictTable);
+            Assert.fail("Expected IllegalStateException");
+        } catch (Exception e) {
+            Throwable exception = e;
+            while (exception.getCause() != null) {
+                exception = exception.getCause();
+            }
+            Assert.assertEquals(IllegalStateException.class, exception.getClass());
+            Assert.assertEquals(
+                    "There are more subtasks in the training process than the number "
+                            + "of elements in each batch. Some subtasks might be idling forever.",
+                    exception.getMessage());
+        }
+    }
+
+    @Test
+    public void testSaveAndReload() throws Exception {
+        Table onlineTrainTable = getTable(2, 2000, TRAIN_DENSE_ROWS, 2, false);
+        Table onlinePredictTable = getTable(2, 3000, PREDICT_DENSE_ROWS, 2, false);
+        OnlineLogisticRegression onlineLogisticRegression =
+                new OnlineLogisticRegression()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initDenseModel)
+                        .setGlobalBatchSize(500)
+                        .setLabelCol(LABEL_COL);
+        OnlineLogisticRegression loadedOnlineLogisticRegression =
+                StageTestUtils.saveAndReload(
+                        tEnv, onlineLogisticRegression, tempFolder.newFolder().getAbsolutePath());
+        OnlineLogisticRegressionModel onlineLogisticRegressionModel =
+                loadedOnlineLogisticRegression.fit(onlineTrainTable);
+        Table resultTable =
+                onlineLogisticRegressionModel.setPredictionCol(PREDICT_COL)
+                        .transform(onlinePredictTable)[0];
+        verifyPredictionResult(resultTable, 4);
+        OnlineLogisticRegressionModel loadedOnlineLogisticRegressionModel =
+                StageTestUtils.saveAndReload(
+                        tEnv,
+                        onlineLogisticRegressionModel,
+                        tempFolder.newFolder().getAbsolutePath());
+        Table resultTableWithLoadedModel =
+                loadedOnlineLogisticRegressionModel.setPredictionCol(PREDICT_COL)
+                        .transform(onlinePredictTable)[0];
+        verifyPredictionResult(resultTableWithLoadedModel, -1);
+    }
+
+    @Test
+    public void testGetModelData() throws Exception {
+        Table onlineTrainTable = getTable(2, 2000, TRAIN_DENSE_ROWS, 1, false);
+        List<DenseVector> expected =
+                new ArrayList<>(
+                        Arrays.asList(
+                                Vectors.dense(0.6094293839451556, -0.3535110997464949),
+                                Vectors.dense(0.8817781161262602, -0.6045148530476719),
+                                Vectors.dense(1.0802504223028735, -0.7809336961447708),
+                                Vectors.dense(1.236292181150552, -0.9166121469926248)));
+        OnlineLogisticRegression onlineLogisticRegression =
+                new OnlineLogisticRegression()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initDenseModel)
+                        .setGlobalBatchSize(500)
+                        .setLabelCol(LABEL_COL);
+        OnlineLogisticRegressionModel onlineLogisticRegressionModel =
+                onlineLogisticRegression.fit(onlineTrainTable);
+        Table modelData = onlineLogisticRegressionModel.getModelData()[0];
+        DataStream<DenseVector> dataStream =
+                tEnv.toDataStream(modelData)
+                        .map((MapFunction<Row, DenseVector>) value -> value.getFieldAs(0));
+        List<DenseVector> result = IteratorUtils.toList(dataStream.executeAndCollect());
+        result.sort((o1, o2) -> MinMaxScalerTest.compare(o1, o2));
+        Assert.assertEquals(expected, result);
+    }
+
+    @Test
+    public void testSetModelData() throws Exception {
+        Table onlinePredictTable = getTable(1, 1, TRAIN_DENSE_ROWS, 1, false);
+
+        OnlineLogisticRegressionModel onlineLogisticRegressionModel =
+                new OnlineLogisticRegressionModel().setFeaturesCol(FEATURE_COL);
+
+        Table modelData =
+                tEnv.fromDataStream(
+                        env.fromElements(Row.of(new DenseVector(new double[] {0.5, 0.1}), 0L)));
+        onlineLogisticRegressionModel.setModelData(modelData);
+
+        Row expected =
+                Row.of(
+                        Vectors.dense(0.1, 2.0),
+                        0.0,
+                        1.0,
+                        Vectors.dense(0.43782349911420193, 0.5621765008857981),
+                        0L);
+        DataStream<Row> results =
+                tEnv.toDataStream(onlineLogisticRegressionModel.transform(onlinePredictTable)[0]);
+        List<Row> resultList = IteratorUtils.toList(results.executeAndCollect());
+        Assert.assertEquals(expected, resultList.get(0));
+    }
+
+    private static void verifyPredictionResult(Table output, int expectedNum) throws Exception {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment();
+        DataStream<Row> stream = tEnv.toDataStream(output);
+        List<Row> result = IteratorUtils.toList(stream.executeAndCollect());
+        Map<Long, Tuple2<Double, Double>> correctRatio = new HashMap<>();
+
+        for (Row row : result) {
+            long modelVersion = row.getFieldAs(MODEL_VERSION_COL);
+            Double pred = row.getFieldAs(PREDICT_COL);
+            Double label = row.getFieldAs(LABEL_COL);
+            if (correctRatio.containsKey(modelVersion)) {
+                Tuple2<Double, Double> t2 = correctRatio.get(modelVersion);
+                if (pred.equals(label)) {
+                    t2.f0 += 1.0;
+                }
+                t2.f1 += 1.0;
+            } else {
+                correctRatio.put(modelVersion, Tuple2.of(pred.equals(label) ? 1.0 : 0.0, 1.0));
+            }
+        }
+        int numModel = 0;
+        for (Long id : correctRatio.keySet()) {
+            assertEquals(1.0, correctRatio.get(id).f0 / correctRatio.get(id).f1, 1.0e-5);
+            numModel++;
+        }
+        if (expectedNum != -1) {
+            assertEquals(expectedNum, numModel);
+        }
+    }
+
+    /** Emits a fixed number of samples from a given sample list, cycling through the list in order. */
+    private static class RandomSample implements SourceFunction<Row> {
+        private volatile boolean isRunning = true;
+        private final long timeInterval;
+        private final long numSample;
+        private final List<Row> data;
+
+        public RandomSample(long timeInterval, long numSample, List<Row> data) {
+            this.timeInterval = timeInterval;
+            this.numSample = numSample;
+            this.data = data;
+        }
+
+        @Override
+        public void run(SourceContext<Row> ctx) throws Exception {
+            int size = data.size();
+            for (int i = 0; i < numSample; ++i) {
+                int idx = i % size;
+                if (isRunning) {
+                    ctx.collect(data.get(idx));
+                    if (timeInterval > 0) {
+                        Thread.sleep(timeInterval);
+                    }
+                }
+            }
+        }
+
+        @Override
+        public void cancel() {
+            isRunning = false;
+        }
+    }
+
+    private Table getTable(
+            int timeInterval, int numSample, List<Row> data, int parallel, boolean isSparse) {

Review Comment:
   For the first question, I don't think so. The data is randomly sampled from a fixed dataset.
   
   The user sets the total number of samples, so we need to calculate the number of samples for each worker; that is why we need the source parallelism parameter.
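   To make the per-worker calculation concrete, here is a minimal sketch. It assumes the total is split as evenly as possible, with the first `total % parallelism` subtasks each taking one extra sample. `SampleSplitter` and `samplesForSubtask` are hypothetical names for illustration; this is not the `getTable` implementation from the PR, whose body is not shown here.

```java
import java.util.Arrays;

/**
 * Hypothetical helper (not part of the PR): splits a user-specified total sample
 * count across the parallel subtasks of a source so that, together, the subtasks
 * emit exactly the requested total.
 */
public class SampleSplitter {

    /** Returns how many samples the subtask with the given index should emit. */
    public static long samplesForSubtask(long totalSamples, int parallelism, int subtaskIndex) {
        if (totalSamples < 0 || parallelism <= 0 || subtaskIndex < 0 || subtaskIndex >= parallelism) {
            throw new IllegalArgumentException("Invalid total, parallelism or subtask index");
        }
        long base = totalSamples / parallelism;
        long remainder = totalSamples % parallelism;
        // The first `remainder` subtasks each take one extra sample.
        return subtaskIndex < remainder ? base + 1 : base;
    }

    public static void main(String[] args) {
        long total = 1000;
        int parallelism = 3;
        long[] perSubtask = new long[parallelism];
        for (int i = 0; i < parallelism; i++) {
            perSubtask[i] = samplesForSubtask(total, parallelism, i);
        }
        // Prints the per-subtask counts, which sum to `total`.
        System.out.println(Arrays.toString(perSubtask));
    }
}
```

   Inside a Flink `RichParallelSourceFunction`, the subtask index and parallelism would typically come from `getRuntimeContext().getIndexOfThisSubtask()` and `getRuntimeContext().getNumberOfParallelSubtasks()`; how the PR's source actually wires this up is an assumption here.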





[GitHub] [flink-ml] weibozhao commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator for OnlineLogisticRegression

Posted by GitBox <gi...@apache.org>.
weibozhao commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r868785218


##########
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/OnlineLogisticRegressionTest.java:
##########
@@ -0,0 +1,444 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.classification.logisticregression.LogisticRegression;
+import org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegression;
+import org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegressionModel;
+import org.apache.flink.ml.feature.MinMaxScalerTest;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static junit.framework.TestCase.assertEquals;
+
+/** Tests {@link OnlineLogisticRegression} and {@link OnlineLogisticRegressionModel}. */
+public class OnlineLogisticRegressionTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table trainDenseTable;
+    private static final String LABEL_COL = "label";
+    private static final String PREDICT_COL = "prediction";
+    private static final String FEATURE_COL = "features";
+    private static final String MODEL_VERSION_COL = "modelVersion";
+    private static final double[] ONE_ARRAY = new double[] {1.0, 1.0, 1.0};
+    private static final List<Row> TRAIN_DENSE_ROWS =
+            Arrays.asList(
+                    Row.of(Vectors.dense(0.1, 2.), 0.),
+                    Row.of(Vectors.dense(0.2, 2.), 0.),
+                    Row.of(Vectors.dense(0.3, 2.), 0.),
+                    Row.of(Vectors.dense(0.4, 2.), 0.),
+                    Row.of(Vectors.dense(0.5, 2.), 0.),
+                    Row.of(Vectors.dense(11., 12.), 1.),
+                    Row.of(Vectors.dense(12., 11.), 1.),
+                    Row.of(Vectors.dense(13., 12.), 1.),
+                    Row.of(Vectors.dense(14., 12.), 1.),
+                    Row.of(Vectors.dense(15., 12.), 1.));
+
+    private static final List<Row> PREDICT_DENSE_ROWS =
+            Arrays.asList(
+                    Row.of(Vectors.dense(0.8, 2.7), 0.),
+                    Row.of(Vectors.dense(0.8, 2.4), 0.),
+                    Row.of(Vectors.dense(0.7, 2.3), 0.),
+                    Row.of(Vectors.dense(0.4, 2.7), 0.),
+                    Row.of(Vectors.dense(0.5, 2.8), 0.),
+                    Row.of(Vectors.dense(10.2, 12.1), 1.),
+                    Row.of(Vectors.dense(13.3, 13.1), 1.),
+                    Row.of(Vectors.dense(13.5, 12.2), 1.),
+                    Row.of(Vectors.dense(14.9, 12.5), 1.),
+                    Row.of(Vectors.dense(15.5, 11.2), 1.));
+
+    private static final List<Row> TRAIN_SPARSE_ROWS =
+            Arrays.asList(
+                    Row.of(Vectors.sparse(10, new int[] {1, 3, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {0, 2, 3}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {0, 3, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {2, 3, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {1, 3, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {6, 7, 8}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {6, 8, 9}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {5, 8, 9}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {5, 6, 8}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {5, 6, 7}, ONE_ARRAY), 1.));
+
+    private static final List<Row> PREDICT_SPARSE_ROWS =
+            Arrays.asList(
+                    Row.of(Vectors.sparse(10, new int[] {1, 2, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {2, 3, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {0, 2, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {1, 3, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {6, 7, 9}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {7, 8, 9}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {5, 7, 9}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {5, 6, 7}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {5, 8, 9}, ONE_ARRAY), 1.));
+
+    private Table initDenseModel;
+
+    @Before
+    public void before() throws Exception {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        DataStream<Row> dataStream = env.fromCollection(TRAIN_DENSE_ROWS);
+        trainDenseTable = tEnv.fromDataStream(dataStream).as(FEATURE_COL, LABEL_COL);
+        initDenseModel =
+                tEnv.fromDataStream(
+                        env.fromElements(Row.of(new DenseVector(new double[] {0.0, 0.0}), 0L)));
+    }
+
+    @Test
+    public void testFit() throws Exception {
+        Table onlinePredictTable = getTable(1, 1000, TRAIN_DENSE_ROWS, 4, false);
+        Table models =
+                new OnlineLogisticRegression()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initDenseModel)
+                        .setGlobalBatchSize(100)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlinePredictTable)
+                        .getModelData()[0];
+        List<Row> modelList = IteratorUtils.toList(tEnv.toDataStream(models).executeAndCollect());
+        assertEquals(10, modelList.size());
+    }
+
+    @Test
+    public void testFitWithInitLrModel() throws Exception {
+        Table initLrModel =
+                new LogisticRegression()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setLabelCol(LABEL_COL)
+                        .fit(trainDenseTable)
+                        .getModelData()[0];
+        Table onlineTrainTable = getTable(50, 1000, TRAIN_DENSE_ROWS, 4, false);
+        Table models =
+                new OnlineLogisticRegression()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initLrModel)
+                        .setGlobalBatchSize(100)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlineTrainTable)
+                        .getModelData()[0];
+        List<Row> modelList = IteratorUtils.toList(tEnv.toDataStream(models).executeAndCollect());
+        assertEquals(10, modelList.size());
+    }
+
+    @Test
+    public void testDenseFitAndPredict() throws Exception {
+        Table onlineTrainTable = getTable(2, 2000, TRAIN_DENSE_ROWS, 2, false);
+        Table onlinePredictTable = getTable(2, 3000, PREDICT_DENSE_ROWS, 2, false);
+        OnlineLogisticRegressionModel model =
+                new OnlineLogisticRegression()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initDenseModel)
+                        .setGlobalBatchSize(500)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlineTrainTable);
+        Table resultTable = model.setPredictionCol(PREDICT_COL).transform(onlinePredictTable)[0];
+        verifyPredictionResult(resultTable, 4);
+    }
+
+    @Test
+    public void testSparseFitAndPredict() throws Exception {
+        double[] doubleArray = new double[] {0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1};
+        Table initSparseModel =
+                tEnv.fromDataStream(env.fromElements(Row.of(new DenseVector(doubleArray), 0L)));
+        Table onlineTrainTable = getTable(5, 2000, TRAIN_SPARSE_ROWS, 4, true);
+        Table onlinePredictTable = getTable(5, 3000, PREDICT_SPARSE_ROWS, 4, true);
+        OnlineLogisticRegressionModel model =
+                new OnlineLogisticRegression()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initSparseModel)
+                        .setGlobalBatchSize(500)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlineTrainTable);
+        Table resultTable = model.setPredictionCol(PREDICT_COL).transform(onlinePredictTable)[0];
+        verifyPredictionResult(resultTable, 4);

Review Comment:
   I will refine it later.
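   The reply above does not say how the check in `verifyPredictionResult` will be refined. As one possible shape, the bookkeeping that method performs (count correct and total predictions per model version, then compute the ratio) can be written with `computeIfAbsent` over a two-element counter. This is a hypothetical, Flink-free sketch over plain arrays instead of `Row`s; `AccuracyByVersion` is an illustrative name, not the author's follow-up.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical sketch: tallies prediction accuracy per model version. */
public class AccuracyByVersion {

    /** Returns, for each model version, the fraction of rows whose prediction equals the label. */
    public static Map<Long, Double> accuracyByVersion(
            long[] versions, double[] predictions, double[] labels) {
        if (versions.length != predictions.length || versions.length != labels.length) {
            throw new IllegalArgumentException("Input arrays must have the same length");
        }
        // version -> {number of correct predictions, total number of predictions}
        Map<Long, long[]> counts = new HashMap<>();
        for (int i = 0; i < versions.length; i++) {
            long[] c = counts.computeIfAbsent(versions[i], v -> new long[2]);
            // Same equality semantics as Double.equals used in the test.
            if (Double.compare(predictions[i], labels[i]) == 0) {
                c[0]++;
            }
            c[1]++;
        }
        Map<Long, Double> accuracy = new HashMap<>();
        counts.forEach((version, c) -> accuracy.put(version, (double) c[0] / c[1]));
        return accuracy;
    }

    public static void main(String[] args) {
        long[] versions = {0L, 0L, 1L, 1L};
        double[] predictions = {0.0, 1.0, 1.0, 1.0};
        double[] labels = {0.0, 0.0, 1.0, 1.0};
        System.out.println(accuracyByVersion(versions, predictions, labels));
    }
}
```

   A test would then assert that every value in the returned map is 1.0 (within a tolerance) and, where relevant, that the map's size matches the expected number of model versions.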





[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator for OnlineLogisticRegression

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r854865460


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionModel.java:
##########
@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * A Model which classifies data using the model data computed by {@link OnlineLogisticRegression}.
+ */
+public class OnlineLogisticRegressionModel
+        implements Model<OnlineLogisticRegressionModel>,
+                OnlineLogisticRegressionModelParams<OnlineLogisticRegressionModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public OnlineLogisticRegressionModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(),
+                                Types.DOUBLE,
+                                TypeInformation.of(DenseVector.class),
+                                Types.LONG),
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldNames(),
+                                getPredictionCol(),
+                                getRawPredictionCol(),
+                                getModelVersionCol()));

Review Comment:
   The behavior here is inconsistent with `OnlineKmeansModel`. We probably need to discuss whether it is necessary to output the model version.





[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator for OnlineLogisticRegression

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r882261550


##########
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java:
##########
@@ -182,4 +193,64 @@ public void snapshotState(StateSnapshotContext context) throws Exception {
             }
         }
     }
+
+    /**
+     * An function that splits a global batch into evenly-sized local batches, and distributes them

Review Comment:
   Could you update the Javadoc here and add a unit test to verify this function?
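
A unit test for this function mainly needs to pin down the size arithmetic. The sketch below assumes the intended contract is that each of `parallelism` subtasks receives a local batch whose size differs from every other by at most one. `LocalBatchSizes.split` is a hypothetical helper for illustration, not the PR's `DataStreamUtils.generateBatchData`.

```java
import java.util.Arrays;

/** Hypothetical helper: sizes of the local batches that one global batch is split into. */
final class LocalBatchSizes {
    private LocalBatchSizes() {}

    /**
     * Splits {@code globalBatchSize} elements over {@code parallelism} subtasks so that local
     * batch sizes differ by at most one. The first {@code globalBatchSize % parallelism}
     * subtasks get one extra element.
     */
    static int[] split(int globalBatchSize, int parallelism) {
        // Mirrors the parallelism <= batchSize precondition in FtrlIterationBody.
        if (parallelism <= 0 || parallelism > globalBatchSize) {
            throw new IllegalArgumentException(
                    "Expected 0 < parallelism <= globalBatchSize, got parallelism="
                            + parallelism
                            + ", globalBatchSize="
                            + globalBatchSize);
        }
        int base = globalBatchSize / parallelism;
        int remainder = globalBatchSize % parallelism;
        int[] sizes = new int[parallelism];
        for (int i = 0; i < parallelism; i++) {
            sizes[i] = base + (i < remainder ? 1 : 0);
        }
        return sizes;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(split(10, 4))); // [3, 3, 2, 2]
    }
}
```

A test can then assert three things: the sizes sum to the global batch size, the largest and smallest sizes differ by at most one, and `parallelism > globalBatchSize` is rejected.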



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java:
##########
@@ -18,40 +18,85 @@
 
 package org.apache.flink.ml.classification.logisticregression;
 
+import org.apache.flink.api.common.functions.MapFunction;
 import org.apache.flink.api.common.serialization.Encoder;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
 import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.core.memory.DataInputViewStreamWrapper;
 import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.ml.common.datastream.TableUtils;
 import org.apache.flink.ml.linalg.DenseVector;
 import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
 import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.table.api.Table;
 import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
 import org.apache.flink.table.api.internal.TableImpl;
 
 import java.io.EOFException;
 import java.io.IOException;
 import java.io.OutputStream;
+import java.util.Random;
 
 /**
- * Model data of {@link LogisticRegressionModel}.
+ * Model data of {@link LogisticRegressionModel} and {@link OnlineLogisticRegressionModel}.
  *
  * <p>This class also provides methods to convert model data from Table to Datastream, and classes
  * to save/load model data.
  */
 public class LogisticRegressionModelData {
 
     public DenseVector coefficient;
+    public long modelVersion;
 
-    public LogisticRegressionModelData(DenseVector coefficient) {
+    public LogisticRegressionModelData(DenseVector coefficient, long modelVersion) {
         this.coefficient = coefficient;
+        this.modelVersion = modelVersion;
     }
 
     public LogisticRegressionModelData() {}
 
+    public LogisticRegressionModelData(DenseVector coefficient) {
+        this.coefficient = coefficient;
+        this.modelVersion = 0L;

Review Comment:
   Since we have agreed to add a model version, should we add it to all existing algorithms, or leave a TODO?
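
Suppose the version field stays for now. Existing one-argument call sites can then remain unchanged if the one-argument constructor delegates to the two-argument one. That also keeps the default version in one place, where a TODO can point.

This is a minimal sketch that assumes version `0` means "not produced by online training". `VersionedModelData` and `DEFAULT_MODEL_VERSION` are hypothetical names, and `double[]` stands in for `DenseVector`.

```java
/**
 * Hypothetical stand-in for LogisticRegressionModelData; double[] replaces DenseVector so the
 * sketch compiles on its own.
 */
class VersionedModelData {
    /** Version given to models not produced by online training (an assumption). */
    static final long DEFAULT_MODEL_VERSION = 0L;

    double[] coefficient;
    long modelVersion;

    VersionedModelData() {}

    VersionedModelData(double[] coefficient, long modelVersion) {
        this.coefficient = coefficient;
        this.modelVersion = modelVersion;
    }

    /** Keeps existing one-argument call sites working with the default version. */
    VersionedModelData(double[] coefficient) {
        // TODO: decide whether all existing algorithms should report a model version.
        this(coefficient, DEFAULT_MODEL_VERSION);
    }
}
```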



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java:
##########
@@ -161,7 +162,8 @@ public Row map(Row dataPoint) {
      * @param coefficient The model parameters.
      * @return The prediction label and the raw probabilities.
      */
-    private static Row predictOneDataPoint(DenseVector feature, DenseVector coefficient) {
+    public static Row predictOneDataPoint(Vector feature, DenseVector coefficient) {

Review Comment:
   nit: This could be package-private.
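
For reference, consider the usual binary logistic model. The predicted label is 1 when the margin (the dot product of `coefficient` and `feature`) is non-negative. The raw prediction holds the two class probabilities from the sigmoid.

The sketch below is a dependency-free approximation of what such a method computes, not the PR's code. It uses plain arrays in place of `Vector`, `DenseVector`, and `Row`, and a loop in place of `BLAS.dot`. The class name is hypothetical.

```java
/** Hypothetical, dependency-free sketch of binary logistic-regression prediction. */
final class LogisticPrediction {
    private LogisticPrediction() {}

    /** Returns {label, P(label = 0), P(label = 1)} for a dense feature vector. */
    static double[] predictOneDataPoint(double[] feature, double[] coefficient) {
        if (feature.length != coefficient.length) {
            throw new IllegalArgumentException("Vector size mismatched.");
        }
        double dot = 0.0;
        for (int i = 0; i < feature.length; i++) {
            dot += feature[i] * coefficient[i];
        }
        // Sigmoid of the margin is the probability of the positive class.
        double prob = 1.0 / (1.0 + Math.exp(-dot));
        double label = dot >= 0 ? 1.0 : 0.0;
        return new double[] {label, 1 - prob, prob};
    }
}
```

`OnlineLogisticRegressionModel` lives in the same package (`org.apache.flink.ml.classification.logisticregression`). Package-private visibility is therefore enough for it to reuse this method.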



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java:
##########
@@ -0,0 +1,434 @@
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the FTRL-Proximal online learning algorithm proposed by H. Brendan
+ * McMahan et al.
+ *
+ * <p>See <a href="https://doi.org/10.1145/2487575.2488200">H. Brendan McMahan et al., Ad click
+ * prediction: a view from the trenches.</a>
+ */
+public class OnlineLogisticRegression
+        implements Estimator<OnlineLogisticRegression, OnlineLogisticRegressionModel>,
+                OnlineLogisticRegressionParams<OnlineLogisticRegression> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineLogisticRegression() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineLogisticRegressionModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<LogisticRegressionModelData> modelDataStream =
+                LogisticRegressionModelData.getModelDataStream(initModelDataTable);
+
+        DataStream<Row> points =
+                tEnv.toDataStream(inputs[0])
+                        .map(new FeaturesExtractor(getFeaturesCol(), getLabelCol()));
+
+        DataStream<DenseVector> initModelData =
+                modelDataStream.map(
+                        (MapFunction<LogisticRegressionModelData, DenseVector>)
+                                value -> value.coefficient);
+
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new FtrlIterationBody(
+                        getGlobalBatchSize(), getAlpha(), getBeta(), getReg(), getElasticNet());
+
+        DataStream<LogisticRegressionModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        OnlineLogisticRegressionModel model =
+                new OnlineLogisticRegressionModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class FeaturesExtractor implements MapFunction<Row, Row> {
+        private final String featuresCol;
+        private final String labelCol;
+
+        private FeaturesExtractor(String featuresCol, String labelCol) {
+            this.featuresCol = featuresCol;
+            this.labelCol = labelCol;
+        }
+
+        @Override
+        public Row map(Row row) throws Exception {
+            return Row.of(row.getField(featuresCol), row.getField(labelCol));
+        }
+    }
+
+    /**
+     * Implementation of ftrl optimizer. In this implementation, gradients are calculated in
+     * distributed workers and reduce to one gradient. The reduced gradient is used to update model
+     * by ftrl method.
+     *
+     * <p>See https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Ftrl
+     *
+     * <p>todo: makes ftrl to be a common optimizer and place it in org.apache.flink.ml.common in
+     * future.
+     */
+    private static class FtrlIterationBody implements IterationBody {
+        private final int batchSize;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+
+        public FtrlIterationBody(
+                int batchSize, double alpha, double beta, double reg, double elasticNet) {
+            this.batchSize = batchSize;
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = elasticNet * reg;
+            this.l2 = (1 - elasticNet) * reg;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<DenseVector> modelData = variableStreams.get(0);
+
+            DataStream<Row> points = dataStreams.get(0);
+            int parallelism = points.getParallelism();
+            Preconditions.checkState(
+                    parallelism <= batchSize,
+                    "There are more subtasks in the training process than the number "
+                            + "of elements in each batch. Some subtasks might be idling forever.");
+
+            DataStream<DenseVector[]> newGradient =
+                    DataStreamUtils.generateBatchData(points, parallelism, batchSize)
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "LocalGradientCalculator",
+                                    TypeInformation.of(DenseVector[].class),
+                                    new CalculateLocalGradient())
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(
+                                    (ReduceFunction<DenseVector[]>)
+                                            (gradientInfo, newGradientInfo) -> {
+                                                for (int i = 0;
+                                                        i < newGradientInfo[1].size();

Review Comment:
   Using BLAS here could be more efficient and make the code more readable. Could you make the change?
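
Here is a sketch of what the suggestion could look like. The two element-wise sums become `y += 1.0 * x` calls. The null check on slot 2 does not depend on the loop index, so it is done once.

The sketch assumes Flink ML's `BLAS.axpy` has the usual `y += a * x` semantics. A local `axpy` over `double[]` stands in for it so the sketch compiles on its own, and `GradientReducer` is a hypothetical name. Like the original lambda, it accumulates into the second argument and returns it.

```java
/**
 * Hypothetical sketch of the gradient reducer written with an axpy helper. Each element has
 * three slots: slots 0 and 1 are summed element-wise, and slot 2 is taken from whichever side
 * is non-null.
 */
final class GradientReducer {
    private GradientReducer() {}

    /** y += a * x, the semantics assumed for BLAS.axpy. */
    static void axpy(double a, double[] x, double[] y) {
        if (x.length != y.length) {
            throw new IllegalArgumentException("Vector size mismatched.");
        }
        for (int i = 0; i < x.length; i++) {
            y[i] += a * x[i];
        }
    }

    /** Accumulates gradientInfo into newGradientInfo and returns newGradientInfo. */
    static double[][] reduce(double[][] gradientInfo, double[][] newGradientInfo) {
        axpy(1.0, gradientInfo[0], newGradientInfo[0]);
        axpy(1.0, gradientInfo[1], newGradientInfo[1]);
        // Loop-invariant, so checked once instead of once per coordinate.
        if (newGradientInfo[2] == null) {
            newGradientInfo[2] = gradientInfo[2];
        }
        return newGradientInfo;
    }
}
```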



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java:
##########
@@ -0,0 +1,434 @@
+/**
+ * An Estimator which implements the FTRL-Proximal online learning algorithm proposed by H. Brendan

Review Comment:
   Could you please update the Javadoc here to describe online logistic regression rather than FTRL?
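
One possible basis for the updated Javadoc is the per-coordinate FTRL-Proximal update from the McMahan et al. paper already cited. The single-machine sketch below implements that update. It uses the same `l1 = elasticNet * reg` and `l2 = (1 - elasticNet) * reg` split as `FtrlIterationBody`.

It is not the PR's distributed implementation, and `FtrlProximal` is a hypothetical name.

```java
/** Hypothetical single-machine sketch of the per-coordinate FTRL-Proximal update. */
final class FtrlProximal {
    private final double alpha;
    private final double beta;
    private final double l1;
    private final double l2;
    private final double[] z; // per-coordinate accumulated (adjusted) gradients
    private final double[] n; // per-coordinate sums of squared gradients
    private final double[] weights;

    FtrlProximal(int dim, double alpha, double beta, double reg, double elasticNet) {
        this.alpha = alpha;
        this.beta = beta;
        // Same regularization split as FtrlIterationBody in this PR.
        this.l1 = elasticNet * reg;
        this.l2 = (1 - elasticNet) * reg;
        this.z = new double[dim];
        this.n = new double[dim];
        this.weights = new double[dim];
    }

    /** Applies one gradient (computed at the current weights) and returns the new weights. */
    double[] update(double[] gradient) {
        if (gradient.length != weights.length) {
            throw new IllegalArgumentException("Vector size mismatched.");
        }
        for (int i = 0; i < weights.length; i++) {
            double g = gradient[i];
            double sigma = (Math.sqrt(n[i] + g * g) - Math.sqrt(n[i])) / alpha;
            z[i] += g - sigma * weights[i];
            n[i] += g * g;
            if (Math.abs(z[i]) <= l1) {
                // The L1 term keeps this coordinate exactly at zero.
                weights[i] = 0.0;
            } else {
                weights[i] =
                        -(z[i] - Math.signum(z[i]) * l1)
                                / ((beta + Math.sqrt(n[i])) / alpha + l2);
            }
        }
        return weights.clone();
    }
}
```

A coordinate stays exactly zero while `|z_i| <= l1`. This is where the sparsity of FTRL-Proximal comes from.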



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionParams.java:
##########
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.common.param.HasElasticNet;
+import org.apache.flink.ml.common.param.HasGlobalBatchSize;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasReg;
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+
+/**
+ * Params of {@link OnlineLogisticRegression}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface OnlineLogisticRegressionParams<T>

Review Comment:
   Do we need to support HasWeight for online logistic regression, as LogisticRegression does?
   
   If not, we should at least add a TODO here.
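
If weights are supported, each sample's contribution to the gradient would be scaled by its weight. A weight of 1 reproduces the unweighted case.

This is a minimal sketch that assumes labels in {0, 1} and the standard logistic loss. `WeightedLogisticGradient` is hypothetical. The sketch does not show how a weight column would be read in `FeaturesExtractor`.

```java
/** Hypothetical sketch: per-sample gradient of the weighted binary logistic loss. */
final class WeightedLogisticGradient {
    private WeightedLogisticGradient() {}

    /**
     * Returns weight * (sigmoid(coefficient . feature) - label) * feature, i.e. the gradient of
     * weight * logLoss with respect to the coefficients, for a label in {0, 1}.
     */
    static double[] gradient(
            double[] feature, double label, double weight, double[] coefficient) {
        if (feature.length != coefficient.length) {
            throw new IllegalArgumentException("Vector size mismatched.");
        }
        double dot = 0.0;
        for (int i = 0; i < feature.length; i++) {
            dot += feature[i] * coefficient[i];
        }
        double multiplier = weight * (1.0 / (1.0 + Math.exp(-dot)) - label);
        double[] grad = new double[feature.length];
        for (int i = 0; i < feature.length; i++) {
            grad[i] = multiplier * feature[i];
        }
        return grad;
    }
}
```

When such gradients are averaged over a batch, one common choice is to divide by the sum of the weights rather than by the sample count.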



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java:
##########
@@ -110,7 +111,9 @@ public LogisticRegressionModel fit(Table... inputs) {
                 optimizer.optimize(initModelData, trainData, BinaryLogisticLoss.INSTANCE);
 
         DataStream<LogisticRegressionModelData> modelData =
-                rawModelData.map(LogisticRegressionModelData::new);
+                rawModelData.map(

Review Comment:
   nit: The code change here seems unnecessary.



##########
flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java:
##########
@@ -65,12 +65,62 @@ public static void hDot(Vector x, Vector y) {
         }
     }
 
+    /** Computes the dot of the two vectors (y = y \dot x). */
+    public static double dot(Vector x, Vector y) {
+        Preconditions.checkArgument(x.size() == y.size(), "Vector size mismatched.");
+        if (x instanceof SparseVector) {
+            if (y instanceof SparseVector) {
+                return dot((SparseVector) x, (SparseVector) y);
+            } else {
+                return dot((SparseVector) x, (DenseVector) y);
+            }
+        } else {
+            if (y instanceof SparseVector) {
+                return dot((DenseVector) x, (SparseVector) y);
+            } else {
+                return dot((DenseVector) x, (DenseVector) y);
+            }
+        }
+    }
+
     /** x \cdot y . */
     public static double dot(DenseVector x, DenseVector y) {
         Preconditions.checkArgument(x.size() == y.size(), "Vector size mismatched.");
         return JAVA_BLAS.ddot(x.size(), x.values, 1, y.values, 1);
     }
 
+    public static double dot(DenseVector x, SparseVector y) {

Review Comment:
   nits: (1) These dot methods (except `dot(Vector, Vector)`) could be private, and (2) we could remove the vector-size checks from these methods, since they are already covered in `dot(Vector, Vector)`.
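
Together, the two nits could look like the sketch below:

- Only the public `dot(Vector, Vector)` checks sizes.
- The typed overloads are private.
- The sparse-dense case reuses the dense-sparse one by swapping arguments.

`Vec`, `DenseVec`, `SparseVec`, and `Dots` are hypothetical stand-ins for Flink ML's vector types and `BLAS`. The sparse-sparse case assumes sorted, duplicate-free indices.

```java
/** Hypothetical stand-in for Flink ML's Vector interface. */
interface Vec {
    int size();
}

final class DenseVec implements Vec {
    final double[] values;

    DenseVec(double... values) {
        this.values = values;
    }

    @Override
    public int size() {
        return values.length;
    }
}

final class SparseVec implements Vec {
    final int n;
    final int[] indices; // assumed sorted ascending, without duplicates
    final double[] values;

    SparseVec(int n, int[] indices, double[] values) {
        this.n = n;
        this.indices = indices;
        this.values = values;
    }

    @Override
    public int size() {
        return n;
    }
}

final class Dots {
    private Dots() {}

    /** Dot product of x and y; the size check is done once here, not in every overload. */
    static double dot(Vec x, Vec y) {
        if (x.size() != y.size()) {
            throw new IllegalArgumentException("Vector size mismatched.");
        }
        if (x instanceof SparseVec) {
            // The dot product is symmetric, so sparse-dense reuses dense-sparse.
            return y instanceof SparseVec
                    ? dot((SparseVec) x, (SparseVec) y)
                    : dot((DenseVec) y, (SparseVec) x);
        }
        return y instanceof SparseVec
                ? dot((DenseVec) x, (SparseVec) y)
                : dot((DenseVec) x, (DenseVec) y);
    }

    private static double dot(DenseVec x, DenseVec y) {
        double sum = 0.0;
        for (int i = 0; i < x.values.length; i++) {
            sum += x.values[i] * y.values[i];
        }
        return sum;
    }

    private static double dot(DenseVec x, SparseVec y) {
        double sum = 0.0;
        for (int i = 0; i < y.indices.length; i++) {
            sum += y.values[i] * x.values[y.indices[i]];
        }
        return sum;
    }

    private static double dot(SparseVec x, SparseVec y) {
        // Two-pointer merge over the sorted index arrays.
        double sum = 0.0;
        int i = 0;
        int j = 0;
        while (i < x.indices.length && j < y.indices.length) {
            if (x.indices[i] == y.indices[j]) {
                sum += x.values[i++] * y.values[j++];
            } else if (x.indices[i] < y.indices[j]) {
                i++;
            } else {
                j++;
            }
        }
        return sum;
    }
}
```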



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java:
##########
@@ -232,7 +232,7 @@ public void testSaveLoadAndPredict() throws Exception {
         LogisticRegressionModel model = logisticRegression.fit(binomialDataTable);
         model = StageTestUtils.saveAndReload(tEnv, model, tempFolder.newFolder().getAbsolutePath());
         assertEquals(
-                Collections.singletonList("coefficient"),
+                Arrays.asList("coefficient", "modelVersion"),

Review Comment:
   Please keep this issue open if it is not yet resolved~



##########
flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java:
##########
@@ -65,12 +65,62 @@ public static void hDot(Vector x, Vector y) {
         }
     }
 
+    /** Computes the dot of the two vectors (y = y \dot x). */
+    public static double dot(Vector x, Vector y) {
+        Preconditions.checkArgument(x.size() == y.size(), "Vector size mismatched.");
+        if (x instanceof SparseVector) {
+            if (y instanceof SparseVector) {
+                return dot((SparseVector) x, (SparseVector) y);
+            } else {
+                return dot((SparseVector) x, (DenseVector) y);
+            }
+        } else {
+            if (y instanceof SparseVector) {
+                return dot((DenseVector) x, (SparseVector) y);
+            } else {
+                return dot((DenseVector) x, (DenseVector) y);
+            }
+        }
+    }
+
     /** x \cdot y . */
     public static double dot(DenseVector x, DenseVector y) {
         Preconditions.checkArgument(x.size() == y.size(), "Vector size mismatched.");
         return JAVA_BLAS.ddot(x.size(), x.values, 1, y.values, 1);
     }
 
+    public static double dot(DenseVector x, SparseVector y) {
+        Preconditions.checkArgument(x.size() == y.size(), "Vector size mismatched.");
+        double dotValue = 0.0;
+        for (int i = 0; i < y.indices.length; ++i) {
+            dotValue += y.values[i] * x.values[y.indices[i]];
+        }
+        return dotValue;
+    }
+
+    public static double dot(SparseVector x, DenseVector y) {

Review Comment:
   This method could be removed: `dot(Vector, Vector)` can call `dot(DenseVector, SparseVector)` with the arguments swapped instead.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java:
##########
@@ -0,0 +1,434 @@
+/**
+ * An Estimator which implements the FTRL-Proximal online learning algorithm proposed by H. Brendan
+ * McMahan et al.
+ *
+ * <p>See <a href="https://doi.org/10.1145/2487575.2488200">H. Brendan McMahan et al., Ad click
+ * prediction: a view from the trenches.</a>
+ */
+public class OnlineLogisticRegression
+        implements Estimator<OnlineLogisticRegression, OnlineLogisticRegressionModel>,
+                OnlineLogisticRegressionParams<OnlineLogisticRegression> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineLogisticRegression() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineLogisticRegressionModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<LogisticRegressionModelData> modelDataStream =
+                LogisticRegressionModelData.getModelDataStream(initModelDataTable);
+
+        DataStream<Row> points =
+                tEnv.toDataStream(inputs[0])
+                        .map(new FeaturesExtractor(getFeaturesCol(), getLabelCol()));
+
+        DataStream<DenseVector> initModelData =
+                modelDataStream.map(
+                        (MapFunction<LogisticRegressionModelData, DenseVector>)
+                                value -> value.coefficient);
+
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new FtrlIterationBody(
+                        getGlobalBatchSize(), getAlpha(), getBeta(), getReg(), getElasticNet());
+
+        DataStream<LogisticRegressionModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        OnlineLogisticRegressionModel model =
+                new OnlineLogisticRegressionModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class FeaturesExtractor implements MapFunction<Row, Row> {
+        private final String featuresCol;
+        private final String labelCol;
+
+        private FeaturesExtractor(String featuresCol, String labelCol) {
+            this.featuresCol = featuresCol;
+            this.labelCol = labelCol;
+        }
+
+        @Override
+        public Row map(Row row) throws Exception {
+            return Row.of(row.getField(featuresCol), row.getField(labelCol));
+        }
+    }
+
+    /**
+     * Implementation of ftrl optimizer. In this implementation, gradients are calculated in
+     * distributed workers and reduce to one gradient. The reduced gradient is used to update model
+     * by ftrl method.
+     *
+     * <p>See https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Ftrl
+     *
+     * <p>todo: makes ftrl to be a common optimizer and place it in org.apache.flink.ml.common in
+     * future.
+     */
+    private static class FtrlIterationBody implements IterationBody {
+        private final int batchSize;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+
+        public FtrlIterationBody(
+                int batchSize, double alpha, double beta, double reg, double elasticNet) {
+            this.batchSize = batchSize;
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = elasticNet * reg;
+            this.l2 = (1 - elasticNet) * reg;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<DenseVector> modelData = variableStreams.get(0);
+
+            DataStream<Row> points = dataStreams.get(0);
+            int parallelism = points.getParallelism();
+            Preconditions.checkState(
+                    parallelism <= batchSize,
+                    "There are more subtasks in the training process than the number "
+                            + "of elements in each batch. Some subtasks might be idling forever.");
+
+            DataStream<DenseVector[]> newGradient =
+                    DataStreamUtils.generateBatchData(points, parallelism, batchSize)
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "LocalGradientCalculator",
+                                    TypeInformation.of(DenseVector[].class),
+                                    new CalculateLocalGradient())
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(
+                                    (ReduceFunction<DenseVector[]>)
+                                            (gradientInfo, newGradientInfo) -> {
+                                                for (int i = 0;
+                                                        i < newGradientInfo[1].size();
+                                                        ++i) {
+                                                    newGradientInfo[0].values[i] =
+                                                            gradientInfo[0].values[i]
+                                                                    + newGradientInfo[0].values[i];
+                                                    newGradientInfo[1].values[i] =
+                                                            gradientInfo[1].values[i]
+                                                                    + newGradientInfo[1].values[i];
+                                                    if (newGradientInfo[2] == null) {
+                                                        newGradientInfo[2] = gradientInfo[2];
+                                                    }
+                                                }
+                                                return newGradientInfo;
+                                            });
+            DataStream<DenseVector> feedbackModelData =
+                    newGradient
+                            .transform(
+                                    "ModelDataUpdater",
+                                    TypeInformation.of(DenseVector.class),
+                                    new UpdateModel(alpha, beta, l1, l2))
+                            .setParallelism(1);
+
+            DataStream<LogisticRegressionModelData> outputModelData =
+                    feedbackModelData.map(new CreateLrModelData()).setParallelism(1);
+            return new IterationBodyResult(
+                    DataStreamList.of(feedbackModelData), DataStreamList.of(outputModelData));
+        }
+    }
+
+    private static class CreateLrModelData
+            implements MapFunction<DenseVector, LogisticRegressionModelData>, CheckpointedFunction {
+        private Long modelVersion = 1L;
+        private transient ListState<Long> modelVersionState;
+
+        @Override
+        public LogisticRegressionModelData map(DenseVector denseVector) throws Exception {
+            return new LogisticRegressionModelData(denseVector, modelVersion++);
+        }
+
+        @Override
+        public void snapshotState(FunctionSnapshotContext functionSnapshotContext)
+                throws Exception {
+            modelVersionState.update(Collections.singletonList(modelVersion));
+        }
+
+        @Override
+        public void initializeState(FunctionInitializationContext context) throws Exception {
+            modelVersionState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelVersionState", Long.class));
+        }
+    }
+
+    /** Updates the model coefficients from the reduced gradient using the FTRL-Proximal rule. */
+    private static class UpdateModel extends AbstractStreamOperator<DenseVector>
+            implements OneInputStreamOperator<DenseVector[], DenseVector> {
+        private ListState<double[]> nParamState;
+        private ListState<double[]> zParamState;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+        private double[] nParam;
+        private double[] zParam;
+
+        public UpdateModel(double alpha, double beta, double l1, double l2) {
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = l1;
+            this.l2 = l2;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            nParamState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("nParamState", double[].class));
+            zParamState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("zParamState", double[].class));
+        }
+
+        @Override
+        public void processElement(StreamRecord<DenseVector[]> streamRecord) throws Exception {
+            DenseVector[] gradientInfo = streamRecord.getValue();
+            double[] coefficient = gradientInfo[2].values;
+            double[] g = gradientInfo[0].values;
+            for (int i = 0; i < g.length; ++i) {
+                if (gradientInfo[1].values[i] != 0.0) {
+                    g[i] = g[i] / gradientInfo[1].values[i];
+                }
+            }
+            if (zParam == null) {
+                zParam = new double[g.length];
+                nParam = new double[g.length];
+                nParamState.add(nParam);
+                zParamState.add(zParam);
+            }
+
+            for (int i = 0; i < zParam.length; ++i) {
+                double sigma = (Math.sqrt(nParam[i] + g[i] * g[i]) - Math.sqrt(nParam[i])) / alpha;
+                zParam[i] += g[i] - sigma * coefficient[i];
+                nParam[i] += g[i] * g[i];
+
+                if (Math.abs(zParam[i]) <= l1) {
+                    coefficient[i] = 0.0;
+                } else {
+                    coefficient[i] =
+                            ((zParam[i] < 0 ? -1 : 1) * l1 - zParam[i])
+                                    / ((beta + Math.sqrt(nParam[i])) / alpha + l2);
+                }
+            }
+            output.collect(new StreamRecord<>(new DenseVector(coefficient)));
+        }
+    }
+
+    private static class CalculateLocalGradient extends AbstractStreamOperator<DenseVector[]>
+            implements TwoInputStreamOperator<Row[], DenseVector, DenseVector[]> {
+        private ListState<DenseVector> modelDataState;
+        private ListState<Row[]> localBatchDataState;
+        private double[] gradient;
+        private double[] weight;
+        private int[] denseVectorIndices;
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", DenseVector.class));
+            TypeInformation<Row[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(TypeInformation.of(Row.class));
+            localBatchDataState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("localBatch", type));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<Row[]> pointsRecord) throws Exception {
+            localBatchDataState.add(pointsRecord.getValue());
+            calculateGradient();
+        }
+
+        private void calculateGradient() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !localBatchDataState.get().iterator().hasNext()) {
+                return;
+            }
+            DenseVector modelData =
+                    OperatorStateUtils.getUniqueElement(modelDataState, "modelData").get();
+            modelDataState.clear();
+
+            List<Row[]> pointsList = IteratorUtils.toList(localBatchDataState.get().iterator());
+            Row[] points = pointsList.remove(0);
+            localBatchDataState.update(pointsList);
+
+            for (Row point : points) {
+                Vector vec = point.getFieldAs(0);
+                double label = point.getFieldAs(1);
+                if (gradient == null) {
+                    gradient = new double[vec.size()];
+                    weight = new double[gradient.length];
+                    if (vec instanceof DenseVector) {
+                        denseVectorIndices = new int[vec.size()];
+                        for (int i = 0; i < denseVectorIndices.length; ++i) {
+                            denseVectorIndices[i] = i;
+                        }
+                    }
+                }
+
+                int[] indices;
+                double[] values;
+                if (vec instanceof DenseVector) {

Review Comment:
   Could this be simplified with Blas.dot(Vector, Vector)?
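   One way the suggested simplification could look is sketched below, assuming a single helper that computes the dot product between the dense coefficient array and either a dense or a sparse feature vector. The class `SparseDot`, its methods, and the (indices, values) encoding with `indices == null` meaning "dense" are hypothetical stand-ins for Flink ML's `DenseVector`/`SparseVector`; the sketch deliberately does not rely on any particular overload of Flink ML's BLAS class.

```java
/**
 * Hypothetical standalone sketch: a single dot-product helper shared by dense and sparse
 * features, so the gradient loop does not need separate index arrays per vector type.
 */
public class SparseDot {

    /**
     * Dot product of a dense coefficient array with a feature vector given as (indices, values).
     * A null indices array means values is dense and aligned position-by-position with coefficient.
     */
    static double dot(double[] coefficient, int[] indices, double[] values) {
        double sum = 0.0;
        for (int i = 0; i < values.length; ++i) {
            int idx = indices == null ? i : indices[i];
            sum += coefficient[idx] * values[i];
        }
        return sum;
    }

    /** Adds the logistic-loss gradient (sigmoid(w . x) - label) * x into gradient. */
    static void accumulateGradient(
            double[] coefficient, int[] indices, double[] values, double label, double[] gradient) {
        double p = 1.0 / (1.0 + Math.exp(-dot(coefficient, indices, values)));
        for (int i = 0; i < values.length; ++i) {
            int idx = indices == null ? i : indices[i];
            gradient[idx] += (p - label) * values[i];
        }
    }

    public static void main(String[] args) {
        double[] coefficient = {0.5, -1.0, 2.0};
        // Dense features: 0.5 * 1 + (-1.0) * 2 + 2.0 * 3
        System.out.println(dot(coefficient, null, new double[] {1.0, 2.0, 3.0}));
        // Sparse features at indices 0 and 2: 0.5 * 1 + 2.0 * 3
        System.out.println(dot(coefficient, new int[] {0, 2}, new double[] {1.0, 3.0}));
    }
}
```

   With one helper covering both layouts, the `denseVectorIndices` array and the `instanceof` branching inside the per-point loop would no longer be needed for the prediction step.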



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java:
##########
@@ -18,40 +18,85 @@
 
 package org.apache.flink.ml.classification.logisticregression;
 
+import org.apache.flink.api.common.functions.MapFunction;
 import org.apache.flink.api.common.serialization.Encoder;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
 import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.core.memory.DataInputViewStreamWrapper;
 import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.ml.common.datastream.TableUtils;
 import org.apache.flink.ml.linalg.DenseVector;
 import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
 import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.table.api.Table;
 import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
 import org.apache.flink.table.api.internal.TableImpl;
 
 import java.io.EOFException;
 import java.io.IOException;
 import java.io.OutputStream;
+import java.util.Random;
 
 /**
- * Model data of {@link LogisticRegressionModel}.
+ * Model data of {@link LogisticRegressionModel} and {@link OnlineLogisticRegressionModel}.
  *
  * <p>This class also provides methods to convert model data from Table to Datastream, and classes
  * to save/load model data.
  */
 public class LogisticRegressionModelData {
 
     public DenseVector coefficient;
+    public long modelVersion;
 
-    public LogisticRegressionModelData(DenseVector coefficient) {
+    public LogisticRegressionModelData(DenseVector coefficient, long modelVersion) {
         this.coefficient = coefficient;
+        this.modelVersion = modelVersion;
     }
 
     public LogisticRegressionModelData() {}
 
+    public LogisticRegressionModelData(DenseVector coefficient) {
+        this.coefficient = coefficient;
+        this.modelVersion = 0L;
+    }
+
+    /**
+     * Generates a Table containing a {@link LogisticRegressionModelData} instance with randomly
+     * generated coefficient.
+     *
+     * @param tEnv The environment in which to create the table.
+     * @param dim The size of generated coefficient.
+     * @param seed Random seed.
+     */
+    public static Table generateRandomModelData(StreamTableEnvironment tEnv, int dim, int seed) {

Review Comment:
   This method does not seem to be used or tested. Could you add a unit test for it?
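   The body of `generateRandomModelData` is not quoted in this hunk, so the following is only a hypothetical sketch of the properties such a test could check, assuming the coefficients are drawn from `java.util.Random` seeded with `seed` (the hunk does add `import java.util.Random`). `generateRandomCoefficient` is an illustrative stand-in, not the PR's method; a real test would call `generateRandomModelData` through a `StreamTableEnvironment` and collect the resulting table.

```java
import java.util.Arrays;
import java.util.Random;

/** Hypothetical sketch of properties a unit test for random model data could assert. */
public class RandomModelDataCheck {

    /** Illustrative stand-in for the coefficient generation inside generateRandomModelData. */
    static double[] generateRandomCoefficient(int dim, int seed) {
        Random random = new Random(seed);
        double[] coefficient = new double[dim];
        for (int i = 0; i < dim; ++i) {
            coefficient[i] = random.nextDouble();
        }
        return coefficient;
    }

    static void check(boolean condition, String message) {
        if (!condition) {
            throw new AssertionError(message);
        }
    }

    public static void main(String[] args) {
        double[] first = generateRandomCoefficient(5, 2022);
        double[] second = generateRandomCoefficient(5, 2022);
        // Property 1: the generated coefficient has the requested dimension.
        check(first.length == 5, "unexpected dimension");
        // Property 2: the same seed reproduces the same coefficient.
        check(Arrays.equals(first, second), "same seed should give the same coefficient");
        // Property 3: nextDouble() values lie in [0, 1).
        for (double v : first) {
            check(v >= 0.0 && v < 1.0, "value out of range");
        }
        System.out.println("all checks passed");
    }
}
```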



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java:
##########
@@ -0,0 +1,434 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.FilterFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the FTRL-Proximal online learning algorithm proposed by H. Brendan
+ * McMahan et al.
+ *
+ * <p>See <a href="https://doi.org/10.1145/2487575.2488200">H. Brendan McMahan et al., Ad click
+ * prediction: a view from the trenches.</a>
+ */
+public class OnlineLogisticRegression
+        implements Estimator<OnlineLogisticRegression, OnlineLogisticRegressionModel>,
+                OnlineLogisticRegressionParams<OnlineLogisticRegression> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineLogisticRegression() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public OnlineLogisticRegressionModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<LogisticRegressionModelData> modelDataStream =
+                LogisticRegressionModelData.getModelDataStream(initModelDataTable);
+
+        DataStream<Row> points =
+                tEnv.toDataStream(inputs[0])
+                        .map(new FeaturesExtractor(getFeaturesCol(), getLabelCol()));
+
+        DataStream<DenseVector> initModelData =
+                modelDataStream.map(
+                        (MapFunction<LogisticRegressionModelData, DenseVector>)
+                                value -> value.coefficient);

Review Comment:
   I share the same opinion as @yunfengzhou-hub. If the initial version is not zero, the version of the output model should be larger than the initial version.
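   In the quoted `CreateLrModelData`, the counter always starts at `1L`, and `initializeState` obtains `modelVersionState` without reading a restored value back from it. The hypothetical sketch below shows the intended behaviour in plain Java: numbering continues from the initial model's version, and a checkpointed value takes precedence on recovery. `ModelVersionCounter` is illustrative only; how the initial version would be passed into the iteration is an assumption left open here.

```java
import java.util.Collections;
import java.util.List;

/** Hypothetical sketch of a model-version counter that continues from the initial model. */
public class ModelVersionCounter {
    private long modelVersion;

    /** Starts numbering right after the version carried by the initial model data. */
    ModelVersionCounter(long initModelVersion) {
        this.modelVersion = initModelVersion + 1;
    }

    /** Returns the version for the next emitted model and advances the counter. */
    long next() {
        return modelVersion++;
    }

    /** Mirrors what snapshotState would write into a ListState<Long>. */
    List<Long> snapshot() {
        return Collections.singletonList(modelVersion);
    }

    /** Mirrors a restore in initializeState, e.g. iterating over modelVersionState.get(). */
    void restore(Iterable<Long> restored) {
        for (Long v : restored) {
            modelVersion = v;
        }
    }

    public static void main(String[] args) {
        ModelVersionCounter counter = new ModelVersionCounter(5L);
        System.out.println(counter.next()); // 6
        System.out.println(counter.next()); // 7
        List<Long> state = counter.snapshot(); // holds 8

        // After a failover, the restored value wins over the initial version.
        ModelVersionCounter recovered = new ModelVersionCounter(5L);
        recovered.restore(state);
        System.out.println(recovered.next()); // 8
    }
}
```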



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java:
##########
@@ -0,0 +1,434 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the FTRL-Proximal online learning algorithm proposed by H. Brendan
+ * McMahan et al.
+ *
+ * <p>See <a href="https://doi.org/10.1145/2487575.2488200">H. Brendan McMahan et al., Ad click
+ * prediction: a view from the trenches.</a>
+ */
+public class OnlineLogisticRegression
+        implements Estimator<OnlineLogisticRegression, OnlineLogisticRegressionModel>,
+                OnlineLogisticRegressionParams<OnlineLogisticRegression> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineLogisticRegression() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineLogisticRegressionModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<LogisticRegressionModelData> modelDataStream =
+                LogisticRegressionModelData.getModelDataStream(initModelDataTable);
+
+        DataStream<Row> points =
+                tEnv.toDataStream(inputs[0])
+                        .map(new FeaturesExtractor(getFeaturesCol(), getLabelCol()));
+
+        DataStream<DenseVector> initModelData =
+                modelDataStream.map(
+                        (MapFunction<LogisticRegressionModelData, DenseVector>)
+                                value -> value.coefficient);
+
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new FtrlIterationBody(
+                        getGlobalBatchSize(), getAlpha(), getBeta(), getReg(), getElasticNet());
+
+        DataStream<LogisticRegressionModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        OnlineLogisticRegressionModel model =
+                new OnlineLogisticRegressionModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class FeaturesExtractor implements MapFunction<Row, Row> {
+        private final String featuresCol;
+        private final String labelCol;
+
+        private FeaturesExtractor(String featuresCol, String labelCol) {
+            this.featuresCol = featuresCol;
+            this.labelCol = labelCol;
+        }
+
+        @Override
+        public Row map(Row row) throws Exception {
+            return Row.of(row.getField(featuresCol), row.getField(labelCol));
+        }
+    }
+
+    /**
+     * Implementation of the FTRL optimizer. In this implementation, gradients are calculated in
+     * distributed workers and reduced to one gradient. The reduced gradient is used to update the
+     * model with the FTRL method.
+     *
+     * <p>See https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Ftrl
+     *
+     * <p>TODO: make FTRL a common optimizer and place it in org.apache.flink.ml.common in the
+     * future.
+     */
+    private static class FtrlIterationBody implements IterationBody {
+        private final int batchSize;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+
+        public FtrlIterationBody(
+                int batchSize, double alpha, double beta, double reg, double elasticNet) {
+            this.batchSize = batchSize;
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = elasticNet * reg;
+            this.l2 = (1 - elasticNet) * reg;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<DenseVector> modelData = variableStreams.get(0);
+
+            DataStream<Row> points = dataStreams.get(0);
+            int parallelism = points.getParallelism();
+            Preconditions.checkState(
+                    parallelism <= batchSize,
+                    "There are more subtasks in the training process than the number "
+                            + "of elements in each batch. Some subtasks might be idling forever.");
+
+            DataStream<DenseVector[]> newGradient =
+                    DataStreamUtils.generateBatchData(points, parallelism, batchSize)
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "LocalGradientCalculator",
+                                    TypeInformation.of(DenseVector[].class),
+                                    new CalculateLocalGradient())
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(
+                                    (ReduceFunction<DenseVector[]>)
+                                            (gradientInfo, newGradientInfo) -> {
+                                                for (int i = 0;
+                                                        i < newGradientInfo[1].size();
+                                                        ++i) {
+                                                    newGradientInfo[0].values[i] =
+                                                            gradientInfo[0].values[i]
+                                                                    + newGradientInfo[0].values[i];
+                                                    newGradientInfo[1].values[i] =
+                                                            gradientInfo[1].values[i]
+                                                                    + newGradientInfo[1].values[i];
+                                                    if (newGradientInfo[2] == null) {
+                                                        newGradientInfo[2] = gradientInfo[2];
+                                                    }
+                                                }
+                                                return newGradientInfo;
+                                            });
+            DataStream<DenseVector> feedbackModelData =
+                    newGradient
+                            .transform(
+                                    "ModelDataUpdater",
+                                    TypeInformation.of(DenseVector.class),
+                                    new UpdateModel(alpha, beta, l1, l2))
+                            .setParallelism(1);
+
+            DataStream<LogisticRegressionModelData> outputModelData =
+                    feedbackModelData.map(new CreateLrModelData()).setParallelism(1);
+            return new IterationBodyResult(
+                    DataStreamList.of(feedbackModelData), DataStreamList.of(outputModelData));
+        }
+    }
+
+    private static class CreateLrModelData
+            implements MapFunction<DenseVector, LogisticRegressionModelData>, CheckpointedFunction {
+        private Long modelVersion = 1L;
+        private transient ListState<Long> modelVersionState;
+
+        @Override
+        public LogisticRegressionModelData map(DenseVector denseVector) throws Exception {
+            return new LogisticRegressionModelData(denseVector, modelVersion++);
+        }
+
+        @Override
+        public void snapshotState(FunctionSnapshotContext functionSnapshotContext)
+                throws Exception {
+            modelVersionState.update(Collections.singletonList(modelVersion));
+        }
+
+        @Override
+        public void initializeState(FunctionInitializationContext context) throws Exception {
+            modelVersionState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelVersionState", Long.class));
+        }
+    }
+
+    /** Updates the model coefficients from the reduced gradient using the FTRL-Proximal rule. */
+    private static class UpdateModel extends AbstractStreamOperator<DenseVector>
+            implements OneInputStreamOperator<DenseVector[], DenseVector> {
+        private ListState<double[]> nParamState;
+        private ListState<double[]> zParamState;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+        private double[] nParam;
+        private double[] zParam;
+
+        public UpdateModel(double alpha, double beta, double l1, double l2) {
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = l1;
+            this.l2 = l2;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            nParamState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("nParamState", double[].class));
+            zParamState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("zParamState", double[].class));
+        }
+
+        @Override
+        public void processElement(StreamRecord<DenseVector[]> streamRecord) throws Exception {
+            DenseVector[] gradientInfo = streamRecord.getValue();
+            double[] coefficient = gradientInfo[2].values;
+            double[] g = gradientInfo[0].values;
+            for (int i = 0; i < g.length; ++i) {
+                if (gradientInfo[1].values[i] != 0.0) {
+                    g[i] = g[i] / gradientInfo[1].values[i];
+                }
+            }
+            if (zParam == null) {
+                zParam = new double[g.length];
+                nParam = new double[g.length];
+                nParamState.add(nParam);
+                zParamState.add(zParam);
+            }
+
+            for (int i = 0; i < zParam.length; ++i) {
+                double sigma = (Math.sqrt(nParam[i] + g[i] * g[i]) - Math.sqrt(nParam[i])) / alpha;
+                zParam[i] += g[i] - sigma * coefficient[i];
+                nParam[i] += g[i] * g[i];
+
+                if (Math.abs(zParam[i]) <= l1) {
+                    coefficient[i] = 0.0;
+                } else {
+                    coefficient[i] =
+                            ((zParam[i] < 0 ? -1 : 1) * l1 - zParam[i])
+                                    / ((beta + Math.sqrt(nParam[i])) / alpha + l2);
+                }
+            }
+            output.collect(new StreamRecord<>(new DenseVector(coefficient)));
+        }
+    }
+
+    private static class CalculateLocalGradient extends AbstractStreamOperator<DenseVector[]>
+            implements TwoInputStreamOperator<Row[], DenseVector, DenseVector[]> {
+        private ListState<DenseVector> modelDataState;
+        private ListState<Row[]> localBatchDataState;
+        private double[] gradient;
+        private double[] weight;
+        private int[] denseVectorIndices;
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", DenseVector.class));
+            TypeInformation<Row[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(TypeInformation.of(Row.class));
+            localBatchDataState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("localBatch", type));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<Row[]> pointsRecord) throws Exception {
+            localBatchDataState.add(pointsRecord.getValue());
+            calculateGradient();
+        }
+
+        private void calculateGradient() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !localBatchDataState.get().iterator().hasNext()) {
+                return;
+            }
+            DenseVector modelData =
+                    OperatorStateUtils.getUniqueElement(modelDataState, "modelData").get();
+            modelDataState.clear();
+
+            List<Row[]> pointsList = IteratorUtils.toList(localBatchDataState.get().iterator());
+            Row[] points = pointsList.remove(0);
+            localBatchDataState.update(pointsList);
+
+            for (Row point : points) {
+                Vector vec = point.getFieldAs(0);
+                double label = point.getFieldAs(1);
+                if (gradient == null) {
+                    gradient = new double[vec.size()];
+                    weight = new double[gradient.length];
+                    if (vec instanceof DenseVector) {
+                        denseVectorIndices = new int[vec.size()];
+                        for (int i = 0; i < denseVectorIndices.length; ++i) {
+                            denseVectorIndices[i] = i;
+                        }
+                    }
+                }
+
+                int[] indices;
+                double[] values;
+                if (vec instanceof DenseVector) {
+                    DenseVector denseVector = (DenseVector) vec;
+                    indices = denseVectorIndices;
+                    values = denseVector.values;
+                } else {
+                    SparseVector sparseVector = (SparseVector) vec;
+                    indices = sparseVector.indices;
+                    values = sparseVector.values;
+                }
+                double p = 0.0;
+                for (int i = 0; i < indices.length; ++i) {
+                    int idx = indices[i];
+                    p += modelData.values[idx] * values[i];
+                }
+                p = 1 / (1 + Math.exp(-p));
+                for (int i = 0; i < indices.length; ++i) {
+                    int idx = indices[i];
+                    gradient[idx] += (p - label) * values[i];
+                    weight[idx] += 1.0;
+                }
+            }
+
+            if (points.length > 0) {
+                output.collect(
+                        new StreamRecord<>(
+                                new DenseVector[] {
+                                    new DenseVector(gradient),
+                                    new DenseVector(weight),

Review Comment:
   It seems that we do not need to output the `weight` to the downstream operator, since the reduced `gradient` is enough for updating the local `weight`. (TensorFlow also does this.)
   
   Moreover, for outputting the models, we could use a separate `OutputTag<LogisticRegressionModelData>` instead of mixing the model output with the communication stream. Please check out SGD#Line292 for an example.
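To make the first point concrete, here is a plain-Java sketch without Flink dependencies. The class `FtrlSketch` and its methods are hypothetical. It sums per-worker gradients element-wise, as the `countWindowAll(...).reduce(...)` step in the diff does, and then applies one FTRL-Proximal step per coordinate using the same formulas as the `UpdateModel` operator. Unlike the PR, it feeds the summed gradient directly and skips the per-coordinate division by `weight`. The `OutputTag` side output is not modelled because it needs the Flink runtime.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch (not the PR's code): element-wise gradient reduction followed by
 * one FTRL-Proximal step per coordinate, using the same update formulas as UpdateModel.
 */
final class FtrlSketch {
    private final double alpha;
    private final double beta;
    private final double l1;
    private final double l2;
    private final double[] z; // per-coordinate accumulated adjusted gradients
    private final double[] n; // per-coordinate sum of squared gradients
    private final double[] w; // current model coefficients

    FtrlSketch(int dim, double alpha, double beta, double l1, double l2) {
        this.alpha = alpha;
        this.beta = beta;
        this.l1 = l1;
        this.l2 = l2;
        this.z = new double[dim];
        this.n = new double[dim];
        this.w = new double[dim];
    }

    /** Element-wise sum of the workers' local gradients, i.e. what the reduce step emits. */
    static double[] sum(double[]... localGradients) {
        double[] total = new double[localGradients[0].length];
        for (double[] g : localGradients) {
            for (int i = 0; i < total.length; i++) {
                total[i] += g[i];
            }
        }
        return total;
    }

    /** Applies one FTRL-Proximal step; the reduced gradient is the only input needed. */
    double[] update(double[] g) {
        for (int i = 0; i < w.length; i++) {
            double sigma = (Math.sqrt(n[i] + g[i] * g[i]) - Math.sqrt(n[i])) / alpha;
            z[i] += g[i] - sigma * w[i];
            n[i] += g[i] * g[i];
            // L1 shrinkage: a small |z| drives the coefficient to exactly zero.
            w[i] = Math.abs(z[i]) <= l1
                    ? 0.0
                    : (Math.signum(z[i]) * l1 - z[i]) / ((beta + Math.sqrt(n[i])) / alpha + l2);
        }
        return w.clone();
    }

    public static void main(String[] args) {
        FtrlSketch ftrl = new FtrlSketch(2, 0.1, 0.1, 0.0, 0.0);
        // Two workers' local gradients, reduced to one vector before the update.
        double[] reduced = sum(new double[] {0.5, -0.2}, new double[] {0.3, 0.0});
        System.out.println(Arrays.toString(ftrl.update(reduced)));
    }
}
```

Because the updater keeps `z` and `n` as its own state, the only per-round input it needs from the workers is the reduced gradient vector.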





[GitHub] [flink-ml] weibozhao commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator for OnlineLogisticRegression

Posted by GitBox <gi...@apache.org>.
weibozhao commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r882428114


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java:
##########
@@ -0,0 +1,434 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.FilterFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the FTRL-Proximal online learning algorithm proposed by H. Brendan
+ * McMahan et al.
+ *
+ * <p>See <a href="https://doi.org/10.1145/2487575.2488200">H. Brendan McMahan et al., Ad click
+ * prediction: a view from the trenches.</a>
+ */
+public class OnlineLogisticRegression
+        implements Estimator<OnlineLogisticRegression, OnlineLogisticRegressionModel>,
+                OnlineLogisticRegressionParams<OnlineLogisticRegression> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineLogisticRegression() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public OnlineLogisticRegressionModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<LogisticRegressionModelData> modelDataStream =
+                LogisticRegressionModelData.getModelDataStream(initModelDataTable);
+
+        DataStream<Row> points =
+                tEnv.toDataStream(inputs[0])
+                        .map(new FeaturesExtractor(getFeaturesCol(), getLabelCol()));
+
+        DataStream<DenseVector> initModelData =
+                modelDataStream.map(
+                        (MapFunction<LogisticRegressionModelData, DenseVector>)
+                                value -> value.coefficient);

Review Comment:
   OK, I will refine it.





[GitHub] [flink-ml] weibozhao commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator for OnlineLogisticRegression

Posted by GitBox <gi...@apache.org>.
weibozhao commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r882484808


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java:
##########
@@ -0,0 +1,434 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the FTRL-Proximal online learning algorithm proposed by H. Brendan
+ * McMahan et al.
+ *
+ * <p>See <a href="https://doi.org/10.1145/2487575.2488200">H. Brendan McMahan et al., Ad click
+ * prediction: a view from the trenches.</a>
+ */
+public class OnlineLogisticRegression
+        implements Estimator<OnlineLogisticRegression, OnlineLogisticRegressionModel>,
+                OnlineLogisticRegressionParams<OnlineLogisticRegression> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineLogisticRegression() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineLogisticRegressionModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<LogisticRegressionModelData> modelDataStream =
+                LogisticRegressionModelData.getModelDataStream(initModelDataTable);
+
+        DataStream<Row> points =
+                tEnv.toDataStream(inputs[0])
+                        .map(new FeaturesExtractor(getFeaturesCol(), getLabelCol()));
+
+        DataStream<DenseVector> initModelData =
+                modelDataStream.map(
+                        (MapFunction<LogisticRegressionModelData, DenseVector>)
+                                value -> value.coefficient);
+
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new FtrlIterationBody(
+                        getGlobalBatchSize(), getAlpha(), getBeta(), getReg(), getElasticNet());
+
+        DataStream<LogisticRegressionModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        OnlineLogisticRegressionModel model =
+                new OnlineLogisticRegressionModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class FeaturesExtractor implements MapFunction<Row, Row> {
+        private final String featuresCol;
+        private final String labelCol;
+
+        private FeaturesExtractor(String featuresCol, String labelCol) {
+            this.featuresCol = featuresCol;
+            this.labelCol = labelCol;
+        }
+
+        @Override
+        public Row map(Row row) throws Exception {
+            return Row.of(row.getField(featuresCol), row.getField(labelCol));
+        }
+    }
+
+    /**
+     * Implementation of ftrl optimizer. In this implementation, gradients are calculated in
+     * distributed workers and reduce to one gradient. The reduced gradient is used to update model
+     * by ftrl method.
+     *
+     * <p>See https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Ftrl
+     *
+     * <p>todo: makes ftrl to be a common optimizer and place it in org.apache.flink.ml.common in
+     * future.
+     */
+    private static class FtrlIterationBody implements IterationBody {
+        private final int batchSize;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+
+        public FtrlIterationBody(
+                int batchSize, double alpha, double beta, double reg, double elasticNet) {
+            this.batchSize = batchSize;
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = elasticNet * reg;
+            this.l2 = (1 - elasticNet) * reg;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<DenseVector> modelData = variableStreams.get(0);
+
+            DataStream<Row> points = dataStreams.get(0);
+            int parallelism = points.getParallelism();
+            Preconditions.checkState(
+                    parallelism <= batchSize,
+                    "There are more subtasks in the training process than the number "
+                            + "of elements in each batch. Some subtasks might be idling forever.");
+
+            DataStream<DenseVector[]> newGradient =
+                    DataStreamUtils.generateBatchData(points, parallelism, batchSize)
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "LocalGradientCalculator",
+                                    TypeInformation.of(DenseVector[].class),
+                                    new CalculateLocalGradient())
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(
+                                    (ReduceFunction<DenseVector[]>)
+                                            (gradientInfo, newGradientInfo) -> {
+                                                for (int i = 0;
+                                                        i < newGradientInfo[1].size();
+                                                        ++i) {
+                                                    newGradientInfo[0].values[i] =
+                                                            gradientInfo[0].values[i]
+                                                                    + newGradientInfo[0].values[i];
+                                                    newGradientInfo[1].values[i] =
+                                                            gradientInfo[1].values[i]
+                                                                    + newGradientInfo[1].values[i];
+                                                    if (newGradientInfo[2] == null) {
+                                                        newGradientInfo[2] = gradientInfo[2];
+                                                    }
+                                                }
+                                                return newGradientInfo;
+                                            });
+            DataStream<DenseVector> feedbackModelData =
+                    newGradient
+                            .transform(
+                                    "ModelDataUpdater",
+                                    TypeInformation.of(DenseVector.class),
+                                    new UpdateModel(alpha, beta, l1, l2))
+                            .setParallelism(1);
+
+            DataStream<LogisticRegressionModelData> outputModelData =
+                    feedbackModelData.map(new CreateLrModelData()).setParallelism(1);
+            return new IterationBodyResult(
+                    DataStreamList.of(feedbackModelData), DataStreamList.of(outputModelData));
+        }
+    }
+
+    private static class CreateLrModelData
+            implements MapFunction<DenseVector, LogisticRegressionModelData>, CheckpointedFunction {
+        private Long modelVersion = 1L;
+        private transient ListState<Long> modelVersionState;
+
+        @Override
+        public LogisticRegressionModelData map(DenseVector denseVector) throws Exception {
+            return new LogisticRegressionModelData(denseVector, modelVersion++);
+        }
+
+        @Override
+        public void snapshotState(FunctionSnapshotContext functionSnapshotContext)
+                throws Exception {
+            modelVersionState.update(Collections.singletonList(modelVersion));
+        }
+
+        @Override
+        public void initializeState(FunctionInitializationContext context) throws Exception {
+            modelVersionState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelVersionState", Long.class));
+        }
+    }
+
+    /** Updates model. */
+    private static class UpdateModel extends AbstractStreamOperator<DenseVector>
+            implements OneInputStreamOperator<DenseVector[], DenseVector> {
+        private ListState<double[]> nParamState;
+        private ListState<double[]> zParamState;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+        private double[] nParam;
+        private double[] zParam;
+
+        public UpdateModel(double alpha, double beta, double l1, double l2) {
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = l1;
+            this.l2 = l2;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            nParamState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("nParamState", double[].class));
+            zParamState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("zParamState", double[].class));
+        }
+
+        @Override
+        public void processElement(StreamRecord<DenseVector[]> streamRecord) throws Exception {
+            DenseVector[] gradientInfo = streamRecord.getValue();
+            double[] coefficient = gradientInfo[2].values;
+            double[] g = gradientInfo[0].values;
+            for (int i = 0; i < g.length; ++i) {
+                if (gradientInfo[1].values[i] != 0.0) {
+                    g[i] = g[i] / gradientInfo[1].values[i];
+                }
+            }
+            if (zParam == null) {
+                zParam = new double[g.length];
+                nParam = new double[g.length];
+                nParamState.add(nParam);
+                zParamState.add(zParam);
+            }
+
+            for (int i = 0; i < zParam.length; ++i) {
+                double sigma = (Math.sqrt(nParam[i] + g[i] * g[i]) - Math.sqrt(nParam[i])) / alpha;
+                zParam[i] += g[i] - sigma * coefficient[i];
+                nParam[i] += g[i] * g[i];
+
+                if (Math.abs(zParam[i]) <= l1) {
+                    coefficient[i] = 0.0;
+                } else {
+                    coefficient[i] =
+                            ((zParam[i] < 0 ? -1 : 1) * l1 - zParam[i])
+                                    / ((beta + Math.sqrt(nParam[i])) / alpha + l2);
+                }
+            }
+            output.collect(new StreamRecord<>(new DenseVector(coefficient)));
+        }
+    }
+
+    private static class CalculateLocalGradient extends AbstractStreamOperator<DenseVector[]>
+            implements TwoInputStreamOperator<Row[], DenseVector, DenseVector[]> {
+        private ListState<DenseVector> modelDataState;
+        private ListState<Row[]> localBatchDataState;
+        private double[] gradient;
+        private double[] weight;
+        private int[] denseVectorIndices;
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", DenseVector.class));
+            TypeInformation<Row[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(TypeInformation.of(Row.class));
+            localBatchDataState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("localBatch", type));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<Row[]> pointsRecord) throws Exception {
+            localBatchDataState.add(pointsRecord.getValue());
+            calculateGradient();
+        }
+
+        private void calculateGradient() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !localBatchDataState.get().iterator().hasNext()) {
+                return;
+            }
+            DenseVector modelData =
+                    OperatorStateUtils.getUniqueElement(modelDataState, "modelData").get();
+            modelDataState.clear();
+
+            List<Row[]> pointsList = IteratorUtils.toList(localBatchDataState.get().iterator());
+            Row[] points = pointsList.remove(0);
+            localBatchDataState.update(pointsList);
+
+            for (Row point : points) {
+                Vector vec = point.getFieldAs(0);
+                double label = point.getFieldAs(1);
+                if (gradient == null) {
+                    gradient = new double[vec.size()];
+                    weight = new double[gradient.length];
+                    if (vec instanceof DenseVector) {
+                        denseVectorIndices = new int[vec.size()];
+                        for (int i = 0; i < denseVectorIndices.length; ++i) {
+                            denseVectorIndices[i] = i;
+                        }
+                    }
+                }
+
+                int[] indices;
+                double[] values;
+                if (vec instanceof DenseVector) {
+                    DenseVector denseVector = (DenseVector) vec;
+                    indices = denseVectorIndices;
+                    values = denseVector.values;
+                } else {
+                    SparseVector sparseVector = (SparseVector) vec;
+                    indices = sparseVector.indices;
+                    values = sparseVector.values;
+                }
+                double p = 0.0;
+                for (int i = 0; i < indices.length; ++i) {
+                    int idx = indices[i];
+                    p += modelData.values[idx] * values[i];
+                }
+                p = 1 / (1 + Math.exp(-p));
+                for (int i = 0; i < indices.length; ++i) {
+                    int idx = indices[i];
+                    gradient[idx] += (p - label) * values[i];
+                    weight[idx] += 1.0;
+                }
+            }
+
+            if (points.length > 0) {
+                output.collect(
+                        new StreamRecord<>(
+                                new DenseVector[] {
+                                    new DenseVector(gradient),
+                                    new DenseVector(weight),

Review Comment:
   The `weight` will be renamed to `weightSum`. As mentioned in my comment above, I added a `weightCol` param, so this variable is simply the sum of the sample weights.
   
   I will add Javadoc for this code and refine how the model data is output.
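Here is a minimal plain-Java sketch of what a `weightSum` accumulator could look like. It assumes each sample's weight (read from the proposed `weightCol`) scales that sample's gradient contribution. `WeightedGradientSketch` is a hypothetical name, and the weighting scheme is an assumption rather than the PR's final code. The `averaged()` step mirrors the per-coordinate division that `UpdateModel` performs before the FTRL update.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch of weighted local-gradient accumulation for logistic regression:
 * each sample adds weight * (sigmoid(w.x) - label) * x_i to gradient[i], and
 * weightSum[i] records the total sample weight that touched coordinate i.
 */
final class WeightedGradientSketch {
    final double[] gradient;
    final double[] weightSum;

    WeightedGradientSketch(int dim) {
        gradient = new double[dim];
        weightSum = new double[dim];
    }

    /** Accumulates one sparse sample given its indices, values, label, and weight. */
    void add(double[] coefficient, int[] indices, double[] values, double label, double weight) {
        double dot = 0.0;
        for (int k = 0; k < indices.length; k++) {
            dot += coefficient[indices[k]] * values[k];
        }
        double p = 1.0 / (1.0 + Math.exp(-dot)); // predicted probability
        for (int k = 0; k < indices.length; k++) {
            gradient[indices[k]] += weight * (p - label) * values[k];
            weightSum[indices[k]] += weight;
        }
    }

    /** Per-coordinate weighted average; coordinates never touched stay at zero. */
    double[] averaged() {
        double[] g = gradient.clone();
        for (int i = 0; i < g.length; i++) {
            if (weightSum[i] != 0.0) {
                g[i] /= weightSum[i];
            }
        }
        return g;
    }

    public static void main(String[] args) {
        WeightedGradientSketch acc = new WeightedGradientSketch(3);
        double[] coef = new double[3]; // all-zero model, so p = 0.5 for every sample
        acc.add(coef, new int[] {0, 2}, new double[] {1.0, 2.0}, 1.0, 2.0);
        acc.add(coef, new int[] {0}, new double[] {1.0}, 0.0, 1.0);
        System.out.println(Arrays.toString(acc.averaged()));
    }
}
```

With all weights equal to 1.0 this reduces to the per-coordinate sample count that the current `weight[idx] += 1.0` line computes.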





[GitHub] [flink-ml] weibozhao commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator for OnlineLogisticRegression

Posted by GitBox <gi...@apache.org>.
weibozhao commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r887422996


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java:
##########
@@ -0,0 +1,434 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.FilterFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the FTRL-Proximal online learning algorithm proposed by H. Brendan
+ * McMahan et al.
+ *
+ * <p>See <a href="https://doi.org/10.1145/2487575.2488200">H. Brendan McMahan et al., Ad click
+ * prediction: a view from the trenches.</a>
+ */
+public class OnlineLogisticRegression
+        implements Estimator<OnlineLogisticRegression, OnlineLogisticRegressionModel>,
+                OnlineLogisticRegressionParams<OnlineLogisticRegression> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineLogisticRegression() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public OnlineLogisticRegressionModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<LogisticRegressionModelData> modelDataStream =
+                LogisticRegressionModelData.getModelDataStream(initModelDataTable);
+
+        DataStream<Row> points =
+                tEnv.toDataStream(inputs[0])
+                        .map(new FeaturesExtractor(getFeaturesCol(), getLabelCol()));
+
+        DataStream<DenseVector> initModelData =
+                modelDataStream.map(
+                        (MapFunction<LogisticRegressionModelData, DenseVector>)
+                                value -> value.coefficient);

Review Comment:
   1. The init model is an offline model. A version is added to the offline model only so that the offline and online models have the same format and can share the same model-data processing functions. The version itself has no meaning for an offline model.
   
   2. Whether the online model versions are larger than the init model version does not matter. When an online algorithm is running and producing models one by one, we only need the (i+1)-th model's version to be larger than the i-th model's version; the init model is a special case.
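   The versioning rule above can be sketched as follows. This is a hypothetical illustration, not code from the PR: `ModelVersionSketch`, `ModelData`, `emit` and `INIT_MODEL_VERSION` are invented names, and the counter is assumed to start at 1 only because `FtrlIterationBody` initializes its `modelVersion` field to `1L`. The init model gets a placeholder version, and the only invariant checked is that successive online models have strictly increasing versions.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical sketch of the model-versioning rule described in the comment above. */
final class ModelVersionSketch {
    /** Placeholder version for the init (offline) model; it carries no ordering meaning. */
    static final long INIT_MODEL_VERSION = 0L;

    /** Hypothetical model-data record: coefficients plus a version number. */
    static final class ModelData {
        final double[] coefficient;
        final long version;

        ModelData(double[] coefficient, long version) {
            this.coefficient = coefficient;
            this.version = version;
        }
    }

    // Version for the next online model (assumed to start at 1, like FtrlIterationBody.modelVersion).
    private long nextVersion = 1L;

    /** Tags a newly trained online model with the next version number. */
    ModelData emit(double[] coefficient) {
        return new ModelData(coefficient, nextVersion++);
    }

    /** The only invariant that matters: version(i + 1) > version(i) among online models. */
    static boolean isStrictlyIncreasing(List<ModelData> onlineModels) {
        for (int i = 1; i < onlineModels.size(); i++) {
            if (onlineModels.get(i).version <= onlineModels.get(i - 1).version) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        // The init model shares the model-data format, but its version is only a placeholder.
        ModelData init = new ModelData(new double[] {0.0, 0.0}, INIT_MODEL_VERSION);
        ModelVersionSketch versioner = new ModelVersionSketch();
        List<ModelData> online = new ArrayList<>();
        online.add(versioner.emit(new double[] {0.1, 0.2}));
        online.add(versioner.emit(new double[] {0.2, 0.3}));
        System.out.println(init.version + " " + isStrictlyIncreasing(online)); // 0 true
    }
}
```

   Comparing the init model's version with the online versions is deliberately left out, which matches the point that the init model is a special case.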





[GitHub] [flink-ml] weibozhao commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator for OnlineLogisticRegression

Posted by GitBox <gi...@apache.org>.
weibozhao commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r881264766


##########
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/OnlineLogisticRegressionTest.java:
##########
@@ -0,0 +1,661 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.RestOptions;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.classification.logisticregression.LogisticRegression;
+import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModel;
+import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelData;
+import org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegression;
+import org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegressionModel;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.InMemorySinkFunction;
+import org.apache.flink.ml.util.InMemorySourceFunction;
+import org.apache.flink.ml.util.TestUtils;
+import org.apache.flink.runtime.minicluster.MiniCluster;
+import org.apache.flink.runtime.minicluster.MiniClusterConfiguration;
+import org.apache.flink.runtime.testutils.InMemoryReporter;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.TestLogger;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import static org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegressionModel.MODEL_DATA_VERSION_GAUGE_KEY;
+
+/** Tests {@link OnlineLogisticRegression} and {@link OnlineLogisticRegressionModel}. */
+public class OnlineLogisticRegressionTest extends TestLogger {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private static final double[] ONE_ARRAY = new double[] {1.0, 1.0, 1.0};
+
+    private static final Row[] TRAIN_DENSE_ROWS_1 =
+            new Row[] {
+                Row.of(Vectors.dense(0.1, 2.), 0.),
+                Row.of(Vectors.dense(0.2, 2.), 0.),
+                Row.of(Vectors.dense(0.3, 2.), 0.),
+                Row.of(Vectors.dense(0.4, 2.), 0.),
+                Row.of(Vectors.dense(0.5, 2.), 0.),
+                Row.of(Vectors.dense(11., 12.), 1.),
+                Row.of(Vectors.dense(12., 11.), 1.),
+                Row.of(Vectors.dense(13., 12.), 1.),
+                Row.of(Vectors.dense(14., 12.), 1.),
+                Row.of(Vectors.dense(15., 12.), 1.)
+            };
+
+    private static final Row[] TRAIN_DENSE_ROWS_2 =
+            new Row[] {
+                Row.of(Vectors.dense(0.2, 3.), 0.),
+                Row.of(Vectors.dense(0.8, 1.), 0.),
+                Row.of(Vectors.dense(0.7, 1.), 0.),
+                Row.of(Vectors.dense(0.6, 2.), 0.),
+                Row.of(Vectors.dense(0.2, 2.), 0.),
+                Row.of(Vectors.dense(14., 17.), 1.),
+                Row.of(Vectors.dense(15., 10.), 1.),
+                Row.of(Vectors.dense(16., 16.), 1.),
+                Row.of(Vectors.dense(17., 10.), 1.),
+                Row.of(Vectors.dense(18., 13.), 1.)
+            };
+
+    private static final Row[] PREDICT_DENSE_ROWS =
+            new Row[] {
+                Row.of(Vectors.dense(0.8, 2.7), 0.0), Row.of(Vectors.dense(15.5, 11.2), 1.0)
+            };
+
+    private static final Row[] TRAIN_SPARSE_ROWS_1 =
+            new Row[] {
+                Row.of(Vectors.sparse(10, new int[] {1, 3, 4}, ONE_ARRAY), 0.),
+                Row.of(Vectors.sparse(10, new int[] {0, 2, 3}, ONE_ARRAY), 0.),
+                Row.of(Vectors.sparse(10, new int[] {0, 3, 4}, ONE_ARRAY), 0.),
+                Row.of(Vectors.sparse(10, new int[] {2, 3, 4}, ONE_ARRAY), 0.),
+                Row.of(Vectors.sparse(10, new int[] {1, 3, 4}, ONE_ARRAY), 0.),
+                Row.of(Vectors.sparse(10, new int[] {6, 7, 8}, ONE_ARRAY), 1.),
+                Row.of(Vectors.sparse(10, new int[] {6, 8, 9}, ONE_ARRAY), 1.),
+                Row.of(Vectors.sparse(10, new int[] {5, 8, 9}, ONE_ARRAY), 1.),
+                Row.of(Vectors.sparse(10, new int[] {5, 6, 7}, ONE_ARRAY), 1.)

Review Comment:
   If we use the same data, some sparse properties may not be covered by these unit tests.
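   One way to make this concern concrete, assuming "sparse properties" includes how the model handles feature indices it never saw during training, is an index-coverage check. The helper below is hypothetical (not part of the PR or of Flink ML) and works on raw index arrays instead of Flink's `SparseVector` to stay self-contained. The sample indices are taken from the rows of `TRAIN_SPARSE_ROWS_1` quoted above.

```java
import java.util.Arrays;
import java.util.Set;
import java.util.TreeSet;

/** Hypothetical helper: which feature indices are active in one dataset but not in another. */
final class SparseCoverageSketch {
    /** Collects every index explicitly stored in at least one sparse row. */
    static Set<Integer> activeIndices(int[][] rowIndices) {
        Set<Integer> active = new TreeSet<>();
        for (int[] indices : rowIndices) {
            for (int index : indices) {
                active.add(index);
            }
        }
        return active;
    }

    /** Indices active in {@code predict} that never appear in {@code train}. */
    static Set<Integer> unseenIndices(int[][] train, int[][] predict) {
        Set<Integer> unseen = activeIndices(predict);
        unseen.removeAll(activeIndices(train));
        return unseen;
    }

    public static void main(String[] args) {
        // Index arrays of the TRAIN_SPARSE_ROWS_1 rows quoted above (all stored values are 1.0).
        int[][] all = {
            {1, 3, 4}, {0, 2, 3}, {0, 3, 4}, {2, 3, 4}, {1, 3, 4},
            {6, 7, 8}, {6, 8, 9}, {5, 8, 9}, {5, 6, 7}
        };
        int[][] firstFive = Arrays.copyOfRange(all, 0, 5);
        int[][] lastFour = Arrays.copyOfRange(all, 5, 9);

        // Reusing the training rows for prediction: no index is unseen.
        System.out.println(unseenIndices(all, all)); // []
        // A disjoint split: indices 5..9 never appear during training.
        System.out.println(unseenIndices(firstFive, lastFour)); // [5, 6, 7, 8, 9]
    }
}
```

   With identical training and prediction data the unseen set is empty, so any behavior specific to never-trained coordinates is not exercised.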





[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator for OnlineLogisticRegression

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r881313632


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java:
##########
@@ -0,0 +1,434 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.FilterFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the FTRL-Proximal online learning algorithm proposed by H. Brendan
+ * McMahan et al.
+ *
+ * <p>See <a href="https://doi.org/10.1145/2487575.2488200">H. Brendan McMahan et al., Ad click
+ * prediction: a view from the trenches.</a>
+ */
+public class OnlineLogisticRegression
+        implements Estimator<OnlineLogisticRegression, OnlineLogisticRegressionModel>,
+                OnlineLogisticRegressionParams<OnlineLogisticRegression> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineLogisticRegression() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public OnlineLogisticRegressionModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<LogisticRegressionModelData> modelDataStream =
+                LogisticRegressionModelData.getModelDataStream(initModelDataTable);
+
+        DataStream<Row> points =
+                tEnv.toDataStream(inputs[0])
+                        .map(new FeaturesExtractor(getFeaturesCol(), getLabelCol()));
+
+        DataStream<DenseVector> initModelData =
+                modelDataStream.map(
+                        (MapFunction<LogisticRegressionModelData, DenseVector>)
+                                value -> value.coefficient);
+
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new FtrlIterationBody(
+                        getGlobalBatchSize(),
+                        getAlpha(),
+                        getBeta(),
+                        getReg(),
+                        getElasticNet(),
+                        getModelSaveInterval());
+
+        DataStream<LogisticRegressionModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        OnlineLogisticRegressionModel model =
+                new OnlineLogisticRegressionModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class FeaturesExtractor implements MapFunction<Row, Row> {
+        private final String featuresCol;
+        private final String labelCol;
+
+        private FeaturesExtractor(String featuresCol, String labelCol) {
+            this.featuresCol = featuresCol;
+            this.labelCol = labelCol;
+        }
+
+        @Override
+        public Row map(Row row) throws Exception {
+            return Row.of(row.getField(featuresCol), row.getField(labelCol));
+        }
+    }
+
+    /**
+     * Implementation of the FTRL optimizer. In this implementation, gradients are calculated in
+     * distributed workers and reduced to one gradient. The reduced gradient is used to update the
+     * model by the FTRL method.
+     *
+     * <p>See https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Ftrl
+     */
+    private static class FtrlIterationBody implements IterationBody {
+        private final int batchSize;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+        private long modelVersion = 1L;
+        private final int modelSaveInterval;
+
+        public FtrlIterationBody(
+                int batchSize,
+                double alpha,
+                double beta,
+                double reg,
+                double elasticNet,
+                int modelSaveInterval) {
+            this.batchSize = batchSize;
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = elasticNet * reg;
+            this.l2 = (1 - elasticNet) * reg;
+            this.modelSaveInterval = modelSaveInterval;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<DenseVector> modelData = variableStreams.get(0);
+
+            DataStream<Row> points = dataStreams.get(0);
+            int parallelism = points.getParallelism();
+            Preconditions.checkState(
+                    parallelism <= batchSize,
+                    "There are more subtasks in the training process than the number "
+                            + "of elements in each batch. Some subtasks might be idling forever.");
+
+            DataStream<DenseVector[]> newGradient =
+                    DataStreamUtils.generateBatchData(points, parallelism, batchSize)
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "LocalGradientCalculator",
+                                    TypeInformation.of(DenseVector[].class),
+                                    new CalculateLocalGradient())
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(
+                                    (ReduceFunction<DenseVector[]>)
+                                            (gradientInfo, newGradientInfo) -> {
+                                                for (int i = 0;
+                                                        i < newGradientInfo[1].size();
+                                                        ++i) {
+                                                    newGradientInfo[0].values[i] =
+                                                            gradientInfo[0].values[i]
+                                                                    + newGradientInfo[0].values[i];
+                                                    newGradientInfo[1].values[i] =
+                                                            gradientInfo[1].values[i]
+                                                                    + newGradientInfo[1].values[i];
+                                                    if (newGradientInfo[2] == null) {
+                                                        newGradientInfo[2] = gradientInfo[2];
+                                                    }
+                                                }
+                                                return newGradientInfo;
+                                            });
+            DataStream<DenseVector> feedbackModelData =
+                    newGradient
+                            .transform(
+                                    "ModelDataUpdater",
+                                    TypeInformation.of(DenseVector.class),
+                                    new UpdateModel(alpha, beta, l1, l2))
+                            .setParallelism(1);
+
+            DataStream<LogisticRegressionModelData> outputModelData =
+                    feedbackModelData
+                            .filter(
+                                    new FilterFunction<DenseVector>() {
+                                        private int step = 0;
+
+                                        @Override
+                                        public boolean filter(DenseVector denseVector) {
+                                            step++;
+                                            return step % modelSaveInterval == 0;
+                                        }
+                                    })
+                            .setParallelism(1)

Review Comment:
   Got it.
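   The `FtrlIterationBody` Javadoc quoted above says local gradients are reduced to one gradient, which is then used for the FTRL update. Below is a minimal sketch of that two-step flow, assuming plain `double[]` vectors and the per-coordinate FTRL-Proximal update (Algorithm 1 of the McMahan et al. paper cited in the class Javadoc). `FtrlSketch`, `sum`, `weight` and `update` are hypothetical names; the PR's `UpdateModel` operator is not shown in this hunk and may differ, for example in how it uses the second summed vector (index 1) of the reduced `DenseVector[]`, which this sketch does not model.

```java
import java.util.Arrays;

/** Hypothetical sketch: sum local gradients, then apply per-coordinate FTRL-Proximal. */
final class FtrlSketch {
    private final double alpha;
    private final double beta;
    private final double l1;
    private final double l2;
    private final double[] z; // accumulated adjusted gradients, one per coordinate
    private final double[] n; // accumulated squared gradients, one per coordinate

    FtrlSketch(int dim, double alpha, double beta, double reg, double elasticNet) {
        this.alpha = alpha;
        this.beta = beta;
        // Same split as FtrlIterationBody: elasticNet blends L1 and L2 strength.
        this.l1 = elasticNet * reg;
        this.l2 = (1 - elasticNet) * reg;
        this.z = new double[dim];
        this.n = new double[dim];
    }

    /** Element-wise sum of local gradients, analogous to the countWindowAll(...).reduce(...) step. */
    static double[] sum(double[]... localGradients) {
        double[] total = new double[localGradients[0].length];
        for (double[] g : localGradients) {
            for (int i = 0; i < total.length; i++) {
                total[i] += g[i];
            }
        }
        return total;
    }

    /** Closed-form weight of coordinate i; exactly 0 when |z_i| <= l1. */
    double weight(int i) {
        if (Math.abs(z[i]) <= l1) {
            return 0.0;
        }
        return -(z[i] - Math.signum(z[i]) * l1) / ((beta + Math.sqrt(n[i])) / alpha + l2);
    }

    /** Applies one FTRL-Proximal step with the (already reduced) gradient g; returns new weights. */
    double[] update(double[] g) {
        for (int i = 0; i < z.length; i++) {
            double wi = weight(i); // weight before this step, as in Algorithm 1
            double sigma = (Math.sqrt(n[i] + g[i] * g[i]) - Math.sqrt(n[i])) / alpha;
            z[i] += g[i] - sigma * wi;
            n[i] += g[i] * g[i];
        }
        double[] w = new double[z.length];
        for (int i = 0; i < z.length; i++) {
            w[i] = weight(i);
        }
        return w;
    }

    public static void main(String[] args) {
        FtrlSketch ftrl = new FtrlSketch(3, 0.1, 0.1, 0.5, 0.5);
        // Two workers' local gradients, summed into one global gradient {2.0, -3.0, 0.2}.
        double[] g = sum(new double[] {1.0, -2.0, 0.1}, new double[] {1.0, -1.0, 0.1});
        System.out.println(Arrays.toString(ftrl.update(g)));
    }
}
```

   In this run the third coordinate ends with |z| = 0.2, which is below l1 = 0.25, so its weight is exactly 0; this thresholding is how the L1 share of `reg` produces sparse models.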



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/OnlineLogisticRegressionTest.java:
##########
@@ -0,0 +1,661 @@
+                Row.of(Vectors.sparse(10, new int[] {5, 6, 7}, ONE_ARRAY), 1.)

Review Comment:
   Got it. Thanks.





[GitHub] [flink-ml] weibozhao commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator for OnlineLogisticRegression

Posted by GitBox <gi...@apache.org>.
weibozhao commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r881379720


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java:
##########
@@ -0,0 +1,434 @@
+            DataStream<DenseVector> feedbackModelData =
+                    newGradient
+                            .transform(
+                                    "ModelDataUpdater",
+                                    TypeInformation.of(DenseVector.class),
+                                    new UpdateModel(alpha, beta, l1, l2))
+                            .setParallelism(1);
+
+            DataStream<LogisticRegressionModelData> outputModelData =
+                    feedbackModelData
+                            .filter(
+                                    new FilterFunction<DenseVector>() {
+                                        private int step = 0;
+
+                                        @Override
+                                        public boolean filter(DenseVector denseVector) {
+                                            step++;
+                                            return step % modelSaveInterval == 0;
+                                        }
+                                    })
+                            .setParallelism(1)
+                            .map(
+                                    (MapFunction<DenseVector, LogisticRegressionModelData>)
+                                            value ->
+                                                    new LogisticRegressionModelData(
+                                                            value, modelVersion++))
+                            .setParallelism(1);
+            return new IterationBodyResult(
+                    DataStreamList.of(feedbackModelData), DataStreamList.of(outputModelData));
+        }
+    }
+
+    /** Updates the model coefficients with the FTRL update rule. */
+    private static class UpdateModel extends AbstractStreamOperator<DenseVector>
+            implements OneInputStreamOperator<DenseVector[], DenseVector> {
+        private ListState<double[]> nParamState;
+        private ListState<double[]> zParamState;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+        private double[] nParam;
+        private double[] zParam;
+
+        public UpdateModel(double alpha, double beta, double l1, double l2) {
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = l1;
+            this.l2 = l2;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            nParamState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("nParamState", double[].class));
+            zParamState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("zParamState", double[].class));
+        }
+
+        @Override
+        public void processElement(StreamRecord<DenseVector[]> streamRecord) throws Exception {
+            DenseVector[] gradientInfo = streamRecord.getValue();
+            double[] coefficient = gradientInfo[2].values;
+            double[] g = gradientInfo[0].values;
+            for (int i = 0; i < g.length; ++i) {
+                if (gradientInfo[1].values[i] != 0.0) {
+                    g[i] = g[i] / gradientInfo[1].values[i];
+                }
+            }
+            if (zParam == null) {
+                zParam = new double[g.length];
+                nParam = new double[g.length];
+                nParamState.add(nParam);
+                zParamState.add(zParam);
+            }
+
+            for (int i = 0; i < zParam.length; ++i) {
+                double sigma = (Math.sqrt(nParam[i] + g[i] * g[i]) - Math.sqrt(nParam[i])) / alpha;
+                zParam[i] += g[i] - sigma * coefficient[i];
+                nParam[i] += g[i] * g[i];
+
+                if (Math.abs(zParam[i]) <= l1) {
+                    coefficient[i] = 0.0;
+                } else {
+                    coefficient[i] =
+                            ((zParam[i] < 0 ? -1 : 1) * l1 - zParam[i])
+                                    / ((beta + Math.sqrt(nParam[i])) / alpha + l2);
+                }
+            }
+            output.collect(new StreamRecord<>(new DenseVector(coefficient)));
+        }
+    }
+
+    private static class CalculateLocalGradient extends AbstractStreamOperator<DenseVector[]>

Review Comment:
   Add a TODO.
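For readers checking the math: the reduce step above sums the partial gradients and per-coordinate weights element-wise, and `UpdateModel` averages the gradient and applies the FTRL-Proximal update. Below is a minimal, Flink-free sketch of those two steps in plain Java. It assumes, as the quoted code suggests, that `gradientInfo[0]` holds summed gradients and `gradientInfo[1]` per-coordinate weights. `FtrlProximalSketch`, `add`, and `update` are hypothetical names, not Flink ML APIs.

```java
/** Hypothetical sketch of the gradient reduce + FTRL-Proximal update quoted above. */
final class FtrlProximalSketch {
    private final double alpha;
    private final double beta;
    private final double l1;
    private final double l2;
    private final double[] z; // per-coordinate accumulated (adjusted) gradients
    private final double[] n; // per-coordinate accumulated squared gradients

    FtrlProximalSketch(int dim, double alpha, double beta, double l1, double l2) {
        this.alpha = alpha;
        this.beta = beta;
        this.l1 = l1;
        this.l2 = l2;
        this.z = new double[dim];
        this.n = new double[dim];
    }

    /** Element-wise sum of two partial results, as in the ReduceFunction above. */
    static double[] add(double[] a, double[] b) {
        double[] out = new double[a.length];
        for (int i = 0; i < a.length; i++) {
            out[i] = a[i] + b[i];
        }
        return out;
    }

    /** Returns updated coefficients; the input array is left unchanged. */
    double[] update(double[] coef, double[] gradSum, double[] weightSum) {
        double[] w = coef.clone();
        for (int i = 0; i < w.length; i++) {
            // Average the summed gradient; keep it as-is when no weight was recorded.
            double g = weightSum[i] != 0.0 ? gradSum[i] / weightSum[i] : gradSum[i];
            double sigma = (Math.sqrt(n[i] + g * g) - Math.sqrt(n[i])) / alpha;
            z[i] += g - sigma * w[i]; // w[i] is still the previous coefficient here
            n[i] += g * g;
            if (Math.abs(z[i]) <= l1) {
                w[i] = 0.0; // L1 threshold: small accumulated gradients give an exact zero
            } else {
                w[i] = (Math.signum(z[i]) * l1 - z[i]) / ((beta + Math.sqrt(n[i])) / alpha + l2);
            }
        }
        return w;
    }
}
```

Keeping `z` and `n` per coordinate is what lets the L1 term drive individual coefficients exactly to zero, which is the main reason to choose FTRL for sparse, high-dimensional inputs.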




[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator of Ftrl

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r847844186


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/ftrl/Ftrl.java:
##########
@@ -0,0 +1,395 @@
+/*

Review Comment:
   Jira tickets for all algorithms planned to be added have been created in advance. The ticket for FTRL is FLINK-20790; you can find it on https://issues.apache.org/jira/secure/RapidBoard.jspa?rapidView=541.



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/FtrlTest.java:
##########
@@ -0,0 +1,384 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.classification.ftrl.Ftrl;
+import org.apache.flink.ml.classification.ftrl.FtrlModel;
+import org.apache.flink.ml.classification.logisticregression.LogisticRegression;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static junit.framework.TestCase.assertEquals;
+
+/** Tests {@link Ftrl} and {@link FtrlModel}. */
+public class FtrlTest extends AbstractTestBase {

Review Comment:
   The test cases are organized differently from existing practice. Let's add tests such that each one covers one of the following situations:
   - getting/setting parameters.
   - the most common fit/transform process.
   - save/load.
   - getting/setting model data.
   - invalid inputs/corner cases.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LinearModelData.java:
##########
@@ -37,60 +38,68 @@
 import java.io.OutputStream;
 
 /**
- * Model data of {@link LogisticRegressionModel}.
+ * Model data of {@link LogisticRegressionModel}, {@link FtrlModel}.
  *
  * <p>This class also provides methods to convert model data from Table to Datastream, and classes
  * to save/load model data.
  */
-public class LogisticRegressionModelData {
+public class LinearModelData {

Review Comment:
   Could you please explain the relationship between FTRL, LogisticRegression, and other algorithms like LinearRegression? I'm not sure why we would rename `LogisticRegressionModelData` to `LinearModelData`.
   
   If, after discussion, we still agree that this renaming is reasonable, the model data class would belong to neither the `logisticregression` nor the `ftrl` package. We would need to place classes like this in a neutral package, such as one named `common`.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LinearModelData.java:
##########
@@ -37,60 +38,68 @@
+public class LinearModelData {
 
     public DenseVector coefficient;
+    public long modelVersion;

Review Comment:
   The versioning mechanism is different from that in `OnlineKMeans`. Shall we adopt the same practice across both online algorithms?



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/ftrl/FtrlParams.java:
##########
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.ftrl;
+
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasGlobalBatchSize;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+
+/**
+ * Params of {@link Ftrl}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface FtrlParams<T>
+        extends HasLabelCol<T>, HasBatchStrategy<T>, HasGlobalBatchSize<T>, HasFeaturesCol<T> {

Review Comment:
   An `Estimator`'s param class should inherit the corresponding `Model`'s param class. In the current implementation, `Ftrl` cannot set the `rawPredictionCol` and `predictionCol` of the generated `FtrlModel`.
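To illustrate the suggested layout, here is a Flink-free sketch in which the estimator's params interface extends the model's params interface. Because of that inheritance, `setPredictionCol` is available on the estimator and can be forwarded to the model in `fit()`. All types below are hypothetical, simplified stand-ins. Flink ML's real mixins (such as `HasPredictionCol`) are built on `Param` objects, and the real `fit()` copies params via `ReadWriteUtils.updateExistingParams`, as in the code quoted in this PR.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical, heavily simplified stand-in for a params mixin base.
interface ParamHolder<T> {
    Map<String, Object> params();

    @SuppressWarnings("unchecked")
    default T set(String name, Object value) {
        params().put(name, value);
        return (T) this;
    }

    default Object get(String name, Object defaultValue) {
        return params().getOrDefault(name, defaultValue);
    }
}

interface HasPredictionColSketch<T> extends ParamHolder<T> {
    default T setPredictionCol(String col) { return set("predictionCol", col); }
    default String getPredictionCol() { return (String) get("predictionCol", "prediction"); }
}

interface HasLabelColSketch<T> extends ParamHolder<T> {
    default T setLabelCol(String col) { return set("labelCol", col); }
    default String getLabelCol() { return (String) get("labelCol", "label"); }
}

// Params the model needs at transform time.
interface FtrlModelParamsSketch<T> extends HasPredictionColSketch<T> {}

// Estimator params extend the model params, so the estimator can set predictionCol too.
interface FtrlParamsSketch<T> extends FtrlModelParamsSketch<T>, HasLabelColSketch<T> {}

final class FtrlModelSketch implements FtrlModelParamsSketch<FtrlModelSketch> {
    private final Map<String, Object> params = new HashMap<>();
    public Map<String, Object> params() { return params; }
}

final class FtrlSketch implements FtrlParamsSketch<FtrlSketch> {
    private final Map<String, Object> params = new HashMap<>();
    public Map<String, Object> params() { return params; }

    FtrlModelSketch fit() {
        FtrlModelSketch model = new FtrlModelSketch();
        // Forward the params that the model also declares.
        model.setPredictionCol(getPredictionCol());
        return model;
    }
}
```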



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/ftrl/Ftrl.java:
##########
@@ -0,0 +1,395 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.ftrl;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.classification.logisticregression.LinearModelData;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/** An Estimator which implements the Ftrl algorithm. */
+public class Ftrl implements Estimator<Ftrl, FtrlModel>, FtrlParams<Ftrl> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public Ftrl() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public FtrlModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<Tuple2<Vector, Double>> points =
+                tEnv.toDataStream(inputs[0]).map(new ParseSample(getFeaturesCol(), getLabelCol()));
+
+        DataStream<LinearModelData> modelDataStream =
+                LinearModelData.getModelDataStream(initModelDataTable);
+        DataStream<DenseVector[]> initModelData = modelDataStream.map(new GetVectorData());
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new FtrlIterationBody(
+                        getGlobalBatchSize(), getAlpha(), getBETA(), getL1(), getL2());
+
+        DataStream<LinearModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        FtrlModel model = new FtrlModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Iteration body of ftrl. */
+    public static class FtrlIterationBody implements IterationBody {
+        private final int batchSize;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+
+        public FtrlIterationBody(int batchSize, double alpha, double beta, double l1, double l2) {
+            this.batchSize = batchSize;
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = l1;
+            this.l2 = l2;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<DenseVector[]> modelData = variableStreams.get(0);
+
+            DataStream<Tuple2<Vector, Double>> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+            DataStream<DenseVector[]> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator())
+                            .flatMap(new GlobalBatchSplitter(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(DenseVector[].class),
+                                    new FtrlLocalUpdater(alpha, beta, l1, l2))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new FtrlGlobalReducer())
+                            .map(
+                                    (MapFunction<DenseVector[], DenseVector[]>)
+                                            value -> new DenseVector[] {value[0]})
+                            .setParallelism(1);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelData),
+                    DataStreamList.of(
+                            modelData
+                                    .map(
+                                            new MapFunction<DenseVector[], LinearModelData>() {
+                                                long iter = 0L;
+
+                                                @Override
+                                                public LinearModelData map(DenseVector[] value) {
+                                                    return new LinearModelData(value[0], iter++);
+                                                }
+                                            })

Review Comment:
   Shall we move logic like this into a separate class or method? That would make the code easier to read.
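One way to follow this suggestion is to turn the anonymous version-counting `MapFunction` into a small named class. The sketch below is plain Java so it stays self-contained. `ModelVersionAssigner` and `VersionedModel` are hypothetical names. In Flink, the class would instead implement `MapFunction` and produce `LinearModelData`.

```java
/** Hypothetical holder pairing model data with its version. */
final class VersionedModel<T> {
    final T model;
    final long version;

    VersionedModel(T model, long version) {
        this.model = model;
        this.version = version;
    }
}

/** Hypothetical named replacement for the anonymous version-counting MapFunction. */
final class ModelVersionAssigner<T> {
    // One counter per instance (i.e. per parallel subtask); it is not checkpointed state.
    private long nextVersion;

    ModelVersionAssigner(long initialVersion) {
        this.nextVersion = initialVersion;
    }

    VersionedModel<T> map(T model) {
        return new VersionedModel<>(model, nextVersion++);
    }
}
```

Because the counter lives in a plain field, versions increase globally only when the operator runs with parallelism 1. The counter also restarts from its initial value after a failover unless it is kept in operator state.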



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/FtrlTest.java:
##########
@@ -0,0 +1,384 @@
+/** Tests {@link Ftrl} and {@link FtrlModel}. */
+public class FtrlTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private static final String LABEL_COL = "label";
+    private static final String PREDICT_COL = "prediction";
+    private static final String FEATURE_COL = "features";
+    private static final String MODEL_VERSION_COL = "modelVersion";
+    private Table trainDenseTable;
+    private static final List<Row> TRAIN_DENSE_ROWS =
+            Arrays.asList(
+                    Row.of(Vectors.dense(0.1, 2.), 0.),
+                    Row.of(Vectors.dense(0.2, 2.), 0.),
+                    Row.of(Vectors.dense(0.3, 2.), 0.),
+                    Row.of(Vectors.dense(0.4, 2.), 0.),
+                    Row.of(Vectors.dense(0.5, 2.), 0.),
+                    Row.of(Vectors.dense(11., 12.), 1.),
+                    Row.of(Vectors.dense(12., 11.), 1.),
+                    Row.of(Vectors.dense(13., 12.), 1.),
+                    Row.of(Vectors.dense(14., 12.), 1.),
+                    Row.of(Vectors.dense(15., 12.), 1.));
+
+    private static final List<Row> TRAIN_SPARSE_ROWS =
+            Arrays.asList(
+                    Row.of(
+                            Vectors.sparse(10, new int[] {1, 3, 4}, new double[] {1.0, 1.0, 1.0}),

Review Comment:
   Shall we extract `new double[]{1.0, 1.0, 1.0}` into a variable? That would make the code cleaner.
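Here is a minimal sketch of one option, assuming a hypothetical helper named `ones`. A factory that returns a fresh array on each call avoids repeating the literal. It also ensures that no two vectors share the same `double[]`, which matters if a vector keeps a reference to the array it was built from.

```java
import java.util.Arrays;

/** Hypothetical test helper for building feature values. */
final class TestVectors {
    private TestVectors() {}

    /** Returns a new array of {@code size} ones; each call allocates a separate array. */
    static double[] ones(int size) {
        double[] values = new double[size];
        Arrays.fill(values, 1.0);
        return values;
    }
}
```

Each row would then read, for example, `Vectors.sparse(10, new int[] {1, 3, 4}, TestVectors.ones(3))`.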



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/FtrlTest.java:
##########
@@ -0,0 +1,384 @@
+/** Tests {@link Ftrl} and {@link FtrlModel}. */
+public class FtrlTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private static final String LABEL_COL = "label";
+    private static final String PREDICT_COL = "prediction";
+    private static final String FEATURE_COL = "features";
+    private static final String MODEL_VERSION_COL = "modelVersion";
+    private Table trainDenseTable;
+    private static final List<Row> TRAIN_DENSE_ROWS =
+            Arrays.asList(
+                    Row.of(Vectors.dense(0.1, 2.), 0.),
+                    Row.of(Vectors.dense(0.2, 2.), 0.),
+                    Row.of(Vectors.dense(0.3, 2.), 0.),
+                    Row.of(Vectors.dense(0.4, 2.), 0.),
+                    Row.of(Vectors.dense(0.5, 2.), 0.),
+                    Row.of(Vectors.dense(11., 12.), 1.),
+                    Row.of(Vectors.dense(12., 11.), 1.),
+                    Row.of(Vectors.dense(13., 12.), 1.),
+                    Row.of(Vectors.dense(14., 12.), 1.),
+                    Row.of(Vectors.dense(15., 12.), 1.));
+
+    private static final List<Row> TRAIN_SPARSE_ROWS =
+            Arrays.asList(
+                    Row.of(
+                            Vectors.sparse(10, new int[] {1, 3, 4}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {0, 2, 3}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(
+                                    10, new int[] {0, 1, 3, 4}, new double[] {1.0, 1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {2, 3, 4}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {1, 3, 4}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {6, 7, 8}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {6, 8, 9}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {5, 8, 9}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {5, 6, 8}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {5, 6, 7}, new double[] {1.0, 1.0, 1.0}),
+                            1.));
+
+    private static final List<Row> PREDICT_DENSE_ROWS =
+            Arrays.asList(
+                    Row.of(Vectors.dense(0.8, 2.7), 0.),
+                    Row.of(Vectors.dense(0.8, 2.4), 0.),
+                    Row.of(Vectors.dense(0.7, 2.3), 0.),
+                    Row.of(Vectors.dense(0.4, 2.7), 0.),
+                    Row.of(Vectors.dense(0.5, 2.8), 0.),
+                    Row.of(Vectors.dense(10.2, 12.1), 1.),
+                    Row.of(Vectors.dense(13.3, 13.1), 1.),
+                    Row.of(Vectors.dense(13.5, 12.2), 1.),
+                    Row.of(Vectors.dense(14.9, 12.5), 1.),
+                    Row.of(Vectors.dense(15.5, 11.2), 1.));
+
+    private static final List<Row> PREDICT_SPARSE_ROWS =
+            Arrays.asList(
+                    Row.of(
+                            Vectors.sparse(10, new int[] {1, 2, 4}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {2, 3, 4}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(
+                                    10, new int[] {0, 1, 2, 4}, new double[] {1.0, 1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {1, 3, 4}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {6, 7, 9}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {7, 8, 9}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {5, 7, 9}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {5, 6, 7}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {5, 8, 9}, new double[] {1.0, 1.0, 1.0}),
+                            1.));
+
+    @Before
+    public void before() throws Exception {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        Schema schema =
+                Schema.newBuilder()
+                        .column("f0", DataTypes.of(DenseVector.class))
+                        .column("f1", DataTypes.DOUBLE())
+                        .build();
+        DataStream<Row> dataStream = env.fromCollection(TRAIN_DENSE_ROWS);
+        trainDenseTable = tEnv.fromDataStream(dataStream, schema).as(FEATURE_COL, LABEL_COL);
+    }
+
+    @Test
+    public void testFtrlWithInitLrModel() throws Exception {
+        Table initModel =
+                new LogisticRegression()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setLabelCol(LABEL_COL)
+                        .fit(trainDenseTable)
+                        .getModelData()[0];
+
+        Table onlinePredictTable =
+                tEnv.fromDataStream(
+                        env.addSource(
+                                new RandomSourceFunction(20, 20, TRAIN_DENSE_ROWS),
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(DenseVector.class), Types.DOUBLE
+                                        },
+                                        new String[] {FEATURE_COL, LABEL_COL})));
+
+        Table models =
+                new Ftrl()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initModel)
+                        .setGlobalBatchSize(10)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlinePredictTable)
+                        .getModelData()[0];
+        tEnv.toDataStream(models).print();
+        env.execute();
+    }
+
+    @Test
+    public void testFtrl() throws Exception {
+        Table initModel =
+                tEnv.fromDataStream(
+                        env.fromElements(Row.of(new DenseVector(new double[] {0.0, 0.0}), 0L)));
+        Table onlinePredictTable =
+                tEnv.fromDataStream(
+                        env.addSource(
+                                new RandomSourceFunction(20, 20, TRAIN_DENSE_ROWS),
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(DenseVector.class), Types.DOUBLE
+                                        },
+                                        new String[] {FEATURE_COL, LABEL_COL})));
+
+        Table models =
+                new Ftrl()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initModel)
+                        .setGlobalBatchSize(10)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlinePredictTable)
+                        .getModelData()[0];
+        tEnv.toDataStream(models).print();
+        env.execute();
+    }
+
+    @Test
+    public void testFtrlModel() throws Exception {
+        Table initModel =
+                tEnv.fromDataStream(
+                        env.fromElements(Row.of(new DenseVector(new double[] {0.0, 0.0}), 0L)));
+
+        Table onlineTrainTable =
+                tEnv.fromDataStream(
+                        env.addSource(
+                                new RandomSourceFunction(5, 2000, TRAIN_DENSE_ROWS),
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(Vector.class), Types.DOUBLE
+                                        },
+                                        new String[] {FEATURE_COL, LABEL_COL})));
+        Table onlinePredictTable =
+                tEnv.fromDataStream(
+                        env.addSource(
+                                new RandomSourceFunction(5, 2000, PREDICT_DENSE_ROWS),
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(Vector.class), Types.DOUBLE
+                                        },
+                                        new String[] {FEATURE_COL, LABEL_COL})));
+
+        FtrlModel model =
+                new Ftrl()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initModel)
+                        .setGlobalBatchSize(500)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlineTrainTable);
+        Table resultTable = model.setPredictionCol(PREDICT_COL).transform(onlinePredictTable)[0];
+        verifyPredictionResult(resultTable);
+    }
+
+    @Test
+    public void testFtrlModelSparse() throws Exception {
+        Table initModelSparse =
+                tEnv.fromDataStream(
+                        env.fromElements(
+                                Row.of(
+                                        new DenseVector(
+                                                new double[] {
+                                                    0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1
+                                                }),
+                                        0L)));
+
+        Table onlineTrainTable =
+                tEnv.fromDataStream(
+                        env.addSource(
+                                new RandomSourceFunction(5, 2000, TRAIN_SPARSE_ROWS),
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(Vector.class), Types.DOUBLE
+                                        },
+                                        new String[] {FEATURE_COL, LABEL_COL})));
+        Table onlinePredictTable =
+                tEnv.fromDataStream(
+                        env.addSource(
+                                new RandomSourceFunction(5, 2000, PREDICT_SPARSE_ROWS),
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(Vector.class), Types.DOUBLE
+                                        },
+                                        new String[] {FEATURE_COL, LABEL_COL})));
+
+        FtrlModel model =
+                new Ftrl()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initModelSparse)
+                        .setGlobalBatchSize(500)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlineTrainTable);
+        Table resultTable = model.setPredictionCol(PREDICT_COL).transform(onlinePredictTable)[0];
+        tEnv.toDataStream(model.getModelData()[0]).print();
+        verifyPredictionResult(resultTable);
+    }
+
+    private static void verifyPredictionResult(Table output) throws Exception {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment();
+        DataStream<Row> stream = tEnv.toDataStream(output);
+        List<Row> result = IteratorUtils.toList(stream.executeAndCollect());
+        Map<Long, Tuple2<Double, Double>> correctRatio = new HashMap<>();
+
+        for (Row row : result) {
+            long modelVersion = row.getFieldAs(MODEL_VERSION_COL);
+            Double pred = row.getFieldAs(PREDICT_COL);
+            Double label = row.getFieldAs(LABEL_COL);
+            if (correctRatio.containsKey(modelVersion)) {
+                Tuple2<Double, Double> t2 = correctRatio.get(modelVersion);
+                if (pred.equals(label)) {
+                    t2.f0 += 1.0;
+                }
+                t2.f1 += 1.0;
+            } else {
+                correctRatio.put(modelVersion, Tuple2.of(pred.equals(label) ? 1.0 : 0.0, 1.0));
+            }
+        }
+        for (Long id : correctRatio.keySet()) {
+            System.out.println(
+                    id
+                            + " : "
+                            + correctRatio.get(id).f0 / correctRatio.get(id).f1
+                            + " total sample num : "
+                            + correctRatio.get(id).f1);
+            if (id > 0L) {
+                assertEquals(1.0, correctRatio.get(id).f0 / correctRatio.get(id).f1, 1.0e-5);
+            }
+        }
+    }
+
+    /** Generates random data for ftrl train and predict. */
+    public static class RandomSourceFunction implements SourceFunction<Row> {
+        private volatile boolean isRunning = true;
+        private final long timeInterval;
+        private final long maxSize;
+        private final List<Row> data;
+
+        public RandomSourceFunction(long timeInterval, long maxSize, List<Row> data)
+                throws InterruptedException {
+            this.timeInterval = timeInterval;
+            this.maxSize = maxSize;
+            this.data = data;
+        }
+
+        @Override
+        public void run(SourceContext<Row> ctx) throws Exception {
+            int size = data.size();
+            for (int i = 0; i < maxSize; ++i) {
+                if (i == 0) {
+                    Thread.sleep(5000);

Review Comment:
   Shall we avoid using `Thread.sleep()` in test cases? If every unit test adopted this practice, the total time for `mvn install` would become very long.
   
   What we actually want here is to make sure the job has been initialized before any input is provided. `OnlineKMeansTest` has established a best practice for this kind of problem using `InMemorySourceFunction`; you can refer to these classes to see how to solve it.
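   For illustration, the general idea (the source blocks until the test hands it data, rather than sleeping for a fixed time) can be sketched in plain Java. This is not the `InMemorySourceFunction` API: `GatedSource`, `addBatch` and `finish` are hypothetical names, and the `run(Consumer)` method only mirrors the shape of `SourceFunction#run`.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.function.Consumer;

/**
 * Hypothetical stand-in for a test source: instead of calling Thread.sleep()
 * and hoping the job is ready, it blocks until the test explicitly offers data.
 */
class GatedSource<T> {
    // Sentinel instance marking the end of input (compared by identity).
    private static final List<Object> END = new ArrayList<>();

    private final BlockingQueue<List<?>> batches = new LinkedBlockingQueue<>();

    /** Called by the test once it wants this batch to be consumed. */
    void addBatch(List<T> batch) {
        batches.add(batch);
    }

    /** Signals that no more batches will arrive, letting run() return. */
    void finish() {
        batches.add(END);
    }

    /** Mirrors SourceFunction#run: emits each batch as soon as it is offered. */
    @SuppressWarnings("unchecked")
    void run(Consumer<T> emit) throws InterruptedException {
        while (true) {
            List<?> batch = batches.take(); // blocks; no fixed sleep needed
            if (batch == END) {
                return;
            }
            for (Object record : batch) {
                emit.accept((T) record);
            }
        }
    }
}
```

   With this shape the test controls exactly when input becomes visible, so the test runtime depends on the data rather than on a hard-coded 5-second delay.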



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/FtrlTest.java:
##########
@@ -0,0 +1,384 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.classification.ftrl.Ftrl;
+import org.apache.flink.ml.classification.ftrl.FtrlModel;
+import org.apache.flink.ml.classification.logisticregression.LogisticRegression;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static junit.framework.TestCase.assertEquals;
+
+/** Tests {@link Ftrl} and {@link FtrlModel}. */
+public class FtrlTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private static final String LABEL_COL = "label";
+    private static final String PREDICT_COL = "prediction";
+    private static final String FEATURE_COL = "features";
+    private static final String MODEL_VERSION_COL = "modelVersion";
+    private Table trainDenseTable;
+    private static final List<Row> TRAIN_DENSE_ROWS =
+            Arrays.asList(
+                    Row.of(Vectors.dense(0.1, 2.), 0.),
+                    Row.of(Vectors.dense(0.2, 2.), 0.),
+                    Row.of(Vectors.dense(0.3, 2.), 0.),
+                    Row.of(Vectors.dense(0.4, 2.), 0.),
+                    Row.of(Vectors.dense(0.5, 2.), 0.),
+                    Row.of(Vectors.dense(11., 12.), 1.),
+                    Row.of(Vectors.dense(12., 11.), 1.),
+                    Row.of(Vectors.dense(13., 12.), 1.),
+                    Row.of(Vectors.dense(14., 12.), 1.),
+                    Row.of(Vectors.dense(15., 12.), 1.));
+
+    private static final List<Row> TRAIN_SPARSE_ROWS =
+            Arrays.asList(
+                    Row.of(
+                            Vectors.sparse(10, new int[] {1, 3, 4}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {0, 2, 3}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(
+                                    10, new int[] {0, 1, 3, 4}, new double[] {1.0, 1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {2, 3, 4}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {1, 3, 4}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {6, 7, 8}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {6, 8, 9}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {5, 8, 9}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {5, 6, 8}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {5, 6, 7}, new double[] {1.0, 1.0, 1.0}),
+                            1.));
+
+    private static final List<Row> PREDICT_DENSE_ROWS =
+            Arrays.asList(
+                    Row.of(Vectors.dense(0.8, 2.7), 0.),
+                    Row.of(Vectors.dense(0.8, 2.4), 0.),
+                    Row.of(Vectors.dense(0.7, 2.3), 0.),
+                    Row.of(Vectors.dense(0.4, 2.7), 0.),
+                    Row.of(Vectors.dense(0.5, 2.8), 0.),
+                    Row.of(Vectors.dense(10.2, 12.1), 1.),
+                    Row.of(Vectors.dense(13.3, 13.1), 1.),
+                    Row.of(Vectors.dense(13.5, 12.2), 1.),
+                    Row.of(Vectors.dense(14.9, 12.5), 1.),
+                    Row.of(Vectors.dense(15.5, 11.2), 1.));
+
+    private static final List<Row> PREDICT_SPARSE_ROWS =
+            Arrays.asList(
+                    Row.of(
+                            Vectors.sparse(10, new int[] {1, 2, 4}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {2, 3, 4}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(
+                                    10, new int[] {0, 1, 2, 4}, new double[] {1.0, 1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {1, 3, 4}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {6, 7, 9}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {7, 8, 9}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {5, 7, 9}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {5, 6, 7}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {5, 8, 9}, new double[] {1.0, 1.0, 1.0}),
+                            1.));
+
+    @Before
+    public void before() throws Exception {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        Schema schema =
+                Schema.newBuilder()
+                        .column("f0", DataTypes.of(DenseVector.class))
+                        .column("f1", DataTypes.DOUBLE())
+                        .build();
+        DataStream<Row> dataStream = env.fromCollection(TRAIN_DENSE_ROWS);
+        trainDenseTable = tEnv.fromDataStream(dataStream, schema).as(FEATURE_COL, LABEL_COL);
+    }
+
+    @Test
+    public void testFtrlWithInitLrModel() throws Exception {
+        Table initModel =
+                new LogisticRegression()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setLabelCol(LABEL_COL)
+                        .fit(trainDenseTable)
+                        .getModelData()[0];
+
+        Table onlinePredictTable =
+                tEnv.fromDataStream(
+                        env.addSource(
+                                new RandomSourceFunction(20, 20, TRAIN_DENSE_ROWS),
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(DenseVector.class), Types.DOUBLE
+                                        },
+                                        new String[] {FEATURE_COL, LABEL_COL})));
+
+        Table models =
+                new Ftrl()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initModel)
+                        .setGlobalBatchSize(10)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlinePredictTable)
+                        .getModelData()[0];
+        tEnv.toDataStream(models).print();
+        env.execute();
+    }
+
+    @Test
+    public void testFtrl() throws Exception {
+        Table initModel =
+                tEnv.fromDataStream(
+                        env.fromElements(Row.of(new DenseVector(new double[] {0.0, 0.0}), 0L)));
+        Table onlinePredictTable =
+                tEnv.fromDataStream(
+                        env.addSource(
+                                new RandomSourceFunction(20, 20, TRAIN_DENSE_ROWS),
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(DenseVector.class), Types.DOUBLE
+                                        },
+                                        new String[] {FEATURE_COL, LABEL_COL})));
+
+        Table models =
+                new Ftrl()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initModel)
+                        .setGlobalBatchSize(10)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlinePredictTable)
+                        .getModelData()[0];
+        tEnv.toDataStream(models).print();
+        env.execute();

Review Comment:
   Let's check the correctness of the output in all test cases. `env.execute()` only guarantees that the job does not throw an exception during execution; it does not mean the computed result is correct.
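   As one possible way to turn the `print()` calls into real checks, a small helper could evaluate the collected coefficients against labelled rows and fail the test below an accuracy threshold. This is a sketch under assumptions: `ModelChecks` is a hypothetical helper, and it assumes a logistic-regression model without intercept that predicts `1.0` when `w.x >= 0` (that is, `sigmoid(w.x) >= 0.5`); the actual decision rule of `FtrlModel` should be confirmed before reusing it.

```java
import java.util.Arrays;
import java.util.List;

/**
 * Hypothetical test helper: checks a learned coefficient vector against
 * labelled rows instead of only printing it. Assumes no intercept and the
 * decision rule "predict 1.0 if w.x >= 0, else 0.0".
 */
final class ModelChecks {
    private ModelChecks() {}

    /** Fraction of rows whose predicted label equals the true label. */
    static double accuracy(double[] coefficients, List<double[]> features, List<Double> labels) {
        if (features.isEmpty() || features.size() != labels.size()) {
            throw new IllegalArgumentException("features and labels must be non-empty and aligned");
        }
        int correct = 0;
        for (int i = 0; i < features.size(); i++) {
            double[] x = features.get(i);
            double dot = 0.0;
            for (int j = 0; j < x.length; j++) {
                dot += coefficients[j] * x[j];
            }
            double prediction = dot >= 0 ? 1.0 : 0.0;
            if (prediction == labels.get(i)) {
                correct++;
            }
        }
        return (double) correct / features.size();
    }

    /** Throws so that the test fails, instead of passing as long as the job ran. */
    static void assertAccuracyAtLeast(
            double[] coefficients, List<double[]> features, List<Double> labels, double min) {
        double acc = accuracy(coefficients, features, labels);
        if (acc < min) {
            throw new AssertionError("accuracy " + acc + " is below " + min);
        }
    }
}
```

   In the tests above, the coefficients would come from collecting the model-data table (e.g. via `executeAndCollect()`) rather than printing it, and the features and labels from `TRAIN_DENSE_ROWS`.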



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/ftrl/FtrlModelParams.java:
##########
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.ftrl;
+
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasPredictionCol;
+import org.apache.flink.ml.common.param.HasRawPredictionCol;
+
+/**
+ * Params for {@link FtrlModel}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface FtrlModelParams<T>
+        extends HasFeaturesCol<T>, HasPredictionCol<T>, HasRawPredictionCol<T> {}

Review Comment:
   This interface is identical to `LogisticRegressionModelParams`. If FTRL is an online version of LogisticRegression, shall we consider reorganizing the code to avoid creating duplicate classes like this?
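   A minimal sketch of the reuse being suggested: declare the params interface once and let the online model implement it directly. The mixins below are simplified, hypothetical stand-ins (backed by a plain `Map`), not Flink ML's real `Param`/`WithParams` machinery, and `FtrlModelSketch` is an illustrative name.

```java
import java.util.HashMap;
import java.util.Map;

// Simplified, hypothetical stand-ins for Flink ML's param mixins.
interface ParamHolder<T> {
    Map<String, Object> params();

    @SuppressWarnings("unchecked")
    default T set(String name, Object value) {
        params().put(name, value);
        return (T) this; // fluent setter returning the concrete type
    }
}

interface HasFeaturesCol<T> extends ParamHolder<T> {
    default T setFeaturesCol(String col) { return set("featuresCol", col); }
    default String getFeaturesCol() { return (String) params().getOrDefault("featuresCol", "features"); }
}

interface HasPredictionCol<T> extends ParamHolder<T> {
    default T setPredictionCol(String col) { return set("predictionCol", col); }
    default String getPredictionCol() { return (String) params().getOrDefault("predictionCol", "prediction"); }
}

interface HasRawPredictionCol<T> extends ParamHolder<T> {
    default T setRawPredictionCol(String col) { return set("rawPredictionCol", col); }
    default String getRawPredictionCol() { return (String) params().getOrDefault("rawPredictionCol", "rawPrediction"); }
}

// Declared once...
interface LogisticRegressionModelParams<T>
        extends HasFeaturesCol<T>, HasPredictionCol<T>, HasRawPredictionCol<T> {}

// ...and reused by the online model instead of a copy-pasted FtrlModelParams.
class FtrlModelSketch implements LogisticRegressionModelParams<FtrlModelSketch> {
    private final Map<String, Object> params = new HashMap<>();

    @Override
    public Map<String, Object> params() {
        return params;
    }
}
```

   If FTRL later needs model params that plain LogisticRegression lacks, a dedicated interface extending `LogisticRegressionModelParams` would still avoid duplicating the shared ones.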



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/FtrlTest.java:
##########
@@ -0,0 +1,384 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.classification.ftrl.Ftrl;
+import org.apache.flink.ml.classification.ftrl.FtrlModel;
+import org.apache.flink.ml.classification.logisticregression.LogisticRegression;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static junit.framework.TestCase.assertEquals;
+
+/** Tests {@link Ftrl} and {@link FtrlModel}. */
+public class FtrlTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private static final String LABEL_COL = "label";
+    private static final String PREDICT_COL = "prediction";
+    private static final String FEATURE_COL = "features";
+    private static final String MODEL_VERSION_COL = "modelVersion";
+    private Table trainDenseTable;
+    private static final List<Row> TRAIN_DENSE_ROWS =
+            Arrays.asList(
+                    Row.of(Vectors.dense(0.1, 2.), 0.),
+                    Row.of(Vectors.dense(0.2, 2.), 0.),
+                    Row.of(Vectors.dense(0.3, 2.), 0.),
+                    Row.of(Vectors.dense(0.4, 2.), 0.),
+                    Row.of(Vectors.dense(0.5, 2.), 0.),
+                    Row.of(Vectors.dense(11., 12.), 1.),
+                    Row.of(Vectors.dense(12., 11.), 1.),
+                    Row.of(Vectors.dense(13., 12.), 1.),
+                    Row.of(Vectors.dense(14., 12.), 1.),
+                    Row.of(Vectors.dense(15., 12.), 1.));
+
+    private static final List<Row> TRAIN_SPARSE_ROWS =
+            Arrays.asList(
+                    Row.of(
+                            Vectors.sparse(10, new int[] {1, 3, 4}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {0, 2, 3}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(
+                                    10, new int[] {0, 1, 3, 4}, new double[] {1.0, 1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {2, 3, 4}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {1, 3, 4}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {6, 7, 8}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {6, 8, 9}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {5, 8, 9}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {5, 6, 8}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {5, 6, 7}, new double[] {1.0, 1.0, 1.0}),
+                            1.));
+
+    private static final List<Row> PREDICT_DENSE_ROWS =
+            Arrays.asList(
+                    Row.of(Vectors.dense(0.8, 2.7), 0.),
+                    Row.of(Vectors.dense(0.8, 2.4), 0.),
+                    Row.of(Vectors.dense(0.7, 2.3), 0.),
+                    Row.of(Vectors.dense(0.4, 2.7), 0.),
+                    Row.of(Vectors.dense(0.5, 2.8), 0.),
+                    Row.of(Vectors.dense(10.2, 12.1), 1.),
+                    Row.of(Vectors.dense(13.3, 13.1), 1.),
+                    Row.of(Vectors.dense(13.5, 12.2), 1.),
+                    Row.of(Vectors.dense(14.9, 12.5), 1.),
+                    Row.of(Vectors.dense(15.5, 11.2), 1.));
+
+    private static final List<Row> PREDICT_SPARSE_ROWS =
+            Arrays.asList(
+                    Row.of(
+                            Vectors.sparse(10, new int[] {1, 2, 4}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {2, 3, 4}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(
+                                    10, new int[] {0, 1, 2, 4}, new double[] {1.0, 1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {1, 3, 4}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {6, 7, 9}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {7, 8, 9}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {5, 7, 9}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {5, 6, 7}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {5, 8, 9}, new double[] {1.0, 1.0, 1.0}),
+                            1.));
+
+    @Before
+    public void before() throws Exception {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        Schema schema =
+                Schema.newBuilder()
+                        .column("f0", DataTypes.of(DenseVector.class))
+                        .column("f1", DataTypes.DOUBLE())
+                        .build();
+        DataStream<Row> dataStream = env.fromCollection(TRAIN_DENSE_ROWS);
+        trainDenseTable = tEnv.fromDataStream(dataStream, schema).as(FEATURE_COL, LABEL_COL);
+    }
+
+    @Test
+    public void testFtrlWithInitLrModel() throws Exception {
+        Table initModel =
+                new LogisticRegression()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setLabelCol(LABEL_COL)
+                        .fit(trainDenseTable)
+                        .getModelData()[0];
+
+        Table onlinePredictTable =
+                tEnv.fromDataStream(
+                        env.addSource(
+                                new RandomSourceFunction(20, 20, TRAIN_DENSE_ROWS),
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(DenseVector.class), Types.DOUBLE
+                                        },
+                                        new String[] {FEATURE_COL, LABEL_COL})));
+
+        Table models =
+                new Ftrl()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initModel)
+                        .setGlobalBatchSize(10)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlinePredictTable)
+                        .getModelData()[0];
+        tEnv.toDataStream(models).print();
+        env.execute();
+    }
+
+    @Test
+    public void testFtrl() throws Exception {
+        Table initModel =
+                tEnv.fromDataStream(
+                        env.fromElements(Row.of(new DenseVector(new double[] {0.0, 0.0}), 0L)));
+        Table onlinePredictTable =
+                tEnv.fromDataStream(
+                        env.addSource(
+                                new RandomSourceFunction(20, 20, TRAIN_DENSE_ROWS),
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(DenseVector.class), Types.DOUBLE
+                                        },
+                                        new String[] {FEATURE_COL, LABEL_COL})));
+
+        Table models =
+                new Ftrl()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initModel)
+                        .setGlobalBatchSize(10)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlinePredictTable)
+                        .getModelData()[0];
+        tEnv.toDataStream(models).print();
+        env.execute();
+    }
+
+    @Test
+    public void testFtrlModel() throws Exception {
+        Table initModel =
+                tEnv.fromDataStream(
+                        env.fromElements(Row.of(new DenseVector(new double[] {0.0, 0.0}), 0L)));
+
+        Table onlineTrainTable =
+                tEnv.fromDataStream(
+                        env.addSource(
+                                new RandomSourceFunction(5, 2000, TRAIN_DENSE_ROWS),
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(Vector.class), Types.DOUBLE
+                                        },
+                                        new String[] {FEATURE_COL, LABEL_COL})));
+        Table onlinePredictTable =
+                tEnv.fromDataStream(
+                        env.addSource(
+                                new RandomSourceFunction(5, 2000, PREDICT_DENSE_ROWS),
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(Vector.class), Types.DOUBLE
+                                        },
+                                        new String[] {FEATURE_COL, LABEL_COL})));
+
+        FtrlModel model =
+                new Ftrl()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initModel)
+                        .setGlobalBatchSize(500)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlineTrainTable);
+        Table resultTable = model.setPredictionCol(PREDICT_COL).transform(onlinePredictTable)[0];
+        verifyPredictionResult(resultTable);
+    }
+
+    @Test
+    public void testFtrlModelSparse() throws Exception {
+        Table initModelSparse =
+                tEnv.fromDataStream(
+                        env.fromElements(
+                                Row.of(
+                                        new DenseVector(
+                                                new double[] {
+                                                    0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1
+                                                }),
+                                        0L)));
+
+        Table onlineTrainTable =
+                tEnv.fromDataStream(
+                        env.addSource(
+                                new RandomSourceFunction(5, 2000, TRAIN_SPARSE_ROWS),
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(Vector.class), Types.DOUBLE
+                                        },
+                                        new String[] {FEATURE_COL, LABEL_COL})));
+        Table onlinePredictTable =
+                tEnv.fromDataStream(
+                        env.addSource(
+                                new RandomSourceFunction(5, 2000, PREDICT_SPARSE_ROWS),
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(Vector.class), Types.DOUBLE
+                                        },
+                                        new String[] {FEATURE_COL, LABEL_COL})));
+
+        FtrlModel model =
+                new Ftrl()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initModelSparse)
+                        .setGlobalBatchSize(500)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlineTrainTable);
+        Table resultTable = model.setPredictionCol(PREDICT_COL).transform(onlinePredictTable)[0];
+        tEnv.toDataStream(model.getModelData()[0]).print();
+        verifyPredictionResult(resultTable);
+    }
+
+    private static void verifyPredictionResult(Table output) throws Exception {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment();
+        DataStream<Row> stream = tEnv.toDataStream(output);
+        List<Row> result = IteratorUtils.toList(stream.executeAndCollect());
+        Map<Long, Tuple2<Double, Double>> correctRatio = new HashMap<>();
+
+        for (Row row : result) {
+            long modelVersion = row.getFieldAs(MODEL_VERSION_COL);
+            Double pred = row.getFieldAs(PREDICT_COL);
+            Double label = row.getFieldAs(LABEL_COL);
+            if (correctRatio.containsKey(modelVersion)) {
+                Tuple2<Double, Double> t2 = correctRatio.get(modelVersion);
+                if (pred.equals(label)) {
+                    t2.f0 += 1.0;
+                }
+                t2.f1 += 1.0;
+            } else {
+                correctRatio.put(modelVersion, Tuple2.of(pred.equals(label) ? 1.0 : 0.0, 1.0));
+            }
+        }
+        for (Long id : correctRatio.keySet()) {
+            System.out.println(

Review Comment:
   Let's avoid printing debugging information in test cases.
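One way to follow this suggestion is to make the per-model-version accuracy from `verifyPredictionResult` an assertion instead of a `println`. Below is a minimal plain-Java sketch of that check with no Flink types. `AccuracyByVersion`, `Record`, and the threshold parameter are hypothetical names and are not part of the PR:

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical helper: checks per-model-version accuracy instead of printing it. */
final class AccuracyByVersion {
    /** One prediction result: the model version used, the predicted label, and the true label. */
    static final class Record {
        final long modelVersion;
        final double prediction;
        final double label;

        Record(long modelVersion, double prediction, double label) {
            this.modelVersion = modelVersion;
            this.prediction = prediction;
            this.label = label;
        }
    }

    /** Returns, for each model version, the fraction of records whose prediction equals the label. */
    static Map<Long, Double> compute(List<Record> records) {
        Map<Long, int[]> counts = new HashMap<>(); // version -> {correct, total}
        for (Record r : records) {
            int[] c = counts.computeIfAbsent(r.modelVersion, v -> new int[2]);
            if (r.prediction == r.label) {
                c[0]++;
            }
            c[1]++;
        }
        Map<Long, Double> ratios = new HashMap<>();
        for (Map.Entry<Long, int[]> e : counts.entrySet()) {
            ratios.put(e.getKey(), (double) e.getValue()[0] / e.getValue()[1]);
        }
        return ratios;
    }

    /** Fails with an AssertionError if any model version's accuracy is below minAccuracy. */
    static void assertMinAccuracy(List<Record> records, double minAccuracy) {
        for (Map.Entry<Long, Double> e : compute(records).entrySet()) {
            if (e.getValue() < minAccuracy) {
                throw new AssertionError(
                        "model version " + e.getKey() + " has accuracy " + e.getValue()
                                + " < " + minAccuracy);
            }
        }
    }
}
```

Inside the JUnit 4 test, the same idea would replace the `System.out.println` loop with `org.junit.Assert.assertTrue(...)` on each ratio. Choosing a threshold for data produced by `RandomSourceFunction` is a judgment call.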



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/ftrl/FtrlParams.java:
##########
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.ftrl;
+
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasGlobalBatchSize;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+
+/**
+ * Params of {@link Ftrl}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface FtrlParams<T>
+        extends HasLabelCol<T>, HasBatchStrategy<T>, HasGlobalBatchSize<T>, HasFeaturesCol<T> {
+
+    Param<Integer> VECTOR_SIZE =
+            new IntParam("vectorSize", "The size of vector.", -1, ParamValidators.gt(-2));
+
+    default Integer getVectorSize() {

Review Comment:
   This param seems unused.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/ftrl/FtrlParams.java:
##########
@@ -0,0 +1,92 @@
+package org.apache.flink.ml.classification.ftrl;
+
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasGlobalBatchSize;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+
+/**
+ * Params of {@link Ftrl}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface FtrlParams<T>
+        extends HasLabelCol<T>, HasBatchStrategy<T>, HasGlobalBatchSize<T>, HasFeaturesCol<T> {
+
+    Param<Integer> VECTOR_SIZE =
+            new IntParam("vectorSize", "The size of vector.", -1, ParamValidators.gt(-2));
+
+    default Integer getVectorSize() {
+        return get(VECTOR_SIZE);
+    }
+
+    default T setVectorSize(Integer value) {
+        return set(VECTOR_SIZE, value);
+    }
+
+    Param<Double> L_1 =
+            new DoubleParam("l1", "The parameter l1 of ftrl.", 0.1, ParamValidators.gt(0.0));
+
+    default Double getL1() {
+        return get(L_1);
+    }
+
+    default T setL1(Double value) {
+        return set(L_1, value);
+    }
+
+    Param<Double> L_2 =
+            new DoubleParam("l2", "The parameter l2 of ftrl.", 0.1, ParamValidators.gt(0.0));
+
+    default Double getL2() {

Review Comment:
   It seems that these parameters are also used in other algorithms, like softmax and multi-layer perceptron. Let's define them as common parameters.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/ftrl/FtrlParams.java:
##########
@@ -0,0 +1,92 @@
+package org.apache.flink.ml.classification.ftrl;
+
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasGlobalBatchSize;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+
+/**
+ * Params of {@link Ftrl}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface FtrlParams<T>
+        extends HasLabelCol<T>, HasBatchStrategy<T>, HasGlobalBatchSize<T>, HasFeaturesCol<T> {

Review Comment:
   It seems that some parameters provided by LogisticRegression, like `weightCol`, `reg` and `multiClass`, are not supported by FTRL for now. Is FTRL supposed to have weaker functionality than LogisticRegression when both are used in an offline training process? If not, do we plan to add support for these parameters in the future?



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/ftrl/Ftrl.java:
##########
@@ -0,0 +1,395 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.ftrl;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.classification.logisticregression.LinearModelData;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/** An Estimator which implements the Ftrl algorithm. */
+public class Ftrl implements Estimator<Ftrl, FtrlModel>, FtrlParams<Ftrl> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public Ftrl() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public FtrlModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<Tuple2<Vector, Double>> points =
+                tEnv.toDataStream(inputs[0]).map(new ParseSample(getFeaturesCol(), getLabelCol()));
+
+        DataStream<LinearModelData> modelDataStream =
+                LinearModelData.getModelDataStream(initModelDataTable);
+        DataStream<DenseVector[]> initModelData = modelDataStream.map(new GetVectorData());
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new FtrlIterationBody(
+                        getGlobalBatchSize(), getAlpha(), getBETA(), getL1(), getL2());
+
+        DataStream<LinearModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        FtrlModel model = new FtrlModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Iteration body of ftrl. */
+    public static class FtrlIterationBody implements IterationBody {
+        private final int batchSize;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+
+        public FtrlIterationBody(int batchSize, double alpha, double beta, double l1, double l2) {
+            this.batchSize = batchSize;
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = l1;
+            this.l2 = l2;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<DenseVector[]> modelData = variableStreams.get(0);
+
+            DataStream<Tuple2<Vector, Double>> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+            DataStream<DenseVector[]> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator())
+                            .flatMap(new GlobalBatchSplitter(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(DenseVector[].class),
+                                    new FtrlLocalUpdater(alpha, beta, l1, l2))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new FtrlGlobalReducer())
+                            .map(
+                                    (MapFunction<DenseVector[], DenseVector[]>)
+                                            value -> new DenseVector[] {value[0]})
+                            .setParallelism(1);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelData),
+                    DataStreamList.of(
+                            modelData
+                                    .map(
+                                            new MapFunction<DenseVector[], LinearModelData>() {
+                                                long iter = 0L;
+
+                                                @Override
+                                                public LinearModelData map(DenseVector[] value) {
+                                                    return new LinearModelData(value[0], iter++);
+                                                }
+                                            })
+                                    .setParallelism(1)));
+        }
+    }
+
+    /** Gets vector data. */
+    public static class GetVectorData implements MapFunction<LinearModelData, DenseVector[]> {
+        @Override
+        public DenseVector[] map(LinearModelData value) throws Exception {
+            return new DenseVector[] {value.coefficient};
+        }
+    }
+
+    /**
+     * Operator that collects a LogisticRegressionModelData from each upstream subtask, and outputs
+     * the weight average of collected model data.
+     */
+    public static class FtrlGlobalReducer implements ReduceFunction<DenseVector[]> {
+        @Override
+        public DenseVector[] reduce(DenseVector[] modelData, DenseVector[] newModelData) {
+            for (int i = 0; i < newModelData[0].size(); ++i) {
+                if ((modelData[1].values[i] + newModelData[1].values[i]) > 0.0) {
+                    newModelData[0].values[i] =
+                            (modelData[0].values[i] * modelData[1].values[i]
+                                            + newModelData[0].values[i] * newModelData[1].values[i])
+                                    / (modelData[1].values[i] + newModelData[1].values[i]);
+                }
+                newModelData[1].values[i] = modelData[1].values[i] + newModelData[1].values[i];
+            }
+            return newModelData;
+        }
+    }
+
+    /** Updates local ftrl model. */
+    public static class FtrlLocalUpdater extends AbstractStreamOperator<DenseVector[]>
+            implements TwoInputStreamOperator<
+                    Tuple2<Vector, Double>[], DenseVector[], DenseVector[]> {
+        private ListState<Tuple2<Vector, Double>[]> localBatchDataState;
+        private ListState<DenseVector[]> modelDataState;
+        private double[] n;
+        private double[] z;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+        private DenseVector weights;
+
+        public FtrlLocalUpdater(double alpha, double beta, double l1, double l2) {
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = l1;
+            this.l2 = l2;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<Tuple2<Vector, Double>[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            localBatchDataState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("localBatch", type));
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", DenseVector[].class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<Tuple2<Vector, Double>[]> pointsRecord)
+                throws Exception {
+            localBatchDataState.add(pointsRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<DenseVector[]> modelDataRecord) throws Exception {
+            modelDataState.add(modelDataRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        private void alignAndComputeModelData() throws Exception {

Review Comment:
   Would it be better to add some comments to this method, or divide this method further into several smaller methods? That could help improve readability.
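This hunk does not show the body of `alignAndComputeModelData`. For reference, here is a sketch of the standard per-coordinate FTRL-Proximal update for logistic regression (McMahan et al., 2013), split into small single-purpose methods along the lines this comment suggests. It is a single-process, plain-Java sketch that assumes dense `double[]` features and labels in {0, 1}. The class and method names are hypothetical, and this is not the PR's implementation:

```java
/** Hypothetical single-process FTRL-Proximal learner for dense logistic regression. */
final class FtrlProximalSketch {
    private final double alpha;
    private final double beta;
    private final double l1;
    private final double l2;
    // Per-coordinate accumulators: z (adjusted gradient sum) and n (sum of squared gradients).
    private final double[] z;
    private final double[] n;

    FtrlProximalSketch(int dim, double alpha, double beta, double l1, double l2) {
        this.alpha = alpha;
        this.beta = beta;
        this.l1 = l1;
        this.l2 = l2;
        this.z = new double[dim];
        this.n = new double[dim];
    }

    /** Closed-form weight of coordinate i; exactly zero while |z_i| <= l1 (L1 sparsity). */
    double weight(int i) {
        if (Math.abs(z[i]) <= l1) {
            return 0.0;
        }
        return -(z[i] - Math.signum(z[i]) * l1) / ((beta + Math.sqrt(n[i])) / alpha + l2);
    }

    /** Materializes the current weight vector from z and n. */
    double[] weights() {
        double[] w = new double[z.length];
        for (int i = 0; i < w.length; i++) {
            w[i] = weight(i);
        }
        return w;
    }

    /** Probability of label 1 under the current weights. */
    double predict(double[] x) {
        return sigmoid(dot(weights(), x));
    }

    /** One FTRL-Proximal step on a single labeled example. */
    void update(double[] x, double label) {
        double[] w = weights();
        double p = sigmoid(dot(w, x));
        for (int i = 0; i < x.length; i++) {
            if (x[i] == 0.0) {
                continue; // a zero feature contributes no gradient
            }
            double g = (p - label) * x[i]; // gradient of the log loss w.r.t. w_i
            double sigma = (Math.sqrt(n[i] + g * g) - Math.sqrt(n[i])) / alpha;
            z[i] += g - sigma * w[i];
            n[i] += g * g;
        }
    }

    private static double dot(double[] a, double[] b) {
        double s = 0.0;
        for (int i = 0; i < a.length; i++) {
            s += a[i] * b[i];
        }
        return s;
    }

    private static double sigmoid(double t) {
        return 1.0 / (1.0 + Math.exp(-t));
    }
}
```

Keeping the weight formula, the prediction, and the accumulator update in separate methods is one way to make the local updater easier to follow. This excerpt does not show how the PR batches points or merges state across subtasks, so the sketch does not reproduce that part.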

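The `FtrlGlobalReducer` quoted above merges two partial models coordinate by coordinate. Each coefficient becomes the average of the two coefficients, weighted by the per-coordinate values in `modelData[1]`, and those weights are summed. What `modelData[1]` holds is set by the local updater, which this excerpt does not show. Below is a dependency-free sketch of the same arithmetic on `double[]` arrays. `WeightedModelMerge` and `merge` are hypothetical names:

```java
/** Hypothetical sketch of the per-coordinate weighted average computed by FtrlGlobalReducer. */
final class WeightedModelMerge {
    /**
     * Folds (coefA, weightA) into (coefB, weightB) in place and returns {coefB, weightB},
     * mirroring how the reducer mutates and returns its second argument.
     */
    static double[][] merge(double[] coefA, double[] weightA, double[] coefB, double[] weightB) {
        for (int i = 0; i < coefB.length; i++) {
            double total = weightA[i] + weightB[i];
            if (total > 0.0) {
                coefB[i] = (coefA[i] * weightA[i] + coefB[i] * weightB[i]) / total;
            }
            // If both weights are zero, coefB[i] is left unchanged.
            weightB[i] = total;
        }
        return new double[][] {coefB, weightB};
    }
}
```

The merged weight is carried forward, so folding N partial models pairwise (as `countWindowAll(parallelism).reduce(...)` does) yields the weighted average over all N, up to floating-point rounding. The exception is a coordinate where every weight is zero: there, the result depends on the reduction order.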


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/ftrl/FtrlModel.java:
##########
@@ -0,0 +1,196 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.ftrl;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.classification.logisticregression.LinearModelData;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+/** A Model which classifies data using the model data computed by {@link Ftrl}. */
+public class FtrlModel implements Model<FtrlModel>, FtrlModelParams<FtrlModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public FtrlModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), Types.DOUBLE, Types.LONG),
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldNames(), getPredictionCol(), "modelVersion"));
+
+        DataStream<Row> predictionResult =
+                tEnv.toDataStream(inputs[0])
+                        .connect(LinearModelData.getModelDataStream(modelDataTable).broadcast())
+                        .transform(
+                                "PredictLabelOperator",
+                                outputTypeInfo,
+                                new PredictLabelOperator(inputTypeInfo, getFeaturesCol()));
+
+        return new Table[] {tEnv.fromDataStream(predictionResult)};
+    }
+
+    /** A utility operator used for prediction. */
+    private static class PredictLabelOperator extends AbstractStreamOperator<Row>
+            implements TwoInputStreamOperator<Row, LinearModelData, Row> {
+        private final RowTypeInfo inputTypeInfo;
+
+        private final String featuresCol;
+        private ListState<Row> bufferedPointsState;
+        private DenseVector coefficient;
+        private long modelDataVersion = 0;
+
+        public PredictLabelOperator(RowTypeInfo inputTypeInfo, String featuresCol) {
+            this.inputTypeInfo = inputTypeInfo;
+            this.featuresCol = featuresCol;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            bufferedPointsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("bufferedPoints", inputTypeInfo));
+        }
+
+        @Override
+        public void open() throws Exception {
+            super.open();
+
+            getRuntimeContext()
+                    .getMetricGroup()
+                    .gauge(
+                            "MODEL_DATA_VERSION_GAUGE_KEY",
+                            (Gauge<String>) () -> Long.toString(modelDataVersion));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<Row> streamRecord) throws Exception {
+            Row dataPoint = streamRecord.getValue();
+            // todo : predict data
+            Vector features = (Vector) dataPoint.getField(featuresCol);
+            Tuple2<Double, DenseVector> predictionResult = predictRaw(features, coefficient);
+            output.collect(
+                    new StreamRecord<>(
+                            Row.join(dataPoint, Row.of(predictionResult.f0, modelDataVersion))));
+        }
+
+        @Override
+        public void processElement2(StreamRecord<LinearModelData> streamRecord) throws Exception {
+            LinearModelData modelData = streamRecord.getValue();
+
+            coefficient = modelData.coefficient;
+            modelDataVersion = modelData.modelVersion;
+            // System.out.println("update model...");
+            // todo : receive model data.
+            // Preconditions.checkArgument(modelData.centroids.length <= k);
+            // centroids = modelData.centroids;
+            // modelDataVersion++;
+            // for (Row dataPoint : bufferedPointsState.get()) {
+            //	processElement1(new StreamRecord<>(dataPoint));
+            // }
+            // bufferedPointsState.clear();

Review Comment:
   Let's clean up this code and remove the unused, commented-out lines.
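   As written, `processElement1` calls `predictRaw(features, coefficient)` right away, so a row that arrives before the first model record would pass a null `coefficient`. The commented-out lines appear to aim at buffering such rows in `bufferedPointsState`. The cleanup should either drop that state or wire it in. The following plain-Java sketch (not Flink code) shows the buffering idea. `BufferingPredictor` and `ModelPrediction` are hypothetical names, and a real operator would keep the buffer in `bufferedPointsState` so it survives checkpoints.

```java
import java.util.ArrayList;
import java.util.List;

/** A prediction tagged with the model version that produced it (hypothetical type). */
class ModelPrediction {
    final double label;
    final long modelVersion;

    ModelPrediction(double label, long modelVersion) {
        this.label = label;
        this.modelVersion = modelVersion;
    }
}

/**
 * Hypothetical stand-in for PredictLabelOperator: rows that arrive before the first model
 * are buffered, then flushed once model data arrives. A plain list replaces ListState here.
 */
class BufferingPredictor {
    private final List<double[]> buffered = new ArrayList<>();
    private final List<ModelPrediction> output = new ArrayList<>();
    private double[] coefficient; // null until the first model record arrives
    private long modelVersion;

    /** Counterpart of processElement1: predict now, or buffer if there is no model yet. */
    void onDataPoint(double[] features) {
        if (coefficient == null) {
            buffered.add(features);
        } else {
            emit(features);
        }
    }

    /** Counterpart of processElement2: install the new model, then flush buffered rows. */
    void onModelData(double[] newCoefficient, long version) {
        coefficient = newCoefficient;
        modelVersion = version;
        for (double[] features : buffered) {
            emit(features);
        }
        buffered.clear();
    }

    private void emit(double[] features) {
        double dot = 0.0;
        for (int i = 0; i < features.length; i++) {
            dot += features[i] * coefficient[i];
        }
        // Label 1.0 when the margin is non-negative, i.e. sigmoid(dot) >= 0.5.
        output.add(new ModelPrediction(dot >= 0 ? 1.0 : 0.0, modelVersion));
    }

    List<ModelPrediction> output() {
        return output;
    }
}
```

   If buffering is not wanted, the simpler fix is to delete `bufferedPointsState` and the commented-out block entirely.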



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/ftrl/Ftrl.java:
##########
@@ -0,0 +1,395 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.ftrl;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.classification.logisticregression.LinearModelData;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/** An Estimator which implements the Ftrl algorithm. */
+public class Ftrl implements Estimator<Ftrl, FtrlModel>, FtrlParams<Ftrl> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public Ftrl() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public FtrlModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<Tuple2<Vector, Double>> points =
+                tEnv.toDataStream(inputs[0]).map(new ParseSample(getFeaturesCol(), getLabelCol()));
+
+        DataStream<LinearModelData> modelDataStream =
+                LinearModelData.getModelDataStream(initModelDataTable);
+        DataStream<DenseVector[]> initModelData = modelDataStream.map(new GetVectorData());
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new FtrlIterationBody(
+                        getGlobalBatchSize(), getAlpha(), getBETA(), getL1(), getL2());
+
+        DataStream<LinearModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        FtrlModel model = new FtrlModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Iteration body of ftrl. */
+    public static class FtrlIterationBody implements IterationBody {
+        private final int batchSize;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+
+        public FtrlIterationBody(int batchSize, double alpha, double beta, double l1, double l2) {
+            this.batchSize = batchSize;
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = l1;
+            this.l2 = l2;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<DenseVector[]> modelData = variableStreams.get(0);
+
+            DataStream<Tuple2<Vector, Double>> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+            DataStream<DenseVector[]> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator())
+                            .flatMap(new GlobalBatchSplitter(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(DenseVector[].class),
+                                    new FtrlLocalUpdater(alpha, beta, l1, l2))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new FtrlGlobalReducer())
+                            .map(
+                                    (MapFunction<DenseVector[], DenseVector[]>)
+                                            value -> new DenseVector[] {value[0]})
+                            .setParallelism(1);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelData),
+                    DataStreamList.of(
+                            modelData
+                                    .map(
+                                            new MapFunction<DenseVector[], LinearModelData>() {
+                                                long iter = 0L;
+
+                                                @Override
+                                                public LinearModelData map(DenseVector[] value) {
+                                                    return new LinearModelData(value[0], iter++);
+                                                }
+                                            })
+                                    .setParallelism(1)));
+        }
+    }
+
+    /** Gets vector data. */
+    public static class GetVectorData implements MapFunction<LinearModelData, DenseVector[]> {
+        @Override
+        public DenseVector[] map(LinearModelData value) throws Exception {
+            return new DenseVector[] {value.coefficient};
+        }
+    }
+
+    /**
+     * Operator that collects a LogisticRegressionModelData from each upstream subtask, and outputs
+     * the weight average of collected model data.
+     */
+    public static class FtrlGlobalReducer implements ReduceFunction<DenseVector[]> {
+        @Override
+        public DenseVector[] reduce(DenseVector[] modelData, DenseVector[] newModelData) {
+            for (int i = 0; i < newModelData[0].size(); ++i) {
+                if ((modelData[1].values[i] + newModelData[1].values[i]) > 0.0) {
+                    newModelData[0].values[i] =
+                            (modelData[0].values[i] * modelData[1].values[i]
+                                            + newModelData[0].values[i] * newModelData[1].values[i])
+                                    / (modelData[1].values[i] + newModelData[1].values[i]);
+                }
+                newModelData[1].values[i] = modelData[1].values[i] + newModelData[1].values[i];
+            }
+            return newModelData;
+        }
+    }
+
+    /** Updates local ftrl model. */
+    public static class FtrlLocalUpdater extends AbstractStreamOperator<DenseVector[]>
+            implements TwoInputStreamOperator<
+                    Tuple2<Vector, Double>[], DenseVector[], DenseVector[]> {
+        private ListState<Tuple2<Vector, Double>[]> localBatchDataState;
+        private ListState<DenseVector[]> modelDataState;
+        private double[] n;
+        private double[] z;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+        private DenseVector weights;
+
+        public FtrlLocalUpdater(double alpha, double beta, double l1, double l2) {
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = l1;
+            this.l2 = l2;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<Tuple2<Vector, Double>[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            localBatchDataState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("localBatch", type));
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", DenseVector[].class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<Tuple2<Vector, Double>[]> pointsRecord)
+                throws Exception {
+            localBatchDataState.add(pointsRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<DenseVector[]> modelDataRecord) throws Exception {
+            modelDataState.add(modelDataRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        private void alignAndComputeModelData() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !localBatchDataState.get().iterator().hasNext()) {
+                return;
+            }
+
+            DenseVector[] modelData =
+                    OperatorStateUtils.getUniqueElement(modelDataState, "modelData").get();
+            modelDataState.clear();
+
+            List<Tuple2<Vector, Double>[]> pointsList =
+                    IteratorUtils.toList(localBatchDataState.get().iterator());
+            Tuple2<Vector, Double>[] points = pointsList.remove(0);
+            localBatchDataState.update(pointsList);
+
+            for (Tuple2<Vector, Double> point : points) {
+                if (n == null) {
+                    n = new double[point.f0.size()];
+                    z = new double[n.length];
+                    weights = new DenseVector(n.length);
+                }
+
+                double p = 0.0;
+                Arrays.fill(weights.values, 0.0);
+                if (point.f0 instanceof DenseVector) {
+                    DenseVector denseVector = (DenseVector) point.f0;
+                    for (int i = 0; i < denseVector.size(); ++i) {
+                        if (Math.abs(z[i]) <= l1) {
+                            modelData[0].values[i] = 0.0;
+                        } else {
+                            modelData[0].values[i] =
+                                    ((z[i] < 0 ? -1 : 1) * l1 - z[i])
+                                            / ((beta + Math.sqrt(n[i])) / alpha + l2);
+                        }
+                        p += modelData[0].values[i] * denseVector.values[i];
+                    }
+                    p = 1 / (1 + Math.exp(-p));
+                    for (int i = 0; i < denseVector.size(); ++i) {
+                        double g = (p - point.f1) * denseVector.values[i];
+                        double sigma = (Math.sqrt(n[i] + g * g) - Math.sqrt(n[i])) / alpha;
+                        z[i] += g - sigma * modelData[0].values[i];
+                        n[i] += g * g;
+                        weights.values[i] += 1.0;
+                    }
+                } else {
+                    SparseVector sparseVector = (SparseVector) point.f0;
+                    for (int i = 0; i < sparseVector.indices.length; ++i) {
+                        int idx = sparseVector.indices[i];
+                        if (Math.abs(z[idx]) <= l1) {
+                            modelData[0].values[idx] = 0.0;
+                        } else {
+                            modelData[0].values[idx] =
+                                    ((z[idx] < 0 ? -1 : 1) * l1 - z[idx])
+                                            / ((beta + Math.sqrt(n[idx])) / alpha + l2);
+                        }
+                        p += modelData[0].values[idx] * sparseVector.values[i];
+                    }
+                    p = 1 / (1 + Math.exp(-p));
+                    for (int i = 0; i < sparseVector.indices.length; ++i) {
+                        int idx = sparseVector.indices[i];
+                        double g = (p - point.f1) * sparseVector.values[i];
+                        double sigma = (Math.sqrt(n[idx] + g * g) - Math.sqrt(n[idx])) / alpha;
+                        z[idx] += g - sigma * modelData[0].values[idx];
+                        n[idx] += g * g;
+                        weights.values[idx] += 1.0;

Review Comment:
   It seems that weights are only updated but not used.
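   In the hunk above, `weights.values[i] += 1.0` counts updates per coordinate, and `FtrlGlobalReducer` already uses `modelData[1]` as per-coordinate weights when averaging `modelData[0]`. The excerpt does not show `weights` being emitted, so one way to address this comment is to send those counts as the second vector. The plain-Java sketch below (not Flink code; `FtrlCountSketch` and its methods are hypothetical) mirrors the FTRL-Proximal update in `alignAndComputeModelData`, assuming labels in {0, 1} as implied by `p - point.f1`, and shows the counts being consumed by a count-weighted merge.

```java
/**
 * Hypothetical sketch, not Flink ML code: a single-worker FTRL-Proximal learner for dense
 * features that counts how many samples touched each coordinate, plus a count-weighted
 * merge of two workers' coefficients like the one FtrlGlobalReducer performs.
 */
class FtrlCountSketch {
    private final double alpha;
    private final double beta;
    private final double l1;
    private final double l2;
    final double[] z;
    final double[] n;
    final double[] counts; // plays the role of `weights` in FtrlLocalUpdater

    FtrlCountSketch(int dim, double alpha, double beta, double l1, double l2) {
        this.alpha = alpha;
        this.beta = beta;
        this.l1 = l1;
        this.l2 = l2;
        this.z = new double[dim];
        this.n = new double[dim];
        this.counts = new double[dim];
    }

    /** Closed-form coefficient for coordinate i, using the same formula as the PR. */
    double coefficient(int i) {
        if (Math.abs(z[i]) <= l1) {
            return 0.0;
        }
        return ((z[i] < 0 ? -1 : 1) * l1 - z[i]) / ((beta + Math.sqrt(n[i])) / alpha + l2);
    }

    double[] coefficients() {
        double[] w = new double[z.length];
        for (int i = 0; i < w.length; i++) {
            w[i] = coefficient(i);
        }
        return w;
    }

    /** One logistic-loss FTRL step on a dense sample (x, y). */
    void update(double[] x, double y) {
        double[] w = coefficients();
        double margin = 0.0;
        for (int i = 0; i < x.length; i++) {
            margin += w[i] * x[i];
        }
        double p = 1.0 / (1.0 + Math.exp(-margin));
        for (int i = 0; i < x.length; i++) {
            if (x[i] == 0.0) {
                continue; // a zero feature leaves z and n unchanged, so it is not counted
            }
            double g = (p - y) * x[i];
            double sigma = (Math.sqrt(n[i] + g * g) - Math.sqrt(n[i])) / alpha;
            z[i] += g - sigma * w[i];
            n[i] += g * g;
            counts[i] += 1.0;
        }
    }

    /** Count-weighted average per coordinate; keeps w2[i] when neither side saw coordinate i. */
    static double[] mergeWeighted(double[] w1, double[] c1, double[] w2, double[] c2) {
        double[] merged = new double[w1.length];
        for (int i = 0; i < w1.length; i++) {
            double total = c1[i] + c2[i];
            merged[i] = total > 0.0 ? (w1[i] * c1[i] + w2[i] * c2[i]) / total : w2[i];
        }
        return merged;
    }
}
```

   With this shape, each local updater would emit `{coefficients, counts}` and the reducer would merge them, so the counts actually influence the global model.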



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/ftrl/FtrlModel.java:
##########
@@ -0,0 +1,196 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.ftrl;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.classification.logisticregression.LinearModelData;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+/** A Model which classifies data using the model data computed by {@link Ftrl}. */
+public class FtrlModel implements Model<FtrlModel>, FtrlModelParams<FtrlModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public FtrlModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), Types.DOUBLE, Types.LONG),
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldNames(), getPredictionCol(), "modelVersion"));
+
+        DataStream<Row> predictionResult =
+                tEnv.toDataStream(inputs[0])
+                        .connect(LinearModelData.getModelDataStream(modelDataTable).broadcast())
+                        .transform(
+                                "PredictLabelOperator",
+                                outputTypeInfo,
+                                new PredictLabelOperator(inputTypeInfo, getFeaturesCol()));
+
+        return new Table[] {tEnv.fromDataStream(predictionResult)};
+    }
+
+    /** A utility operator used for prediction. */
+    private static class PredictLabelOperator extends AbstractStreamOperator<Row>
+            implements TwoInputStreamOperator<Row, LinearModelData, Row> {
+        private final RowTypeInfo inputTypeInfo;
+
+        private final String featuresCol;
+        private ListState<Row> bufferedPointsState;
+        private DenseVector coefficient;
+        private long modelDataVersion = 0;
+
+        public PredictLabelOperator(RowTypeInfo inputTypeInfo, String featuresCol) {
+            this.inputTypeInfo = inputTypeInfo;
+            this.featuresCol = featuresCol;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            bufferedPointsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("bufferedPoints", inputTypeInfo));
+        }
+
+        @Override
+        public void open() throws Exception {
+            super.open();
+
+            getRuntimeContext()
+                    .getMetricGroup()
+                    .gauge(
+                            "MODEL_DATA_VERSION_GAUGE_KEY",
+                            (Gauge<String>) () -> Long.toString(modelDataVersion));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<Row> streamRecord) throws Exception {
+            Row dataPoint = streamRecord.getValue();
+            // todo : predict data
+            Vector features = (Vector) dataPoint.getField(featuresCol);
+            Tuple2<Double, DenseVector> predictionResult = predictRaw(features, coefficient);
+            output.collect(
+                    new StreamRecord<>(
+                            Row.join(dataPoint, Row.of(predictionResult.f0, modelDataVersion))));
+        }
+
+        @Override
+        public void processElement2(StreamRecord<LinearModelData> streamRecord) throws Exception {
+            LinearModelData modelData = streamRecord.getValue();
+
+            coefficient = modelData.coefficient;
+            modelDataVersion = modelData.modelVersion;
+            // System.out.println("update model...");
+            // todo : receive model data.
+            // Preconditions.checkArgument(modelData.centroids.length <= k);
+            // centroids = modelData.centroids;
+            // modelDataVersion++;
+            // for (Row dataPoint : bufferedPointsState.get()) {
+            //	processElement1(new StreamRecord<>(dataPoint));
+            // }
+            // bufferedPointsState.clear();
+        }
+    }
+
+    /**
+     * The main logic that predicts one input record.
+     *
+     * @param feature The input feature.
+     * @param coefficient The model parameters.
+     * @return The prediction label and the raw probabilities.
+     */
+    public static Tuple2<Double, DenseVector> predictRaw(Vector feature, DenseVector coefficient)

Review Comment:
   Methods like this are almost identical to those in `LogisticRegressionModel`. It would be better to reuse the logic already defined in `LogisticRegression` and `LogisticRegressionModel`.
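   One way to act on this is to move the margin-to-prediction step into a single helper that both models call. The plain-Java sketch below assumes the semantics stated in the `predictRaw` Javadoc ("the prediction label and the raw probabilities") with a logistic link; `LinearPredictionSketch` and `LabelWithProbabilities` are hypothetical names, not existing Flink ML APIs. In Flink ML the shared code would take `Vector`/`DenseVector`, likely computing the dot product via the `BLAS` class that `FtrlModel` already imports.

```java
/** Hypothetical result type: predicted label plus [P(y=0), P(y=1)]. */
class LabelWithProbabilities {
    final double label;
    final double[] rawProbabilities;

    LabelWithProbabilities(double label, double[] rawProbabilities) {
        this.label = label;
        this.rawProbabilities = rawProbabilities;
    }
}

/**
 * Hypothetical shared helper so the sigmoid-and-threshold logic lives in one place.
 * Dense and sparse inputs both reduce to a margin, then share fromMargin().
 */
final class LinearPredictionSketch {
    private LinearPredictionSketch() {}

    /** Dense dot product. */
    static double dot(double[] features, double[] coefficient) {
        double sum = 0.0;
        for (int i = 0; i < features.length; i++) {
            sum += features[i] * coefficient[i];
        }
        return sum;
    }

    /** Sparse dot product: only the stored (index, value) pairs contribute. */
    static double dot(int[] indices, double[] values, double[] coefficient) {
        double sum = 0.0;
        for (int i = 0; i < indices.length; i++) {
            sum += values[i] * coefficient[indices[i]];
        }
        return sum;
    }

    /** Turns a linear margin into a 0/1 label and logistic probabilities. */
    static LabelWithProbabilities fromMargin(double margin) {
        double p = 1.0 / (1.0 + Math.exp(-margin));
        return new LabelWithProbabilities(margin >= 0 ? 1.0 : 0.0, new double[] {1.0 - p, p});
    }
}
```

   Keeping the dense/sparse branching inside `dot` also removes the duplicated `instanceof DenseVector` handling from each model class.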



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/ftrl/Ftrl.java:
##########
@@ -0,0 +1,395 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.ftrl;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.classification.logisticregression.LinearModelData;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/** An Estimator which implements the Ftrl algorithm. */
+public class Ftrl implements Estimator<Ftrl, FtrlModel>, FtrlParams<Ftrl> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public Ftrl() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public FtrlModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<Tuple2<Vector, Double>> points =
+                tEnv.toDataStream(inputs[0]).map(new ParseSample(getFeaturesCol(), getLabelCol()));
+
+        DataStream<LinearModelData> modelDataStream =
+                LinearModelData.getModelDataStream(initModelDataTable);
+        DataStream<DenseVector[]> initModelData = modelDataStream.map(new GetVectorData());
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new FtrlIterationBody(
+                        getGlobalBatchSize(), getAlpha(), getBETA(), getL1(), getL2());
+
+        DataStream<LinearModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        FtrlModel model = new FtrlModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Iteration body of ftrl. */
+    public static class FtrlIterationBody implements IterationBody {
+        private final int batchSize;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+
+        public FtrlIterationBody(int batchSize, double alpha, double beta, double l1, double l2) {
+            this.batchSize = batchSize;
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = l1;
+            this.l2 = l2;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<DenseVector[]> modelData = variableStreams.get(0);
+
+            DataStream<Tuple2<Vector, Double>> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+            DataStream<DenseVector[]> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator())
+                            .flatMap(new GlobalBatchSplitter(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(DenseVector[].class),
+                                    new FtrlLocalUpdater(alpha, beta, l1, l2))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new FtrlGlobalReducer())
+                            .map(
+                                    (MapFunction<DenseVector[], DenseVector[]>)
+                                            value -> new DenseVector[] {value[0]})
+                            .setParallelism(1);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelData),
+                    DataStreamList.of(
+                            modelData
+                                    .map(
+                                            new MapFunction<DenseVector[], LinearModelData>() {
+                                                long iter = 0L;
+
+                                                @Override
+                                                public LinearModelData map(DenseVector[] value) {
+                                                    return new LinearModelData(value[0], iter++);
+                                                }
+                                            })
+                                    .setParallelism(1)));
+        }
+    }
+
+    /** Gets vector data. */
+    public static class GetVectorData implements MapFunction<LinearModelData, DenseVector[]> {
+        @Override
+        public DenseVector[] map(LinearModelData value) throws Exception {
+            return new DenseVector[] {value.coefficient};
+        }
+    }
+
+    /**
+     * Operator that collects a LogisticRegressionModelData from each upstream subtask, and outputs
+     * the weighted average of the collected model data.
+     */
+    public static class FtrlGlobalReducer implements ReduceFunction<DenseVector[]> {
+        @Override
+        public DenseVector[] reduce(DenseVector[] modelData, DenseVector[] newModelData) {
+            for (int i = 0; i < newModelData[0].size(); ++i) {
+                if ((modelData[1].values[i] + newModelData[1].values[i]) > 0.0) {
+                    newModelData[0].values[i] =
+                            (modelData[0].values[i] * modelData[1].values[i]
+                                            + newModelData[0].values[i] * newModelData[1].values[i])
+                                    / (modelData[1].values[i] + newModelData[1].values[i]);
+                }
+                newModelData[1].values[i] = modelData[1].values[i] + newModelData[1].values[i];
+            }
+            return newModelData;
+        }
+    }
+
+    /** Updates local ftrl model. */
+    public static class FtrlLocalUpdater extends AbstractStreamOperator<DenseVector[]>
+            implements TwoInputStreamOperator<
+                    Tuple2<Vector, Double>[], DenseVector[], DenseVector[]> {
+        private ListState<Tuple2<Vector, Double>[]> localBatchDataState;
+        private ListState<DenseVector[]> modelDataState;
+        private double[] n;
+        private double[] z;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+        private DenseVector weights;
+
+        public FtrlLocalUpdater(double alpha, double beta, double l1, double l2) {
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = l1;
+            this.l2 = l2;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<Tuple2<Vector, Double>[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            localBatchDataState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("localBatch", type));
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", DenseVector[].class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<Tuple2<Vector, Double>[]> pointsRecord)
+                throws Exception {
+            localBatchDataState.add(pointsRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<DenseVector[]> modelDataRecord) throws Exception {
+            modelDataState.add(modelDataRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        private void alignAndComputeModelData() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !localBatchDataState.get().iterator().hasNext()) {
+                return;
+            }
+
+            DenseVector[] modelData =
+                    OperatorStateUtils.getUniqueElement(modelDataState, "modelData").get();
+            modelDataState.clear();
+
+            List<Tuple2<Vector, Double>[]> pointsList =
+                    IteratorUtils.toList(localBatchDataState.get().iterator());
+            Tuple2<Vector, Double>[] points = pointsList.remove(0);
+            localBatchDataState.update(pointsList);
+
+            for (Tuple2<Vector, Double> point : points) {
+                if (n == null) {
+                    n = new double[point.f0.size()];
+                    z = new double[n.length];
+                    weights = new DenseVector(n.length);
+                }
+
+                double p = 0.0;
+                Arrays.fill(weights.values, 0.0);
+                if (point.f0 instanceof DenseVector) {

Review Comment:
   It seems that some of the logic in the `if` and `else` branches is the same. Shall we move this code out of the if-else?
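A minimal standalone sketch of what this suggestion could look like (not code from the PR): if a dense vector is viewed as a sparse vector whose indices are `0..size-1`, the dense and sparse branches collapse into one pair of loops. It assumes plain `int[]`/`double[]` inputs instead of Flink ML's `DenseVector`/`SparseVector`, and the names `FtrlSingleLoopSketch`, `denseIndices` and `update` are hypothetical. The per-coordinate arithmetic mirrors the FTRL-Proximal update in the quoted `alignAndComputeModelData`.

```java
/**
 * Hypothetical sketch (not from the PR): one FTRL-Proximal step per labeled sample, with dense
 * and sparse inputs handled by the same loops. A sample is given as parallel arrays
 * (indices, values); a dense vector of size d is passed with indices 0..d-1.
 */
class FtrlSingleLoopSketch {
    private final double alpha;
    private final double beta;
    private final double l1;
    private final double l2;
    // Per-coordinate FTRL accumulators and current coefficients (package-private for inspection).
    final double[] n;
    final double[] z;
    final double[] coefficient;

    FtrlSingleLoopSketch(int dim, double alpha, double beta, double l1, double l2) {
        this.alpha = alpha;
        this.beta = beta;
        this.l1 = l1;
        this.l2 = l2;
        this.n = new double[dim];
        this.z = new double[dim];
        this.coefficient = new double[dim];
    }

    /** Returns 0..size-1, so that a dense vector can be treated as a sparse one. */
    static int[] denseIndices(int size) {
        int[] indices = new int[size];
        for (int i = 0; i < size; ++i) {
            indices[i] = i;
        }
        return indices;
    }

    /** Applies one update for a sample with the given active indices, values and 0/1 label. */
    void update(int[] indices, double[] values, double label) {
        // Pass 1: derive the active coefficients from (z, n), then accumulate the margin.
        double margin = 0.0;
        for (int k = 0; k < indices.length; ++k) {
            int i = indices[k];
            if (Math.abs(z[i]) <= l1) {
                coefficient[i] = 0.0;
            } else {
                coefficient[i] =
                        ((z[i] < 0 ? -1 : 1) * l1 - z[i]) / ((beta + Math.sqrt(n[i])) / alpha + l2);
            }
            margin += coefficient[i] * values[k];
        }
        double p = 1.0 / (1.0 + Math.exp(-margin)); // logistic prediction

        // Pass 2: logistic-loss gradient per coordinate, then update the accumulators z and n.
        for (int k = 0; k < indices.length; ++k) {
            int i = indices[k];
            double g = (p - label) * values[k];
            double sigma = (Math.sqrt(n[i] + g * g) - Math.sqrt(n[i])) / alpha;
            z[i] += g - sigma * coefficient[i];
            n[i] += g * g;
        }
    }
}
```

With this shape, the dense branch would reduce to something like `update(denseIndices(v.size()), v.values, label)` and the sparse branch to `update(v.indices, v.values, label)`. Building the dense index array once per batch rather than per sample would avoid repeated allocation; whether the extra indirection matters for dense inputs is a judgment call for the PR.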



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/ftrl/Ftrl.java:
##########
@@ -0,0 +1,395 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.ftrl;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.classification.logisticregression.LinearModelData;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/** An Estimator which implements the Ftrl algorithm. */
+public class Ftrl implements Estimator<Ftrl, FtrlModel>, FtrlParams<Ftrl> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public Ftrl() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public FtrlModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<Tuple2<Vector, Double>> points =
+                tEnv.toDataStream(inputs[0]).map(new ParseSample(getFeaturesCol(), getLabelCol()));
+
+        DataStream<LinearModelData> modelDataStream =
+                LinearModelData.getModelDataStream(initModelDataTable);
+        DataStream<DenseVector[]> initModelData = modelDataStream.map(new GetVectorData());
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new FtrlIterationBody(
+                        getGlobalBatchSize(), getAlpha(), getBETA(), getL1(), getL2());
+
+        DataStream<LinearModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        FtrlModel model = new FtrlModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Iteration body of ftrl. */
+    public static class FtrlIterationBody implements IterationBody {
+        private final int batchSize;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+
+        public FtrlIterationBody(int batchSize, double alpha, double beta, double l1, double l2) {
+            this.batchSize = batchSize;
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = l1;
+            this.l2 = l2;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<DenseVector[]> modelData = variableStreams.get(0);
+
+            DataStream<Tuple2<Vector, Double>> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+            DataStream<DenseVector[]> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator())
+                            .flatMap(new GlobalBatchSplitter(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(DenseVector[].class),
+                                    new FtrlLocalUpdater(alpha, beta, l1, l2))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new FtrlGlobalReducer())
+                            .map(
+                                    (MapFunction<DenseVector[], DenseVector[]>)
+                                            value -> new DenseVector[] {value[0]})
+                            .setParallelism(1);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelData),
+                    DataStreamList.of(
+                            modelData
+                                    .map(
+                                            new MapFunction<DenseVector[], LinearModelData>() {
+                                                long iter = 0L;
+
+                                                @Override
+                                                public LinearModelData map(DenseVector[] value) {
+                                                    return new LinearModelData(value[0], iter++);
+                                                }
+                                            })
+                                    .setParallelism(1)));
+        }
+    }
+
+    /** Gets vector data. */
+    public static class GetVectorData implements MapFunction<LinearModelData, DenseVector[]> {
+        @Override
+        public DenseVector[] map(LinearModelData value) throws Exception {
+            return new DenseVector[] {value.coefficient};
+        }
+    }
+
+    /**
+     * Operator that collects a LogisticRegressionModelData from each upstream subtask, and outputs
+     * the weighted average of the collected model data.
+     */
+    public static class FtrlGlobalReducer implements ReduceFunction<DenseVector[]> {
+        @Override
+        public DenseVector[] reduce(DenseVector[] modelData, DenseVector[] newModelData) {
+            for (int i = 0; i < newModelData[0].size(); ++i) {
+                if ((modelData[1].values[i] + newModelData[1].values[i]) > 0.0) {
+                    newModelData[0].values[i] =
+                            (modelData[0].values[i] * modelData[1].values[i]
+                                            + newModelData[0].values[i] * newModelData[1].values[i])
+                                    / (modelData[1].values[i] + newModelData[1].values[i]);
+                }
+                newModelData[1].values[i] = modelData[1].values[i] + newModelData[1].values[i];
+            }
+            return newModelData;
+        }
+    }
+
+    /** Updates local ftrl model. */
+    public static class FtrlLocalUpdater extends AbstractStreamOperator<DenseVector[]>
+            implements TwoInputStreamOperator<
+                    Tuple2<Vector, Double>[], DenseVector[], DenseVector[]> {
+        private ListState<Tuple2<Vector, Double>[]> localBatchDataState;
+        private ListState<DenseVector[]> modelDataState;
+        private double[] n;
+        private double[] z;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+        private DenseVector weights;
+
+        public FtrlLocalUpdater(double alpha, double beta, double l1, double l2) {
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = l1;
+            this.l2 = l2;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<Tuple2<Vector, Double>[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            localBatchDataState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("localBatch", type));
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", DenseVector[].class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<Tuple2<Vector, Double>[]> pointsRecord)
+                throws Exception {
+            localBatchDataState.add(pointsRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<DenseVector[]> modelDataRecord) throws Exception {
+            modelDataState.add(modelDataRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        private void alignAndComputeModelData() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !localBatchDataState.get().iterator().hasNext()) {
+                return;
+            }
+
+            DenseVector[] modelData =
+                    OperatorStateUtils.getUniqueElement(modelDataState, "modelData").get();
+            modelDataState.clear();
+
+            List<Tuple2<Vector, Double>[]> pointsList =
+                    IteratorUtils.toList(localBatchDataState.get().iterator());
+            Tuple2<Vector, Double>[] points = pointsList.remove(0);
+            localBatchDataState.update(pointsList);
+
+            for (Tuple2<Vector, Double> point : points) {
+                if (n == null) {
+                    n = new double[point.f0.size()];
+                    z = new double[n.length];
+                    weights = new DenseVector(n.length);
+                }
+
+                double p = 0.0;
+                Arrays.fill(weights.values, 0.0);
+                if (point.f0 instanceof DenseVector) {
+                    DenseVector denseVector = (DenseVector) point.f0;
+                    for (int i = 0; i < denseVector.size(); ++i) {
+                        if (Math.abs(z[i]) <= l1) {
+                            modelData[0].values[i] = 0.0;
+                        } else {
+                            modelData[0].values[i] =
+                                    ((z[i] < 0 ? -1 : 1) * l1 - z[i])
+                                            / ((beta + Math.sqrt(n[i])) / alpha + l2);
+                        }
+                        p += modelData[0].values[i] * denseVector.values[i];
+                    }
+                    p = 1 / (1 + Math.exp(-p));
+                    for (int i = 0; i < denseVector.size(); ++i) {
+                        double g = (p - point.f1) * denseVector.values[i];
+                        double sigma = (Math.sqrt(n[i] + g * g) - Math.sqrt(n[i])) / alpha;
+                        z[i] += g - sigma * modelData[0].values[i];
+                        n[i] += g * g;
+                        weights.values[i] += 1.0;
+                    }
+                } else {
+                    SparseVector sparseVector = (SparseVector) point.f0;
+                    for (int i = 0; i < sparseVector.indices.length; ++i) {
+                        int idx = sparseVector.indices[i];
+                        if (Math.abs(z[idx]) <= l1) {
+                            modelData[0].values[idx] = 0.0;
+                        } else {
+                            modelData[0].values[idx] =
+                                    ((z[idx] < 0 ? -1 : 1) * l1 - z[idx])
+                                            / ((beta + Math.sqrt(n[idx])) / alpha + l2);
+                        }
+                        p += modelData[0].values[idx] * sparseVector.values[i];
+                    }
+                    p = 1 / (1 + Math.exp(-p));
+                    for (int i = 0; i < sparseVector.indices.length; ++i) {
+                        int idx = sparseVector.indices[i];
+                        double g = (p - point.f1) * sparseVector.values[i];
+                        double sigma = (Math.sqrt(n[idx] + g * g) - Math.sqrt(n[idx])) / alpha;
+                        z[idx] += g - sigma * modelData[0].values[idx];
+                        n[idx] += g * g;
+                        weights.values[idx] += 1.0;
+                    }
+                }
+            }
+            output.collect(new StreamRecord<>(new DenseVector[] {modelData[0], weights}));
+        }
+    }
+
+    /** Parses samples of input data. */
+    public static class ParseSample extends RichMapFunction<Row, Tuple2<Vector, Double>> {
+        private static final long serialVersionUID = 3738888745125082777L;

Review Comment:
   Are variables like this required?
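For context on the question: Flink user functions such as `RichMapFunction` are `java.io.Serializable`, and Java serialization does not require a `serialVersionUID` field. If none is declared, the runtime computes one from the class's structure, so the value may change when the class changes or is compiled differently; declaring it explicitly pins the value. The standalone sketch below uses only `java.io`, and its class names (`WithExplicitUid`, `WithComputedUid`, `SerializationCheck`) are hypothetical, not from the PR.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.ObjectStreamClass;
import java.io.Serializable;

/** Declares its UID explicitly, like ParseSample in the diff. */
class WithExplicitUid implements Serializable {
    private static final long serialVersionUID = 3738888745125082777L;
    int field;
}

/** Declares no UID; the serialization runtime computes one from the class structure. */
class WithComputedUid implements Serializable {
    int field;
}

final class SerializationCheck {
    private SerializationCheck() {}

    /** Returns the UID that Java serialization uses for the given serializable class. */
    static long uidOf(Class<?> cls) {
        return ObjectStreamClass.lookup(cls).getSerialVersionUID();
    }

    /** Serializes and deserializes obj in memory; checked exceptions are rethrown unchecked. */
    static Object roundTrip(Serializable obj) {
        try {
            ByteArrayOutputStream bytes = new ByteArrayOutputStream();
            try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
                out.writeObject(obj);
            }
            try (ObjectInputStream in =
                    new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {
                return in.readObject();
            }
        } catch (IOException | ClassNotFoundException e) {
            throw new IllegalStateException(e);
        }
    }
}
```

Both classes round-trip within a single run; the explicit field only matters when serialized bytes must stay readable across versions of the class, for example state or savepoints written by an older build.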





[GitHub] [flink-ml] weibozhao commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator for OnlineLogisticRegression

Posted by GitBox <gi...@apache.org>.
weibozhao commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r881275210


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java:
##########
@@ -0,0 +1,434 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.FilterFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the FTRL-Proximal online learning algorithm proposed by H. Brendan
+ * McMahan et al.
+ *
+ * <p>See <a href="https://doi.org/10.1145/2487575.2488200">H. Brendan McMahan et al., Ad click
+ * prediction: a view from the trenches.</a>
+ */
+public class OnlineLogisticRegression
+        implements Estimator<OnlineLogisticRegression, OnlineLogisticRegressionModel>,
+                OnlineLogisticRegressionParams<OnlineLogisticRegression> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineLogisticRegression() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public OnlineLogisticRegressionModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<LogisticRegressionModelData> modelDataStream =
+                LogisticRegressionModelData.getModelDataStream(initModelDataTable);
+
+        DataStream<Row> points =
+                tEnv.toDataStream(inputs[0])
+                        .map(new FeaturesExtractor(getFeaturesCol(), getLabelCol()));
+
+        DataStream<DenseVector> initModelData =
+                modelDataStream.map(
+                        (MapFunction<LogisticRegressionModelData, DenseVector>)
+                                value -> value.coefficient);
+
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new FtrlIterationBody(
+                        getGlobalBatchSize(),
+                        getAlpha(),
+                        getBeta(),
+                        getReg(),
+                        getElasticNet(),
+                        getModelSaveInterval());
+
+        DataStream<LogisticRegressionModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        OnlineLogisticRegressionModel model =
+                new OnlineLogisticRegressionModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class FeaturesExtractor implements MapFunction<Row, Row> {
+        private final String featuresCol;
+        private final String labelCol;
+
+        private FeaturesExtractor(String featuresCol, String labelCol) {
+            this.featuresCol = featuresCol;
+            this.labelCol = labelCol;
+        }
+
+        @Override
+        public Row map(Row row) throws Exception {
+            return Row.of(row.getField(featuresCol), row.getField(labelCol));
+        }
+    }
+
+    /**
+     * Implementation of the FTRL optimizer. In this implementation, gradients are calculated by
+     * distributed workers and reduced to one gradient. The reduced gradient is then used to update
+     * the model with the FTRL method.
+     *
+     * <p>See https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Ftrl
+     */
+    private static class FtrlIterationBody implements IterationBody {
+        private final int batchSize;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+        private long modelVersion = 1L;
+        private final int modelSaveInterval;
+
+        public FtrlIterationBody(
+                int batchSize,
+                double alpha,
+                double beta,
+                double reg,
+                double elasticNet,
+                int modelSaveInterval) {
+            this.batchSize = batchSize;
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = elasticNet * reg;
+            this.l2 = (1 - elasticNet) * reg;
+            this.modelSaveInterval = modelSaveInterval;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<DenseVector> modelData = variableStreams.get(0);
+
+            DataStream<Row> points = dataStreams.get(0);
+            int parallelism = points.getParallelism();
+            Preconditions.checkState(
+                    parallelism <= batchSize,
+                    "There are more subtasks in the training process than the number "
+                            + "of elements in each batch. Some subtasks might be idling forever.");
+
+            DataStream<DenseVector[]> newGradient =
+                    DataStreamUtils.generateBatchData(points, parallelism, batchSize)
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "LocalGradientCalculator",
+                                    TypeInformation.of(DenseVector[].class),
+                                    new CalculateLocalGradient())
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(
+                                    (ReduceFunction<DenseVector[]>)
+                                            (gradientInfo, newGradientInfo) -> {
+                                                for (int i = 0;
+                                                        i < newGradientInfo[1].size();
+                                                        ++i) {
+                                                    newGradientInfo[0].values[i] =
+                                                            gradientInfo[0].values[i]
+                                                                    + newGradientInfo[0].values[i];
+                                                    newGradientInfo[1].values[i] =
+                                                            gradientInfo[1].values[i]
+                                                                    + newGradientInfo[1].values[i];
+                                                    if (newGradientInfo[2] == null) {
+                                                        newGradientInfo[2] = gradientInfo[2];
+                                                    }
+                                                }
+                                                return newGradientInfo;
+                                            });
+            DataStream<DenseVector> feedbackModelData =
+                    newGradient
+                            .transform(
+                                    "ModelDataUpdater",
+                                    TypeInformation.of(DenseVector.class),
+                                    new UpdateModel(alpha, beta, l1, l2))
+                            .setParallelism(1);
+
+            DataStream<LogisticRegressionModelData> outputModelData =
+                    feedbackModelData
+                            .filter(
+                                    new FilterFunction<DenseVector>() {
+                                        private int step = 0;
+
+                                        @Override
+                                        public boolean filter(DenseVector denseVector) {
+                                            step++;
+                                            return step % modelSaveInterval == 0;
+                                        }
+                                    })
+                            .setParallelism(1)
+                            .map(
+                                    (MapFunction<DenseVector, LogisticRegressionModelData>)
+                                            value ->
+                                                    new LogisticRegressionModelData(
+                                                            value, modelVersion++))
+                            .setParallelism(1);
+            return new IterationBodyResult(
+                    DataStreamList.of(feedbackModelData), DataStreamList.of(outputModelData));
+        }
+    }
+
+    /** Updates model. */
+    private static class UpdateModel extends AbstractStreamOperator<DenseVector>
+            implements OneInputStreamOperator<DenseVector[], DenseVector> {
+        private ListState<double[]> nParamState;
+        private ListState<double[]> zParamState;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+        private double[] nParam;
+        private double[] zParam;
+
+        public UpdateModel(double alpha, double beta, double l1, double l2) {
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = l1;
+            this.l2 = l2;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            nParamState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("nParamState", double[].class));
+            zParamState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("zParamState", double[].class));
+        }
+
+        @Override
+        public void processElement(StreamRecord<DenseVector[]> streamRecord) throws Exception {
+            DenseVector[] gradientInfo = streamRecord.getValue();
+            double[] coefficient = gradientInfo[2].values;
+            double[] g = gradientInfo[0].values;
+            for (int i = 0; i < g.length; ++i) {
+                if (gradientInfo[1].values[i] != 0.0) {
+                    g[i] = g[i] / gradientInfo[1].values[i];
+                }
+            }
+            if (zParam == null) {
+                zParam = new double[g.length];
+                nParam = new double[g.length];
+                nParamState.add(nParam);
+                zParamState.add(zParam);
+            }
+
+            for (int i = 0; i < zParam.length; ++i) {
+                double sigma = (Math.sqrt(nParam[i] + g[i] * g[i]) - Math.sqrt(nParam[i])) / alpha;
+                zParam[i] += g[i] - sigma * coefficient[i];
+                nParam[i] += g[i] * g[i];
+
+                if (Math.abs(zParam[i]) <= l1) {
+                    coefficient[i] = 0.0;
+                } else {
+                    coefficient[i] =
+                            ((zParam[i] < 0 ? -1 : 1) * l1 - zParam[i])
+                                    / ((beta + Math.sqrt(nParam[i])) / alpha + l2);
+                }
+            }
+            output.collect(new StreamRecord<>(new DenseVector(coefficient)));
+        }
+    }
+
+    private static class CalculateLocalGradient extends AbstractStreamOperator<DenseVector[]>
+            implements TwoInputStreamOperator<Row[], DenseVector, DenseVector[]> {
+        private ListState<DenseVector> modelDataState;
+        private ListState<Row[]> localBatchDataState;
+        private double[] gradient;
+        private double[] weight;
+        private int[] denseVectorIndices;
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", DenseVector.class));
+            TypeInformation<Row[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(TypeInformation.of(Row.class));
+            localBatchDataState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("localBatch", type));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<Row[]> pointsRecord) throws Exception {
+            localBatchDataState.add(pointsRecord.getValue());
+            calculateGradient();
+        }
+
+        private void calculateGradient() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !localBatchDataState.get().iterator().hasNext()) {
+                return;
+            }
+            DenseVector modelData =
+                    OperatorStateUtils.getUniqueElement(modelDataState, "modelData").get();
+            modelDataState.clear();
+
+            List<Row[]> pointsList = IteratorUtils.toList(localBatchDataState.get().iterator());
+            Row[] points = pointsList.remove(0);
+            localBatchDataState.update(pointsList);
+
+            for (Row point : points) {
+                Vector vec = point.getFieldAs(0);
+                double label = point.getFieldAs(1);
+                if (gradient == null) {
+                    gradient = new double[vec.size()];
+                    weight = new double[gradient.length];
+                    if (vec instanceof DenseVector) {
+                        denseVectorIndices = new int[vec.size()];
+                        for (int i = 0; i < denseVectorIndices.length; ++i) {
+                            denseVectorIndices[i] = i;
+                        }
+                    }
+                }
+
+                int[] indices;
+                double[] values;
+                if (vec instanceof DenseVector) {
+                    DenseVector denseVector = (DenseVector) vec;
+                    indices = denseVectorIndices;
+                    values = denseVector.values;
+                } else {
+                    SparseVector sparseVector = (SparseVector) vec;
+                    indices = sparseVector.indices;
+                    values = sparseVector.values;
+                }
+                double p = 0.0;
+                for (int i = 0; i < indices.length; ++i) {
+                    int idx = indices[i];
+                    p += modelData.values[idx] * values[i];
+                }
+                p = 1 / (1 + Math.exp(-p));
+                for (int i = 0; i < indices.length; ++i) {
+                    int idx = indices[i];
+                    gradient[idx] += (p - label) * values[i];
+                    weight[idx] += 1.0;
+                }
+            }
+
+            if (points.length > 0) {
+                output.collect(
+                        new StreamRecord<>(
+                                new DenseVector[] {
+                                    new DenseVector(gradient),
+                                    new DenseVector(weight),
+                                    (getRuntimeContext().getIndexOfThisSubtask() == 0)
+                                            ? modelData
+                                            : null
+                                }));
+            }
+            Arrays.fill(gradient, 0.0);
+            Arrays.fill(weight, 0.0);
+        }
+
+        @Override
+        public void processElement2(StreamRecord<DenseVector> modelDataRecord) throws Exception {
+            modelDataState.add(modelDataRecord.getValue());
+            calculateGradient();
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+        ReadWriteUtils.saveModelData(

Review Comment:
   Here, I only save the metadata and initModelData.
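The `CalculateLocalGradient` and `UpdateModel` operators quoted above split one training round into two steps. First, each subtask accumulates the logistic-loss gradient `(p - label) * x_i` and a per-coordinate count over its local batch. Then, presumably after the per-subtask results are combined (that step is not shown in this excerpt), `UpdateModel` divides the gradient by the count and applies the per-coordinate FTRL-Proximal update to the `z` and `n` accumulators. The following is a minimal single-process sketch of the same arithmetic, assuming dense `double[]` features. The class `FtrlSketch` and its methods are hypothetical and stand in for Flink's operators, operator state, and `DenseVector`/`SparseVector`.

```java
import java.util.Arrays;

/** Hypothetical single-process sketch of one FTRL-Proximal round; not a Flink ML API. */
class FtrlSketch {
    final double alpha, beta, l1, l2;
    final double[] z, n; // FTRL accumulators, one entry per coordinate
    final double[] w;    // current coefficients

    FtrlSketch(int dim, double alpha, double beta, double l1, double l2) {
        this.alpha = alpha;
        this.beta = beta;
        this.l1 = l1;
        this.l2 = l2;
        this.z = new double[dim];
        this.n = new double[dim];
        this.w = new double[dim];
    }

    /** Logistic-loss gradient over a batch of dense rows, averaged per coordinate. */
    double[] batchGradient(double[][] xs, double[] labels) {
        double[] g = new double[w.length];
        double[] count = new double[w.length];
        for (int r = 0; r < xs.length; r++) {
            double margin = 0.0;
            for (int i = 0; i < w.length; i++) {
                margin += w[i] * xs[r][i];
            }
            double p = 1.0 / (1.0 + Math.exp(-margin)); // sigmoid, as in calculateGradient()
            for (int i = 0; i < w.length; i++) {
                g[i] += (p - labels[r]) * xs[r][i];
                count[i] += 1.0;
            }
        }
        for (int i = 0; i < g.length; i++) {
            if (count[i] != 0.0) {
                g[i] /= count[i]; // same normalization as the start of UpdateModel#processElement
            }
        }
        return g;
    }

    /** Per-coordinate FTRL-Proximal update, mirroring UpdateModel#processElement. */
    void update(double[] g) {
        for (int i = 0; i < w.length; i++) {
            double sigma = (Math.sqrt(n[i] + g[i] * g[i]) - Math.sqrt(n[i])) / alpha;
            z[i] += g[i] - sigma * w[i];
            n[i] += g[i] * g[i];
            if (Math.abs(z[i]) <= l1) {
                w[i] = 0.0; // L1 keeps coordinates with small accumulated gradients at exactly zero
            } else {
                w[i] = (Math.signum(z[i]) * l1 - z[i]) / ((beta + Math.sqrt(n[i])) / alpha + l2);
            }
        }
    }

    public static void main(String[] args) {
        FtrlSketch ftrl = new FtrlSketch(2, 0.1, 0.1, 0.0, 0.0);
        double[][] xs = {{1.0, 0.0}, {1.0, 0.0}};
        double[] labels = {1.0, 1.0};
        ftrl.update(ftrl.batchGradient(xs, labels));
        System.out.println(Arrays.toString(ftrl.w));
    }
}
```

With two positive examples that only activate the first feature, the first coefficient becomes positive and the untouched second coefficient stays at zero.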





[GitHub] [flink-ml] weibozhao commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator for OnlineLogisticRegression

Posted by GitBox <gi...@apache.org>.
weibozhao commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r884283585


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionParams.java:
##########
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.common.param.HasElasticNet;
+import org.apache.flink.ml.common.param.HasGlobalBatchSize;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasReg;
+import org.apache.flink.ml.common.param.HasWeightCol;
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+
+/**
+ * Params of {@link OnlineLogisticRegression}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface OnlineLogisticRegressionParams<T>
+        extends HasLabelCol<T>,
+                HasWeightCol<T>,
+                HasBatchStrategy<T>,
+                HasGlobalBatchSize<T>,
+                HasReg<T>,
+                HasElasticNet<T>,
+                OnlineLogisticRegressionModelParams<T> {
+
+    Param<Double> ALPHA =
+            new DoubleParam("alpha", "The parameter alpha of ftrl.", 0.1, ParamValidators.gt(0.0));
+
+    default Double getAlpha() {
+        return get(ALPHA);
+    }
+
+    default T setAlpha(Double value) {
+        return set(ALPHA, value);
+    }
+
+    Param<Double> BETA =
+            new DoubleParam("beta", "The parameter beta of ftrl.", 0.1, ParamValidators.gt(0.0));

Review Comment:
   I have added the FTRL documentation to the documentation of OnlineLogisticRegression.
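In FTRL-Proximal, `alpha` and `beta` together set the per-coordinate learning rate `eta_i = alpha / (beta + sqrt(n_i))`, where `n_i` is the sum of squared gradients seen so far for coordinate `i`. The denominator `(beta + Math.sqrt(nParam[i])) / alpha + l2` in `UpdateModel` is exactly `1 / eta_i + l2`. A small sketch of that schedule may be a useful companion to the parameter documentation; the class `FtrlLearningRate` is hypothetical.

```java
/** Hypothetical helper showing how FTRL's alpha and beta shape the per-coordinate step size. */
class FtrlLearningRate {
    /** eta_i = alpha / (beta + sqrt(n_i)), where n_i is the running sum of squared gradients. */
    static double learningRate(double alpha, double beta, double sumSquaredGradients) {
        return alpha / (beta + Math.sqrt(sumSquaredGradients));
    }

    public static void main(String[] args) {
        double alpha = 0.1;
        double beta = 0.1;
        double n = 0.0;
        double[] gradients = {0.5, 0.5, 0.5};
        for (double g : gradients) {
            n += g * g; // the step size shrinks as squared gradients accumulate
            System.out.printf("n=%.2f eta=%.4f%n", n, learningRate(alpha, beta, n));
        }
    }
}
```

Larger `alpha` scales every step up; `beta` mainly matters early on, when `n_i` is still small, and keeps the first steps from being unboundedly large.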





[GitHub] [flink-ml] weibozhao commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator for OnlineLogisticRegression

Posted by GitBox <gi...@apache.org>.
weibozhao commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r854941445


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasL2.java:
##########
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
+
+/** Interface for the shared L2 param. */
+public interface HasL2<T> extends WithParams<T> {
+    Param<Double> L_2 = new DoubleParam("l2", "The l2 param.", 0.1, ParamValidators.gt(0.0));

Review Comment:
   L1 and L2 are hyperparameters of FTRL; they have a different meaning from regParam.





[GitHub] [flink-ml] weibozhao commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator of Ftrl

Posted by GitBox <gi...@apache.org>.
weibozhao commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r847915850


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LinearModelData.java:
##########
@@ -37,60 +38,68 @@
 import java.io.OutputStream;
 
 /**
- * Model data of {@link LogisticRegressionModel}.
+ * Model data of {@link LogisticRegressionModel}, {@link FtrlModel}.
  *
  * <p>This class also provides methods to convert model data from Table to Datastream, and classes
  * to save/load model data.
  */
-public class LogisticRegressionModelData {
+public class LinearModelData {

Review Comment:
   LinearModelData is shared by FTRL, LR, SVM and linear regression, because these algorithms have the same model data format. We need to discuss this rename; if we do rename it, I agree to move LinearModelData to a common package.
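The argument that these algorithms can share one model data format rests on the fact that each of them predicts from the same coefficient vector `w` through the score `w . x`; only the function applied to that score differs. A hypothetical illustration follows. The class and method names are not Flink ML APIs, and the 0/1 label convention for the SVM is an assumption.

```java
/** Hypothetical illustration: one coefficient vector, several linear-model prediction rules. */
class SharedLinearModel {
    final double[] coefficient;

    SharedLinearModel(double[] coefficient) {
        this.coefficient = coefficient;
    }

    /** w . x, the score every linear model starts from. */
    double dot(double[] x) {
        double sum = 0.0;
        for (int i = 0; i < coefficient.length; i++) {
            sum += coefficient[i] * x[i];
        }
        return sum;
    }

    /** Linear regression: the prediction is the score itself. */
    double predictLinearRegression(double[] x) {
        return dot(x);
    }

    /** Logistic regression (batch- or FTRL-trained): probability of label 1. */
    double predictLogisticProbability(double[] x) {
        return 1.0 / (1.0 + Math.exp(-dot(x)));
    }

    /** Linear SVM: the sign of the score decides the label. */
    double predictSvmLabel(double[] x) {
        return dot(x) >= 0.0 ? 1.0 : 0.0;
    }

    public static void main(String[] args) {
        SharedLinearModel model = new SharedLinearModel(new double[] {0.5, -1.0});
        double[] x = {4.0, 1.0};
        System.out.println(model.predictLinearRegression(x));
        System.out.println(model.predictLogisticProbability(x));
        System.out.println(model.predictSvmLabel(x));
    }
}
```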





[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator for OnlineLogisticRegression

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r868872809


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasL2.java:
##########
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
+
+/** Interface for the shared L2 param. */
+public interface HasL2<T> extends WithParams<T> {
+    Param<Double> L_2 = new DoubleParam("l2", "The l2 param.", 0.1, ParamValidators.gt(0.0));

Review Comment:
   After some offline discussion, @weibozhao and I agree that we can reuse the existing `HasReg` and `HasElasticNet` params instead of introducing `HasL1` and `HasL2` in this implementation.
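One way to reconcile FTRL's `l1`/`l2` with the shared `HasReg`/`HasElasticNet` params is a common elastic-net split: `elasticNet` is the fraction of the total penalty `reg` applied as L1, and the remainder is applied as L2. This sketch assumes that convention; whether the merged implementation uses exactly this split is not shown in this thread, and `ElasticNetToFtrl` is a hypothetical name.

```java
/** Hypothetical mapping from reg/elasticNet to FTRL's l1/l2 under a common elastic-net convention. */
class ElasticNetToFtrl {
    /** Returns {l1, l2}; elasticNet = 1 is pure L1, elasticNet = 0 is pure L2. */
    static double[] toL1L2(double reg, double elasticNet) {
        if (reg < 0.0 || elasticNet < 0.0 || elasticNet > 1.0) {
            throw new IllegalArgumentException("reg must be >= 0 and elasticNet must be in [0, 1]");
        }
        double l1 = elasticNet * reg;         // share of the penalty applied as L1
        double l2 = (1.0 - elasticNet) * reg; // remaining share applied as L2
        return new double[] {l1, l2};
    }

    public static void main(String[] args) {
        double[] l1l2 = toL1L2(0.1, 0.5);
        System.out.println("l1=" + l1l2[0] + " l2=" + l1l2[1]);
    }
}
```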





[GitHub] [flink-ml] weibozhao commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator of Ftrl

Posted by GitBox <gi...@apache.org>.
weibozhao commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r847916030


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LinearModelData.java:
##########
@@ -37,60 +38,68 @@
 import java.io.OutputStream;
 
 /**
- * Model data of {@link LogisticRegressionModel}.
+ * Model data of {@link LogisticRegressionModel}, {@link FtrlModel}.
  *
  * <p>This class also provides methods to convert model data from Table to Datastream, and classes
  * to save/load model data.
  */
-public class LogisticRegressionModelData {
+public class LinearModelData {
 
     public DenseVector coefficient;
+    public long modelVersion;

Review Comment:
   We need to talk about it.





[GitHub] [flink-ml] weibozhao commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator of Ftrl

Posted by GitBox <gi...@apache.org>.
weibozhao commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r848984962


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/ftrl/FtrlModelParams.java:
##########
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.ftrl;
+
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasPredictionCol;
+import org.apache.flink.ml.common.param.HasRawPredictionCol;
+
+/**
+ * Params for {@link FtrlModel}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface FtrlModelParams<T>
+        extends HasFeaturesCol<T>, HasPredictionCol<T>, HasRawPredictionCol<T> {}

Review Comment:
   FTRL is not an online version of LogisticRegression. It only shares the same model and prediction process; the update processes are different.





[GitHub] [flink-ml] zhipeng93 commented on pull request #83: [FLINK-27170] Add Transformer and Estimator for Ftrl

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on PR #83:
URL: https://github.com/apache/flink-ml/pull/83#issuecomment-1101907239

   Should we make `FTRL` an optimizer or an algorithm? As I see it, FTRL is an optimizer that could be used to optimize Logistic Regression, just like SGD.
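To make the optimizer-versus-algorithm question concrete: SGD and FTRL consume the same logistic-loss gradient and differ only in how they turn it into new coefficients. The sketch below puts that difference behind a hypothetical `Optimizer` interface. None of these types exist in Flink ML; the FTRL branch mirrors the `UpdateModel` arithmetic from this PR.

```java
/** Hypothetical optimizer abstraction; these names are illustrative, not Flink ML APIs. */
interface Optimizer {
    /** Updates the coefficients w in place from an averaged gradient g. */
    void update(double[] w, double[] g);
}

/** Plain SGD with a fixed learning rate and no per-coordinate state. */
class SgdOptimizer implements Optimizer {
    private final double learningRate;

    SgdOptimizer(double learningRate) {
        this.learningRate = learningRate;
    }

    @Override
    public void update(double[] w, double[] g) {
        for (int i = 0; i < w.length; i++) {
            w[i] -= learningRate * g[i];
        }
    }
}

/** FTRL-Proximal: keeps per-coordinate z and n accumulators between calls. */
class FtrlOptimizer implements Optimizer {
    private final double alpha, beta, l1, l2;
    private double[] z, n;

    FtrlOptimizer(double alpha, double beta, double l1, double l2) {
        this.alpha = alpha;
        this.beta = beta;
        this.l1 = l1;
        this.l2 = l2;
    }

    @Override
    public void update(double[] w, double[] g) {
        if (z == null) { // sized lazily from the first gradient, as UpdateModel does
            z = new double[g.length];
            n = new double[g.length];
        }
        for (int i = 0; i < w.length; i++) {
            double sigma = (Math.sqrt(n[i] + g[i] * g[i]) - Math.sqrt(n[i])) / alpha;
            z[i] += g[i] - sigma * w[i];
            n[i] += g[i] * g[i];
            w[i] = Math.abs(z[i]) <= l1
                    ? 0.0
                    : (Math.signum(z[i]) * l1 - z[i]) / ((beta + Math.sqrt(n[i])) / alpha + l2);
        }
    }
}

class OptimizerDemo {
    public static void main(String[] args) {
        double[] g = {-0.5}; // the same averaged gradient is fed to both optimizers
        double[] wSgd = {0.0};
        double[] wFtrl = {0.0};
        new SgdOptimizer(0.1).update(wSgd, g);
        new FtrlOptimizer(0.1, 0.1, 0.0, 0.0).update(wFtrl, g);
        System.out.println(wSgd[0] + " " + wFtrl[0]);
    }
}
```

With the gradient `{-0.5}`, SGD with learning rate 0.1 moves the coefficient to 0.05, while FTRL with `alpha = beta = 0.1` and no regularization moves it to 0.5 / 6 (about 0.083), because its first step size is `alpha / (beta + 0.5)`. The stateful FTRL accumulators are the main thing an optimizer abstraction would have to carry across iterations.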




[GitHub] [flink-ml] weibozhao commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator for OnlineLogisticRegression

Posted by GitBox <gi...@apache.org>.
weibozhao commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r854909621


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java:
##########
@@ -37,17 +37,19 @@
 import java.io.OutputStream;
 
 /**
- * Model data of {@link LogisticRegressionModel}.
+ * Model data of {@link LogisticRegressionModel} and {@link OnlineLogisticRegressionModel}.
  *
  * <p>This class also provides methods to convert model data from Table to Datastream, and classes
  * to save/load model data.
  */
 public class LogisticRegressionModelData {
 
     public DenseVector coefficient;
+    public long modelVersion;

Review Comment:
   See the comments above.





[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator for OnlineLogisticRegression

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r886644705


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java:
##########
@@ -0,0 +1,434 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.FilterFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the FTRL-Proximal online learning algorithm proposed by H. Brendan
+ * McMahan et al.
+ *
+ * <p>See <a href="https://doi.org/10.1145/2487575.2488200">H. Brendan McMahan et al., Ad click
+ * prediction: a view from the trenches.</a>
+ */
+public class OnlineLogisticRegression
+        implements Estimator<OnlineLogisticRegression, OnlineLogisticRegressionModel>,
+                OnlineLogisticRegressionParams<OnlineLogisticRegression> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineLogisticRegression() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public OnlineLogisticRegressionModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<LogisticRegressionModelData> modelDataStream =
+                LogisticRegressionModelData.getModelDataStream(initModelDataTable);
+
+        DataStream<Row> points =
+                tEnv.toDataStream(inputs[0])
+                        .map(new FeaturesExtractor(getFeaturesCol(), getLabelCol()));
+
+        DataStream<DenseVector> initModelData =
+                modelDataStream.map(
+                        (MapFunction<LogisticRegressionModelData, DenseVector>)
+                                value -> value.coefficient);

Review Comment:
   If the init model version is 5, then will you create a model with version smaller than 5?
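One way to guarantee that emitted versions never fall below the initial model's version is to seed the version counter from the `modelVersion` field of the initial `LogisticRegressionModelData` rather than from a fixed value. The quoted diff does not show where the `modelVersion` used in the `modelVersion++` lambda is initialized, so this is only a sketch of the idea; `ModelVersionCounter` is hypothetical.

```java
/** Hypothetical version counter that continues from the initial model's version. */
class ModelVersionCounter {
    private long lastVersion;

    ModelVersionCounter(long initModelVersion) {
        this.lastVersion = initModelVersion;
    }

    /** Returns the version for the next emitted model; always greater than the initial one. */
    long next() {
        return ++lastVersion;
    }

    public static void main(String[] args) {
        ModelVersionCounter counter = new ModelVersionCounter(5L);
        System.out.println(counter.next()); // 6
        System.out.println(counter.next()); // 7
    }
}
```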



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java:
##########
@@ -18,40 +18,86 @@
 
 package org.apache.flink.ml.classification.logisticregression;
 
+import org.apache.flink.api.common.functions.MapFunction;
 import org.apache.flink.api.common.serialization.Encoder;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
 import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.core.memory.DataInputViewStreamWrapper;
 import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.ml.common.datastream.TableUtils;
 import org.apache.flink.ml.linalg.DenseVector;
 import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
 import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.table.api.Table;
 import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
 import org.apache.flink.table.api.internal.TableImpl;
 
 import java.io.EOFException;
 import java.io.IOException;
 import java.io.OutputStream;
+import java.util.Random;
 
 /**
- * Model data of {@link LogisticRegressionModel}.
+ * Model data of {@link LogisticRegressionModel} and {@link OnlineLogisticRegressionModel}.
  *
  * <p>This class also provides methods to convert model data from Table to Datastream, and classes
  * to save/load model data.
  */
 public class LogisticRegressionModelData {
 
     public DenseVector coefficient;
+    public long modelVersion;
 
-    public LogisticRegressionModelData(DenseVector coefficient) {
+    /* todo : modelVersion is a new idea to manage model, other algorithm if needed should add this parameter later. */
+    public LogisticRegressionModelData(DenseVector coefficient, long modelVersion) {
         this.coefficient = coefficient;
+        this.modelVersion = modelVersion;
     }
 
     public LogisticRegressionModelData() {}
 
+    public LogisticRegressionModelData(DenseVector coefficient) {
+        this.coefficient = coefficient;
+        this.modelVersion = 0L;
+    }
+
+    /**
+     * Generates a Table containing a {@link LogisticRegressionModelData} instance with randomly
+     * generated coefficient.
+     *
+     * @param tEnv The environment where to create the table.
+     * @param dim The size of generated coefficient.
+     * @param seed Random seed.
+     */
+    public static Table generateRandomModelData(StreamTableEnvironment tEnv, int dim, int seed) {
+        StreamExecutionEnvironment env = TableUtils.getExecutionEnvironment(tEnv);
+        return tEnv.fromDataStream(env.fromElements(1).map(new GenerateRandomModel(dim, seed)));
+    }
+
+    private static class GenerateRandomModel

Review Comment:
   This comment is marked as resolved but seems not addressed. Did you forget about this?
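The body of `GenerateRandomModel` is cut off in this excerpt. According to the Javadoc of `generateRandomModelData`, it produces a randomly generated coefficient of size `dim` from a `seed`; a sketch of that idea with `java.util.Random` follows. The uniform `[0, 1)` distribution and the class name `RandomCoefficients` are assumptions, not the PR's code.

```java
import java.util.Arrays;
import java.util.Random;

/** Hypothetical sketch of seeded coefficient generation; not the PR's GenerateRandomModel. */
class RandomCoefficients {
    /** Draws dim values uniformly from [0, 1); the same seed always yields the same vector. */
    static double[] generate(int dim, int seed) {
        Random random = new Random(seed);
        double[] values = new double[dim];
        for (int i = 0; i < dim; i++) {
            values[i] = random.nextDouble();
        }
        return values;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(generate(3, 42)));
    }
}
```

Seeding the generator is what makes tests built on `generateRandomModelData` reproducible across runs.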



##########
flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java:
##########
@@ -74,6 +75,15 @@ public void testReduce() throws Exception {
         assertArrayEquals(new long[] {190L}, sum.stream().mapToLong(Long::longValue).toArray());
     }
 
+    @Test
+    public void testGenerateBatchData() throws Exception {
+        DataStream<Long> dataStream =
+                env.fromParallelCollection(new NumberSequenceIterator(0L, 19L), Types.LONG);
+        DataStream<Long[]> result = DataStreamUtils.generateBatchData(dataStream, 2, 4);
+        List<Long[]> batches = IteratorUtils.toList(result.executeAndCollect());
+        assertEquals(10, batches.size());

Review Comment:
   Could you please check the size of each local batch?
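In the quoted test the source emits 20 values (`NumberSequenceIterator(0L, 19L)` is inclusive at both ends). Assuming the arguments `2` and `4` of `generateBatchData` are the downstream parallelism and the global batch size, each local batch should hold 4 / 2 = 2 elements, which is consistent with the expected 10 batches. The hypothetical helpers below compute the expected local sizes and check every collected batch. They assume an even split with the remainder going to the first subtasks and no trailing partial batches, which may not match the actual implementation.

```java
import java.util.Arrays;
import java.util.List;

/** Hypothetical helpers for checking local batch sizes; not part of DataStreamUtils. */
class LocalBatchSizes {
    /**
     * Splits a global batch size across downstream subtasks as evenly as possible: the first
     * (globalBatchSize % parallelism) subtasks receive one extra element.
     */
    static int[] split(int globalBatchSize, int parallelism) {
        int[] sizes = new int[parallelism];
        for (int i = 0; i < parallelism; i++) {
            sizes[i] = globalBatchSize / parallelism + (i < globalBatchSize % parallelism ? 1 : 0);
        }
        return sizes;
    }

    /** Returns true if every collected batch has exactly the expected number of elements. */
    static boolean allBatchesHaveSize(List<Long[]> batches, int expectedSize) {
        for (Long[] batch : batches) {
            if (batch.length != expectedSize) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(split(4, 2))); // [2, 2]
        System.out.println(Arrays.toString(split(5, 2))); // [3, 2]
        List<Long[]> batches = Arrays.asList(new Long[] {0L, 1L}, new Long[] {2L, 3L});
        System.out.println(allBatchesHaveSize(batches, 2)); // true
    }
}
```

In the test itself, the same check could follow the existing size assertion as a loop asserting `batch.length == 2` for each element of `batches`.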



##########
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java:
##########
@@ -182,4 +193,63 @@ public void snapshotState(StateSnapshotContext context) throws Exception {
             }
         }
     }
+
+    /**
+     * A function that generate several data batches and distribute them to downstream operator.
+     *
+     * @param <T> Data type of batch data.
+     */
+    public static <T> DataStream<T[]> generateBatchData(

Review Comment:
   This comment is not addressed but is marked as resolved. Did you forget to address it?





[GitHub] [flink-ml] zhipeng93 merged pull request #83: [FLINK-27170] Add Transformer and Estimator for OnlineLogisticRegression

Posted by GitBox <gi...@apache.org>.
zhipeng93 merged PR #83:
URL: https://github.com/apache/flink-ml/pull/83




[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator for OnlineLogisticRegression

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r882371295


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java:
##########
@@ -0,0 +1,434 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the FTRL-Proximal online learning algorithm proposed by H. Brendan
+ * McMahan et al.
+ *
+ * <p>See <a href="https://doi.org/10.1145/2487575.2488200">H. Brendan McMahan et al., Ad click
+ * prediction: a view from the trenches.</a>
+ */
+public class OnlineLogisticRegression
+        implements Estimator<OnlineLogisticRegression, OnlineLogisticRegressionModel>,
+                OnlineLogisticRegressionParams<OnlineLogisticRegression> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineLogisticRegression() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineLogisticRegressionModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<LogisticRegressionModelData> modelDataStream =
+                LogisticRegressionModelData.getModelDataStream(initModelDataTable);
+
+        DataStream<Row> points =
+                tEnv.toDataStream(inputs[0])
+                        .map(new FeaturesExtractor(getFeaturesCol(), getLabelCol()));
+
+        DataStream<DenseVector> initModelData =
+                modelDataStream.map(
+                        (MapFunction<LogisticRegressionModelData, DenseVector>)
+                                value -> value.coefficient);
+
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new FtrlIterationBody(
+                        getGlobalBatchSize(), getAlpha(), getBeta(), getReg(), getElasticNet());
+
+        DataStream<LogisticRegressionModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        OnlineLogisticRegressionModel model =
+                new OnlineLogisticRegressionModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class FeaturesExtractor implements MapFunction<Row, Row> {
+        private final String featuresCol;
+        private final String labelCol;
+
+        private FeaturesExtractor(String featuresCol, String labelCol) {
+            this.featuresCol = featuresCol;
+            this.labelCol = labelCol;
+        }
+
+        @Override
+        public Row map(Row row) throws Exception {
+            return Row.of(row.getField(featuresCol), row.getField(labelCol));
+        }
+    }
+
+    /**
+     * Implementation of ftrl optimizer. In this implementation, gradients are calculated in
+     * distributed workers and reduce to one gradient. The reduced gradient is used to update model
+     * by ftrl method.
+     *
+     * <p>See https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Ftrl
+     *
+     * <p>todo: makes ftrl to be a common optimizer and place it in org.apache.flink.ml.common in
+     * future.
+     */
+    private static class FtrlIterationBody implements IterationBody {
+        private final int batchSize;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+
+        public FtrlIterationBody(
+                int batchSize, double alpha, double beta, double reg, double elasticNet) {
+            this.batchSize = batchSize;
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = elasticNet * reg;
+            this.l2 = (1 - elasticNet) * reg;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<DenseVector> modelData = variableStreams.get(0);
+
+            DataStream<Row> points = dataStreams.get(0);
+            int parallelism = points.getParallelism();
+            Preconditions.checkState(
+                    parallelism <= batchSize,
+                    "There are more subtasks in the training process than the number "
+                            + "of elements in each batch. Some subtasks might be idling forever.");
+
+            DataStream<DenseVector[]> newGradient =
+                    DataStreamUtils.generateBatchData(points, parallelism, batchSize)
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "LocalGradientCalculator",
+                                    TypeInformation.of(DenseVector[].class),
+                                    new CalculateLocalGradient())
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(
+                                    (ReduceFunction<DenseVector[]>)
+                                            (gradientInfo, newGradientInfo) -> {
+                                                for (int i = 0;
+                                                        i < newGradientInfo[1].size();
+                                                        ++i) {
+                                                    newGradientInfo[0].values[i] =
+                                                            gradientInfo[0].values[i]
+                                                                    + newGradientInfo[0].values[i];
+                                                    newGradientInfo[1].values[i] =
+                                                            gradientInfo[1].values[i]
+                                                                    + newGradientInfo[1].values[i];
+                                                    if (newGradientInfo[2] == null) {
+                                                        newGradientInfo[2] = gradientInfo[2];
+                                                    }
+                                                }
+                                                return newGradientInfo;
+                                            });
+            DataStream<DenseVector> feedbackModelData =
+                    newGradient
+                            .transform(
+                                    "ModelDataUpdater",
+                                    TypeInformation.of(DenseVector.class),
+                                    new UpdateModel(alpha, beta, l1, l2))
+                            .setParallelism(1);
+
+            DataStream<LogisticRegressionModelData> outputModelData =
+                    feedbackModelData.map(new CreateLrModelData()).setParallelism(1);
+            return new IterationBodyResult(
+                    DataStreamList.of(feedbackModelData), DataStreamList.of(outputModelData));
+        }
+    }
+
+    private static class CreateLrModelData
+            implements MapFunction<DenseVector, LogisticRegressionModelData>, CheckpointedFunction {
+        private Long modelVersion = 1L;
+        private transient ListState<Long> modelVersionState;
+
+        @Override
+        public LogisticRegressionModelData map(DenseVector denseVector) throws Exception {
+            return new LogisticRegressionModelData(denseVector, modelVersion++);
+        }
+
+        @Override
+        public void snapshotState(FunctionSnapshotContext functionSnapshotContext)
+                throws Exception {
+            modelVersionState.update(Collections.singletonList(modelVersion));
+        }
+
+        @Override
+        public void initializeState(FunctionInitializationContext context) throws Exception {
+            modelVersionState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelVersionState", Long.class));
+        }
+    }
+
+    /** Updates model. */
+    private static class UpdateModel extends AbstractStreamOperator<DenseVector>
+            implements OneInputStreamOperator<DenseVector[], DenseVector> {
+        private ListState<double[]> nParamState;
+        private ListState<double[]> zParamState;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+        private double[] nParam;
+        private double[] zParam;
+
+        public UpdateModel(double alpha, double beta, double l1, double l2) {
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = l1;
+            this.l2 = l2;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            nParamState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("nParamState", double[].class));
+            zParamState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("zParamState", double[].class));
+        }
+
+        @Override
+        public void processElement(StreamRecord<DenseVector[]> streamRecord) throws Exception {
+            DenseVector[] gradientInfo = streamRecord.getValue();
+            double[] coefficient = gradientInfo[2].values;
+            double[] g = gradientInfo[0].values;
+            for (int i = 0; i < g.length; ++i) {
+                if (gradientInfo[1].values[i] != 0.0) {
+                    g[i] = g[i] / gradientInfo[1].values[i];
+                }
+            }
+            if (zParam == null) {
+                zParam = new double[g.length];
+                nParam = new double[g.length];
+                nParamState.add(nParam);
+                zParamState.add(zParam);
+            }
+
+            for (int i = 0; i < zParam.length; ++i) {
+                double sigma = (Math.sqrt(nParam[i] + g[i] * g[i]) - Math.sqrt(nParam[i])) / alpha;
+                zParam[i] += g[i] - sigma * coefficient[i];
+                nParam[i] += g[i] * g[i];
+
+                if (Math.abs(zParam[i]) <= l1) {
+                    coefficient[i] = 0.0;
+                } else {
+                    coefficient[i] =
+                            ((zParam[i] < 0 ? -1 : 1) * l1 - zParam[i])
+                                    / ((beta + Math.sqrt(nParam[i])) / alpha + l2);
+                }
+            }
+            output.collect(new StreamRecord<>(new DenseVector(coefficient)));
+        }
+    }
+
+    private static class CalculateLocalGradient extends AbstractStreamOperator<DenseVector[]>
+            implements TwoInputStreamOperator<Row[], DenseVector, DenseVector[]> {
+        private ListState<DenseVector> modelDataState;
+        private ListState<Row[]> localBatchDataState;
+        private double[] gradient;
+        private double[] weight;
+        private int[] denseVectorIndices;
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", DenseVector.class));
+            TypeInformation<Row[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(TypeInformation.of(Row.class));
+            localBatchDataState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("localBatch", type));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<Row[]> pointsRecord) throws Exception {
+            localBatchDataState.add(pointsRecord.getValue());
+            calculateGradient();
+        }
+
+        private void calculateGradient() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !localBatchDataState.get().iterator().hasNext()) {
+                return;
+            }
+            DenseVector modelData =
+                    OperatorStateUtils.getUniqueElement(modelDataState, "modelData").get();
+            modelDataState.clear();
+
+            List<Row[]> pointsList = IteratorUtils.toList(localBatchDataState.get().iterator());
+            Row[] points = pointsList.remove(0);
+            localBatchDataState.update(pointsList);
+
+            for (Row point : points) {
+                Vector vec = point.getFieldAs(0);
+                double label = point.getFieldAs(1);
+                if (gradient == null) {
+                    gradient = new double[vec.size()];
+                    weight = new double[gradient.length];
+                    if (vec instanceof DenseVector) {
+                        denseVectorIndices = new int[vec.size()];
+                        for (int i = 0; i < denseVectorIndices.length; ++i) {
+                            denseVectorIndices[i] = i;
+                        }
+                    }
+                }
+
+                int[] indices;
+                double[] values;
+                if (vec instanceof DenseVector) {
+                    DenseVector denseVector = (DenseVector) vec;
+                    indices = denseVectorIndices;
+                    values = denseVector.values;
+                } else {
+                    SparseVector sparseVector = (SparseVector) vec;
+                    indices = sparseVector.indices;
+                    values = sparseVector.values;
+                }
+                double p = 0.0;
+                for (int i = 0; i < indices.length; ++i) {
+                    int idx = indices[i];
+                    p += modelData.values[idx] * values[i];
+                }
+                p = 1 / (1 + Math.exp(-p));
+                for (int i = 0; i < indices.length; ++i) {
+                    int idx = indices[i];
+                    gradient[idx] += (p - label) * values[i];
+                    weight[idx] += 1.0;
+                }
+            }
+
+            if (points.length > 0) {
+                output.collect(
+                        new StreamRecord<>(
+                                new DenseVector[] {
+                                    new DenseVector(gradient),
+                                    new DenseVector(weight),

Review Comment:
   The variable `weight` is a bit confusing here. Would `featureImportance` be a better name? (I am not sure, but this is the best name I can come up with.)
   Also, could you add Javadoc here explaining that for dense input this matches the TensorFlow implementation, while for sparse input it differs? It would be great if you could also explain why it differs for sparse input.
   
   Moreover, for outputting the models, we could use a separate `OutputTag<LogisticRegressionModelData>` instead of mixing the model output into the communication stream. Please check out SGD#Line292 for an example.
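   To make the dense-versus-sparse question concrete, here is a minimal plain-Java sketch (no Flink APIs; `GradientAccumulator` and its methods are hypothetical names) of how `CalculateLocalGradient`, the `countWindowAll` reduce, and `UpdateModel` treat the `gradient` and `weight` arrays: each coordinate's summed log-loss gradient is divided by the number of samples in which that coordinate appeared. For dense input every coordinate appears in every sample, so the divisor is simply the batch size; for sparse input a rarely seen feature is averaged only over the samples that contain it.

```java
/**
 * Hypothetical sketch mirroring the `gradient` and `weight` arrays of CalculateLocalGradient:
 * per-coordinate sums of the log-loss gradient plus the number of samples touching each
 * coordinate. Features are (indices, values) pairs; a dense vector of size d uses indices 0..d-1.
 */
final class GradientAccumulator {
    final double[] gradientSum;
    final double[] count; // called `weight` in the PR

    GradientAccumulator(int dim) {
        gradientSum = new double[dim];
        count = new double[dim];
    }

    /** Adds one labeled sample (label in {0, 1}) given the current coefficients. */
    void add(int[] indices, double[] values, double label, double[] coefficient) {
        double dot = 0.0;
        for (int i = 0; i < indices.length; i++) {
            dot += coefficient[indices[i]] * values[i];
        }
        double p = 1.0 / (1.0 + Math.exp(-dot)); // sigmoid prediction
        for (int i = 0; i < indices.length; i++) {
            gradientSum[indices[i]] += (p - label) * values[i];
            count[indices[i]] += 1.0; // only coordinates present in this sample are counted
        }
    }

    /** Element-wise sum of another worker's partial result, like the countWindowAll reduce. */
    void merge(GradientAccumulator other) {
        for (int i = 0; i < gradientSum.length; i++) {
            gradientSum[i] += other.gradientSum[i];
            count[i] += other.count[i];
        }
    }

    /** Divides each coordinate by its own count; untouched coordinates keep their (zero) sum. */
    double[] averagedGradient() {
        double[] g = new double[gradientSum.length];
        for (int i = 0; i < g.length; i++) {
            g[i] = count[i] != 0.0 ? gradientSum[i] / count[i] : gradientSum[i];
        }
        return g;
    }
}
```

   For example, with zero coefficients, a dense sample `[1, 2]` with label 1 on one worker and a sparse sample `{1: 4}` with label 0 on another give an averaged gradient of `[-0.5, 0.5]`; dividing both coordinates by the batch size of 2 would instead give `[-0.25, 0.5]`.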



[GitHub] [flink-ml] weibozhao commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator for OnlineLogisticRegression

Posted by GitBox <gi...@apache.org>.
weibozhao commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r887419273


##########
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java:
##########
@@ -182,4 +193,63 @@ public void snapshotState(StateSnapshotContext context) throws Exception {
             }
         }
     }
+
+    /**
+     * A function that generate several data batches and distribute them to downstream operator.
+     *
+     * @param <T> Data type of batch data.
+     */
+    public static <T> DataStream<T[]> generateBatchData(

Review Comment:
   I have already addressed this.
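   For context on the `generateBatchData` helper quoted above: its body is not shown in this message, but the earlier revision of `FtrlIterationBody` in this thread built a global batch with `countWindowAll(batchSize)` and then split it across `parallelism` subtasks. A minimal plain-Java sketch of such a split, assuming contiguous chunks whose sizes differ by at most one (`BatchSplitter` is a hypothetical name, not the Flink ML helper), also shows why the `parallelism <= batchSize` precondition exists: with fewer elements than subtasks, some local batches are empty.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical sketch: splits one global batch into per-subtask local batches. */
final class BatchSplitter {
    private BatchSplitter() {}

    static <T> List<List<T>> split(List<T> globalBatch, int parallelism) {
        if (parallelism <= 0) {
            throw new IllegalArgumentException("parallelism must be positive");
        }
        int n = globalBatch.size();
        List<List<T>> localBatches = new ArrayList<>(parallelism);
        int start = 0;
        for (int i = 0; i < parallelism; i++) {
            // The first (n % parallelism) subtasks each receive one extra element.
            int size = n / parallelism + (i < n % parallelism ? 1 : 0);
            localBatches.add(new ArrayList<>(globalBatch.subList(start, start + size)));
            start += size;
        }
        return localBatches;
    }
}
```

   Splitting seven elements across three subtasks yields `[1, 2, 3]`, `[4, 5]` and `[6, 7]`; splitting two elements across three subtasks leaves the third local batch empty, which is the idle-subtask case the precondition guards against.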



[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator for OnlineLogisticRegression

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r854856268


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java:
##########
@@ -0,0 +1,428 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the OnlineLogisticRegression algorithm.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Online_machine_learning.
+ */
+public class OnlineLogisticRegression
+        implements Estimator<OnlineLogisticRegression, OnlineLogisticRegressionModel>,
+                OnlineLogisticRegressionParams<OnlineLogisticRegression> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineLogisticRegression() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public OnlineLogisticRegressionModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<LogisticRegressionModelData> modelDataStream =
+                LogisticRegressionModelData.getModelDataStream(initModelDataTable);
+        DataStream<DenseVector[]> initModelData =
+                modelDataStream.map(
+                        (MapFunction<LogisticRegressionModelData, DenseVector[]>)
+                                value -> new DenseVector[] {value.coefficient});
+
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new FtrlIterationBody(
+                        getGlobalBatchSize(),
+                        getAlpha(),
+                        getBETA(),
+                        getL1(),
+                        getL2(),
+                        getFeaturesCol(),
+                        getLabelCol());
+
+        DataStream<LogisticRegressionModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData),
+                                DataStreamList.of(tEnv.toDataStream(inputs[0])),
+                                body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        OnlineLogisticRegressionModel model =
+                new OnlineLogisticRegressionModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Iteration body of ftrl algorithm. */
+    private static class FtrlIterationBody implements IterationBody {
+        private final int batchSize;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+        private final String featureCol;
+        private final String labelCol;
+
+        public FtrlIterationBody(
+                int batchSize,
+                double alpha,
+                double beta,
+                double l1,
+                double l2,
+                String featureCol,
+                String labelCol) {
+            this.batchSize = batchSize;
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = l1;
+            this.l2 = l2;
+            this.featureCol = featureCol;
+            this.labelCol = labelCol;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<DenseVector[]> modelData = variableStreams.get(0);
+
+            DataStream<Row> points = dataStreams.get(0);
+            int parallelism = points.getParallelism();
+
+            Preconditions.checkState(
+                    parallelism <= batchSize,
+                    "There are more subtasks in the training process than the number "
+                            + "of elements in each batch. Some subtasks might be idling forever.");
+
+            DataStream<DenseVector[]> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator())
+                            .flatMap(new GlobalBatchSplitter(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(DenseVector[].class),
+                                    new ModelDataLocalUpdater(
+                                            alpha, beta, l1, l2, featureCol, labelCol))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new ModelDataGlobalReducer());
+            DataStream<DenseVector[]> feedbackModelData =
+                    newModelData
+                            .map(
+                                    (MapFunction<DenseVector[], DenseVector[]>)
+                                            value -> {
+                                                double[] z = value[1].values;
+                                                double[] n = value[2].values;
+                                                for (int i = 0; i < z.length; ++i) {
+                                                    if (Math.abs(z[i]) <= l1) {
+                                                        value[0].values[i] = 0.0;
+                                                    } else {
+                                                        value[0].values[i] =
+                                                                ((z[i] < 0 ? -1 : 1) * l1 - z[i])
+                                                                        / ((beta + Math.sqrt(n[i]))
+                                                                                        / alpha
+                                                                                + l2);
+                                                    }
+                                                }
+                                                return new DenseVector[] {value[0]};
+                                            })
+                            .setParallelism(1);
+
+            DataStream<LogisticRegressionModelData> outputModelData =
+                    feedbackModelData
+                            .map(
+                                    (MapFunction<DenseVector[], LogisticRegressionModelData>)
+                                            value ->
+                                                    new LogisticRegressionModelData(
+                                                            value[0], System.nanoTime()))

Review Comment:
   Can we add the model version once the infrastructure for model versioning is ready?
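   The `feedbackModelData` map quoted above recovers each coefficient from the FTRL-Proximal accumulators `z` and `n` in closed form: the coefficient is 0 when |z_i| <= l1, and otherwise (sign(z_i) * l1 - z_i) / ((beta + sqrt(n_i)) / alpha + l2). The accumulator updates themselves happen in `ModelDataLocalUpdater`, which is not part of this diff. Below is a minimal plain-Java sketch of one complete per-coordinate step, following the arithmetic of the later revision's `UpdateModel` operator quoted earlier in this thread, including the elastic-net split `l1 = elasticNet * reg`, `l2 = (1 - elasticNet) * reg` from the later `FtrlIterationBody`. `FtrlProximal` and `withElasticNet` are hypothetical names, not Flink ML API.

```java
/**
 * Hypothetical sketch of the per-coordinate FTRL-Proximal update (McMahan et al.), following
 * the arithmetic of the UpdateModel operator quoted in this PR. Not the Flink ML API.
 */
final class FtrlProximal {
    final double alpha;
    final double beta;
    final double l1;
    final double l2;
    final double[] z; // per-coordinate accumulator z_i
    final double[] n; // per-coordinate sum of squared gradients n_i
    final double[] w; // current coefficients

    FtrlProximal(int dim, double alpha, double beta, double l1, double l2) {
        this.alpha = alpha;
        this.beta = beta;
        this.l1 = l1;
        this.l2 = l2;
        this.z = new double[dim];
        this.n = new double[dim];
        this.w = new double[dim];
    }

    /** Splits a total regularization strength into L1 and L2 parts, as FtrlIterationBody does. */
    static FtrlProximal withElasticNet(
            int dim, double alpha, double beta, double reg, double elasticNet) {
        return new FtrlProximal(dim, alpha, beta, elasticNet * reg, (1 - elasticNet) * reg);
    }

    /** Applies one step with the (already averaged) gradient g. */
    void update(double[] g) {
        for (int i = 0; i < w.length; i++) {
            // Increase of the inverse per-coordinate learning rate (sigma_i in the paper).
            double sigma = (Math.sqrt(n[i] + g[i] * g[i]) - Math.sqrt(n[i])) / alpha;
            z[i] += g[i] - sigma * w[i];
            n[i] += g[i] * g[i];
            if (Math.abs(z[i]) <= l1) {
                // The L1 threshold sets small coordinates to exactly zero.
                w[i] = 0.0;
            } else {
                w[i] = (Math.signum(z[i]) * l1 - z[i]) / ((beta + Math.sqrt(n[i])) / alpha + l2);
            }
        }
    }
}
```

   With alpha = beta = 1, no regularization, and a single gradient of 2.0, one step sets z = 2, n = 4 and the coefficient to -2/3; with l1 = 3 the same step leaves the coefficient at exactly 0, which is how the L1 term yields sparse models.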



[GitHub] [flink-ml] lindong28 commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator for OnlineLogisticRegression

Posted by GitBox <gi...@apache.org>.
lindong28 commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r860550025


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java:
##########
@@ -162,9 +164,17 @@ public Row map(Row dataPoint) {
      * @param coefficient The model parameters.
      * @return The prediction label and the raw probabilities.
      */
-    private static Tuple2<Double, DenseVector> predictRaw(
-            DenseVector feature, DenseVector coefficient) {
-        double dotValue = BLAS.dot(feature, coefficient);
+    protected static Tuple2<Double, DenseVector> predictRaw(
+            Vector feature, DenseVector coefficient) {
+        double dotValue = 0.0;
+        if (feature instanceof SparseVector) {
+            SparseVector svec = (SparseVector) feature;

Review Comment:
   Would it be more useful to move this logic to infra, by adding the method `public static double dot(Vector x, DenseVector y)` in `BLAS.java`?
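   A minimal plain-Java sketch of the helper proposed above, assuming plain arrays stand in for Flink ML's `DenseVector` (its `values`) and `SparseVector` (its `indices` and `values`); `DotSketch` and its methods are hypothetical and not the actual `BLAS` API. The sparse overload visits only the stored non-zero entries, so its cost is proportional to the number of non-zeros rather than to the vector size.

```java
/**
 * Hypothetical sketch of a dot product between a possibly sparse vector and a dense one.
 * Plain arrays stand in for Flink ML's DenseVector and SparseVector.
 */
final class DotSketch {
    private DotSketch() {}

    /** Dense x dense: both arrays must have the same length. */
    static double dot(double[] x, double[] y) {
        if (x.length != y.length) {
            throw new IllegalArgumentException("Vector sizes differ");
        }
        double sum = 0.0;
        for (int i = 0; i < x.length; i++) {
            sum += x[i] * y[i];
        }
        return sum;
    }

    /** Sparse x dense: only the stored (index, value) pairs of x contribute. */
    static double dot(int[] xIndices, double[] xValues, double[] y) {
        if (xIndices.length != xValues.length) {
            throw new IllegalArgumentException("Indices and values differ in length");
        }
        double sum = 0.0;
        for (int i = 0; i < xIndices.length; i++) {
            sum += xValues[i] * y[xIndices[i]];
        }
        return sum;
    }
}
```

   For `y = [4, 5, 6]`, the dense `x = [1, 2, 3]` gives 32, and the sparse `x` with indices `{0, 2}` and values `{1, 3}` gives 22.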



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/OnlineLogisticRegressionTest.java:
##########
@@ -0,0 +1,444 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.classification.logisticregression.LogisticRegression;
+import org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegression;
+import org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegressionModel;
+import org.apache.flink.ml.feature.MinMaxScalerTest;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static junit.framework.TestCase.assertEquals;
+
+/** Tests {@link OnlineLogisticRegression} and {@link OnlineLogisticRegressionModel}. */
+public class OnlineLogisticRegressionTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table trainDenseTable;
+    private static final String LABEL_COL = "label";
+    private static final String PREDICT_COL = "prediction";
+    private static final String FEATURE_COL = "features";
+    private static final String MODEL_VERSION_COL = "modelVersion";
+    private static final double[] ONE_ARRAY = new double[] {1.0, 1.0, 1.0};
+    private static final List<Row> TRAIN_DENSE_ROWS =
+            Arrays.asList(
+                    Row.of(Vectors.dense(0.1, 2.), 0.),
+                    Row.of(Vectors.dense(0.2, 2.), 0.),
+                    Row.of(Vectors.dense(0.3, 2.), 0.),
+                    Row.of(Vectors.dense(0.4, 2.), 0.),
+                    Row.of(Vectors.dense(0.5, 2.), 0.),
+                    Row.of(Vectors.dense(11., 12.), 1.),
+                    Row.of(Vectors.dense(12., 11.), 1.),
+                    Row.of(Vectors.dense(13., 12.), 1.),
+                    Row.of(Vectors.dense(14., 12.), 1.),
+                    Row.of(Vectors.dense(15., 12.), 1.));
+
+    private static final List<Row> PREDICT_DENSE_ROWS =
+            Arrays.asList(
+                    Row.of(Vectors.dense(0.8, 2.7), 0.),
+                    Row.of(Vectors.dense(0.8, 2.4), 0.),
+                    Row.of(Vectors.dense(0.7, 2.3), 0.),
+                    Row.of(Vectors.dense(0.4, 2.7), 0.),
+                    Row.of(Vectors.dense(0.5, 2.8), 0.),
+                    Row.of(Vectors.dense(10.2, 12.1), 1.),
+                    Row.of(Vectors.dense(13.3, 13.1), 1.),
+                    Row.of(Vectors.dense(13.5, 12.2), 1.),
+                    Row.of(Vectors.dense(14.9, 12.5), 1.),
+                    Row.of(Vectors.dense(15.5, 11.2), 1.));
+
+    private static final List<Row> TRAIN_SPARSE_ROWS =
+            Arrays.asList(
+                    Row.of(Vectors.sparse(10, new int[] {1, 3, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {0, 2, 3}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {0, 3, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {2, 3, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {1, 3, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {6, 7, 8}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {6, 8, 9}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {5, 8, 9}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {5, 6, 8}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {5, 6, 7}, ONE_ARRAY), 1.));
+
+    private static final List<Row> PREDICT_SPARSE_ROWS =
+            Arrays.asList(
+                    Row.of(Vectors.sparse(10, new int[] {1, 2, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {2, 3, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {0, 2, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {1, 3, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {6, 7, 9}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {7, 8, 9}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {5, 7, 9}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {5, 6, 7}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {5, 8, 9}, ONE_ARRAY), 1.));
+
+    private Table initDenseModel;
+
+    @Before
+    public void before() throws Exception {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        DataStream<Row> dataStream = env.fromCollection(TRAIN_DENSE_ROWS);
+        trainDenseTable = tEnv.fromDataStream(dataStream).as(FEATURE_COL, LABEL_COL);
+        initDenseModel =
+                tEnv.fromDataStream(
+                        env.fromElements(Row.of(new DenseVector(new double[] {0.0, 0.0}), 0L)));
+    }
+
+    @Test
+    public void testFit() throws Exception {
+        Table onlinePredictTable = getTable(1, 1000, TRAIN_DENSE_ROWS, 4, false);
+        Table models =
+                new OnlineLogisticRegression()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initDenseModel)
+                        .setGlobalBatchSize(100)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlinePredictTable)
+                        .getModelData()[0];
+        List<Row> modelList = IteratorUtils.toList(tEnv.toDataStream(models).executeAndCollect());
+        assertEquals(10, modelList.size());
+    }
+
+    @Test
+    public void testFitWithInitLrModel() throws Exception {
+        Table initLrModel =
+                new LogisticRegression()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setLabelCol(LABEL_COL)
+                        .fit(trainDenseTable)
+                        .getModelData()[0];
+        Table onlineTrainTable = getTable(50, 1000, TRAIN_DENSE_ROWS, 4, false);
+        Table models =
+                new OnlineLogisticRegression()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initLrModel)
+                        .setGlobalBatchSize(100)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlineTrainTable)
+                        .getModelData()[0];
+        List<Row> modelList = IteratorUtils.toList(tEnv.toDataStream(models).executeAndCollect());
+        assertEquals(10, modelList.size());
+    }
+
+    @Test
+    public void testDenseFitAndPredict() throws Exception {
+        Table onlineTrainTable = getTable(2, 2000, TRAIN_DENSE_ROWS, 2, false);
+        Table onlinePredictTable = getTable(2, 3000, PREDICT_DENSE_ROWS, 2, false);
+        OnlineLogisticRegressionModel model =
+                new OnlineLogisticRegression()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initDenseModel)
+                        .setGlobalBatchSize(500)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlineTrainTable);
+        Table resultTable = model.setPredictionCol(PREDICT_COL).transform(onlinePredictTable)[0];
+        verifyPredictionResult(resultTable, 4);
+    }
+
+    @Test
+    public void testSparseFitAndPredict() throws Exception {
+        double[] doubleArray = new double[] {0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1};
+        Table initSparseModel =
+                tEnv.fromDataStream(env.fromElements(Row.of(new DenseVector(doubleArray), 0L)));
+        Table onlineTrainTable = getTable(5, 2000, TRAIN_SPARSE_ROWS, 4, true);
+        Table onlinePredictTable = getTable(5, 3000, PREDICT_SPARSE_ROWS, 4, true);
+        OnlineLogisticRegressionModel model =
+                new OnlineLogisticRegression()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initSparseModel)
+                        .setGlobalBatchSize(500)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlineTrainTable);
+        Table resultTable = model.setPredictionCol(PREDICT_COL).transform(onlinePredictTable)[0];
+        verifyPredictionResult(resultTable, 4);

Review Comment:
   This test would generate 2000 / 500 = 4 model data versions, and a given element in `onlinePredictTable` might be transformed by any one of these 4 versions, right?
   
   If so, the prediction result is non-deterministic and we cannot reliably verify its numerical values. This might explain why this PR has test failures.
   
   `OnlineKMeansTest::testFitAndPredict` addresses this problem using, e.g., `waitModelDataUpdate()`.
   
   Could you try to follow the approach used in `OnlineKMeansTest` and make sure all tests added in this PR pass?
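   The core idea can be sketched in plain Java: record the latest model data version the job has emitted, and have the test hold back the rows to be predicted until the version it wants to check against is available. `ModelVersionGate`, `onModelUpdate` and `waitForVersion` are hypothetical names, not the `OnlineKMeansTest` code; in a real Flink test the version signal would have to be reported from the running job (for example from a sink), which this sketch does not show.

```java
import java.util.concurrent.TimeUnit;

/**
 * Hypothetical helper (not the OnlineKMeansTest implementation): records the
 * latest model data version seen and lets a test block until a given version
 * is available before it releases the rows to be predicted.
 */
public final class ModelVersionGate {
    // Guarded by "this".
    private long latestVersion = -1L;

    /** Records that a model data version has been emitted and wakes up waiters. */
    public synchronized void onModelUpdate(long version) {
        latestVersion = Math.max(latestVersion, version);
        notifyAll();
    }

    /** Returns the highest version recorded so far, or -1 if none. */
    public synchronized long latestVersion() {
        return latestVersion;
    }

    /**
     * Waits until a version >= expectedVersion has been recorded.
     *
     * @return true if the version was reached, false if the timeout elapsed first
     */
    public synchronized boolean waitForVersion(long expectedVersion, long timeoutMillis)
            throws InterruptedException {
        long deadline = System.nanoTime() + TimeUnit.MILLISECONDS.toNanos(timeoutMillis);
        // Loop to tolerate spurious wakeups and updates to lower versions.
        while (latestVersion < expectedVersion) {
            long remainingMillis = TimeUnit.NANOSECONDS.toMillis(deadline - System.nanoTime());
            if (remainingMillis <= 0) {
                return false;
            }
            wait(remainingMillis); // releases the monitor while waiting
        }
        return true;
    }
}
```

   If each group of prediction rows is released only after its target version is visible, the expected predictions (and the `MODEL_VERSION_COL` value, if the model outputs it) can be asserted deterministically.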



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java:
##########
@@ -342,17 +342,18 @@ public DenseVector map(Row row) {
     /**
      * An operator that splits a global batch into evenly-sized local batches, and distributes them
      * to downstream operator.
+     *
+     * @param <T> Data type of batch data.
      */
-    private static class GlobalBatchSplitter
-            implements FlatMapFunction<DenseVector[], DenseVector[]> {
+    public static class GlobalBatchSplitter<T> implements FlatMapFunction<T[], T[]> {

Review Comment:
   It seems a bit strange for `OnlineLogisticRegression` to reference a class defined in `OnlineKMeans`, given that these are two different algorithms.
   
   How about moving this class to `DataStreamUtils.java`?
   
   Same for `GlobalBatchCreator`.
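   The splitting logic itself does not depend on either algorithm, which makes it a good fit for a shared utility. Below is a minimal plain-Java sketch of splitting a global batch into evenly-sized local batches (sizes differ by at most one); `EvenBatchSplit.splitEvenly` is a hypothetical name, and the actual `GlobalBatchSplitter` is a `FlatMapFunction` whose distribution details may differ.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Hypothetical sketch of the even-split logic, independent of any ML algorithm. */
public final class EvenBatchSplit {
    private EvenBatchSplit() {}

    /**
     * Splits batch into numParts contiguous sub-arrays whose sizes differ by at
     * most one. The first (batch.length % numParts) parts get one extra element.
     */
    public static <T> List<T[]> splitEvenly(T[] batch, int numParts) {
        if (numParts <= 0) {
            throw new IllegalArgumentException("numParts must be positive: " + numParts);
        }
        int base = batch.length / numParts;
        int remainder = batch.length % numParts;
        List<T[]> parts = new ArrayList<>(numParts);
        int offset = 0;
        for (int i = 0; i < numParts; i++) {
            int size = base + (i < remainder ? 1 : 0);
            parts.add(Arrays.copyOfRange(batch, offset, offset + size));
            offset += size;
        }
        return parts;
    }
}
```

   For example, a global batch of 10 elements split across 4 subtasks yields local batches of sizes 3, 3, 2 and 2.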



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/OnlineLogisticRegressionTest.java:
##########
@@ -0,0 +1,444 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.classification.logisticregression.LogisticRegression;
+import org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegression;
+import org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegressionModel;
+import org.apache.flink.ml.feature.MinMaxScalerTest;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static junit.framework.TestCase.assertEquals;
+
+/** Tests {@link OnlineLogisticRegression} and {@link OnlineLogisticRegressionModel}. */
+public class OnlineLogisticRegressionTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table trainDenseTable;
+    private static final String LABEL_COL = "label";
+    private static final String PREDICT_COL = "prediction";
+    private static final String FEATURE_COL = "features";
+    private static final String MODEL_VERSION_COL = "modelVersion";
+    private static final double[] ONE_ARRAY = new double[] {1.0, 1.0, 1.0};
+    private static final List<Row> TRAIN_DENSE_ROWS =
+            Arrays.asList(
+                    Row.of(Vectors.dense(0.1, 2.), 0.),
+                    Row.of(Vectors.dense(0.2, 2.), 0.),
+                    Row.of(Vectors.dense(0.3, 2.), 0.),
+                    Row.of(Vectors.dense(0.4, 2.), 0.),
+                    Row.of(Vectors.dense(0.5, 2.), 0.),
+                    Row.of(Vectors.dense(11., 12.), 1.),
+                    Row.of(Vectors.dense(12., 11.), 1.),
+                    Row.of(Vectors.dense(13., 12.), 1.),
+                    Row.of(Vectors.dense(14., 12.), 1.),
+                    Row.of(Vectors.dense(15., 12.), 1.));
+
+    private static final List<Row> PREDICT_DENSE_ROWS =
+            Arrays.asList(
+                    Row.of(Vectors.dense(0.8, 2.7), 0.),
+                    Row.of(Vectors.dense(0.8, 2.4), 0.),
+                    Row.of(Vectors.dense(0.7, 2.3), 0.),
+                    Row.of(Vectors.dense(0.4, 2.7), 0.),
+                    Row.of(Vectors.dense(0.5, 2.8), 0.),
+                    Row.of(Vectors.dense(10.2, 12.1), 1.),
+                    Row.of(Vectors.dense(13.3, 13.1), 1.),
+                    Row.of(Vectors.dense(13.5, 12.2), 1.),
+                    Row.of(Vectors.dense(14.9, 12.5), 1.),
+                    Row.of(Vectors.dense(15.5, 11.2), 1.));
+
+    private static final List<Row> TRAIN_SPARSE_ROWS =
+            Arrays.asList(
+                    Row.of(Vectors.sparse(10, new int[] {1, 3, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {0, 2, 3}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {0, 3, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {2, 3, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {1, 3, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {6, 7, 8}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {6, 8, 9}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {5, 8, 9}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {5, 6, 8}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {5, 6, 7}, ONE_ARRAY), 1.));
+
+    private static final List<Row> PREDICT_SPARSE_ROWS =
+            Arrays.asList(
+                    Row.of(Vectors.sparse(10, new int[] {1, 2, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {2, 3, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {0, 2, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {1, 3, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {6, 7, 9}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {7, 8, 9}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {5, 7, 9}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {5, 6, 7}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {5, 8, 9}, ONE_ARRAY), 1.));
+
+    private Table initDenseModel;
+
+    @Before
+    public void before() throws Exception {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        DataStream<Row> dataStream = env.fromCollection(TRAIN_DENSE_ROWS);
+        trainDenseTable = tEnv.fromDataStream(dataStream).as(FEATURE_COL, LABEL_COL);
+        initDenseModel =
+                tEnv.fromDataStream(
+                        env.fromElements(Row.of(new DenseVector(new double[] {0.0, 0.0}), 0L)));
+    }
+
+    @Test
+    public void testFit() throws Exception {
+        Table onlinePredictTable = getTable(1, 1000, TRAIN_DENSE_ROWS, 4, false);
+        Table models =
+                new OnlineLogisticRegression()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initDenseModel)
+                        .setGlobalBatchSize(100)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlinePredictTable)
+                        .getModelData()[0];
+        List<Row> modelList = IteratorUtils.toList(tEnv.toDataStream(models).executeAndCollect());
+        assertEquals(10, modelList.size());
+    }
+
+    @Test
+    public void testFitWithInitLrModel() throws Exception {
+        Table initLrModel =
+                new LogisticRegression()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setLabelCol(LABEL_COL)
+                        .fit(trainDenseTable)
+                        .getModelData()[0];
+        Table onlineTrainTable = getTable(50, 1000, TRAIN_DENSE_ROWS, 4, false);
+        Table models =
+                new OnlineLogisticRegression()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initLrModel)
+                        .setGlobalBatchSize(100)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlineTrainTable)
+                        .getModelData()[0];
+        List<Row> modelList = IteratorUtils.toList(tEnv.toDataStream(models).executeAndCollect());
+        assertEquals(10, modelList.size());
+    }
+
+    @Test
+    public void testDenseFitAndPredict() throws Exception {
+        Table onlineTrainTable = getTable(2, 2000, TRAIN_DENSE_ROWS, 2, false);
+        Table onlinePredictTable = getTable(2, 3000, PREDICT_DENSE_ROWS, 2, false);
+        OnlineLogisticRegressionModel model =
+                new OnlineLogisticRegression()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initDenseModel)
+                        .setGlobalBatchSize(500)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlineTrainTable);
+        Table resultTable = model.setPredictionCol(PREDICT_COL).transform(onlinePredictTable)[0];

Review Comment:
   Since `setPredictionCol` is already tested in `testParam()`, would it be simpler to skip calling `setPredictionCol(PREDICT_COL)` in this test?
   
   Same for other invocations of `setFeaturesCol(FEATURE_COL)` and `setLabelCol(LABEL_COL)` in this class.



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinMaxScalerTest.java:
##########
@@ -78,7 +78,7 @@ public class MinMaxScalerTest {
                             Vectors.dense(0.75, 0.225)));
 
     /** Note: this comparator imposes orderings that are inconsistent with equals. */
-    private static int compare(DenseVector first, DenseVector second) {
+    public static int compare(DenseVector first, DenseVector second) {

Review Comment:
   Since this static method is used by multiple test classes, how about moving this method to `TestUtils.java`?
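   A sketch of what a shared comparator might look like, assuming a lexicographic ordering. It takes `double[]` instead of `DenseVector` to stay self-contained, the class name `TestUtilsSketch` is hypothetical, and it does not necessarily reproduce the ordering of the existing `MinMaxScalerTest.compare`.

```java
/** Hypothetical shared test utility; double[] stands in for DenseVector. */
public final class TestUtilsSketch {
    private TestUtilsSketch() {}

    /**
     * Lexicographic comparison: element-wise via Double.compare, then by length
     * when one vector is a prefix of the other.
     */
    public static int compare(double[] first, double[] second) {
        int n = Math.min(first.length, second.length);
        for (int i = 0; i < n; i++) {
            int cmp = Double.compare(first[i], second[i]);
            if (cmp != 0) {
                return cmp;
            }
        }
        return Integer.compare(first.length, second.length);
    }
}
```

   A comparator like this is typically used to sort rows collected from a parallel job, e.g. `rows.sort(TestUtilsSketch::compare)`, before comparing them with expected values. On Java 9+, `java.util.Arrays.compare(double[], double[])` gives the same lexicographic ordering.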



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/OnlineLogisticRegressionTest.java:
##########
@@ -0,0 +1,444 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.classification.logisticregression.LogisticRegression;
+import org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegression;
+import org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegressionModel;
+import org.apache.flink.ml.feature.MinMaxScalerTest;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static junit.framework.TestCase.assertEquals;
+
+/** Tests {@link OnlineLogisticRegression} and {@link OnlineLogisticRegressionModel}. */
+public class OnlineLogisticRegressionTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table trainDenseTable;
+    private static final String LABEL_COL = "label";
+    private static final String PREDICT_COL = "prediction";
+    private static final String FEATURE_COL = "features";
+    private static final String MODEL_VERSION_COL = "modelVersion";
+    private static final double[] ONE_ARRAY = new double[] {1.0, 1.0, 1.0};
+    private static final List<Row> TRAIN_DENSE_ROWS =
+            Arrays.asList(
+                    Row.of(Vectors.dense(0.1, 2.), 0.),
+                    Row.of(Vectors.dense(0.2, 2.), 0.),
+                    Row.of(Vectors.dense(0.3, 2.), 0.),
+                    Row.of(Vectors.dense(0.4, 2.), 0.),
+                    Row.of(Vectors.dense(0.5, 2.), 0.),
+                    Row.of(Vectors.dense(11., 12.), 1.),
+                    Row.of(Vectors.dense(12., 11.), 1.),
+                    Row.of(Vectors.dense(13., 12.), 1.),
+                    Row.of(Vectors.dense(14., 12.), 1.),
+                    Row.of(Vectors.dense(15., 12.), 1.));
+
+    private static final List<Row> PREDICT_DENSE_ROWS =
+            Arrays.asList(
+                    Row.of(Vectors.dense(0.8, 2.7), 0.),
+                    Row.of(Vectors.dense(0.8, 2.4), 0.),
+                    Row.of(Vectors.dense(0.7, 2.3), 0.),
+                    Row.of(Vectors.dense(0.4, 2.7), 0.),
+                    Row.of(Vectors.dense(0.5, 2.8), 0.),
+                    Row.of(Vectors.dense(10.2, 12.1), 1.),
+                    Row.of(Vectors.dense(13.3, 13.1), 1.),
+                    Row.of(Vectors.dense(13.5, 12.2), 1.),
+                    Row.of(Vectors.dense(14.9, 12.5), 1.),
+                    Row.of(Vectors.dense(15.5, 11.2), 1.));
+
+    private static final List<Row> TRAIN_SPARSE_ROWS =
+            Arrays.asList(
+                    Row.of(Vectors.sparse(10, new int[] {1, 3, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {0, 2, 3}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {0, 3, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {2, 3, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {1, 3, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {6, 7, 8}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {6, 8, 9}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {5, 8, 9}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {5, 6, 8}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {5, 6, 7}, ONE_ARRAY), 1.));
+
+    private static final List<Row> PREDICT_SPARSE_ROWS =
+            Arrays.asList(
+                    Row.of(Vectors.sparse(10, new int[] {1, 2, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {2, 3, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {0, 2, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {1, 3, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {6, 7, 9}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {7, 8, 9}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {5, 7, 9}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {5, 6, 7}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {5, 8, 9}, ONE_ARRAY), 1.));
+
+    private Table initDenseModel;
+
+    @Before
+    public void before() throws Exception {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        DataStream<Row> dataStream = env.fromCollection(TRAIN_DENSE_ROWS);
+        trainDenseTable = tEnv.fromDataStream(dataStream).as(FEATURE_COL, LABEL_COL);
+        initDenseModel =
+                tEnv.fromDataStream(
+                        env.fromElements(Row.of(new DenseVector(new double[] {0.0, 0.0}), 0L)));
+    }
+
+    @Test
+    public void testFit() throws Exception {
+        Table onlinePredictTable = getTable(1, 1000, TRAIN_DENSE_ROWS, 4, false);
+        Table models =
+                new OnlineLogisticRegression()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initDenseModel)
+                        .setGlobalBatchSize(100)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlinePredictTable)
+                        .getModelData()[0];
+        List<Row> modelList = IteratorUtils.toList(tEnv.toDataStream(models).executeAndCollect());
+        assertEquals(10, modelList.size());
+    }
+
+    @Test
+    public void testFitWithInitLrModel() throws Exception {
+        Table initLrModel =
+                new LogisticRegression()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setLabelCol(LABEL_COL)
+                        .fit(trainDenseTable)
+                        .getModelData()[0];
+        Table onlineTrainTable = getTable(50, 1000, TRAIN_DENSE_ROWS, 4, false);
+        Table models =
+                new OnlineLogisticRegression()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initLrModel)
+                        .setGlobalBatchSize(100)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlineTrainTable)
+                        .getModelData()[0];
+        List<Row> modelList = IteratorUtils.toList(tEnv.toDataStream(models).executeAndCollect());
+        assertEquals(10, modelList.size());
+    }
+
+    @Test
+    public void testDenseFitAndPredict() throws Exception {
+        Table onlineTrainTable = getTable(2, 2000, TRAIN_DENSE_ROWS, 2, false);
+        Table onlinePredictTable = getTable(2, 3000, PREDICT_DENSE_ROWS, 2, false);
+        OnlineLogisticRegressionModel model =
+                new OnlineLogisticRegression()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initDenseModel)
+                        .setGlobalBatchSize(500)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlineTrainTable);
+        Table resultTable = model.setPredictionCol(PREDICT_COL).transform(onlinePredictTable)[0];
+        verifyPredictionResult(resultTable, 4);
+    }
+
+    @Test
+    public void testSparseFitAndPredict() throws Exception {
+        double[] doubleArray = new double[] {0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1};
+        Table initSparseModel =
+                tEnv.fromDataStream(env.fromElements(Row.of(new DenseVector(doubleArray), 0L)));
+        Table onlineTrainTable = getTable(5, 2000, TRAIN_SPARSE_ROWS, 4, true);
+        Table onlinePredictTable = getTable(5, 3000, PREDICT_SPARSE_ROWS, 4, true);
+        OnlineLogisticRegressionModel model =
+                new OnlineLogisticRegression()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initSparseModel)
+                        .setGlobalBatchSize(500)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlineTrainTable);
+        Table resultTable = model.setPredictionCol(PREDICT_COL).transform(onlinePredictTable)[0];
+        verifyPredictionResult(resultTable, 4);
+    }
+
+    @Test
+    public void testParam() {

Review Comment:
   nit: Could we order the methods in this class to be consistent with the test methods in other test classes, e.g. move `testParam()` to be the first method in this class?



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/OnlineLogisticRegressionTest.java:
##########
@@ -0,0 +1,444 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.classification.logisticregression.LogisticRegression;
+import org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegression;
+import org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegressionModel;
+import org.apache.flink.ml.feature.MinMaxScalerTest;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static junit.framework.TestCase.assertEquals;
+
+/** Tests {@link OnlineLogisticRegression} and {@link OnlineLogisticRegressionModel}. */
+public class OnlineLogisticRegressionTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table trainDenseTable;
+    private static final String LABEL_COL = "label";
+    private static final String PREDICT_COL = "prediction";
+    private static final String FEATURE_COL = "features";
+    private static final String MODEL_VERSION_COL = "modelVersion";
+    private static final double[] ONE_ARRAY = new double[] {1.0, 1.0, 1.0};
+    private static final List<Row> TRAIN_DENSE_ROWS =
+            Arrays.asList(
+                    Row.of(Vectors.dense(0.1, 2.), 0.),
+                    Row.of(Vectors.dense(0.2, 2.), 0.),
+                    Row.of(Vectors.dense(0.3, 2.), 0.),
+                    Row.of(Vectors.dense(0.4, 2.), 0.),
+                    Row.of(Vectors.dense(0.5, 2.), 0.),
+                    Row.of(Vectors.dense(11., 12.), 1.),
+                    Row.of(Vectors.dense(12., 11.), 1.),
+                    Row.of(Vectors.dense(13., 12.), 1.),
+                    Row.of(Vectors.dense(14., 12.), 1.),
+                    Row.of(Vectors.dense(15., 12.), 1.));
+
+    private static final List<Row> PREDICT_DENSE_ROWS =
+            Arrays.asList(
+                    Row.of(Vectors.dense(0.8, 2.7), 0.),
+                    Row.of(Vectors.dense(0.8, 2.4), 0.),
+                    Row.of(Vectors.dense(0.7, 2.3), 0.),
+                    Row.of(Vectors.dense(0.4, 2.7), 0.),
+                    Row.of(Vectors.dense(0.5, 2.8), 0.),
+                    Row.of(Vectors.dense(10.2, 12.1), 1.),
+                    Row.of(Vectors.dense(13.3, 13.1), 1.),
+                    Row.of(Vectors.dense(13.5, 12.2), 1.),
+                    Row.of(Vectors.dense(14.9, 12.5), 1.),
+                    Row.of(Vectors.dense(15.5, 11.2), 1.));
+
+    private static final List<Row> TRAIN_SPARSE_ROWS =
+            Arrays.asList(
+                    Row.of(Vectors.sparse(10, new int[] {1, 3, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {0, 2, 3}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {0, 3, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {2, 3, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {1, 3, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {6, 7, 8}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {6, 8, 9}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {5, 8, 9}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {5, 6, 8}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {5, 6, 7}, ONE_ARRAY), 1.));
+
+    private static final List<Row> PREDICT_SPARSE_ROWS =
+            Arrays.asList(
+                    Row.of(Vectors.sparse(10, new int[] {1, 2, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {2, 3, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {0, 2, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {1, 3, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {6, 7, 9}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {7, 8, 9}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {5, 7, 9}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {5, 6, 7}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {5, 8, 9}, ONE_ARRAY), 1.));
+
+    private Table initDenseModel;
+
+    @Before
+    public void before() throws Exception {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        DataStream<Row> dataStream = env.fromCollection(TRAIN_DENSE_ROWS);
+        trainDenseTable = tEnv.fromDataStream(dataStream).as(FEATURE_COL, LABEL_COL);
+        initDenseModel =
+                tEnv.fromDataStream(
+                        env.fromElements(Row.of(new DenseVector(new double[] {0.0, 0.0}), 0L)));
+    }
+
+    @Test
+    public void testFit() throws Exception {
+        Table onlinePredictTable = getTable(1, 1000, TRAIN_DENSE_ROWS, 4, false);
+        Table models =
+                new OnlineLogisticRegression()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initDenseModel)
+                        .setGlobalBatchSize(100)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlinePredictTable)
+                        .getModelData()[0];
+        List<Row> modelList = IteratorUtils.toList(tEnv.toDataStream(models).executeAndCollect());
+        assertEquals(10, modelList.size());
+    }
+
+    @Test
+    public void testFitWithInitLrModel() throws Exception {
+        Table initLrModel =
+                new LogisticRegression()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setLabelCol(LABEL_COL)
+                        .fit(trainDenseTable)
+                        .getModelData()[0];
+        Table onlineTrainTable = getTable(50, 1000, TRAIN_DENSE_ROWS, 4, false);
+        Table models =
+                new OnlineLogisticRegression()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initLrModel)
+                        .setGlobalBatchSize(100)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlineTrainTable)
+                        .getModelData()[0];
+        List<Row> modelList = IteratorUtils.toList(tEnv.toDataStream(models).executeAndCollect());
+        assertEquals(10, modelList.size());
+    }
+
+    @Test
+    public void testDenseFitAndPredict() throws Exception {
+        Table onlineTrainTable = getTable(2, 2000, TRAIN_DENSE_ROWS, 2, false);
+        Table onlinePredictTable = getTable(2, 3000, PREDICT_DENSE_ROWS, 2, false);
+        OnlineLogisticRegressionModel model =
+                new OnlineLogisticRegression()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initDenseModel)
+                        .setGlobalBatchSize(500)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlineTrainTable);
+        Table resultTable = model.setPredictionCol(PREDICT_COL).transform(onlinePredictTable)[0];
+        verifyPredictionResult(resultTable, 4);
+    }
+
+    @Test
+    public void testSparseFitAndPredict() throws Exception {
+        double[] doubleArray = new double[] {0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1};
+        Table initSparseModel =
+                tEnv.fromDataStream(env.fromElements(Row.of(new DenseVector(doubleArray), 0L)));
+        Table onlineTrainTable = getTable(5, 2000, TRAIN_SPARSE_ROWS, 4, true);
+        Table onlinePredictTable = getTable(5, 3000, PREDICT_SPARSE_ROWS, 4, true);
+        OnlineLogisticRegressionModel model =
+                new OnlineLogisticRegression()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initSparseModel)
+                        .setGlobalBatchSize(500)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlineTrainTable);
+        Table resultTable = model.setPredictionCol(PREDICT_COL).transform(onlinePredictTable)[0];
+        verifyPredictionResult(resultTable, 4);
+    }
+
+    @Test
+    public void testParam() {
+        OnlineLogisticRegression onlineLogisticRegression = new OnlineLogisticRegression();
+        Assert.assertEquals("features", onlineLogisticRegression.getFeaturesCol());
+        Assert.assertEquals("count", onlineLogisticRegression.getBatchStrategy());
+        Assert.assertEquals("label", onlineLogisticRegression.getLabelCol());
+        Assert.assertEquals(0.1, onlineLogisticRegression.getL1(), 1.0e-5);
+        Assert.assertEquals(0.1, onlineLogisticRegression.getL2(), 1.0e-5);
+        Assert.assertEquals(0.1, onlineLogisticRegression.getAlpha(), 1.0e-5);
+        Assert.assertEquals(0.1, onlineLogisticRegression.getBeta(), 1.0e-5);
+        Assert.assertEquals(32, onlineLogisticRegression.getGlobalBatchSize());
+
+        onlineLogisticRegression
+                .setFeaturesCol("test_feature")
+                .setLabelCol("test_label")
+                .setGlobalBatchSize(5)
+                .setL1(0.25)
+                .setL2(0.25)
+                .setAlpha(0.25)
+                .setBeta(0.25);
+
+        Assert.assertEquals("test_feature", onlineLogisticRegression.getFeaturesCol());
+        Assert.assertEquals("test_label", onlineLogisticRegression.getLabelCol());
+        Assert.assertEquals(0.25, onlineLogisticRegression.getL1(), 1.0e-5);
+        Assert.assertEquals(0.25, onlineLogisticRegression.getL2(), 1.0e-5);
+        Assert.assertEquals(0.25, onlineLogisticRegression.getAlpha(), 1.0e-5);
+        Assert.assertEquals(0.25, onlineLogisticRegression.getBeta(), 1.0e-5);
+        Assert.assertEquals(5, onlineLogisticRegression.getGlobalBatchSize());
+
+        OnlineLogisticRegressionModel onlineLogisticRegressionModel =
+                new OnlineLogisticRegressionModel();
+        Assert.assertEquals("features", onlineLogisticRegressionModel.getFeaturesCol());
+        Assert.assertEquals("modelVersion", onlineLogisticRegressionModel.getModelVersionCol());
+        Assert.assertEquals("prediction", onlineLogisticRegressionModel.getPredictionCol());
+        Assert.assertEquals("rawPrediction", onlineLogisticRegressionModel.getRawPredictionCol());
+
+        onlineLogisticRegressionModel
+                .setFeaturesCol("test_feature")
+                .setPredictionCol("pred")
+                .setModelVersionCol("version")
+                .setRawPredictionCol("raw");
+
+        Assert.assertEquals("test_feature", onlineLogisticRegressionModel.getFeaturesCol());
+        Assert.assertEquals("version", onlineLogisticRegressionModel.getModelVersionCol());
+        Assert.assertEquals("pred", onlineLogisticRegressionModel.getPredictionCol());
+        Assert.assertEquals("raw", onlineLogisticRegressionModel.getRawPredictionCol());
+    }
+
+    @Test
+    public void testBatchSizeLessThanParallelism() throws Exception {
+        Table onlinePredictTable = getTable(1, 20, TRAIN_DENSE_ROWS, 4, true);
+        try {
+            new OnlineLogisticRegression()
+                    .setFeaturesCol(FEATURE_COL)
+                    .setInitialModelData(initDenseModel)
+                    .setGlobalBatchSize(2)
+                    .setLabelCol(LABEL_COL)
+                    .fit(onlinePredictTable);
+            Assert.fail("Expected IllegalStateException");
+        } catch (Exception e) {
+            Throwable exception = e;
+            while (exception.getCause() != null) {
+                exception = exception.getCause();
+            }
+            Assert.assertEquals(IllegalStateException.class, exception.getClass());
+            Assert.assertEquals(
+                    "There are more subtasks in the training process than the number "
+                            + "of elements in each batch. Some subtasks might be idling forever.",
+                    exception.getMessage());
+        }
+    }
+
+    @Test
+    public void testSaveAndReload() throws Exception {
+        Table onlineTrainTable = getTable(2, 2000, TRAIN_DENSE_ROWS, 2, false);
+        Table onlinePredictTable = getTable(2, 3000, PREDICT_DENSE_ROWS, 2, false);
+        OnlineLogisticRegression onlineLogisticRegression =
+                new OnlineLogisticRegression()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initDenseModel)
+                        .setGlobalBatchSize(500)
+                        .setLabelCol(LABEL_COL);
+        OnlineLogisticRegression loadedOnlineLogisticRegression =
+                StageTestUtils.saveAndReload(
+                        tEnv, onlineLogisticRegression, tempFolder.newFolder().getAbsolutePath());
+        OnlineLogisticRegressionModel onlineLogisticRegressionModel =
+                loadedOnlineLogisticRegression.fit(onlineTrainTable);
+        Table resultTable =
+                onlineLogisticRegressionModel.setPredictionCol(PREDICT_COL)
+                        .transform(onlinePredictTable)[0];
+        verifyPredictionResult(resultTable, 4);
+        OnlineLogisticRegressionModel loadedOnlineLogisticRegressionModel =
+                StageTestUtils.saveAndReload(
+                        tEnv,
+                        onlineLogisticRegressionModel,
+                        tempFolder.newFolder().getAbsolutePath());
+        Table resultTableWithLoadedModel =
+                loadedOnlineLogisticRegressionModel.setPredictionCol(PREDICT_COL)
+                        .transform(onlinePredictTable)[0];
+        verifyPredictionResult(resultTableWithLoadedModel, -1);
+    }
+
+    @Test
+    public void testGetModelData() throws Exception {
+        Table onlineTrainTable = getTable(2, 2000, TRAIN_DENSE_ROWS, 1, false);
+        List<DenseVector> expected =
+                new ArrayList<>(
+                        Arrays.asList(
+                                Vectors.dense(0.6094293839451556, -0.3535110997464949),
+                                Vectors.dense(0.8817781161262602, -0.6045148530476719),
+                                Vectors.dense(1.0802504223028735, -0.7809336961447708),
+                                Vectors.dense(1.236292181150552, -0.9166121469926248)));
+        OnlineLogisticRegression onlineLogisticRegression =
+                new OnlineLogisticRegression()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initDenseModel)
+                        .setGlobalBatchSize(500)
+                        .setLabelCol(LABEL_COL);
+        OnlineLogisticRegressionModel onlineLogisticRegressionModel =
+                onlineLogisticRegression.fit(onlineTrainTable);
+        Table modelData = onlineLogisticRegressionModel.getModelData()[0];
+        DataStream<DenseVector> dataStream =
+                tEnv.toDataStream(modelData)
+                        .map((MapFunction<Row, DenseVector>) value -> value.getFieldAs(0));
+        List<DenseVector> result = IteratorUtils.toList(dataStream.executeAndCollect());
+        result.sort((o1, o2) -> MinMaxScalerTest.compare(o1, o2));
+        Assert.assertEquals(expected, result);
+    }
+
+    @Test
+    public void testSetModelData() throws Exception {
+        Table onlinePredictTable = getTable(1, 1, TRAIN_DENSE_ROWS, 1, false);
+
+        OnlineLogisticRegressionModel onlineLogisticRegressionModel =
+                new OnlineLogisticRegressionModel().setFeaturesCol(FEATURE_COL);
+
+        Table modelData =
+                tEnv.fromDataStream(
+                        env.fromElements(Row.of(new DenseVector(new double[] {0.5, 0.1}), 0L)));
+        onlineLogisticRegressionModel.setModelData(modelData);
+
+        Row expected =
+                Row.of(
+                        Vectors.dense(0.1, 2.0),
+                        0.0,
+                        1.0,
+                        Vectors.dense(0.43782349911420193, 0.5621765008857981),
+                        0L);
+        DataStream<Row> results =
+                tEnv.toDataStream(onlineLogisticRegressionModel.transform(onlinePredictTable)[0]);
+        List<Row> resultList = IteratorUtils.toList(results.executeAndCollect());
+        Assert.assertEquals(expected, resultList.get(0));
+    }
+
+    private static void verifyPredictionResult(Table output, int expectedNum) throws Exception {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment();
+        DataStream<Row> stream = tEnv.toDataStream(output);
+        List<Row> result = IteratorUtils.toList(stream.executeAndCollect());
+        Map<Long, Tuple2<Double, Double>> correctRatio = new HashMap<>();
+
+        for (Row row : result) {
+            long modelVersion = row.getFieldAs(MODEL_VERSION_COL);
+            Double pred = row.getFieldAs(PREDICT_COL);
+            Double label = row.getFieldAs(LABEL_COL);
+            if (correctRatio.containsKey(modelVersion)) {
+                Tuple2<Double, Double> t2 = correctRatio.get(modelVersion);
+                if (pred.equals(label)) {
+                    t2.f0 += 1.0;
+                }
+                t2.f1 += 1.0;
+            } else {
+                correctRatio.put(modelVersion, Tuple2.of(pred.equals(label) ? 1.0 : 0.0, 1.0));
+            }
+        }
+        int numModel = 0;
+        for (Long id : correctRatio.keySet()) {
+            assertEquals(1.0, correctRatio.get(id).f0 / correctRatio.get(id).f1, 1.0e-5);
+            numModel++;
+        }
+        if (expectedNum != -1) {
+            assertEquals(expectedNum, numModel);
+        }
+    }
+
+    /** Emits a fixed number of samples by cycling through a given sample list. */
+    private static class RandomSample implements SourceFunction<Row> {
+        private volatile boolean isRunning = true;
+        private final long timeInterval;
+        private final long numSample;
+        private final List<Row> data;
+
+        public RandomSample(long timeInterval, long numSample, List<Row> data) {
+            this.timeInterval = timeInterval;
+            this.numSample = numSample;
+            this.data = data;
+        }
+
+        @Override
+        public void run(SourceContext<Row> ctx) throws Exception {
+            int size = data.size();
+            for (int i = 0; i < numSample; ++i) {
+                int idx = i % size;
+                if (isRunning) {
+                    ctx.collect(data.get(idx));
+                    if (timeInterval > 0) {
+                        Thread.sleep(timeInterval);
+                    }
+                }
+            }
+        }
+
+        @Override
+        public void cancel() {
+            isRunning = false;
+        }
+    }
+
+    private Table getTable(
+            int timeInterval, int numSample, List<Row> data, int parallel, boolean isSparse) {

Review Comment:
   Would it be more readable to rename `parallel` to `sourceParallelism`?
   
   Would it be more readable to rename this method to `getSourceTable`?
   
   When the parallelism is 2, the last 50% of the elements of `data` will never be emitted by the source function. Is this expected?
   
   Could you explain the motivation for testing the source function with different parallelisms? For example, would it be sufficient to only test `sourceParallelism = 2`?
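   
   To illustrate one possible way to avoid dropping elements: if each source subtask emits only the samples whose global sample index is congruent to its subtask index modulo the source parallelism, the subtasks together emit exactly the `numSample` cyclic samples, with no element of `data` skipped, regardless of the parallelism. Below is a minimal plain-Java sketch of that index arithmetic. `RoundRobinSplit` and `elementsForSubtask` are hypothetical names. Inside a Flink `RichParallelSourceFunction`, the two index arguments would typically come from `getRuntimeContext().getIndexOfThisSubtask()` and `getRuntimeContext().getNumberOfParallelSubtasks()`.
   
   ```java
   import java.util.ArrayList;
   import java.util.List;
   
   /** Hypothetical helper: decides which samples a given source subtask should emit. */
   final class RoundRobinSplit {
       private RoundRobinSplit() {}
   
       /**
        * Returns the elements that subtask {@code subtaskIndex} (0-based) emits when
        * {@code numSample} samples are drawn cyclically from {@code data} and dealt out
        * round-robin over {@code numSubtasks} subtasks.
        */
       static <T> List<T> elementsForSubtask(
               List<T> data, long numSample, int subtaskIndex, int numSubtasks) {
           List<T> result = new ArrayList<>();
           // Subtask k handles global sample indices k, k + numSubtasks, k + 2 * numSubtasks, ...
           for (long i = subtaskIndex; i < numSample; i += numSubtasks) {
               result.add(data.get((int) (i % data.size())));
           }
           return result;
       }
   }
   ```
   
   With 10 rows and a parallelism of 2, subtask 0 emits rows 0, 2, 4, 6, 8 and subtask 1 emits rows 1, 3, 5, 7, 9, so every row still reaches the training operator.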
   



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java:
##########
@@ -0,0 +1,428 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the OnlineLogisticRegression algorithm.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Online_machine_learning.

Review Comment:
   It would be useful to provide sufficient references that allow users to judge the performance of this algorithm and that give credit to the original authors of the algorithm. The link to the wiki page above is probably not sufficient.
   
   How about the following Java doc? This is similar to the Java doc used by Spark's `KMeans`.
   
   ```
   /**
    * An Estimator which implements the FTRL-Proximal online learning algorithm proposed by H. Brendan McMahan et al.
    *
    * <p>See <a href="https://doi.org/10.1145/2487575.2488200">H. Brendan McMahan et al., Ad click prediction: a view from the trenches.</a>
    */
   ```
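   
   For context on what the suggested reference describes, below is a minimal, single-threaded sketch of the per-coordinate FTRL-Proximal update for logistic regression from the McMahan et al. paper. Here `alpha`, `beta`, `l1` and `l2` play the roles of the parameters exercised in `testParam()`. `FtrlProximalSketch` is a hypothetical class and is not the code in this PR, which trains on global mini-batches (see `globalBatchSize`).
   
   ```java
   /**
    * Hypothetical sketch of per-coordinate FTRL-Proximal for logistic regression
    * (McMahan et al., "Ad click prediction: a view from the trenches").
    * Single-threaded and dense-array based; only non-zero features are updated.
    */
   class FtrlProximalSketch {
       private final double alpha;
       private final double beta;
       private final double l1;
       private final double l2;
       private final double[] z; // per-coordinate accumulated (adjusted) gradients
       private final double[] n; // per-coordinate accumulated squared gradients
   
       FtrlProximalSketch(int dim, double alpha, double beta, double l1, double l2) {
           this.alpha = alpha;
           this.beta = beta;
           this.l1 = l1;
           this.l2 = l2;
           this.z = new double[dim];
           this.n = new double[dim];
       }
   
       /** Closed-form weight of coordinate i; exactly 0 while |z_i| <= l1. */
       double weight(int i) {
           if (Math.abs(z[i]) <= l1) {
               return 0.0;
           }
           return -(z[i] - Math.signum(z[i]) * l1) / ((beta + Math.sqrt(n[i])) / alpha + l2);
       }
   
       /** Returns P(label = 1 | x) under the current weights. */
       double predict(double[] x) {
           double margin = 0.0;
           for (int i = 0; i < x.length; i++) {
               if (x[i] != 0.0) {
                   margin += weight(i) * x[i];
               }
           }
           return 1.0 / (1.0 + Math.exp(-margin));
       }
   
       /** Processes one labeled example, with the label in {0, 1}. */
       void update(double[] x, double label) {
           // Capture the weights used for this prediction before z and n change.
           double[] w = new double[x.length];
           for (int i = 0; i < x.length; i++) {
               if (x[i] != 0.0) {
                   w[i] = weight(i);
               }
           }
           double p = predict(x);
           for (int i = 0; i < x.length; i++) {
               if (x[i] == 0.0) {
                   continue;
               }
               double g = (p - label) * x[i]; // gradient of the log loss w.r.t. w_i
               double sigma = (Math.sqrt(n[i] + g * g) - Math.sqrt(n[i])) / alpha;
               z[i] += g - sigma * w[i];
               n[i] += g * g;
           }
       }
   }
   ```
   
   Because `weight(i)` is exactly zero while `|z_i| <= l1`, larger `l1` values yield sparser models. This sparsity is one of the main properties the paper highlights.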



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasL1.java:
##########
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
+
+/** Interface for the shared L1 param. */
+public interface HasL1<T> extends WithParams<T> {
+    Param<Double> L_1 = new DoubleParam("l1", "The l1 param.", 0.1, ParamValidators.gt(0.0));

Review Comment:
   nit: Would it be simpler to replace `L_1` with `L1`, so that referencing this variable is more consistent with specifying the parameter string `l1`?
   
   Would it be a bit better to say `L1 regularization parameter` in the Java doc? This would also make the doc more consistent with other parameters' docs, which typically don't start with `the`.
   
   Same for `L_2`.



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/OnlineLogisticRegressionTest.java:
##########
@@ -0,0 +1,444 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.classification.logisticregression.LogisticRegression;
+import org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegression;
+import org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegressionModel;
+import org.apache.flink.ml.feature.MinMaxScalerTest;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static junit.framework.TestCase.assertEquals;
+
+/** Tests {@link OnlineLogisticRegression} and {@link OnlineLogisticRegressionModel}. */
+public class OnlineLogisticRegressionTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table trainDenseTable;
+    private static final String LABEL_COL = "label";
+    private static final String PREDICT_COL = "prediction";
+    private static final String FEATURE_COL = "features";
+    private static final String MODEL_VERSION_COL = "modelVersion";
+    private static final double[] ONE_ARRAY = new double[] {1.0, 1.0, 1.0};
+    private static final List<Row> TRAIN_DENSE_ROWS =
+            Arrays.asList(
+                    Row.of(Vectors.dense(0.1, 2.), 0.),
+                    Row.of(Vectors.dense(0.2, 2.), 0.),
+                    Row.of(Vectors.dense(0.3, 2.), 0.),
+                    Row.of(Vectors.dense(0.4, 2.), 0.),
+                    Row.of(Vectors.dense(0.5, 2.), 0.),
+                    Row.of(Vectors.dense(11., 12.), 1.),
+                    Row.of(Vectors.dense(12., 11.), 1.),
+                    Row.of(Vectors.dense(13., 12.), 1.),
+                    Row.of(Vectors.dense(14., 12.), 1.),
+                    Row.of(Vectors.dense(15., 12.), 1.));
+
+    private static final List<Row> PREDICT_DENSE_ROWS =
+            Arrays.asList(
+                    Row.of(Vectors.dense(0.8, 2.7), 0.),
+                    Row.of(Vectors.dense(0.8, 2.4), 0.),
+                    Row.of(Vectors.dense(0.7, 2.3), 0.),
+                    Row.of(Vectors.dense(0.4, 2.7), 0.),
+                    Row.of(Vectors.dense(0.5, 2.8), 0.),
+                    Row.of(Vectors.dense(10.2, 12.1), 1.),
+                    Row.of(Vectors.dense(13.3, 13.1), 1.),
+                    Row.of(Vectors.dense(13.5, 12.2), 1.),
+                    Row.of(Vectors.dense(14.9, 12.5), 1.),
+                    Row.of(Vectors.dense(15.5, 11.2), 1.));
+
+    private static final List<Row> TRAIN_SPARSE_ROWS =
+            Arrays.asList(
+                    Row.of(Vectors.sparse(10, new int[] {1, 3, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {0, 2, 3}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {0, 3, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {2, 3, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {1, 3, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {6, 7, 8}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {6, 8, 9}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {5, 8, 9}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {5, 6, 8}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {5, 6, 7}, ONE_ARRAY), 1.));
+
+    private static final List<Row> PREDICT_SPARSE_ROWS =
+            Arrays.asList(
+                    Row.of(Vectors.sparse(10, new int[] {1, 2, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {2, 3, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {0, 2, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {1, 3, 4}, ONE_ARRAY), 0.),
+                    Row.of(Vectors.sparse(10, new int[] {6, 7, 9}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {7, 8, 9}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {5, 7, 9}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {5, 6, 7}, ONE_ARRAY), 1.),
+                    Row.of(Vectors.sparse(10, new int[] {5, 8, 9}, ONE_ARRAY), 1.));
+
+    private Table initDenseModel;
+
+    @Before
+    public void before() throws Exception {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        DataStream<Row> dataStream = env.fromCollection(TRAIN_DENSE_ROWS);
+        trainDenseTable = tEnv.fromDataStream(dataStream).as(FEATURE_COL, LABEL_COL);
+        initDenseModel =
+                tEnv.fromDataStream(
+                        env.fromElements(Row.of(new DenseVector(new double[] {0.0, 0.0}), 0L)));
+    }
+
+    @Test
+    public void testFit() throws Exception {
+        Table onlinePredictTable = getTable(1, 1000, TRAIN_DENSE_ROWS, 4, false);
+        Table models =
+                new OnlineLogisticRegression()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initDenseModel)
+                        .setGlobalBatchSize(100)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlinePredictTable)
+                        .getModelData()[0];
+        List<Row> modelList = IteratorUtils.toList(tEnv.toDataStream(models).executeAndCollect());
+        assertEquals(10, modelList.size());
+    }
+
+    @Test
+    public void testFitWithInitLrModel() throws Exception {
+        Table initLrModel =
+                new LogisticRegression()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setLabelCol(LABEL_COL)
+                        .fit(trainDenseTable)
+                        .getModelData()[0];
+        Table onlineTrainTable = getTable(50, 1000, TRAIN_DENSE_ROWS, 4, false);
+        Table models =
+                new OnlineLogisticRegression()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initLrModel)
+                        .setGlobalBatchSize(100)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlineTrainTable)
+                        .getModelData()[0];
+        List<Row> modelList = IteratorUtils.toList(tEnv.toDataStream(models).executeAndCollect());
+        assertEquals(10, modelList.size());
+    }
+
+    @Test
+    public void testDenseFitAndPredict() throws Exception {
+        Table onlineTrainTable = getTable(2, 2000, TRAIN_DENSE_ROWS, 2, false);
+        Table onlinePredictTable = getTable(2, 3000, PREDICT_DENSE_ROWS, 2, false);
+        OnlineLogisticRegressionModel model =
+                new OnlineLogisticRegression()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initDenseModel)
+                        .setGlobalBatchSize(500)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlineTrainTable);
+        Table resultTable = model.setPredictionCol(PREDICT_COL).transform(onlinePredictTable)[0];
+        verifyPredictionResult(resultTable, 4);
+    }
+
+    @Test
+    public void testSparseFitAndPredict() throws Exception {
+        double[] doubleArray = new double[] {0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1};
+        Table initSparseModel =
+                tEnv.fromDataStream(env.fromElements(Row.of(new DenseVector(doubleArray), 0L)));
+        Table onlineTrainTable = getTable(5, 2000, TRAIN_SPARSE_ROWS, 4, true);
+        Table onlinePredictTable = getTable(5, 3000, PREDICT_SPARSE_ROWS, 4, true);
+        OnlineLogisticRegressionModel model =
+                new OnlineLogisticRegression()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initSparseModel)
+                        .setGlobalBatchSize(500)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlineTrainTable);
+        Table resultTable = model.setPredictionCol(PREDICT_COL).transform(onlinePredictTable)[0];
+        verifyPredictionResult(resultTable, 4);
+    }
+
+    @Test
+    public void testParam() {
+        OnlineLogisticRegression onlineLogisticRegression = new OnlineLogisticRegression();
+        Assert.assertEquals("features", onlineLogisticRegression.getFeaturesCol());
+        Assert.assertEquals("count", onlineLogisticRegression.getBatchStrategy());
+        Assert.assertEquals("label", onlineLogisticRegression.getLabelCol());
+        Assert.assertEquals(0.1, onlineLogisticRegression.getL1(), 1.0e-5);
+        Assert.assertEquals(0.1, onlineLogisticRegression.getL2(), 1.0e-5);
+        Assert.assertEquals(0.1, onlineLogisticRegression.getAlpha(), 1.0e-5);
+        Assert.assertEquals(0.1, onlineLogisticRegression.getBeta(), 1.0e-5);
+        Assert.assertEquals(32, onlineLogisticRegression.getGlobalBatchSize());
+
+        onlineLogisticRegression
+                .setFeaturesCol("test_feature")
+                .setLabelCol("test_label")
+                .setGlobalBatchSize(5)
+                .setL1(0.25)
+                .setL2(0.25)
+                .setAlpha(0.25)
+                .setBeta(0.25);
+
+        Assert.assertEquals("test_feature", onlineLogisticRegression.getFeaturesCol());
+        Assert.assertEquals("test_label", onlineLogisticRegression.getLabelCol());
+        Assert.assertEquals(0.25, onlineLogisticRegression.getL1(), 1.0e-5);
+        Assert.assertEquals(0.25, onlineLogisticRegression.getL2(), 1.0e-5);
+        Assert.assertEquals(0.25, onlineLogisticRegression.getAlpha(), 1.0e-5);
+        Assert.assertEquals(0.25, onlineLogisticRegression.getBeta(), 1.0e-5);
+        Assert.assertEquals(5, onlineLogisticRegression.getGlobalBatchSize());
+
+        OnlineLogisticRegressionModel onlineLogisticRegressionModel =
+                new OnlineLogisticRegressionModel();
+        Assert.assertEquals("features", onlineLogisticRegressionModel.getFeaturesCol());
+        Assert.assertEquals("modelVersion", onlineLogisticRegressionModel.getModelVersionCol());
+        Assert.assertEquals("prediction", onlineLogisticRegressionModel.getPredictionCol());
+        Assert.assertEquals("rawPrediction", onlineLogisticRegressionModel.getRawPredictionCol());
+
+        onlineLogisticRegressionModel
+                .setFeaturesCol("test_feature")
+                .setPredictionCol("pred")
+                .setModelVersionCol("version")
+                .setRawPredictionCol("raw");
+
+        Assert.assertEquals("test_feature", onlineLogisticRegressionModel.getFeaturesCol());
+        Assert.assertEquals("version", onlineLogisticRegressionModel.getModelVersionCol());
+        Assert.assertEquals("pred", onlineLogisticRegressionModel.getPredictionCol());
+        Assert.assertEquals("raw", onlineLogisticRegressionModel.getRawPredictionCol());
+    }
+
+    @Test
+    public void testBatchSizeLessThanParallelism() throws Exception {
+        Table onlinePredictTable = getTable(1, 20, TRAIN_DENSE_ROWS, 4, true);
+        try {
+            new OnlineLogisticRegression()
+                    .setFeaturesCol(FEATURE_COL)
+                    .setInitialModelData(initDenseModel)
+                    .setGlobalBatchSize(2)
+                    .setLabelCol(LABEL_COL)
+                    .fit(onlinePredictTable);
+            Assert.fail("Expected IllegalStateException");
+        } catch (Exception e) {
+            Throwable exception = e;
+            while (exception.getCause() != null) {
+                exception = exception.getCause();
+            }
+            Assert.assertEquals(IllegalStateException.class, exception.getClass());
+            Assert.assertEquals(
+                    "There are more subtasks in the training process than the number "
+                            + "of elements in each batch. Some subtasks might be idling forever.",
+                    exception.getMessage());
+        }
+    }
+
+    @Test
+    public void testSaveAndReload() throws Exception {
+        Table onlineTrainTable = getTable(2, 2000, TRAIN_DENSE_ROWS, 2, false);
+        Table onlinePredictTable = getTable(2, 3000, PREDICT_DENSE_ROWS, 2, false);
+        OnlineLogisticRegression onlineLogisticRegression =
+                new OnlineLogisticRegression()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initDenseModel)
+                        .setGlobalBatchSize(500)
+                        .setLabelCol(LABEL_COL);
+        OnlineLogisticRegression loadedOnlineLogisticRegression =
+                StageTestUtils.saveAndReload(
+                        tEnv, onlineLogisticRegression, tempFolder.newFolder().getAbsolutePath());
+        OnlineLogisticRegressionModel onlineLogisticRegressionModel =
+                loadedOnlineLogisticRegression.fit(onlineTrainTable);
+        Table resultTable =
+                onlineLogisticRegressionModel.setPredictionCol(PREDICT_COL)
+                        .transform(onlinePredictTable)[0];
+        verifyPredictionResult(resultTable, 4);
+        OnlineLogisticRegressionModel loadedOnlineLogisticRegressionModel =
+                StageTestUtils.saveAndReload(
+                        tEnv,
+                        onlineLogisticRegressionModel,
+                        tempFolder.newFolder().getAbsolutePath());
+        Table resultTableWithLoadedModel =
+                loadedOnlineLogisticRegressionModel.setPredictionCol(PREDICT_COL)
+                        .transform(onlinePredictTable)[0];
+        verifyPredictionResult(resultTableWithLoadedModel, -1);
+    }
+
+    @Test
+    public void testGetModelData() throws Exception {
+        Table onlineTrainTable = getTable(2, 2000, TRAIN_DENSE_ROWS, 1, false);
+        List<DenseVector> expected =
+                new ArrayList<>(
+                        Arrays.asList(
+                                Vectors.dense(0.6094293839451556, -0.3535110997464949),
+                                Vectors.dense(0.8817781161262602, -0.6045148530476719),
+                                Vectors.dense(1.0802504223028735, -0.7809336961447708),
+                                Vectors.dense(1.236292181150552, -0.9166121469926248)));
+        OnlineLogisticRegression onlineLogisticRegression =
+                new OnlineLogisticRegression()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initDenseModel)
+                        .setGlobalBatchSize(500)
+                        .setLabelCol(LABEL_COL);
+        OnlineLogisticRegressionModel onlineLogisticRegressionModel =
+                onlineLogisticRegression.fit(onlineTrainTable);
+        Table modelData = onlineLogisticRegressionModel.getModelData()[0];
+        DataStream<DenseVector> dataStream =
+                tEnv.toDataStream(modelData)
+                        .map((MapFunction<Row, DenseVector>) value -> value.getFieldAs(0));
+        List<DenseVector> result = IteratorUtils.toList(dataStream.executeAndCollect());
+        result.sort((o1, o2) -> MinMaxScalerTest.compare(o1, o2));
+        Assert.assertEquals(expected, result);
+    }
+
+    @Test
+    public void testSetModelData() throws Exception {
+        Table onlinePredictTable = getTable(1, 1, TRAIN_DENSE_ROWS, 1, false);
+
+        OnlineLogisticRegressionModel onlineLogisticRegressionModel =
+                new OnlineLogisticRegressionModel().setFeaturesCol(FEATURE_COL);
+
+        Table modelData =
+                tEnv.fromDataStream(
+                        env.fromElements(Row.of(new DenseVector(new double[] {0.5, 0.1}), 0L)));
+        onlineLogisticRegressionModel.setModelData(modelData);
+
+        Row expected =
+                Row.of(
+                        Vectors.dense(0.1, 2.0),
+                        0.0,
+                        1.0,
+                        Vectors.dense(0.43782349911420193, 0.5621765008857981),
+                        0L);
+        DataStream<Row> results =
+                tEnv.toDataStream(onlineLogisticRegressionModel.transform(onlinePredictTable)[0]);
+        List<Row> resultList = IteratorUtils.toList(results.executeAndCollect());
+        Assert.assertEquals(expected, resultList.get(0));
+    }
+
+    private static void verifyPredictionResult(Table output, int expectedNum) throws Exception {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment();
+        DataStream<Row> stream = tEnv.toDataStream(output);
+        List<Row> result = IteratorUtils.toList(stream.executeAndCollect());
+        Map<Long, Tuple2<Double, Double>> correctRatio = new HashMap<>();
+
+        for (Row row : result) {
+            long modelVersion = row.getFieldAs(MODEL_VERSION_COL);
+            Double pred = row.getFieldAs(PREDICT_COL);
+            Double label = row.getFieldAs(LABEL_COL);
+            if (correctRatio.containsKey(modelVersion)) {
+                Tuple2<Double, Double> t2 = correctRatio.get(modelVersion);
+                if (pred.equals(label)) {
+                    t2.f0 += 1.0;
+                }
+                t2.f1 += 1.0;
+            } else {
+                correctRatio.put(modelVersion, Tuple2.of(pred.equals(label) ? 1.0 : 0.0, 1.0));
+            }
+        }
+        int numModel = 0;
+        for (Long id : correctRatio.keySet()) {
+            assertEquals(1.0, correctRatio.get(id).f0 / correctRatio.get(id).f1, 1.0e-5);
+            numModel++;
+        }
+        if (expectedNum != -1) {
+            assertEquals(expectedNum, numModel);
+        }
+    }
+
+    /** Random selects samples with fixed size from a given sample list. */

Review Comment:
   Could we rephrase the Javadoc so that it describes the behavior related to `timeInterval`?
   
   And is it correct to say `random` for this method? It appears that this method deterministically emits `numSample` values from the given list.
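The sketch below makes the deterministic behavior concrete in plain Java, with no Flink dependency. It is based on the reviewer's description and on the call sites in this test; for example, `testSetModelData` expects the first emitted row to be the first row of `TRAIN_DENSE_ROWS`. The helper emits exactly `numSamples` elements by cycling through the list in order and pauses `timeIntervalMs` between emissions. The class and method names, the wrap-around order, and the millisecond unit are assumptions for illustration, not the PR's code. A Javadoc worded like the one on `emit` would describe this behavior without implying randomness.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Consumer;

/** Hypothetical helper illustrating deterministic (not random) sample emission. */
public class CyclicSampleEmitter {

    /**
     * Emits exactly {@code numSamples} elements of {@code samples} to {@code sink}, in list
     * order, wrapping around to the first element when the list is exhausted. Waits
     * {@code timeIntervalMs} milliseconds between two consecutive emissions.
     */
    public static <T> void emit(
            List<T> samples, int numSamples, long timeIntervalMs, Consumer<T> sink) {
        if (samples.isEmpty()) {
            throw new IllegalArgumentException("samples must not be empty");
        }
        for (int i = 0; i < numSamples; i++) {
            // The i-th emitted element is always samples[i % size]: no randomness involved.
            sink.accept(samples.get(i % samples.size()));
            if (timeIntervalMs > 0 && i < numSamples - 1) {
                try {
                    Thread.sleep(timeIntervalMs);
                } catch (InterruptedException e) {
                    // Restore the interrupt flag and stop emitting.
                    Thread.currentThread().interrupt();
                    return;
                }
            }
        }
    }

    public static void main(String[] args) {
        List<String> out = new ArrayList<>();
        emit(Arrays.asList("a", "b", "c"), 5, 0L, out::add);
        System.out.println(out); // [a, b, c, a, b]
    }
}
```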





[GitHub] [flink-ml] weibozhao commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator for OnlineLogisticRegression

Posted by GitBox <gi...@apache.org>.
weibozhao commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r854906971


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegression.java:
##########
@@ -391,7 +391,8 @@ public void onIterationTerminated(Context context, Collector<double[]> collector
             feedbackBufferState.clear();
             if (getRuntimeContext().getIndexOfThisSubtask() == 0) {
                 updateModel();
-                context.output(modelDataOutputTag, new LogisticRegressionModelData(coefficient));
+                context.output(
+                        modelDataOutputTag, new LogisticRegressionModelData(coefficient, 0L));

Review Comment:
   This does not conflict with `barrierModelData`. It just adds a version to every model produced by either the online or the offline algorithm, so that the offline and online algorithms can share the same model data.
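A plain-Java sketch of the idea: a single model-data type carries the coefficient together with a `modelVersion`. The offline trainer emits its one model with version `0L` (as in the diff above), and the online trainer emits a new model per update with a higher version, so one predictor can consume either source and report which version it used. `VersionedModel` and its methods are hypothetical names, and numbering online versions by `+1` per update is an assumption, not necessarily what the PR does.

```java
/** Hypothetical model data type shared by an offline and an online trainer. */
public final class VersionedModel {
    final double[] coefficient;
    final long modelVersion;

    VersionedModel(double[] coefficient, long modelVersion) {
        this.coefficient = coefficient.clone();
        this.modelVersion = modelVersion;
    }

    /** Offline training produces a single model, tagged with version 0. */
    static VersionedModel fromOfflineTraining(double[] coefficient) {
        return new VersionedModel(coefficient, 0L);
    }

    /** Each online update yields a new model whose version is one higher (assumed numbering). */
    VersionedModel nextOnlineUpdate(double[] newCoefficient) {
        return new VersionedModel(newCoefficient, modelVersion + 1);
    }

    /** Logistic-regression label: 1.0 if sigmoid(w.x) >= 0.5, i.e. if w.x >= 0. */
    double predict(double[] features) {
        double dot = 0.0;
        for (int i = 0; i < features.length; i++) {
            dot += features[i] * coefficient[i];
        }
        return dot >= 0.0 ? 1.0 : 0.0;
    }

    public static void main(String[] args) {
        double[] x = {1.0, 2.0};
        // An online run can start from the offline model, as testFitWithInitLrModel does.
        VersionedModel offline = fromOfflineTraining(new double[] {0.5, -0.1});
        VersionedModel online = offline.nextOnlineUpdate(new double[] {-0.6, 0.1});
        for (VersionedModel m : new VersionedModel[] {offline, online}) {
            System.out.println("version " + m.modelVersion + " -> label " + m.predict(x));
        }
    }
}
```

Because the predictor only sees `VersionedModel`, it does not need to know whether a given model came from the offline or the online trainer.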





[GitHub] [flink-ml] lindong28 commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator for OnlineLogisticRegression

Posted by GitBox <gi...@apache.org>.
lindong28 commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r884764718


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionParams.java:
##########
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.common.param.HasElasticNet;
+import org.apache.flink.ml.common.param.HasGlobalBatchSize;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasReg;
+import org.apache.flink.ml.common.param.HasWeightCol;
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+
+/**
+ * Params of {@link OnlineLogisticRegression}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface OnlineLogisticRegressionParams<T>
+        extends HasLabelCol<T>,
+                HasWeightCol<T>,
+                HasBatchStrategy<T>,
+                HasGlobalBatchSize<T>,
+                HasReg<T>,
+                HasElasticNet<T>,
+                OnlineLogisticRegressionModelParams<T> {
+
+    Param<Double> ALPHA =
+            new DoubleParam("alpha", "The parameter alpha of ftrl.", 0.1, ParamValidators.gt(0.0));
+
+    Param<Double> BETA =
+            new DoubleParam("alpha", "The parameter beta of ftrl.", 0.1, ParamValidators.gt(0.0));

Review Comment:
   alpha -> beta. `BETA` is declared with the name "alpha", which duplicates the name of `ALPHA`.
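The typo has more than a cosmetic effect if params are identified by their name. This is an assumption about how the param map resolves keys and should be checked against `Param`'s `equals`/`hashCode`. Under that assumption, `ALPHA` and `BETA` share one slot, so setting one silently overwrites the other. `testParam` sets both to `0.25`, which would hide such a collision. The plain-Java sketch below uses a hypothetical name-keyed map to show the effect, plus a simple duplicate-name check that would catch it. `ParamDef` and `DuplicateParamNameDemo` are illustrative stand-ins, not Flink ML classes.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/** Hypothetical illustration of two params that accidentally share a name. */
public class DuplicateParamNameDemo {

    /** Minimal stand-in for a param definition; only the name matters here. */
    static final class ParamDef {
        final String name;

        ParamDef(String name) {
            this.name = name;
        }
    }

    // Mirrors the bug under review: BETA is declared with the name "alpha".
    static final ParamDef ALPHA = new ParamDef("alpha");
    static final ParamDef BETA = new ParamDef("alpha");

    /** Returns the names that appear more than once among the given params. */
    static Set<String> duplicateNames(List<ParamDef> params) {
        Set<String> seen = new HashSet<>();
        Set<String> duplicates = new HashSet<>();
        for (ParamDef p : params) {
            if (!seen.add(p.name)) {
                duplicates.add(p.name);
            }
        }
        return duplicates;
    }

    public static void main(String[] args) {
        // With a map keyed by name, setting BETA overwrites the value set for ALPHA.
        Map<String, Double> values = new HashMap<>();
        values.put(ALPHA.name, 0.2);
        values.put(BETA.name, 0.3);
        System.out.println(values.get(ALPHA.name)); // 0.3, not 0.2
        System.out.println(duplicateNames(Arrays.asList(ALPHA, BETA))); // [alpha]
    }
}
```

Setting the two params to different values in a test is the simplest way to expose this kind of collision.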



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionModel.java:
##########
@@ -0,0 +1,202 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.apache.flink.ml.classification.logisticregression.LogisticRegressionModel.predictOneDataPoint;
+
+/**
+ * A Model which classifies data using the model data computed by {@link OnlineLogisticRegression}.
+ */
+public class OnlineLogisticRegressionModel
+        implements Model<OnlineLogisticRegressionModel>,
+                OnlineLogisticRegressionModelParams<OnlineLogisticRegressionModel> {
+    public static final String MODEL_DATA_VERSION_GAUGE_KEY = "modelDataVersion";
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public OnlineLogisticRegressionModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(),
+                                Types.DOUBLE,
+                                TypeInformation.of(DenseVector.class),
+                                Types.LONG),
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldNames(),
+                                getPredictionCol(),
+                                getRawPredictionCol(),
+                                getModelVersionCol()));
+
+        DataStream<Row> predictionResult =
+                tEnv.toDataStream(inputs[0])
+                        .connect(
+                                LogisticRegressionModelData.getModelDataStream(modelDataTable)
+                                        .broadcast())
+                        .transform(
+                                "PredictLabelOperator",
+                                outputTypeInfo,
+                                new PredictLabelOperator(inputTypeInfo, getFeaturesCol()));
+
+        return new Table[] {tEnv.fromDataStream(predictionResult)};
+    }
+
+    /** A utility operator used for prediction. */
+    private static class PredictLabelOperator extends AbstractStreamOperator<Row>
+            implements TwoInputStreamOperator<Row, LogisticRegressionModelData, Row> {
+        private final RowTypeInfo inputTypeInfo;
+
+        private final String featuresCol;
+        private ListState<Row> bufferedPointsState;
+        private DenseVector coefficient;
+        private long modelDataVersion = 0L;
+
+        public PredictLabelOperator(RowTypeInfo inputTypeInfo, String featuresCol) {
+            this.inputTypeInfo = inputTypeInfo;
+            this.featuresCol = featuresCol;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            bufferedPointsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("bufferedPoints", inputTypeInfo));
+        }
+
+        @Override
+        public void open() throws Exception {
+            super.open();
+
+            getRuntimeContext()
+                    .getMetricGroup()
+                    .gauge(
+                            MODEL_DATA_VERSION_GAUGE_KEY,
+                            (Gauge<String>) () -> Long.toString(modelDataVersion));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<Row> streamRecord) throws Exception {
+            processElement(streamRecord);
+        }
+
+        @Override
+        public void processElement2(StreamRecord<LogisticRegressionModelData> streamRecord)
+                throws Exception {
+            LogisticRegressionModelData modelData = streamRecord.getValue();
+            coefficient = modelData.coefficient;
+            modelDataVersion = modelData.modelVersion;
+            for (Row dataPoint : bufferedPointsState.get()) {
+                processElement(new StreamRecord<>(dataPoint));
+            }
+            bufferedPointsState.clear();
+        }
+
+        public void processElement(StreamRecord<Row> streamRecord) throws Exception {
+            Row dataPoint = streamRecord.getValue();
+            if (coefficient == null) {
+                bufferedPointsState.add(dataPoint);
+                return;
+            }
+            Vector features = (Vector) dataPoint.getField(featuresCol);
+            Row predictionResult = predictOneDataPoint(features, coefficient);
+            output.collect(
+                    new StreamRecord<>(
+                            Row.join(
+                                    dataPoint,
+                                    Row.of(
+                                            predictionResult.getField(0),
+                                            predictionResult.getField(1),
+                                            modelDataVersion))));
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineLogisticRegressionModel load(StreamTableEnvironment tEnv, String path)
+            throws IOException {
+        OnlineLogisticRegressionModel model = ReadWriteUtils.loadStageParam(path);
+        Table modelDataTable =

Review Comment:
   Given that `OnlineLogisticRegressionModel::save(...)` does not write model data to disk, it seems inconsistent to have `OnlineLogisticRegressionModel::load(...)` read model data from the given path.
   
   Should we follow the same behavior as `OnlineKMeansModel::load(...)` and only read metadata in this method?
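The plain-Java sketch below shows the symmetry the reviewer asks for: `save` writes only the params, `load` reads only the params, and model data is attached separately afterwards, for example through `setModelData(...)` as `testSetModelData` does. The `SketchModel` class, the `metadata.properties` file name, and the use of `java.util.Properties` are assumptions for illustration. Flink ML's `ReadWriteUtils` uses its own metadata format.

```java
import java.io.IOException;
import java.io.Reader;
import java.io.Writer;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Properties;

/** Hypothetical model whose save/load only round-trip metadata, never model data. */
public class SketchModel {
    private static final String METADATA_FILE = "metadata.properties";

    private final Properties params = new Properties();
    private double[] modelData; // supplied at runtime, never persisted by save()

    public SketchModel setPredictionCol(String col) {
        params.setProperty("predictionCol", col);
        return this;
    }

    public String getPredictionCol() {
        return params.getProperty("predictionCol", "prediction");
    }

    public SketchModel setModelData(double[] data) {
        this.modelData = data.clone();
        return this;
    }

    public double[] getModelData() {
        return modelData;
    }

    /** Writes only the params; model data is intentionally not saved. */
    public void save(Path dir) throws IOException {
        Files.createDirectories(dir);
        try (Writer w = Files.newBufferedWriter(dir.resolve(METADATA_FILE))) {
            params.store(w, "model metadata");
        }
    }

    /** Reads only the params, matching save(); the caller sets model data afterwards. */
    public static SketchModel load(Path dir) throws IOException {
        SketchModel model = new SketchModel();
        try (Reader r = Files.newBufferedReader(dir.resolve(METADATA_FILE))) {
            model.params.load(r);
        }
        return model;
    }

    public static void main(String[] args) throws IOException {
        Path dir = Files.createTempDirectory("sketch-model");
        new SketchModel().setPredictionCol("pred").setModelData(new double[] {0.5, 0.1}).save(dir);
        SketchModel loaded = SketchModel.load(dir);
        System.out.println(loaded.getPredictionCol()); // pred
        System.out.println(loaded.getModelData() == null); // true: model data was not persisted
    }
}
```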



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/OnlineLogisticRegressionTest.java:
##########
@@ -0,0 +1,687 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification;
+
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.api.common.JobSubmissionResult;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.RestOptions;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.classification.logisticregression.LogisticRegression;
+import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModel;
+import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelData;
+import org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegression;
+import org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegressionModel;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.InMemorySinkFunction;
+import org.apache.flink.ml.util.InMemorySourceFunction;
+import org.apache.flink.ml.util.TestUtils;
+import org.apache.flink.runtime.jobgraph.JobGraph;
+import org.apache.flink.runtime.minicluster.MiniCluster;
+import org.apache.flink.runtime.minicluster.MiniClusterConfiguration;
+import org.apache.flink.runtime.testutils.InMemoryReporter;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.TestLogger;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+
+import static org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegressionModel.MODEL_DATA_VERSION_GAUGE_KEY;
+
+/** Tests {@link OnlineLogisticRegression} and {@link OnlineLogisticRegressionModel}. */
+public class OnlineLogisticRegressionTest extends TestLogger {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private static final double[] ONE_ARRAY = new double[] {1.0, 1.0, 1.0};
+
+    private static final Row[] TRAIN_DENSE_ROWS_1 =
+            new Row[] {
+                Row.of(Vectors.dense(0.1, 2.), 0.),
+                Row.of(Vectors.dense(0.2, 2.), 0.),
+                Row.of(Vectors.dense(0.3, 2.), 0.),
+                Row.of(Vectors.dense(0.4, 2.), 0.),
+                Row.of(Vectors.dense(0.5, 2.), 0.),
+                Row.of(Vectors.dense(11., 12.), 1.),
+                Row.of(Vectors.dense(12., 11.), 1.),
+                Row.of(Vectors.dense(13., 12.), 1.),
+                Row.of(Vectors.dense(14., 12.), 1.),
+                Row.of(Vectors.dense(15., 12.), 1.)
+            };
+
+    private static final Row[] TRAIN_DENSE_ROWS_2 =
+            new Row[] {
+                Row.of(Vectors.dense(0.2, 3.), 0.),
+                Row.of(Vectors.dense(0.8, 1.), 0.),
+                Row.of(Vectors.dense(0.7, 1.), 0.),
+                Row.of(Vectors.dense(0.6, 2.), 0.),
+                Row.of(Vectors.dense(0.2, 2.), 0.),
+                Row.of(Vectors.dense(14., 17.), 1.),
+                Row.of(Vectors.dense(15., 10.), 1.),
+                Row.of(Vectors.dense(16., 16.), 1.),
+                Row.of(Vectors.dense(17., 10.), 1.),
+                Row.of(Vectors.dense(18., 13.), 1.)
+            };
+
+    private static final Row[] PREDICT_DENSE_ROWS =
+            new Row[] {
+                Row.of(Vectors.dense(0.8, 2.7), 0.0), Row.of(Vectors.dense(15.5, 11.2), 1.0)
+            };
+
+    private static final Row[] TRAIN_SPARSE_ROWS_1 =
+            new Row[] {
+                Row.of(Vectors.sparse(10, new int[] {1, 3, 4}, ONE_ARRAY), 0., 1.0),
+                Row.of(Vectors.sparse(10, new int[] {0, 2, 3}, ONE_ARRAY), 0., 1.4),
+                Row.of(Vectors.sparse(10, new int[] {0, 3, 4}, ONE_ARRAY), 0., 1.3),
+                Row.of(Vectors.sparse(10, new int[] {2, 3, 4}, ONE_ARRAY), 0., 1.4),
+                Row.of(Vectors.sparse(10, new int[] {1, 3, 4}, ONE_ARRAY), 0., 1.6),
+                Row.of(Vectors.sparse(10, new int[] {6, 7, 8}, ONE_ARRAY), 1., 1.8),
+                Row.of(Vectors.sparse(10, new int[] {6, 8, 9}, ONE_ARRAY), 1., 1.9),
+                Row.of(Vectors.sparse(10, new int[] {5, 8, 9}, ONE_ARRAY), 1., 1.0),
+                Row.of(Vectors.sparse(10, new int[] {5, 6, 7}, ONE_ARRAY), 1., 1.1)
+            };
+
+    private static final Row[] TRAIN_SPARSE_ROWS_2 =
+            new Row[] {
+                Row.of(Vectors.sparse(10, new int[] {1, 2, 4}, ONE_ARRAY), 0., 1.0),
+                Row.of(Vectors.sparse(10, new int[] {2, 3, 4}, ONE_ARRAY), 0., 1.3),
+                Row.of(Vectors.sparse(10, new int[] {0, 2, 4}, ONE_ARRAY), 0., 1.4),
+                Row.of(Vectors.sparse(10, new int[] {1, 3, 4}, ONE_ARRAY), 0., 1.0),
+                Row.of(Vectors.sparse(10, new int[] {6, 7, 9}, ONE_ARRAY), 1., 1.6),
+                Row.of(Vectors.sparse(10, new int[] {7, 8, 9}, ONE_ARRAY), 1., 1.8),
+                Row.of(Vectors.sparse(10, new int[] {5, 7, 9}, ONE_ARRAY), 1., 1.0),
+                Row.of(Vectors.sparse(10, new int[] {5, 6, 7}, ONE_ARRAY), 1., 1.5),
+                Row.of(Vectors.sparse(10, new int[] {5, 8, 9}, ONE_ARRAY), 1., 1.0)
+            };
+
+    private static final Row[] PREDICT_SPARSE_ROWS =
+            new Row[] {
+                Row.of(Vectors.sparse(10, new int[] {1, 3, 5}, ONE_ARRAY), 0.),
+                Row.of(Vectors.sparse(10, new int[] {5, 8, 9}, ONE_ARRAY), 1.)
+            };
+
+    private static final int defaultParallelism = 4;
+    private static final int numTaskManagers = 2;
+    private static final int numSlotsPerTaskManager = 2;
+
+    private long currentModelDataVersion;
+
+    private InMemorySourceFunction<Row> trainDenseSource;
+    private InMemorySourceFunction<Row> predictDenseSource;
+    private InMemorySourceFunction<Row> trainSparseSource;
+    private InMemorySourceFunction<Row> predictSparseSource;
+    private InMemorySinkFunction<Row> outputSink;
+    private InMemorySinkFunction<LogisticRegressionModelData> modelDataSink;
+
+    // TODO: creates static mini cluster once for whole test class after dependency upgrades to
+    // Flink 1.15.
+    private InMemoryReporter reporter;
+    private MiniCluster miniCluster;
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+
+    private Table offlineTrainDenseTable;
+    private Table onlineTrainDenseTable;
+    private Table onlinePredictDenseTable;
+    private Table onlineTrainSparseTable;
+    private Table onlinePredictSparseTable;
+    private Table initDenseModel;
+    private Table initSparseModel;
+
+    @Before
+    public void before() throws Exception {
+        currentModelDataVersion = 0;
+
+        trainDenseSource = new InMemorySourceFunction<>();
+        predictDenseSource = new InMemorySourceFunction<>();
+        trainSparseSource = new InMemorySourceFunction<>();
+        predictSparseSource = new InMemorySourceFunction<>();
+        outputSink = new InMemorySinkFunction<>();
+        modelDataSink = new InMemorySinkFunction<>();
+
+        Configuration config = new Configuration();
+        config.set(RestOptions.BIND_PORT, "18081-19091");
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        reporter = InMemoryReporter.create();
+        reporter.addToConfiguration(config);
+
+        miniCluster =
+                new MiniCluster(
+                        new MiniClusterConfiguration.Builder()
+                                .setConfiguration(config)
+                                .setNumTaskManagers(numTaskManagers)
+                                .setNumSlotsPerTaskManager(numSlotsPerTaskManager)
+                                .build());
+        miniCluster.start();
+
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(defaultParallelism);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        offlineTrainDenseTable =
+                tEnv.fromDataStream(env.fromElements(TRAIN_DENSE_ROWS_1)).as("features", "label");
+        onlineTrainDenseTable =
+                tEnv.fromDataStream(
+                        env.addSource(
+                                trainDenseSource,
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(DenseVector.class), Types.DOUBLE
+                                        },
+                                        new String[] {"features", "label"})));
+
+        onlinePredictDenseTable =
+                tEnv.fromDataStream(
+                        env.addSource(
+                                predictDenseSource,
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(DenseVector.class), Types.DOUBLE
+                                        },
+                                        new String[] {"features", "label"})));
+
+        onlineTrainSparseTable =
+                tEnv.fromDataStream(
+                        env.addSource(
+                                trainSparseSource,
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(SparseVector.class),
+                                            Types.DOUBLE,
+                                            Types.DOUBLE
+                                        },
+                                        new String[] {"features", "label", "weight"})));
+
+        onlinePredictSparseTable =
+                tEnv.fromDataStream(
+                        env.addSource(
+                                predictSparseSource,
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(SparseVector.class), Types.DOUBLE
+                                        },
+                                        new String[] {"features", "label"})));
+
+        initDenseModel =
+                tEnv.fromDataStream(
+                        env.fromElements(
+                                Row.of(
+                                        new DenseVector(
+                                                new double[] {
+                                                    0.41233679404769874, -0.18088118293232122
+                                                }),
+                                        0L)));
+        initSparseModel =
+                tEnv.fromDataStream(
+                        env.fromElements(
+                                Row.of(
+                                        new DenseVector(
+                                                new double[] {
+                                                    0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01,
+                                                    0.01, 0.01
+                                                }),
+                                        0L)));
+    }
+
+    @After
+    public void after() throws Exception {
+        miniCluster.close();
+    }
+
+    /**
+     * Performs transform() on the provided model with predictTable, and adds sinks for
+     * OnlineLogisticRegressionModel's transform output and model data.
+     */
+    private void transformAndOutputData(
+            OnlineLogisticRegressionModel onlineModel, boolean isSparse) {
+        Table outputTable =
+                onlineModel
+                        .transform(isSparse ? onlinePredictSparseTable : onlinePredictDenseTable)[
+                        0];
+        tEnv.toDataStream(outputTable).addSink(outputSink);
+
+        Table modelDataTable = onlineModel.getModelData()[0];
+        LogisticRegressionModelData.getModelDataStream(modelDataTable).addSink(modelDataSink);
+    }
+
+    /** Blocks the thread until Model has set up init model data. */
+    private void waitInitModelDataSetup(JobID jobID) throws InterruptedException {
+        while (reporter.findMetrics(jobID, MODEL_DATA_VERSION_GAUGE_KEY).size()
+                < defaultParallelism) {
+            Thread.sleep(100);
+        }
+        waitModelDataUpdate(jobID);
+    }
+
+    /** Blocks the thread until the Model has received the next model-data-update event. */
+    @SuppressWarnings("unchecked")
+    private void waitModelDataUpdate(JobID jobID) throws InterruptedException {
+        do {
+            long tmpModelDataVersion =
+                    reporter.findMetrics(jobID, MODEL_DATA_VERSION_GAUGE_KEY).values().stream()
+                            .map(x -> Long.parseLong(((Gauge<String>) x).getValue()))
+                            .min(Long::compareTo)
+                            .get();
+            if (tmpModelDataVersion == currentModelDataVersion) {
+                Thread.sleep(100);
+            } else {
+                currentModelDataVersion = tmpModelDataVersion;
+                break;
+            }
+        } while (true);
+    }
+
+    /**
+     * Inserts default predict data to the predict queue, fetches the prediction results, and
+     * asserts that the grouping result is as expected.
+     *
+     * @param expectedRawInfo A list containing sets of expected result RawInfo.
+     */
+    private void predictAndAssert(List<DenseVector> expectedRawInfo, boolean isSparse)
+            throws Exception {
+        if (isSparse) {
+            predictSparseSource.addAll(PREDICT_SPARSE_ROWS);
+        } else {
+            predictDenseSource.addAll(PREDICT_DENSE_ROWS);
+        }
+        List<Row> rawResult =
+                outputSink.poll(isSparse ? PREDICT_SPARSE_ROWS.length : PREDICT_DENSE_ROWS.length);
+        List<DenseVector> resultDetail = new ArrayList<>(rawResult.size());
+        for (Row row : rawResult) {
+            resultDetail.add(row.getFieldAs(3));
+        }
+        resultDetail.sort(TestUtils::compare);
+        expectedRawInfo.sort(TestUtils::compare);
+        for (int i = 0; i < resultDetail.size(); ++i) {
+            double[] realData = resultDetail.get(i).values;
+            double[] expectedData = expectedRawInfo.get(i).values;
+            for (int j = 0; j < expectedData.length; ++j) {
+                Assert.assertEquals(realData[j], expectedData[j], 1.0e-5);
+            }
+        }
+    }
+
+    private JobID submitJob(JobGraph jobGraph)
+            throws ExecutionException, InterruptedException, TimeoutException {
+        return miniCluster
+                .submitJob(jobGraph)
+                .thenApply(JobSubmissionResult::getJobID)
+                .get(1, TimeUnit.SECONDS);
+    }
+
+    @Test
+    public void testParam() {
+        OnlineLogisticRegression onlineLogisticRegression = new OnlineLogisticRegression();
+        Assert.assertEquals("features", onlineLogisticRegression.getFeaturesCol());
+        Assert.assertEquals("count", onlineLogisticRegression.getBatchStrategy());
+        Assert.assertEquals("label", onlineLogisticRegression.getLabelCol());
+        Assert.assertEquals(0.0, onlineLogisticRegression.getReg(), 1.0e-5);
+        Assert.assertEquals(0.0, onlineLogisticRegression.getElasticNet(), 1.0e-5);
+        Assert.assertEquals(0.1, onlineLogisticRegression.getAlpha(), 1.0e-5);
+        Assert.assertEquals(0.1, onlineLogisticRegression.getBeta(), 1.0e-5);
+        Assert.assertEquals(32, onlineLogisticRegression.getGlobalBatchSize());
+
+        onlineLogisticRegression
+                .setFeaturesCol("test_feature")
+                .setLabelCol("test_label")
+                .setGlobalBatchSize(5)
+                .setReg(0.25)
+                .setElasticNet(0.25)

Review Comment:
   Should we use different values for these parameters so that we can catch the bugs mentioned in the other comment?
   
   Same for the Python tests.
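   
   A minimal sketch of why identical values can hide such bugs. The toy class below is a stand-in, not the Flink ML `Param` API: it keys values by param name, so if two params accidentally share a name (as `BETA` being constructed with the name `"alpha"` in the params diff would), a round-trip check still passes when both are set to the same value and fails only when the values differ. `ToyParams` and its members are hypothetical.
   
   ```java
   import java.util.HashMap;
   import java.util.Map;

   // Toy stand-in for a param map keyed by param name. This is NOT the Flink ML
   // Param API; it only models that two params sharing a name share one slot.
   final class ToyParams {
       // Hypothetical names: BETA is deliberately registered under "alpha" to mimic
       // a copy-paste slip that only a test with distinct values would catch.
       static final String ALPHA_NAME = "alpha";
       static final String BETA_NAME = "alpha";

       private final Map<String, Double> values = new HashMap<>();

       void setAlpha(double v) { values.put(ALPHA_NAME, v); }
       void setBeta(double v) { values.put(BETA_NAME, v); }
       double getAlpha() { return values.get(ALPHA_NAME); }
       double getBeta() { return values.get(BETA_NAME); }

       /** True if each getter returns the value passed to its own setter. */
       static boolean roundTrips(double alpha, double beta) {
           ToyParams params = new ToyParams();
           params.setAlpha(alpha);
           params.setBeta(beta);
           return params.getAlpha() == alpha && params.getBeta() == beta;
       }
   }
   ```
   
   With `roundTrips(0.25, 0.25)` the collision goes unnoticed; with `roundTrips(0.1, 0.2)` the `getAlpha()` check fails because `setBeta` overwrote the shared slot.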



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionParams.java:
##########
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.common.param.HasElasticNet;
+import org.apache.flink.ml.common.param.HasGlobalBatchSize;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasReg;
+import org.apache.flink.ml.common.param.HasWeightCol;
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+
+/**
+ * Params of {@link OnlineLogisticRegression}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface OnlineLogisticRegressionParams<T>
+        extends HasLabelCol<T>,
+                HasWeightCol<T>,
+                HasBatchStrategy<T>,
+                HasGlobalBatchSize<T>,
+                HasReg<T>,
+                HasElasticNet<T>,
+                OnlineLogisticRegressionModelParams<T> {
+
+    Param<Double> ALPHA =
+            new DoubleParam("alpha", "The parameter alpha of ftrl.", 0.1, ParamValidators.gt(0.0));

Review Comment:
   nit: Would it be more intuitive to say `The alpha parameter ...`?
   
   Same for `The beta parameter ...`.
   
   Same for the Python parameter description.
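   
   For context on what the two descriptions refer to: in the FTRL-Proximal paper cited in the Javadoc, `alpha` and `beta` define the per-coordinate learning rate alpha / (beta + sqrt(n_i)), where n_i is the running sum of squared gradients for coordinate i. Below is a minimal single-machine sketch of the paper's per-coordinate update for logistic regression. It is not this PR's distributed implementation, and `FtrlSketch`, `l1` and `l2` are hypothetical names used only for illustration.
   
   ```java
   // Single-machine FTRL-Proximal for logistic regression, following the
   // per-coordinate update in McMahan et al. (not the PR's distributed code).
   final class FtrlSketch {
       final double alpha; // scales the per-coordinate learning rate
       final double beta;  // smooths the learning rate while n_i is still small
       final double l1;    // L1 strength: |z_i| <= l1 yields an exact zero weight
       final double l2;    // L2 strength
       final double[] z;   // accumulated adjusted gradients
       final double[] n;   // accumulated squared gradients

       FtrlSketch(int dim, double alpha, double beta, double l1, double l2) {
           this.alpha = alpha;
           this.beta = beta;
           this.l1 = l1;
           this.l2 = l2;
           this.z = new double[dim];
           this.n = new double[dim];
       }

       /** Closed-form weight for coordinate i, computed lazily from z and n. */
       double weight(int i) {
           if (Math.abs(z[i]) <= l1) {
               return 0.0;
           }
           return -(z[i] - Math.signum(z[i]) * l1) / ((beta + Math.sqrt(n[i])) / alpha + l2);
       }

       /** Predicted probability of label 1. */
       double predict(double[] x) {
           double margin = 0.0;
           for (int i = 0; i < x.length; i++) {
               margin += weight(i) * x[i];
           }
           return 1.0 / (1.0 + Math.exp(-margin));
       }

       /** One online step on a single example with label 0.0 or 1.0. */
       void update(double[] x, double label) {
           double[] w = new double[x.length];
           double margin = 0.0;
           for (int i = 0; i < x.length; i++) {
               w[i] = weight(i);
               margin += w[i] * x[i];
           }
           double p = 1.0 / (1.0 + Math.exp(-margin));
           for (int i = 0; i < x.length; i++) {
               if (x[i] == 0.0) {
                   continue; // zero features leave z and n untouched
               }
               double g = (p - label) * x[i]; // log-loss gradient w.r.t. w_i
               double sigma = (Math.sqrt(n[i] + g * g) - Math.sqrt(n[i])) / alpha;
               z[i] += g - sigma * w[i];
               n[i] += g * g;
           }
       }
   }
   ```
   
   Here `l1` and `l2` are plain regularization strengths; how this PR derives them from `reg` and `elasticNet` is not shown in this excerpt.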




[GitHub] [flink-ml] weibozhao commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator for OnlineLogisticRegression

Posted by GitBox <gi...@apache.org>.
weibozhao commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r884801285


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java:
##########
@@ -0,0 +1,434 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.FilterFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the FTRL-Proximal online learning algorithm proposed by H. Brendan
+ * McMahan et al.
+ *
+ * <p>See <a href="https://doi.org/10.1145/2487575.2488200">H. Brendan McMahan et al., Ad click
+ * prediction: a view from the trenches.</a>
+ */
+public class OnlineLogisticRegression
+        implements Estimator<OnlineLogisticRegression, OnlineLogisticRegressionModel>,
+                OnlineLogisticRegressionParams<OnlineLogisticRegression> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineLogisticRegression() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public OnlineLogisticRegressionModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<LogisticRegressionModelData> modelDataStream =
+                LogisticRegressionModelData.getModelDataStream(initModelDataTable);
+
+        DataStream<Row> points =
+                tEnv.toDataStream(inputs[0])
+                        .map(new FeaturesExtractor(getFeaturesCol(), getLabelCol()));
+
+        DataStream<DenseVector> initModelData =
+                modelDataStream.map(
+                        (MapFunction<LogisticRegressionModelData, DenseVector>)
+                                value -> value.coefficient);

Review Comment:
   modelVersion has been written to state, so if a restart happens, the algorithm reads the model version back from the checkpoint. Taking your example, if a restart happens at record 301, the model data (without the model version) is restored from the checkpoint in CalculateLocalGradient(), and the modelVersion (modelVersion=3) is restored from the checkpoint in CreateLrModelData() before the model data is emitted.
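   
   A toy, Flink-free simulation of the recovery path described above, assuming the coefficients (gradient-side state) and the model version (model-data-side state) are captured by the same checkpoint. After a restart, the next emitted version continues from the checkpointed value (3 in the example) instead of starting over from 0. `VersionedModelSim` and its methods are hypothetical and do not use Flink's state API.
   
   ```java
   // Toy, Flink-free model of two stateful steps whose state is captured by the same
   // checkpoint. Names are hypothetical; Flink's state API is not used here.
   final class VersionedModelSim {
       // Gradient-side state: coefficients only, no version.
       double[] coefficients;
       // Model-data-side state: version of the last emitted model data.
       long modelVersion;

       VersionedModelSim(int dim) {
           coefficients = new double[dim];
       }

       /** Applies one (fake) update to all coefficients and emits the next model version. */
       long step(double delta) {
           for (int i = 0; i < coefficients.length; i++) {
               coefficients[i] += delta;
           }
           return ++modelVersion;
       }

       /** Deep copy standing in for a consistent checkpoint of both pieces of state. */
       VersionedModelSim snapshot() {
           VersionedModelSim copy = new VersionedModelSim(coefficients.length);
           copy.coefficients = coefficients.clone();
           copy.modelVersion = modelVersion;
           return copy;
       }
   }
   ```
   
   Restoring from a snapshot taken right after version 3 and calling `step` again yields version 4.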




[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator for OnlineLogisticRegression

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r881176285


##########
flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java:
##########
@@ -71,6 +71,43 @@ public static double dot(DenseVector x, DenseVector y) {
         return JAVA_BLAS.ddot(x.size(), x.values, 1, y.values, 1);
     }
 
+    public static double dot(DenseVector x, SparseVector y) {
+        Preconditions.checkArgument(x.size() == y.size(), "Vector size mismatched.");
+        double dotValue = 0.0;
+        for (int i = 0; i < y.indices.length; ++i) {
+            dotValue += y.values[i] * x.values[y.indices[i]];
+        }
+        return dotValue;
+    }
+
+    public static double dot(SparseVector x, DenseVector y) {

Review Comment:
   nit: 
   ```java
       public static double dot(SparseVector x, DenseVector y) {
           return dot(y, x);
       }
   ```
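   
   A self-contained sketch of the suggested delegation. `Dense`, `Sparse` and `MiniBlas` are hypothetical stand-ins for `DenseVector`, `SparseVector` and `BLAS`; the dense-sparse loop mirrors the one in the diff, iterating only over the sparse vector's non-zero entries.
   
   ```java
   // Hypothetical stand-in for DenseVector.
   final class Dense {
       final double[] values;

       Dense(double... values) {
           this.values = values;
       }

       int size() {
           return values.length;
       }
   }

   // Hypothetical stand-in for SparseVector: logical size plus non-zero entries.
   final class Sparse {
       final int n;
       final int[] indices;
       final double[] values;

       Sparse(int n, int[] indices, double[] values) {
           this.n = n;
           this.indices = indices;
           this.values = values;
       }

       int size() {
           return n;
       }
   }

   final class MiniBlas {
       /** Dense-sparse dot product that touches only the sparse non-zeros. */
       static double dot(Dense x, Sparse y) {
           if (x.size() != y.size()) {
               throw new IllegalArgumentException("Vector size mismatched.");
           }
           double dotValue = 0.0;
           for (int i = 0; i < y.indices.length; ++i) {
               dotValue += y.values[i] * x.values[y.indices[i]];
           }
           return dotValue;
       }

       /** The dot product is symmetric, so the reversed overload can simply delegate. */
       static double dot(Sparse x, Dense y) {
           return dot(y, x);
       }
   }
   ```
   
   Delegating keeps the size check and the loop in one place, so a later fix to one overload cannot drift from the other.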



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionParams.java:
##########
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.common.param.HasElasticNet;
+import org.apache.flink.ml.common.param.HasGlobalBatchSize;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasReg;
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+
+/**
+ * Params of {@link OnlineLogisticRegression}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface OnlineLogisticRegressionParams<T>
+        extends HasLabelCol<T>,
+                HasBatchStrategy<T>,
+                HasGlobalBatchSize<T>,
+                HasReg<T>,
+                HasElasticNet<T>,
+                OnlineLogisticRegressionModelParams<T> {
+
+    Param<Double> ALPHA =
+            new DoubleParam("alpha", "The parameter alpha of ftrl.", 0.1, ParamValidators.gt(0.0));
+
+    default Double getAlpha() {
+        return get(ALPHA);
+    }
+
+    default T setAlpha(Double value) {
+        return set(ALPHA, value);
+    }
+
+    Param<Double> BETA =
+            new DoubleParam("alpha", "The parameter beta of ftrl.", 0.1, ParamValidators.gt(0.0));
+
+    default Double getBeta() {
+        return get(BETA);
+    }
+
+    default T setBeta(Double value) {
+        return set(BETA, value);
+    }
+
+    Param<Integer> MODEL_SAVE_INTERVAL =
+            new IntParam(
+                    "modelSaveInterval",
+                    "The iteration steps between two output models.",

Review Comment:
   Does this parameter mean that OLR should generate a new version of model data after every 100 batches of training data? The term "interval" gives me the impression that it is related to time. Shall we modify the parameter name and description to better reflect its semantics?
   
   Besides, "model data" would be more accurate than "models".
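   
   A minimal sketch of the count-based reading described above (a new model-data version after every N training batches, independent of wall-clock time). `BatchCountTrigger` and `batchesPerVersion` are hypothetical names used only for illustration, not the PR's code or a proposed final name.
   
   ```java
   // Hypothetical sketch of count-based (not time-based) semantics: a new model-data
   // version is emitted after every `batchesPerVersion` training batches.
   final class BatchCountTrigger {
       private final int batchesPerVersion;
       private int batchesSinceLastEmit;
       private long modelVersion;

       BatchCountTrigger(int batchesPerVersion) {
           if (batchesPerVersion <= 0) {
               throw new IllegalArgumentException("batchesPerVersion must be positive");
           }
           this.batchesPerVersion = batchesPerVersion;
       }

       /** Call once per processed batch; returns the new version to emit, or -1 if none. */
       long onBatch() {
           batchesSinceLastEmit++;
           if (batchesSinceLastEmit < batchesPerVersion) {
               return -1;
           }
           batchesSinceLastEmit = 0;
           return ++modelVersion;
       }
   }
   ```
   
   With `batchesPerVersion = 100`, 250 batches produce exactly two model-data versions, however long those batches take to arrive.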



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/util/TestUtils.java:
##########
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util;
+
+import org.apache.flink.ml.linalg.DenseVector;
+
+/** Utility methods for testing. */
+public class TestUtils {
+    /** Note: this comparator imposes orderings that are inconsistent with equals. */
+    public static int compare(DenseVector first, DenseVector second) {
+        for (int i = 0; i < first.size(); i++) {

Review Comment:
   nit: we should check for size equality before doing the comparison.
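   
   A sketch of a lexicographic comparator with the size check up front, using plain `double[]` as a stand-in for `DenseVector.values`. `VectorCompare` is a hypothetical name, and the use of `Double.compare` per element is an assumption since the original loop body is not shown. Assuming that loop reads `second` at index `i`, a longer `first` would otherwise fail with `ArrayIndexOutOfBoundsException`, and a shorter `first` would silently ignore the trailing entries of `second`.
   
   ```java
   // Hypothetical lexicographic comparator over vector values, with a size check first.
   final class VectorCompare {
       static int compare(double[] first, double[] second) {
           if (first.length != second.length) {
               throw new IllegalArgumentException(
                       "Vector size mismatched: " + first.length + " vs " + second.length);
           }
           for (int i = 0; i < first.length; i++) {
               int cmp = Double.compare(first[i], second[i]);
               if (cmp != 0) {
                   return cmp;
               }
           }
           return 0;
       }
   }
   ```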



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionModel.java:
##########
@@ -0,0 +1,201 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.apache.flink.ml.classification.logisticregression.LogisticRegressionModel.predictOneDataPoint;
+
+/**
+ * A Model which classifies data using the model data computed by {@link OnlineLogisticRegression}.
+ */
+public class OnlineLogisticRegressionModel
+        implements Model<OnlineLogisticRegressionModel>,
+                OnlineLogisticRegressionModelParams<OnlineLogisticRegressionModel> {
+    public static final String MODEL_DATA_VERSION_GAUGE_KEY = "modelDataVersion";
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public OnlineLogisticRegressionModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(),
+                                Types.DOUBLE,
+                                TypeInformation.of(DenseVector.class),
+                                Types.LONG),
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldNames(),
+                                getPredictionCol(),
+                                getRawPredictionCol(),
+                                getModelVersionCol()));
+
+        DataStream<Row> predictionResult =
+                tEnv.toDataStream(inputs[0])
+                        .connect(
+                                LogisticRegressionModelData.getModelDataStream(modelDataTable)
+                                        .broadcast())
+                        .transform(
+                                "PredictLabelOperator",
+                                outputTypeInfo,
+                                new PredictLabelOperator(inputTypeInfo, getFeaturesCol()));
+
+        return new Table[] {tEnv.fromDataStream(predictionResult)};
+    }
+
+    /** A utility operator used for prediction. */
+    private static class PredictLabelOperator extends AbstractStreamOperator<Row>
+            implements TwoInputStreamOperator<Row, LogisticRegressionModelData, Row> {
+        private final RowTypeInfo inputTypeInfo;
+
+        private final String featuresCol;
+        private ListState<Row> bufferedPointsState;
+        private DenseVector coefficient;
+        private long modelDataVersion = 0L;
+
+        public PredictLabelOperator(RowTypeInfo inputTypeInfo, String featuresCol) {
+            this.inputTypeInfo = inputTypeInfo;
+            this.featuresCol = featuresCol;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            bufferedPointsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("bufferedPoints", inputTypeInfo));
+        }
+
+        @Override
+        public void open() throws Exception {
+            super.open();
+
+            getRuntimeContext()
+                    .getMetricGroup()
+                    .gauge(
+                            MODEL_DATA_VERSION_GAUGE_KEY,
+                            (Gauge<String>) () -> Long.toString(modelDataVersion));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<Row> streamRecord) throws Exception {
+            Row dataPoint = streamRecord.getValue();
+            if (coefficient == null) {
+                bufferedPointsState.add(dataPoint);
+                return;
+            }
+            Vector features = (Vector) dataPoint.getField(featuresCol);
+            Row predictionResult = predictOneDataPoint(features, coefficient);
+            output.collect(
+                    new StreamRecord<>(
+                            Row.join(
+                                    dataPoint,
+                                    Row.of(
+                                            predictionResult.getField(0),
+                                            predictionResult.getField(1),
+                                            modelDataVersion))));
+        }
+
+        @Override
+        public void processElement2(StreamRecord<LogisticRegressionModelData> streamRecord)
+                throws Exception {
+            LogisticRegressionModelData modelData = streamRecord.getValue();
+            coefficient = modelData.coefficient;
+            modelDataVersion = modelData.modelVersion;
+            for (Row dataPoint : bufferedPointsState.get()) {
+                processElement1(new StreamRecord<>(dataPoint));
+            }
+            bufferedPointsState.clear();
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+        ReadWriteUtils.saveModelData(
+                LogisticRegressionModelData.getModelDataStream(modelDataTable),
+                path,
+                new LogisticRegressionModelData.ModelDataEncoder());
+    }
+
+    public static OnlineLogisticRegressionModel load(StreamTableEnvironment tEnv, String path)
+            throws IOException {
+        OnlineLogisticRegressionModel model = ReadWriteUtils.loadStageParam(path);
+        Table modelDataTable =
+                ReadWriteUtils.loadModelData(
+                        tEnv, path, new LogisticRegressionModelData.ModelDataDecoder());
+        return model.setModelData(modelDataTable);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    @Override
+    public OnlineLogisticRegressionModel setModelData(Table... inputs) {
+        modelDataTable = inputs[0];

Review Comment:
   nit: Let's check `inputs.length` in this method.
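   Below is a minimal sketch of the suggested check. It is plain Java, so it compiles without Flink. `ModelDataHolder` is a hypothetical stand-in for `OnlineLogisticRegressionModel`. In the real class, the `Preconditions.checkArgument(inputs.length == 1)` call that `transform` already uses would presumably be enough, because it also throws `IllegalArgumentException`.

```java
/** Hypothetical stand-in for a Model whose setModelData expects exactly one table. */
final class ModelDataHolder {
    private Object modelDataTable;

    /** Stores the model data, rejecting any call that does not pass exactly one input. */
    ModelDataHolder setModelData(Object... inputs) {
        // Mirrors Preconditions.checkArgument(inputs.length == 1) without depending on Flink.
        if (inputs.length != 1) {
            throw new IllegalArgumentException(
                    "Expected exactly one model data table, but got " + inputs.length);
        }
        modelDataTable = inputs[0];
        return this;
    }

    Object getModelData() {
        return modelDataTable;
    }
}
```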



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java:
##########
@@ -0,0 +1,434 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.FilterFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the FTRL-Proximal online learning algorithm proposed by H. Brendan
+ * McMahan et al.
+ *
+ * <p>See <a href="https://doi.org/10.1145/2487575.2488200">H. Brendan McMahan et al., Ad click
+ * prediction: a view from the trenches.</a>
+ */
+public class OnlineLogisticRegression
+        implements Estimator<OnlineLogisticRegression, OnlineLogisticRegressionModel>,
+                OnlineLogisticRegressionParams<OnlineLogisticRegression> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineLogisticRegression() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public OnlineLogisticRegressionModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<LogisticRegressionModelData> modelDataStream =
+                LogisticRegressionModelData.getModelDataStream(initModelDataTable);
+
+        DataStream<Row> points =
+                tEnv.toDataStream(inputs[0])
+                        .map(new FeaturesExtractor(getFeaturesCol(), getLabelCol()));
+
+        DataStream<DenseVector> initModelData =
+                modelDataStream.map(
+                        (MapFunction<LogisticRegressionModelData, DenseVector>)
+                                value -> value.coefficient);

Review Comment:
   Should we also extract `modelVersion` from the input model data, and increment `modelVersion` starting from the extracted value?
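   For illustration, here is a plain-Java sketch of that idea. It assumes the initial model data carries a `modelVersion`, as `LogisticRegressionModelData` does in the diff. `VersionedModelData` and `ModelVersionAssigner` are hypothetical names. The sketch does not show how the version would be threaded through `Iterations.iterateUnboundedStreams`. The diff currently maps the initial data to `value.coefficient` only, and `FtrlIterationBody` starts counting at `1L`.

```java
/** Hypothetical stand-in for LogisticRegressionModelData: coefficients plus a version number. */
final class VersionedModelData {
    final double[] coefficient;
    final long modelVersion;

    VersionedModelData(double[] coefficient, long modelVersion) {
        this.coefficient = coefficient;
        this.modelVersion = modelVersion;
    }
}

/** Assigns versions to newly trained models, continuing from the initial model's version. */
final class ModelVersionAssigner {
    private long nextVersion;

    ModelVersionAssigner(VersionedModelData initModelData) {
        // Start one past the version carried by the initial model data instead of
        // restarting at a fixed value such as 1L.
        this.nextVersion = initModelData.modelVersion + 1;
    }

    /** Wraps freshly updated coefficients with the next version number. */
    VersionedModelData assign(double[] coefficient) {
        return new VersionedModelData(coefficient, nextVersion++);
    }
}
```

   With an initial model at version 5, the models emitted afterwards would carry versions 6, 7, and so on, instead of restarting at 1.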



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java:
##########
@@ -0,0 +1,434 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.FilterFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the FTRL-Proximal online learning algorithm proposed by H. Brendan
+ * McMahan et al.
+ *
+ * <p>See <a href="https://doi.org/10.1145/2487575.2488200">H. Brendan McMahan et al., Ad click
+ * prediction: a view from the trenches.</a>
+ */
+public class OnlineLogisticRegression
+        implements Estimator<OnlineLogisticRegression, OnlineLogisticRegressionModel>,
+                OnlineLogisticRegressionParams<OnlineLogisticRegression> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineLogisticRegression() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public OnlineLogisticRegressionModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<LogisticRegressionModelData> modelDataStream =
+                LogisticRegressionModelData.getModelDataStream(initModelDataTable);
+
+        DataStream<Row> points =
+                tEnv.toDataStream(inputs[0])
+                        .map(new FeaturesExtractor(getFeaturesCol(), getLabelCol()));
+
+        DataStream<DenseVector> initModelData =
+                modelDataStream.map(
+                        (MapFunction<LogisticRegressionModelData, DenseVector>)
+                                value -> value.coefficient);
+
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new FtrlIterationBody(
+                        getGlobalBatchSize(),
+                        getAlpha(),
+                        getBeta(),
+                        getReg(),
+                        getElasticNet(),
+                        getModelSaveInterval());
+
+        DataStream<LogisticRegressionModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        OnlineLogisticRegressionModel model =
+                new OnlineLogisticRegressionModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class FeaturesExtractor implements MapFunction<Row, Row> {
+        private final String featuresCol;
+        private final String labelCol;
+
+        private FeaturesExtractor(String featuresCol, String labelCol) {
+            this.featuresCol = featuresCol;
+            this.labelCol = labelCol;
+        }
+
+        @Override
+        public Row map(Row row) throws Exception {
+            return Row.of(row.getField(featuresCol), row.getField(labelCol));
+        }
+    }
+
+    /**
+     * Implementation of the FTRL optimizer. In this implementation, gradients are calculated in
+     * distributed workers and reduced to one gradient. The reduced gradient is used to update the
+     * model with the FTRL method.
+     *
+     * <p>See https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Ftrl
+     */
+    private static class FtrlIterationBody implements IterationBody {
+        private final int batchSize;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+        private long modelVersion = 1L;
+        private final int modelSaveInterval;
+
+        public FtrlIterationBody(
+                int batchSize,
+                double alpha,
+                double beta,
+                double reg,
+                double elasticNet,
+                int modelSaveInterval) {
+            this.batchSize = batchSize;
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = elasticNet * reg;
+            this.l2 = (1 - elasticNet) * reg;
+            this.modelSaveInterval = modelSaveInterval;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<DenseVector> modelData = variableStreams.get(0);
+
+            DataStream<Row> points = dataStreams.get(0);
+            int parallelism = points.getParallelism();
+            Preconditions.checkState(
+                    parallelism <= batchSize,
+                    "There are more subtasks in the training process than the number "
+                            + "of elements in each batch. Some subtasks might be idling forever.");
+
+            DataStream<DenseVector[]> newGradient =
+                    DataStreamUtils.generateBatchData(points, parallelism, batchSize)
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "LocalGradientCalculator",
+                                    TypeInformation.of(DenseVector[].class),
+                                    new CalculateLocalGradient())
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(
+                                    (ReduceFunction<DenseVector[]>)
+                                            (gradientInfo, newGradientInfo) -> {
+                                                for (int i = 0;
+                                                        i < newGradientInfo[1].size();
+                                                        ++i) {
+                                                    newGradientInfo[0].values[i] =
+                                                            gradientInfo[0].values[i]
+                                                                    + newGradientInfo[0].values[i];
+                                                    newGradientInfo[1].values[i] =
+                                                            gradientInfo[1].values[i]
+                                                                    + newGradientInfo[1].values[i];
+                                                    if (newGradientInfo[2] == null) {
+                                                        newGradientInfo[2] = gradientInfo[2];
+                                                    }
+                                                }
+                                                return newGradientInfo;
+                                            });
+            DataStream<DenseVector> feedbackModelData =
+                    newGradient
+                            .transform(
+                                    "ModelDataUpdater",
+                                    TypeInformation.of(DenseVector.class),
+                                    new UpdateModel(alpha, beta, l1, l2))
+                            .setParallelism(1);
+
+            DataStream<LogisticRegressionModelData> outputModelData =
+                    feedbackModelData
+                            .filter(
+                                    new FilterFunction<DenseVector>() {
+                                        private int step = 0;
+
+                                        @Override
+                                        public boolean filter(DenseVector denseVector) {
+                                            step++;
+                                            return step % modelSaveInterval == 0;
+                                        }
+                                    })
+                            .setParallelism(1)
+                            .map(
+                                    (MapFunction<DenseVector, LogisticRegressionModelData>)
+                                            value ->
+                                                    new LogisticRegressionModelData(
+                                                            value, modelVersion++))
+                            .setParallelism(1);
+            return new IterationBodyResult(
+                    DataStreamList.of(feedbackModelData), DataStreamList.of(outputModelData));
+        }
+    }
+
+    /** Updates model. */
+    private static class UpdateModel extends AbstractStreamOperator<DenseVector>
+            implements OneInputStreamOperator<DenseVector[], DenseVector> {
+        private ListState<double[]> nParamState;
+        private ListState<double[]> zParamState;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+        private double[] nParam;
+        private double[] zParam;
+
+        public UpdateModel(double alpha, double beta, double l1, double l2) {
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = l1;
+            this.l2 = l2;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            nParamState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("nParamState", double[].class));
+            zParamState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("zParamState", double[].class));
+        }
+
+        @Override
+        public void processElement(StreamRecord<DenseVector[]> streamRecord) throws Exception {
+            DenseVector[] gradientInfo = streamRecord.getValue();
+            double[] coefficient = gradientInfo[2].values;
+            double[] g = gradientInfo[0].values;
+            for (int i = 0; i < g.length; ++i) {
+                if (gradientInfo[1].values[i] != 0.0) {
+                    g[i] = g[i] / gradientInfo[1].values[i];
+                }
+            }
+            if (zParam == null) {
+                zParam = new double[g.length];
+                nParam = new double[g.length];
+                nParamState.add(nParam);
+                zParamState.add(zParam);
+            }
+
+            for (int i = 0; i < zParam.length; ++i) {
+                double sigma = (Math.sqrt(nParam[i] + g[i] * g[i]) - Math.sqrt(nParam[i])) / alpha;
+                zParam[i] += g[i] - sigma * coefficient[i];
+                nParam[i] += g[i] * g[i];
+
+                if (Math.abs(zParam[i]) <= l1) {
+                    coefficient[i] = 0.0;
+                } else {
+                    coefficient[i] =
+                            ((zParam[i] < 0 ? -1 : 1) * l1 - zParam[i])
+                                    / ((beta + Math.sqrt(nParam[i])) / alpha + l2);
+                }
+            }
+            output.collect(new StreamRecord<>(new DenseVector(coefficient)));
+        }
+    }
+
+    private static class CalculateLocalGradient extends AbstractStreamOperator<DenseVector[]>

Review Comment:
   Is it possible for us to reuse code in `org.apache.flink.ml.common.lossfunc` and `org.apache.flink.ml.common.optimizer` for the implementation of this class? Or shall we make FTRL a common optimizer and place it in `org.apache.flink.ml.common`?
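   As a rough illustration of the second option, the per-coordinate update in `UpdateModel` could be pulled out into a reusable class along these lines. `FtrlOptimizer` and `update` are hypothetical names. The sketch is plain Java with no Flink dependency, and it does not attempt to match the existing interfaces in `org.apache.flink.ml.common.optimizer`.

   It applies the same rule as the diff. It first accumulates `z` and `n` per coordinate. It then sets the coefficient to zero when `|z| <= l1`, and to `(sign(z) * l1 - z) / ((beta + sqrt(n)) / alpha + l2)` otherwise.

```java
/**
 * Hypothetical standalone FTRL-Proximal optimizer holding the per-coordinate z and n
 * accumulators, using the same update rule as UpdateModel in the diff above.
 */
final class FtrlOptimizer {
    private final double alpha;
    private final double beta;
    private final double l1;
    private final double l2;
    private double[] z;
    private double[] n;

    FtrlOptimizer(double alpha, double beta, double l1, double l2) {
        this.alpha = alpha;
        this.beta = beta;
        this.l1 = l1;
        this.l2 = l2;
    }

    /** Updates the coefficient array in place from an already averaged gradient and returns it. */
    double[] update(double[] coefficient, double[] gradient) {
        if (z == null) {
            // Lazily size the accumulators on the first gradient.
            z = new double[gradient.length];
            n = new double[gradient.length];
        }
        for (int i = 0; i < gradient.length; ++i) {
            double g = gradient[i];
            double sigma = (Math.sqrt(n[i] + g * g) - Math.sqrt(n[i])) / alpha;
            z[i] += g - sigma * coefficient[i];
            n[i] += g * g;
            if (Math.abs(z[i]) <= l1) {
                // L1 regularization drives small coordinates exactly to zero.
                coefficient[i] = 0.0;
            } else {
                // Here z[i] != 0, so Math.signum matches (z < 0 ? -1 : 1) in the diff.
                coefficient[i] =
                        (Math.signum(z[i]) * l1 - z[i]) / ((beta + Math.sqrt(n[i])) / alpha + l2);
            }
        }
        return coefficient;
    }
}
```

   The caller would still pass a gradient that is already averaged by `gradientInfo[1]`, as `UpdateModel` does. It would also derive `l1 = elasticNet * reg` and `l2 = (1 - elasticNet) * reg`, as in `FtrlIterationBody`. Inside a Flink operator, the `z` and `n` arrays would still need to live in operator state to survive failover, as `nParamState` and `zParamState` do.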



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java:
##########
@@ -0,0 +1,434 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.FilterFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the FTRL-Proximal online learning algorithm proposed by H. Brendan
+ * McMahan et al.
+ *
+ * <p>See <a href="https://doi.org/10.1145/2487575.2488200">H. Brendan McMahan et al., Ad click
+ * prediction: a view from the trenches.</a>
+ */
+public class OnlineLogisticRegression
+        implements Estimator<OnlineLogisticRegression, OnlineLogisticRegressionModel>,
+                OnlineLogisticRegressionParams<OnlineLogisticRegression> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineLogisticRegression() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public OnlineLogisticRegressionModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<LogisticRegressionModelData> modelDataStream =
+                LogisticRegressionModelData.getModelDataStream(initModelDataTable);
+
+        DataStream<Row> points =
+                tEnv.toDataStream(inputs[0])
+                        .map(new FeaturesExtractor(getFeaturesCol(), getLabelCol()));
+
+        DataStream<DenseVector> initModelData =
+                modelDataStream.map(
+                        (MapFunction<LogisticRegressionModelData, DenseVector>)
+                                value -> value.coefficient);
+
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new FtrlIterationBody(
+                        getGlobalBatchSize(),
+                        getAlpha(),
+                        getBeta(),
+                        getReg(),
+                        getElasticNet(),
+                        getModelSaveInterval());
+
+        DataStream<LogisticRegressionModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        OnlineLogisticRegressionModel model =
+                new OnlineLogisticRegressionModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class FeaturesExtractor implements MapFunction<Row, Row> {
+        private final String featuresCol;
+        private final String labelCol;
+
+        private FeaturesExtractor(String featuresCol, String labelCol) {
+            this.featuresCol = featuresCol;
+            this.labelCol = labelCol;
+        }
+
+        @Override
+        public Row map(Row row) throws Exception {
+            return Row.of(row.getField(featuresCol), row.getField(labelCol));
+        }
+    }
+
+    /**
+     * Implementation of the FTRL optimizer. In this implementation, gradients are calculated in
+     * distributed workers and reduced to one gradient. The reduced gradient is used to update the
+     * model with the FTRL method.
+     *
+     * <p>See https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Ftrl
+     */
+    private static class FtrlIterationBody implements IterationBody {
+        private final int batchSize;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+        private long modelVersion = 1L;
+        private final int modelSaveInterval;
+
+        public FtrlIterationBody(
+                int batchSize,
+                double alpha,
+                double beta,
+                double reg,
+                double elasticNet,
+                int modelSaveInterval) {
+            this.batchSize = batchSize;
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = elasticNet * reg;
+            this.l2 = (1 - elasticNet) * reg;
+            this.modelSaveInterval = modelSaveInterval;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<DenseVector> modelData = variableStreams.get(0);
+
+            DataStream<Row> points = dataStreams.get(0);
+            int parallelism = points.getParallelism();
+            Preconditions.checkState(
+                    parallelism <= batchSize,
+                    "There are more subtasks in the training process than the number "
+                            + "of elements in each batch. Some subtasks might be idling forever.");
+
+            DataStream<DenseVector[]> newGradient =
+                    DataStreamUtils.generateBatchData(points, parallelism, batchSize)
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "LocalGradientCalculator",
+                                    TypeInformation.of(DenseVector[].class),
+                                    new CalculateLocalGradient())
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(
+                                    (ReduceFunction<DenseVector[]>)
+                                            (gradientInfo, newGradientInfo) -> {
+                                                for (int i = 0;
+                                                        i < newGradientInfo[1].size();
+                                                        ++i) {
+                                                    newGradientInfo[0].values[i] =
+                                                            gradientInfo[0].values[i]
+                                                                    + newGradientInfo[0].values[i];
+                                                    newGradientInfo[1].values[i] =
+                                                            gradientInfo[1].values[i]
+                                                                    + newGradientInfo[1].values[i];
+                                                    if (newGradientInfo[2] == null) {
+                                                        newGradientInfo[2] = gradientInfo[2];
+                                                    }
+                                                }
+                                                return newGradientInfo;
+                                            });
+            DataStream<DenseVector> feedbackModelData =
+                    newGradient
+                            .transform(
+                                    "ModelDataUpdater",
+                                    TypeInformation.of(DenseVector.class),
+                                    new UpdateModel(alpha, beta, l1, l2))
+                            .setParallelism(1);
+
+            DataStream<LogisticRegressionModelData> outputModelData =
+                    feedbackModelData
+                            .filter(
+                                    new FilterFunction<DenseVector>() {
+                                        private int step = 0;
+
+                                        @Override
+                                        public boolean filter(DenseVector denseVector) {
+                                            step++;
+                                            return step % modelSaveInterval == 0;
+                                        }
+                                    })
+                            .setParallelism(1)
+                            .map(
+                                    (MapFunction<DenseVector, LogisticRegressionModelData>)
+                                            value ->
+                                                    new LogisticRegressionModelData(
+                                                            value, modelVersion++))
+                            .setParallelism(1);
+            return new IterationBodyResult(
+                    DataStreamList.of(feedbackModelData), DataStreamList.of(outputModelData));
+        }
+    }
+
+    /** Updates model. */
+    private static class UpdateModel extends AbstractStreamOperator<DenseVector>
+            implements OneInputStreamOperator<DenseVector[], DenseVector> {
+        private ListState<double[]> nParamState;
+        private ListState<double[]> zParamState;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+        private double[] nParam;
+        private double[] zParam;
+
+        public UpdateModel(double alpha, double beta, double l1, double l2) {
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = l1;
+            this.l2 = l2;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            nParamState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("nParamState", double[].class));
+            zParamState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("zParamState", double[].class));
+        }
+
+        @Override
+        public void processElement(StreamRecord<DenseVector[]> streamRecord) throws Exception {
+            DenseVector[] gradientInfo = streamRecord.getValue();
+            double[] coefficient = gradientInfo[2].values;
+            double[] g = gradientInfo[0].values;
+            for (int i = 0; i < g.length; ++i) {
+                if (gradientInfo[1].values[i] != 0.0) {
+                    g[i] = g[i] / gradientInfo[1].values[i];
+                }
+            }
+            if (zParam == null) {
+                zParam = new double[g.length];
+                nParam = new double[g.length];
+                nParamState.add(nParam);
+                zParamState.add(zParam);
+            }
+
+            for (int i = 0; i < zParam.length; ++i) {
+                double sigma = (Math.sqrt(nParam[i] + g[i] * g[i]) - Math.sqrt(nParam[i])) / alpha;
+                zParam[i] += g[i] - sigma * coefficient[i];
+                nParam[i] += g[i] * g[i];
+
+                if (Math.abs(zParam[i]) <= l1) {
+                    coefficient[i] = 0.0;
+                } else {
+                    coefficient[i] =
+                            ((zParam[i] < 0 ? -1 : 1) * l1 - zParam[i])
+                                    / ((beta + Math.sqrt(nParam[i])) / alpha + l2);
+                }
+            }
+            output.collect(new StreamRecord<>(new DenseVector(coefficient)));
+        }
+    }
+
+    private static class CalculateLocalGradient extends AbstractStreamOperator<DenseVector[]>
+            implements TwoInputStreamOperator<Row[], DenseVector, DenseVector[]> {
+        private ListState<DenseVector> modelDataState;
+        private ListState<Row[]> localBatchDataState;
+        private double[] gradient;
+        private double[] weight;
+        private int[] denseVectorIndices;
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", DenseVector.class));
+            TypeInformation<Row[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(TypeInformation.of(Row.class));
+            localBatchDataState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("localBatch", type));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<Row[]> pointsRecord) throws Exception {
+            localBatchDataState.add(pointsRecord.getValue());
+            calculateGradient();
+        }
+
+        private void calculateGradient() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !localBatchDataState.get().iterator().hasNext()) {
+                return;
+            }
+            DenseVector modelData =
+                    OperatorStateUtils.getUniqueElement(modelDataState, "modelData").get();
+            modelDataState.clear();
+
+            List<Row[]> pointsList = IteratorUtils.toList(localBatchDataState.get().iterator());
+            Row[] points = pointsList.remove(0);
+            localBatchDataState.update(pointsList);
+
+            for (Row point : points) {
+                Vector vec = point.getFieldAs(0);
+                double label = point.getFieldAs(1);
+                if (gradient == null) {
+                    gradient = new double[vec.size()];
+                    weight = new double[gradient.length];
+                    if (vec instanceof DenseVector) {
+                        denseVectorIndices = new int[vec.size()];
+                        for (int i = 0; i < denseVectorIndices.length; ++i) {
+                            denseVectorIndices[i] = i;
+                        }
+                    }
+                }
+
+                int[] indices;
+                double[] values;
+                if (vec instanceof DenseVector) {
+                    DenseVector denseVector = (DenseVector) vec;
+                    indices = denseVectorIndices;
+                    values = denseVector.values;
+                } else {
+                    SparseVector sparseVector = (SparseVector) vec;
+                    indices = sparseVector.indices;
+                    values = sparseVector.values;
+                }
+                double p = 0.0;
+                for (int i = 0; i < indices.length; ++i) {
+                    int idx = indices[i];
+                    p += modelData.values[idx] * values[i];
+                }
+                p = 1 / (1 + Math.exp(-p));
+                for (int i = 0; i < indices.length; ++i) {
+                    int idx = indices[i];
+                    gradient[idx] += (p - label) * values[i];
+                    weight[idx] += 1.0;
+                }
+            }
+
+            if (points.length > 0) {
+                output.collect(
+                        new StreamRecord<>(
+                                new DenseVector[] {
+                                    new DenseVector(gradient),
+                                    new DenseVector(weight),
+                                    (getRuntimeContext().getIndexOfThisSubtask() == 0)
+                                            ? modelData
+                                            : null
+                                }));
+            }
+            Arrays.fill(gradient, 0.0);
+            Arrays.fill(weight, 0.0);
+        }
+
+        @Override
+        public void processElement2(StreamRecord<DenseVector> modelDataRecord) throws Exception {
+            modelDataState.add(modelDataRecord.getValue());
+            calculateGradient();
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+        ReadWriteUtils.saveModelData(

Review Comment:
   In the `OnlineKMeans` PR, it was discussed that a stage should not attempt to save and load an unbounded stream. If an online algorithm needs to be reloaded, its model data stream should be set externally instead of being loaded.
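For reference while reviewing the hunk above, the `UpdateModel` operator applies the per-coordinate FTRL-Proximal rule from the McMahan et al. paper cited in the class Javadoc. The following is a minimal standalone sketch of that rule in plain Java, without any Flink types. The class name `FtrlSketch` and its methods are hypothetical; the elastic-net split (`l1 = elasticNet * reg`, `l2 = (1 - elasticNet) * reg`) and the averaging of the summed gradient by the per-coordinate sample count are taken from `FtrlIterationBody` and `UpdateModel` as quoted.

```java
import java.util.Arrays;

/** Hypothetical standalone sketch of the per-coordinate FTRL-Proximal update in UpdateModel. */
final class FtrlSketch {
    private final double alpha;
    private final double beta;
    private final double l1;
    private final double l2;
    private final double[] z; // accumulated adjusted gradients ("zParam")
    private final double[] n; // accumulated squared gradients ("nParam")

    FtrlSketch(int dim, double alpha, double beta, double reg, double elasticNet) {
        this.alpha = alpha;
        this.beta = beta;
        // Same elastic-net split as FtrlIterationBody.
        this.l1 = elasticNet * reg;
        this.l2 = (1 - elasticNet) * reg;
        this.z = new double[dim];
        this.n = new double[dim];
    }

    /**
     * Updates coefficient in place. gradientSum[i] is the batch-summed gradient of coordinate i
     * and count[i] the number of samples touching it; coordinates with count 0 keep the raw sum.
     */
    void update(double[] coefficient, double[] gradientSum, double[] count) {
        for (int i = 0; i < coefficient.length; ++i) {
            double g = count[i] != 0.0 ? gradientSum[i] / count[i] : gradientSum[i];
            double sigma = (Math.sqrt(n[i] + g * g) - Math.sqrt(n[i])) / alpha;
            z[i] += g - sigma * coefficient[i];
            n[i] += g * g;
            if (Math.abs(z[i]) <= l1) {
                // The L1 term drives small coordinates exactly to zero (sparse model).
                coefficient[i] = 0.0;
            } else {
                coefficient[i] =
                        (Math.signum(z[i]) * l1 - z[i]) / ((beta + Math.sqrt(n[i])) / alpha + l2);
            }
        }
    }

    public static void main(String[] args) {
        FtrlSketch ftrl = new FtrlSketch(2, 0.1, 0.1, 0.0, 0.0);
        double[] w = new double[2];
        ftrl.update(w, new double[] {-1.0, 0.0}, new double[] {1.0, 0.0});
        System.out.println(Arrays.toString(w));
    }
}
```

With `l1 = 0`, the first step on coordinate 0 gives `z = -1`, `n = 1` and a coefficient of about `1 / 11`, while the untouched coordinate stays at zero.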



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/OnlineLogisticRegressionTest.java:
##########
@@ -0,0 +1,661 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.RestOptions;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.classification.logisticregression.LogisticRegression;
+import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModel;
+import org.apache.flink.ml.classification.logisticregression.LogisticRegressionModelData;
+import org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegression;
+import org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegressionModel;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.InMemorySinkFunction;
+import org.apache.flink.ml.util.InMemorySourceFunction;
+import org.apache.flink.ml.util.TestUtils;
+import org.apache.flink.runtime.minicluster.MiniCluster;
+import org.apache.flink.runtime.minicluster.MiniClusterConfiguration;
+import org.apache.flink.runtime.testutils.InMemoryReporter;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.TestLogger;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import static org.apache.flink.ml.classification.logisticregression.OnlineLogisticRegressionModel.MODEL_DATA_VERSION_GAUGE_KEY;
+
+/** Tests {@link OnlineLogisticRegression} and {@link OnlineLogisticRegressionModel}. */
+public class OnlineLogisticRegressionTest extends TestLogger {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private static final double[] ONE_ARRAY = new double[] {1.0, 1.0, 1.0};
+
+    private static final Row[] TRAIN_DENSE_ROWS_1 =
+            new Row[] {
+                Row.of(Vectors.dense(0.1, 2.), 0.),
+                Row.of(Vectors.dense(0.2, 2.), 0.),
+                Row.of(Vectors.dense(0.3, 2.), 0.),
+                Row.of(Vectors.dense(0.4, 2.), 0.),
+                Row.of(Vectors.dense(0.5, 2.), 0.),
+                Row.of(Vectors.dense(11., 12.), 1.),
+                Row.of(Vectors.dense(12., 11.), 1.),
+                Row.of(Vectors.dense(13., 12.), 1.),
+                Row.of(Vectors.dense(14., 12.), 1.),
+                Row.of(Vectors.dense(15., 12.), 1.)
+            };
+
+    private static final Row[] TRAIN_DENSE_ROWS_2 =
+            new Row[] {
+                Row.of(Vectors.dense(0.2, 3.), 0.),
+                Row.of(Vectors.dense(0.8, 1.), 0.),
+                Row.of(Vectors.dense(0.7, 1.), 0.),
+                Row.of(Vectors.dense(0.6, 2.), 0.),
+                Row.of(Vectors.dense(0.2, 2.), 0.),
+                Row.of(Vectors.dense(14., 17.), 1.),
+                Row.of(Vectors.dense(15., 10.), 1.),
+                Row.of(Vectors.dense(16., 16.), 1.),
+                Row.of(Vectors.dense(17., 10.), 1.),
+                Row.of(Vectors.dense(18., 13.), 1.)
+            };
+
+    private static final Row[] PREDICT_DENSE_ROWS =
+            new Row[] {
+                Row.of(Vectors.dense(0.8, 2.7), 0.0), Row.of(Vectors.dense(15.5, 11.2), 1.0)
+            };
+
+    private static final Row[] TRAIN_SPARSE_ROWS_1 =
+            new Row[] {
+                Row.of(Vectors.sparse(10, new int[] {1, 3, 4}, ONE_ARRAY), 0.),
+                Row.of(Vectors.sparse(10, new int[] {0, 2, 3}, ONE_ARRAY), 0.),
+                Row.of(Vectors.sparse(10, new int[] {0, 3, 4}, ONE_ARRAY), 0.),
+                Row.of(Vectors.sparse(10, new int[] {2, 3, 4}, ONE_ARRAY), 0.),
+                Row.of(Vectors.sparse(10, new int[] {1, 3, 4}, ONE_ARRAY), 0.),
+                Row.of(Vectors.sparse(10, new int[] {6, 7, 8}, ONE_ARRAY), 1.),
+                Row.of(Vectors.sparse(10, new int[] {6, 8, 9}, ONE_ARRAY), 1.),
+                Row.of(Vectors.sparse(10, new int[] {5, 8, 9}, ONE_ARRAY), 1.),
+                Row.of(Vectors.sparse(10, new int[] {5, 6, 7}, ONE_ARRAY), 1.)

Review Comment:
   Shall we use the same train/predict data for dense and sparse vectors? That would simplify the code structure and make it clearer that the algorithm produces the same results regardless of whether the input data is dense or sparse.
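The suggestion above relies on the dense and sparse branches of `CalculateLocalGradient` computing the same prediction for the same logical vector. A minimal plain-Java sketch of that property follows, assuming raw arrays in place of Flink ML's `DenseVector`/`SparseVector`. The class `DenseSparseEquivalence` and its methods are hypothetical; the dense path uses every position as an index (like `denseVectorIndices`), and the sparse path keeps only non-zero entries.

```java
/** Hypothetical sketch: dense and sparse forms of one vector give the same logistic prediction. */
final class DenseSparseEquivalence {
    /** sigmoid(w . x) over (indices, values) pairs, as in CalculateLocalGradient. */
    static double predict(double[] coefficient, int[] indices, double[] values) {
        double dot = 0.0;
        for (int i = 0; i < indices.length; ++i) {
            dot += coefficient[indices[i]] * values[i];
        }
        return 1.0 / (1.0 + Math.exp(-dot));
    }

    /** Dense path: every position is an index. */
    static double predictDense(double[] coefficient, double[] dense) {
        int[] indices = new int[dense.length];
        for (int i = 0; i < indices.length; ++i) {
            indices[i] = i;
        }
        return predict(coefficient, indices, dense);
    }

    /** Sparse path: keep only the non-zero entries of the same vector. */
    static double predictSparse(double[] coefficient, double[] dense) {
        int nnz = 0;
        for (double v : dense) {
            if (v != 0.0) {
                nnz++;
            }
        }
        int[] indices = new int[nnz];
        double[] values = new double[nnz];
        for (int i = 0, k = 0; i < dense.length; ++i) {
            if (dense[i] != 0.0) {
                indices[k] = i;
                values[k] = dense[i];
                k++;
            }
        }
        return predict(coefficient, indices, values);
    }

    public static void main(String[] args) {
        double[] coefficient = {0.5, -0.25, 0.0, 1.0};
        double[] features = {1.0, 0.0, 3.0, 2.0};
        System.out.println(
                predictDense(coefficient, features) + " " + predictSparse(coefficient, features));
    }
}
```

In the test itself, the same idea would mean building each training row once and emitting it both through `Vectors.dense(...)` and through `Vectors.sparse(...)`, then asserting identical predictions.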



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java:
##########
@@ -0,0 +1,434 @@
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.FilterFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the FTRL-Proximal online learning algorithm proposed by H. Brendan
+ * McMahan et al.
+ *
+ * <p>See <a href="https://doi.org/10.1145/2487575.2488200">H. Brendan McMahan et al., Ad click
+ * prediction: a view from the trenches.</a>
+ */
+public class OnlineLogisticRegression
+        implements Estimator<OnlineLogisticRegression, OnlineLogisticRegressionModel>,
+                OnlineLogisticRegressionParams<OnlineLogisticRegression> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineLogisticRegression() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public OnlineLogisticRegressionModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<LogisticRegressionModelData> modelDataStream =
+                LogisticRegressionModelData.getModelDataStream(initModelDataTable);
+
+        DataStream<Row> points =
+                tEnv.toDataStream(inputs[0])
+                        .map(new FeaturesExtractor(getFeaturesCol(), getLabelCol()));
+
+        DataStream<DenseVector> initModelData =
+                modelDataStream.map(
+                        (MapFunction<LogisticRegressionModelData, DenseVector>)
+                                value -> value.coefficient);
+
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new FtrlIterationBody(
+                        getGlobalBatchSize(),
+                        getAlpha(),
+                        getBeta(),
+                        getReg(),
+                        getElasticNet(),
+                        getModelSaveInterval());
+
+        DataStream<LogisticRegressionModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        OnlineLogisticRegressionModel model =
+                new OnlineLogisticRegressionModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class FeaturesExtractor implements MapFunction<Row, Row> {
+        private final String featuresCol;
+        private final String labelCol;
+
+        private FeaturesExtractor(String featuresCol, String labelCol) {
+            this.featuresCol = featuresCol;
+            this.labelCol = labelCol;
+        }
+
+        @Override
+        public Row map(Row row) throws Exception {
+            return Row.of(row.getField(featuresCol), row.getField(labelCol));
+        }
+    }
+
+    /**
+     * Implementation of the FTRL optimizer. In this implementation, gradients are calculated in
+     * distributed workers and reduced to one gradient. The reduced gradient is used to update the
+     * model with the FTRL method.
+     *
+     * <p>See https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Ftrl
+     */
+    private static class FtrlIterationBody implements IterationBody {
+        private final int batchSize;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+        private long modelVersion = 1L;
+        private final int modelSaveInterval;
+
+        public FtrlIterationBody(
+                int batchSize,
+                double alpha,
+                double beta,
+                double reg,
+                double elasticNet,
+                int modelSaveInterval) {
+            this.batchSize = batchSize;
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = elasticNet * reg;
+            this.l2 = (1 - elasticNet) * reg;
+            this.modelSaveInterval = modelSaveInterval;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<DenseVector> modelData = variableStreams.get(0);
+
+            DataStream<Row> points = dataStreams.get(0);
+            int parallelism = points.getParallelism();
+            Preconditions.checkState(
+                    parallelism <= batchSize,
+                    "There are more subtasks in the training process than the number "
+                            + "of elements in each batch. Some subtasks might be idling forever.");
+
+            DataStream<DenseVector[]> newGradient =
+                    DataStreamUtils.generateBatchData(points, parallelism, batchSize)
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "LocalGradientCalculator",
+                                    TypeInformation.of(DenseVector[].class),
+                                    new CalculateLocalGradient())
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(
+                                    (ReduceFunction<DenseVector[]>)
+                                            (gradientInfo, newGradientInfo) -> {
+                                                for (int i = 0;
+                                                        i < newGradientInfo[1].size();
+                                                        ++i) {
+                                                    newGradientInfo[0].values[i] =
+                                                            gradientInfo[0].values[i]
+                                                                    + newGradientInfo[0].values[i];
+                                                    newGradientInfo[1].values[i] =
+                                                            gradientInfo[1].values[i]
+                                                                    + newGradientInfo[1].values[i];
+                                                    if (newGradientInfo[2] == null) {
+                                                        newGradientInfo[2] = gradientInfo[2];
+                                                    }
+                                                }
+                                                return newGradientInfo;
+                                            });
+            DataStream<DenseVector> feedbackModelData =
+                    newGradient
+                            .transform(
+                                    "ModelDataUpdater",
+                                    TypeInformation.of(DenseVector.class),
+                                    new UpdateModel(alpha, beta, l1, l2))
+                            .setParallelism(1);
+
+            DataStream<LogisticRegressionModelData> outputModelData =
+                    feedbackModelData
+                            .filter(
+                                    new FilterFunction<DenseVector>() {
+                                        private int step = 0;
+
+                                        @Override
+                                        public boolean filter(DenseVector denseVector) {
+                                            step++;
+                                            return step % modelSaveInterval == 0;
+                                        }
+                                    })
+                            .setParallelism(1)
+                            .map(
+                                    (MapFunction<DenseVector, LogisticRegressionModelData>)
+                                            value ->
+                                                    new LogisticRegressionModelData(
+                                                            value, modelVersion++))

Review Comment:
   Some values, like `modelVersion` or `step`, should be checkpointed and restored from a snapshot during failover, but the current implementation does not do this. Could you please check all variables and see whether they need to be stored in state?
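To illustrate the failover concern, here is a minimal plain-Java sketch of the filter-and-version step from `FtrlIterationBody`, where `step` and `modelVersion` are captured in a snapshot and restored into a fresh instance. Everything here is hypothetical: `snapshot()`/`restore()` merely stand in for Flink's state mechanisms (for example operator `ListState` registered in `initializeState`, as `UpdateModel` already does for `nParam`/`zParam`). Without the restore, a recovered instance would restart at version 1 and emit duplicate versions.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical sketch: counters that must survive failover are part of the snapshot. */
final class VersionedModelEmitter {
    private final int modelSaveInterval;
    private long step = 0L;
    private long modelVersion = 1L;

    VersionedModelEmitter(int modelSaveInterval) {
        this.modelSaveInterval = modelSaveInterval;
    }

    /** Returns the version assigned to this update, or -1 if the update is filtered out. */
    long onModelUpdate() {
        step++;
        if (step % modelSaveInterval != 0) {
            return -1L;
        }
        return modelVersion++;
    }

    /** Captures everything needed to resume, as a checkpoint would. */
    long[] snapshot() {
        return new long[] {step, modelVersion};
    }

    /** Rebuilds an emitter from a snapshot, as state initialization would on recovery. */
    static VersionedModelEmitter restore(int modelSaveInterval, long[] state) {
        VersionedModelEmitter emitter = new VersionedModelEmitter(modelSaveInterval);
        emitter.step = state[0];
        emitter.modelVersion = state[1];
        return emitter;
    }

    public static void main(String[] args) {
        VersionedModelEmitter emitter = new VersionedModelEmitter(2);
        List<Long> versions = new ArrayList<>();
        for (int i = 0; i < 4; i++) {
            long v = emitter.onModelUpdate();
            if (v >= 0) {
                versions.add(v);
            }
        }
        // Simulated failover: a fresh instance restored from the checkpoint keeps counting.
        VersionedModelEmitter recovered = VersionedModelEmitter.restore(2, emitter.snapshot());
        for (int i = 0; i < 2; i++) {
            long v = recovered.onModelUpdate();
            if (v >= 0) {
                versions.add(v);
            }
        }
        System.out.println(versions);
    }
}
```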



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java:
##########
@@ -0,0 +1,434 @@
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.FilterFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the FTRL-Proximal online learning algorithm proposed by H. Brendan
+ * McMahan et al.
+ *
+ * <p>See <a href="https://doi.org/10.1145/2487575.2488200">H. Brendan McMahan et al., Ad click
+ * prediction: a view from the trenches.</a>
+ */
+public class OnlineLogisticRegression
+        implements Estimator<OnlineLogisticRegression, OnlineLogisticRegressionModel>,
+                OnlineLogisticRegressionParams<OnlineLogisticRegression> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineLogisticRegression() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public OnlineLogisticRegressionModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<LogisticRegressionModelData> modelDataStream =
+                LogisticRegressionModelData.getModelDataStream(initModelDataTable);
+
+        DataStream<Row> points =
+                tEnv.toDataStream(inputs[0])
+                        .map(new FeaturesExtractor(getFeaturesCol(), getLabelCol()));
+
+        DataStream<DenseVector> initModelData =
+                modelDataStream.map(
+                        (MapFunction<LogisticRegressionModelData, DenseVector>)
+                                value -> value.coefficient);
+
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new FtrlIterationBody(
+                        getGlobalBatchSize(),
+                        getAlpha(),
+                        getBeta(),
+                        getReg(),
+                        getElasticNet(),
+                        getModelSaveInterval());
+
+        DataStream<LogisticRegressionModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        OnlineLogisticRegressionModel model =
+                new OnlineLogisticRegressionModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class FeaturesExtractor implements MapFunction<Row, Row> {
+        private final String featuresCol;
+        private final String labelCol;
+
+        private FeaturesExtractor(String featuresCol, String labelCol) {
+            this.featuresCol = featuresCol;
+            this.labelCol = labelCol;
+        }
+
+        @Override
+        public Row map(Row row) throws Exception {
+            return Row.of(row.getField(featuresCol), row.getField(labelCol));
+        }
+    }
+
+    /**
+     * Implementation of ftrl optimizer. In this implementation, gradients are calculated in
+     * distributed workers and reduce to one gradient. The reduced gradient is used to update model
+     * by ftrl method.
+     *
+     * <p>See https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Ftrl
+     */
+    private static class FtrlIterationBody implements IterationBody {
+        private final int batchSize;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+        private long modelVersion = 1L;
+        private final int modelSaveInterval;
+
+        public FtrlIterationBody(
+                int batchSize,
+                double alpha,
+                double beta,
+                double reg,
+                double elasticNet,
+                int modelSaveInterval) {
+            this.batchSize = batchSize;
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = elasticNet * reg;
+            this.l2 = (1 - elasticNet) * reg;
+            this.modelSaveInterval = modelSaveInterval;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<DenseVector> modelData = variableStreams.get(0);
+
+            DataStream<Row> points = dataStreams.get(0);
+            int parallelism = points.getParallelism();
+            Preconditions.checkState(
+                    parallelism <= batchSize,
+                    "There are more subtasks in the training process than the number "
+                            + "of elements in each batch. Some subtasks might be idling forever.");
+
+            DataStream<DenseVector[]> newGradient =
+                    DataStreamUtils.generateBatchData(points, parallelism, batchSize)
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "LocalGradientCalculator",
+                                    TypeInformation.of(DenseVector[].class),
+                                    new CalculateLocalGradient())
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(
+                                    (ReduceFunction<DenseVector[]>)
+                                            (gradientInfo, newGradientInfo) -> {
+                                                for (int i = 0;
+                                                        i < newGradientInfo[1].size();
+                                                        ++i) {
+                                                    newGradientInfo[0].values[i] =
+                                                            gradientInfo[0].values[i]
+                                                                    + newGradientInfo[0].values[i];
+                                                    newGradientInfo[1].values[i] =
+                                                            gradientInfo[1].values[i]
+                                                                    + newGradientInfo[1].values[i];
+                                                    if (newGradientInfo[2] == null) {
+                                                        newGradientInfo[2] = gradientInfo[2];
+                                                    }
+                                                }
+                                                return newGradientInfo;
+                                            });
+            DataStream<DenseVector> feedbackModelData =
+                    newGradient
+                            .transform(
+                                    "ModelDataUpdater",
+                                    TypeInformation.of(DenseVector.class),
+                                    new UpdateModel(alpha, beta, l1, l2))
+                            .setParallelism(1);
+
+            DataStream<LogisticRegressionModelData> outputModelData =
+                    feedbackModelData
+                            .filter(
+                                    new FilterFunction<DenseVector>() {
+                                        private int step = 0;
+
+                                        @Override
+                                        public boolean filter(DenseVector denseVector) {
+                                            step++;
+                                            return step % modelSaveInterval == 0;
+                                        }
+                                    })
+                            .setParallelism(1)

Review Comment:
   Do you think we should extract this part of the code into reusable infrastructure, similar to `org.apache.flink.ml.common.iteration.ForwardInputsOfLastRound`?
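
A minimal sketch of what the extraction could look like: the anonymous filter becomes a small named class that can be reused wherever a stream should forward only every n-th element. The class name `EveryNthElementFilter` is hypothetical. The sketch implements `java.util.function.Predicate` so it compiles without Flink; in the PR it would implement Flink's `FilterFunction<T>` and override `filter` instead of `test`.

```java
import java.util.function.Predicate;

/**
 * Hypothetical reusable filter that passes every n-th element, mirroring the anonymous
 * FilterFunction that emits model data once every `modelSaveInterval` updates.
 * In Flink this would implement org.apache.flink.api.common.functions.FilterFunction<T>
 * (method `filter`) instead of java.util.function.Predicate<T> (method `test`).
 */
class EveryNthElementFilter<T> implements Predicate<T> {
    private final int interval;
    // Plain field, as in the original anonymous class, so it is not checkpointed state.
    private int step = 0;

    EveryNthElementFilter(int interval) {
        if (interval <= 0) {
            throw new IllegalArgumentException("interval must be positive, got " + interval);
        }
        this.interval = interval;
    }

    @Override
    public boolean test(T element) {
        step++;
        return step % interval == 0;
    }
}
```

With the `FilterFunction` variant, the iteration body would read roughly `feedbackModelData.filter(new EveryNthElementFilter<>(modelSaveInterval))`. Like the anonymous class, the counter is a plain field rather than operator state, so it restarts from zero after a failover.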



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java:
##########
@@ -0,0 +1,434 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.FilterFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the FTRL-Proximal online learning algorithm proposed by H. Brendan
+ * McMahan et al.
+ *
+ * <p>See <a href="https://doi.org/10.1145/2487575.2488200">H. Brendan McMahan et al., Ad click
+ * prediction: a view from the trenches.</a>
+ */
+public class OnlineLogisticRegression
+        implements Estimator<OnlineLogisticRegression, OnlineLogisticRegressionModel>,
+                OnlineLogisticRegressionParams<OnlineLogisticRegression> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineLogisticRegression() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")

Review Comment:
   nit: this could be removed.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java:
##########
@@ -161,8 +163,12 @@ public Row map(Row dataPoint) {
      * @param coefficient The model parameters.
      * @return The prediction label and the raw probabilities.
      */
-    private static Row predictOneDataPoint(DenseVector feature, DenseVector coefficient) {
-        double dotValue = BLAS.dot(feature, coefficient);
+    public static Row predictOneDataPoint(Vector feature, DenseVector coefficient) {
+
+        double dotValue =
+                feature instanceof SparseVector
+                        ? BLAS.dot((SparseVector) feature, coefficient)
+                        : BLAS.dot((DenseVector) feature, coefficient);

Review Comment:
   nit: It might be better to add a `BLAS.dot(Vector, Vector)` method and do this type dispatch inside that method.
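
The suggested overload could look roughly like the sketch below. It uses minimal stand-in types (`Vec`, `DenseVec`, `SparseVec`, `Blas`, all hypothetical) rather than Flink ML's real `Vector`, `DenseVector`, `SparseVector` and `BLAS` classes, and it assumes sparse indices are sorted ascending. The point is only that the `instanceof` dispatch moves into one place.

```java
/** Minimal stand-in vector types; NOT the real org.apache.flink.ml.linalg classes. */
interface Vec {
    int size();
}

final class DenseVec implements Vec {
    final double[] values;

    DenseVec(double... values) {
        this.values = values;
    }

    @Override
    public int size() {
        return values.length;
    }
}

final class SparseVec implements Vec {
    final int n;
    final int[] indices; // assumed sorted ascending, without duplicates
    final double[] values;

    SparseVec(int n, int[] indices, double[] values) {
        this.n = n;
        this.indices = indices;
        this.values = values;
    }

    @Override
    public int size() {
        return n;
    }
}

final class Blas {
    /** Hypothetical dot(Vector, Vector): picks the kernel from the runtime types. */
    static double dot(Vec x, Vec y) {
        if (x.size() != y.size()) {
            throw new IllegalArgumentException("Vector sizes differ: " + x.size() + " vs " + y.size());
        }
        if (x instanceof DenseVec && y instanceof DenseVec) {
            return denseDense((DenseVec) x, (DenseVec) y);
        } else if (x instanceof SparseVec && y instanceof DenseVec) {
            return sparseDense((SparseVec) x, (DenseVec) y);
        } else if (x instanceof DenseVec && y instanceof SparseVec) {
            return sparseDense((SparseVec) y, (DenseVec) x); // the dot product is symmetric
        } else if (x instanceof SparseVec && y instanceof SparseVec) {
            return sparseSparse((SparseVec) x, (SparseVec) y);
        }
        throw new UnsupportedOperationException("Unsupported vector types");
    }

    private static double denseDense(DenseVec x, DenseVec y) {
        double s = 0.0;
        for (int i = 0; i < x.values.length; i++) {
            s += x.values[i] * y.values[i];
        }
        return s;
    }

    private static double sparseDense(SparseVec x, DenseVec y) {
        double s = 0.0;
        for (int k = 0; k < x.indices.length; k++) {
            s += x.values[k] * y.values[x.indices[k]];
        }
        return s;
    }

    private static double sparseSparse(SparseVec x, SparseVec y) {
        double s = 0.0;
        int i = 0;
        int j = 0;
        // Merge-style walk over the two sorted index arrays.
        while (i < x.indices.length && j < y.indices.length) {
            if (x.indices[i] == y.indices[j]) {
                s += x.values[i++] * y.values[j++];
            } else if (x.indices[i] < y.indices[j]) {
                i++;
            } else {
                j++;
            }
        }
        return s;
    }
}
```

With such an overload, the body of `predictOneDataPoint` would reduce to `double dotValue = BLAS.dot(feature, coefficient);`, assuming the real `BLAS` gained an equivalent method.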



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionModel.java:
##########
@@ -0,0 +1,201 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.apache.flink.ml.classification.logisticregression.LogisticRegressionModel.predictOneDataPoint;
+
+/**
+ * A Model which classifies data using the model data computed by {@link OnlineLogisticRegression}.
+ */
+public class OnlineLogisticRegressionModel
+        implements Model<OnlineLogisticRegressionModel>,
+                OnlineLogisticRegressionModelParams<OnlineLogisticRegressionModel> {
+    public static final String MODEL_DATA_VERSION_GAUGE_KEY = "modelDataVersion";
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public OnlineLogisticRegressionModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(),
+                                Types.DOUBLE,
+                                TypeInformation.of(DenseVector.class),
+                                Types.LONG),
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldNames(),
+                                getPredictionCol(),
+                                getRawPredictionCol(),
+                                getModelVersionCol()));
+
+        DataStream<Row> predictionResult =
+                tEnv.toDataStream(inputs[0])
+                        .connect(
+                                LogisticRegressionModelData.getModelDataStream(modelDataTable)
+                                        .broadcast())
+                        .transform(
+                                "PredictLabelOperator",
+                                outputTypeInfo,
+                                new PredictLabelOperator(inputTypeInfo, getFeaturesCol()));
+
+        return new Table[] {tEnv.fromDataStream(predictionResult)};
+    }
+
+    /** A utility operator used for prediction. */
+    private static class PredictLabelOperator extends AbstractStreamOperator<Row>
+            implements TwoInputStreamOperator<Row, LogisticRegressionModelData, Row> {
+        private final RowTypeInfo inputTypeInfo;
+
+        private final String featuresCol;
+        private ListState<Row> bufferedPointsState;
+        private DenseVector coefficient;
+        private long modelDataVersion = 0L;
+
+        public PredictLabelOperator(RowTypeInfo inputTypeInfo, String featuresCol) {
+            this.inputTypeInfo = inputTypeInfo;
+            this.featuresCol = featuresCol;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            bufferedPointsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("bufferedPoints", inputTypeInfo));
+        }
+
+        @Override
+        public void open() throws Exception {
+            super.open();
+
+            getRuntimeContext()
+                    .getMetricGroup()
+                    .gauge(
+                            MODEL_DATA_VERSION_GAUGE_KEY,
+                            (Gauge<String>) () -> Long.toString(modelDataVersion));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<Row> streamRecord) throws Exception {
+            Row dataPoint = streamRecord.getValue();
+            if (coefficient == null) {
+                bufferedPointsState.add(dataPoint);
+                return;
+            }
+            Vector features = (Vector) dataPoint.getField(featuresCol);
+            Row predictionResult = predictOneDataPoint(features, coefficient);
+            output.collect(
+                    new StreamRecord<>(
+                            Row.join(
+                                    dataPoint,
+                                    Row.of(
+                                            predictionResult.getField(0),
+                                            predictionResult.getField(1),
+                                            modelDataVersion))));
+        }
+
+        @Override
+        public void processElement2(StreamRecord<LogisticRegressionModelData> streamRecord)
+                throws Exception {
+            LogisticRegressionModelData modelData = streamRecord.getValue();
+            coefficient = modelData.coefficient;
+            modelDataVersion = modelData.modelVersion;
+            for (Row dataPoint : bufferedPointsState.get()) {
+                processElement1(new StreamRecord<>(dataPoint));

Review Comment:
   nit: Shall we create a method like `processElement(Row)` and reuse it in `processElement1` and `processElement2`? Wrapping the value in a `StreamRecord` only to unwrap it immediately in `processElement1` does not seem elegant.
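
A Flink-free sketch of the suggested refactoring: both entry points delegate to one private helper, so buffered points never need to be re-wrapped in a `StreamRecord`. `BufferingPredictor` and its methods are hypothetical; `onDataPoint` plays the role of `processElement1`, `onModelData` the role of `processElement2`, a `List` stands in for the `ListState<Row>`, and the prediction is a plain sigmoid of the dot product. The sketch also clears the buffer after replaying it, so buffered points are not predicted again when the next model version arrives.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical, Flink-free model of PredictLabelOperator's control flow.
 * onDataPoint ~ processElement1, onModelData ~ processElement2.
 */
class BufferingPredictor {
    // Stands in for ListState<Row> bufferedPointsState.
    private final List<double[]> bufferedPoints = new ArrayList<>();
    // Stands in for the operator's output collector.
    private final List<String> emitted = new ArrayList<>();
    private double[] coefficient;
    private long modelVersion = 0L;

    void onDataPoint(double[] features) {
        if (coefficient == null) {
            // No model yet: buffer the point until model data arrives.
            bufferedPoints.add(features);
            return;
        }
        predictAndEmit(features);
    }

    void onModelData(double[] newCoefficient, long newVersion) {
        coefficient = newCoefficient;
        modelVersion = newVersion;
        for (double[] features : bufferedPoints) {
            predictAndEmit(features); // no StreamRecord wrapping needed
        }
        bufferedPoints.clear();
    }

    /** Shared helper used by both entry points. */
    private void predictAndEmit(double[] features) {
        double dot = 0.0;
        for (int i = 0; i < features.length; i++) {
            dot += features[i] * coefficient[i];
        }
        double probability = 1.0 / (1.0 + Math.exp(-dot));
        double label = probability >= 0.5 ? 1.0 : 0.0;
        emitted.add(label + "@v" + modelVersion);
    }

    List<String> emitted() {
        return emitted;
    }
}
```

In the real operator, the helper would take a `Row`, call `predictOneDataPoint`, and `output.collect(...)` the joined row, and the buffer could be emptied with `bufferedPointsState.clear()`.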




[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator for OnlineLogisticRegression

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r868872809


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasL2.java:
##########
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
+
+/** Interface for the shared L2 param. */
+public interface HasL2<T> extends WithParams<T> {
+    Param<Double> L_2 = new DoubleParam("l2", "The l2 param.", 0.1, ParamValidators.gt(0.0));

Review Comment:
   After some offline discussion, @weibozhao and I agreed that we could reuse the existing `HasReg` and `HasElasticNet` params instead of introducing `HasL1` and `HasL2`.
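
For reference, the `FtrlIterationBody` constructor quoted earlier in this thread shows how `HasReg` and `HasElasticNet` map onto the two penalties: `l1 = elasticNet * reg` and `l2 = (1 - elasticNet) * reg`. The sketch below reuses that mapping inside the per-coordinate FTRL-Proximal update described in McMahan et al., to make visible where each penalty enters. It is a Flink-free illustration with a hypothetical class name that applies one (for example batch-aggregated) gradient per call; it is not the PR's `UpdateModel` operator, whose body is not quoted here.

```java
/**
 * Hypothetical per-coordinate FTRL-Proximal learner (McMahan et al.), Flink-free.
 * l1/l2 are derived from reg/elasticNet exactly as in the quoted FtrlIterationBody.
 */
final class FtrlProximalSketch {
    private final double alpha;
    private final double beta;
    private final double l1;
    private final double l2;
    private final double[] z; // accumulated adjusted gradients
    private final double[] n; // accumulated squared gradients

    FtrlProximalSketch(int dim, double alpha, double beta, double reg, double elasticNet) {
        this.alpha = alpha;
        this.beta = beta;
        this.l1 = elasticNet * reg;
        this.l2 = (1 - elasticNet) * reg;
        this.z = new double[dim];
        this.n = new double[dim];
    }

    /** Closed-form weights; coordinates with |z_i| <= l1 are exactly zero. */
    double[] weights() {
        double[] w = new double[z.length];
        for (int i = 0; i < z.length; i++) {
            if (Math.abs(z[i]) > l1) {
                w[i] = -(z[i] - Math.signum(z[i]) * l1) / ((beta + Math.sqrt(n[i])) / alpha + l2);
            }
        }
        return w;
    }

    /** Applies one gradient vector (e.g. the sum over a batch) to the accumulators. */
    void update(double[] gradient) {
        double[] w = weights();
        for (int i = 0; i < z.length; i++) {
            double g = gradient[i];
            double sigma = (Math.sqrt(n[i] + g * g) - Math.sqrt(n[i])) / alpha;
            z[i] += g - sigma * w[i];
            n[i] += g * g;
        }
    }

    double l1() {
        return l1;
    }

    double l2() {
        return l2;
    }
}
```

Coordinates whose accumulated `|z_i|` stays at or below `l1` get a weight of exactly zero, which is where FTRL's sparsity comes from; `l2` only enlarges the denominator.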




[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator for OnlineLogisticRegression

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r868883335


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java:
##########
@@ -0,0 +1,382 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils.GlobalBatchCreator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils.GlobalBatchSplitter;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the FTRL-Proximal online learning algorithm proposed by H. Brendan
+ * McMahan et al.
+ *
+ * <p>See <a href="https://doi.org/10.1145/2487575.2488200">H. Brendan McMahan et al., Ad click
+ * prediction: a view from the trenches.</a>
+ */
+public class OnlineLogisticRegression
+        implements Estimator<OnlineLogisticRegression, OnlineLogisticRegressionModel>,
+                OnlineLogisticRegressionParams<OnlineLogisticRegression> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineLogisticRegression() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public OnlineLogisticRegressionModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<LogisticRegressionModelData> modelDataStream =
+                LogisticRegressionModelData.getModelDataStream(initModelDataTable);
+        DataStream<DenseVector[]> initModelData =
+                modelDataStream.map(
+                        (MapFunction<LogisticRegressionModelData, DenseVector[]>)
+                                value -> new DenseVector[] {value.coefficient});
+
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new FtrlIterationBody(
+                        getGlobalBatchSize(),
+                        getAlpha(),
+                        getBeta(),
+                        getL1(),
+                        getL2(),
+                        getFeaturesCol(),
+                        getLabelCol());
+
+        DataStream<LogisticRegressionModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData),
+                                DataStreamList.of(tEnv.toDataStream(inputs[0])),
+                                body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        OnlineLogisticRegressionModel model =
+                new OnlineLogisticRegressionModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Iteration body of ftrl algorithm. */
+    private static class FtrlIterationBody implements IterationBody {
+        private final int batchSize;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+        private final String featureCol;
+        private final String labelCol;
+
+        public FtrlIterationBody(
+                int batchSize,
+                double alpha,
+                double beta,
+                double l1,
+                double l2,
+                String featureCol,
+                String labelCol) {
+            this.batchSize = batchSize;
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = l1;
+            this.l2 = l2;
+            this.featureCol = featureCol;
+            this.labelCol = labelCol;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<DenseVector[]> modelData = variableStreams.get(0);
+
+            DataStream<Row> points = dataStreams.get(0);
+            int parallelism = points.getParallelism();
+
+            Preconditions.checkState(
+                    parallelism <= batchSize,
+                    "There are more subtasks in the training process than the number "
+                            + "of elements in each batch. Some subtasks might be idling forever.");
+
+            DataStream<DenseVector[]> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator<>())
+                            .flatMap(new GlobalBatchSplitter<>(parallelism))
+                            .rebalance()

Review Comment:
   This may lead to a case where some downstream tasks receive no records while others receive more than one. Is this expected?
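
One hedged way to make the assignment deterministic is to tag each local batch with its index within the global batch and route on that index instead of calling `rebalance()`; in Flink this could be wired up with `DataStream#partitionCustom(Partitioner, KeySelector)`. The sketch below models only the routing rule in plain Java; `IndexRouting` and its methods are hypothetical names, and `partition` mirrors the `partition(key, numPartitions)` contract of Flink's `Partitioner`.

```java
/** Hypothetical routing rule: local batch i goes to subtask i (mod numPartitions). */
final class IndexRouting {
    private IndexRouting() {}

    /** Mirrors Flink's Partitioner#partition(key, numPartitions) contract. */
    static int partition(int batchIndex, int numPartitions) {
        return Math.floorMod(batchIndex, numPartitions);
    }

    /** How many of numBatches local batches each subtask would receive. */
    static int[] countPerSubtask(int numBatches, int numPartitions) {
        int[] counts = new int[numPartitions];
        for (int i = 0; i < numBatches; i++) {
            counts[partition(i, numPartitions)]++;
        }
        return counts;
    }
}
```

When a global batch is split into exactly `parallelism` local batches, each subtask receives exactly one; any extra batches wrap around in index order.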



##########
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java:
##########
@@ -182,4 +191,49 @@ public void snapshotState(StateSnapshotContext context) throws Exception {
             }
         }
     }
+
+    /**
+     * An operator that splits a global batch into evenly-sized local batches, and distributes them
+     * to downstream operator.
+     *
+     * @param <T> Data type of batch data.
+     */
+    public static class GlobalBatchSplitter<T> implements FlatMapFunction<T[], T[]> {

Review Comment:
   Would it be better to put `GlobalBatchSplitter` and `GlobalBatchCreator` in separate classes?
   
   Also, could these two functions be merged into one?
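
To illustrate the second question, here is a sketch of a single function that both assembles the global batch and splits it into `parallelism` local batches whose sizes differ by at most one. It is Flink-free and hypothetical (`GlobalBatchBuffer` is not the PR's `GlobalBatchCreator` or `GlobalBatchSplitter`), and it assumes `parallelism <= globalBatchSize`, the same condition the `FtrlIterationBody` precondition checks.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/**
 * Hypothetical merge of "create global batch" and "split into local batches" in one step.
 * Flink-free sketch; not the PR's GlobalBatchCreator/GlobalBatchSplitter.
 */
final class GlobalBatchBuffer<T> {
    private final int globalBatchSize;
    private final int parallelism;
    private final List<T> buffer = new ArrayList<>();

    GlobalBatchBuffer(int globalBatchSize, int parallelism) {
        if (parallelism <= 0 || parallelism > globalBatchSize) {
            throw new IllegalArgumentException("Requires 0 < parallelism <= globalBatchSize");
        }
        this.globalBatchSize = globalBatchSize;
        this.parallelism = parallelism;
    }

    /** Buffers one record; returns the local batches once a global batch is complete. */
    List<List<T>> add(T record) {
        buffer.add(record);
        if (buffer.size() < globalBatchSize) {
            return Collections.emptyList();
        }
        List<List<T>> localBatches = splitEvenly(buffer, parallelism);
        buffer.clear(); // safe: splitEvenly copied the elements
        return localBatches;
    }

    /** Splits into `parts` contiguous batches whose sizes differ by at most one. */
    static <E> List<List<E>> splitEvenly(List<E> batch, int parts) {
        List<List<E>> result = new ArrayList<>(parts);
        int base = batch.size() / parts;
        int remainder = batch.size() % parts;
        int start = 0;
        for (int i = 0; i < parts; i++) {
            int size = base + (i < remainder ? 1 : 0); // first `remainder` batches get one extra
            result.add(new ArrayList<>(batch.subList(start, start + size)));
            start += size;
        }
        return result;
    }
}
```

Since the global batch has to be assembled in one place anyway (the quoted code uses `countWindowAll`), doing the cheap split in the same operator is unlikely to become a bottleneck, though that is an assumption worth measuring; the trade-off is a less reusable building block.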



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java:
##########
@@ -0,0 +1,428 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the OnlineLogisticRegression algorithm.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Online_machine_learning.
+ */
+public class OnlineLogisticRegression
+        implements Estimator<OnlineLogisticRegression, OnlineLogisticRegressionModel>,
+                OnlineLogisticRegressionParams<OnlineLogisticRegression> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineLogisticRegression() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public OnlineLogisticRegressionModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<LogisticRegressionModelData> modelDataStream =
+                LogisticRegressionModelData.getModelDataStream(initModelDataTable);
+        DataStream<DenseVector[]> initModelData =
+                modelDataStream.map(
+                        (MapFunction<LogisticRegressionModelData, DenseVector[]>)
+                                value -> new DenseVector[] {value.coefficient});
+
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new FtrlIterationBody(
+                        getGlobalBatchSize(),
+                        getAlpha(),
+                        getBETA(),
+                        getL1(),
+                        getL2(),
+                        getFeaturesCol(),
+                        getLabelCol());
+
+        DataStream<LogisticRegressionModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData),
+                                DataStreamList.of(tEnv.toDataStream(inputs[0])),
+                                body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        OnlineLogisticRegressionModel model =
+                new OnlineLogisticRegressionModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Iteration body of ftrl algorithm. */

Review Comment:
   As I understand Algorithm 1 in paper [1], FTRL has no initial model data, and the original algorithm is sequential.
   
   As for distributed implementations, the only one I can find is Tensorflow#FTRL [2][3]. In that implementation, the gradients from the workers are accumulated first, then used to update n and z, and finally w.
   
   This process is quite different from the current implementation. Could you explain the difference a bit?
   
   [1] https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/41159.pdf
   [2] https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Ftrl
   [3] https://github.com/keras-team/keras/blob/d8fcb9d4d4dad45080ecfdd575483653028f8eda/keras/optimizer_v2/ftrl.py#L216
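For reference, the per-coordinate FTRL-Proximal update from Algorithm 1 of [1], plus the "accumulate worker gradients, then update n and z, then w" pattern described above, can be sketched in plain Java as follows. This is a sketch for dense features only; `FtrlProximalSketch` and its methods are hypothetical names, and summing (rather than averaging) worker gradients is an assumption, not a claim about how [2][3] aggregate them.

```java
/**
 * Hypothetical sketch of per-coordinate FTRL-Proximal for logistic regression with dense
 * features (after Algorithm 1 in McMahan et al.). Not the implementation in this PR.
 */
final class FtrlProximalSketch {
    private final double alpha;
    private final double beta;
    private final double l1;
    private final double l2;
    private final double[] z; // accumulated adjusted gradients
    private final double[] n; // accumulated squared gradients

    FtrlProximalSketch(int dim, double alpha, double beta, double l1, double l2) {
        this.alpha = alpha;
        this.beta = beta;
        this.l1 = l1;
        this.l2 = l2;
        this.z = new double[dim];
        this.n = new double[dim];
    }

    /** Closed-form weights from z and n; coordinates with |z_i| <= l1 are exactly zero. */
    double[] weights() {
        double[] w = new double[z.length];
        for (int i = 0; i < z.length; i++) {
            if (Math.abs(z[i]) > l1) {
                w[i] = -(z[i] - Math.signum(z[i]) * l1) / ((beta + Math.sqrt(n[i])) / alpha + l2);
            }
        }
        return w;
    }

    /** Predicted probability sigmoid(w . x). */
    double predict(double[] x) {
        double[] w = weights();
        double margin = 0.0;
        for (int i = 0; i < x.length; i++) {
            margin += w[i] * x[i];
        }
        return 1.0 / (1.0 + Math.exp(-margin));
    }

    /** Log-loss gradient (p - y) * x against the current weights, i.e. one worker's share. */
    double[] gradient(double[] x, double y) {
        double p = predict(x);
        double[] g = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            g[i] = (p - y) * x[i];
        }
        return g;
    }

    /** Updates z and n from a gradient; sigma uses the weights from before this update. */
    void applyGradient(double[] g) {
        double[] w = weights();
        for (int i = 0; i < g.length; i++) {
            double sigma = (Math.sqrt(n[i] + g[i] * g[i]) - Math.sqrt(n[i])) / alpha;
            z[i] += g[i] - sigma * w[i];
            n[i] += g[i] * g[i];
        }
    }

    /** Distributed-style step: sum the workers' gradients, then update n and z once. */
    void applyAggregated(double[][] workerGradients) {
        double[] sum = new double[z.length];
        for (double[] g : workerGradients) {
            for (int i = 0; i < sum.length; i++) {
                sum[i] += g[i];
            }
        }
        applyGradient(sum);
    }
}
```

Note that w is never stored: it is re-derived from z and n, which is also what the `feedbackModelData` map in the later revision of this PR does.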



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java:
##########
@@ -0,0 +1,382 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils.GlobalBatchCreator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils.GlobalBatchSplitter;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the FTRL-Proximal online learning algorithm proposed by H. Brendan
+ * McMahan et al.
+ *
+ * <p>See <a href="https://doi.org/10.1145/2487575.2488200">H. Brendan McMahan et al., Ad click
+ * prediction: a view from the trenches.</a>
+ */
+public class OnlineLogisticRegression
+        implements Estimator<OnlineLogisticRegression, OnlineLogisticRegressionModel>,
+                OnlineLogisticRegressionParams<OnlineLogisticRegression> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineLogisticRegression() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public OnlineLogisticRegressionModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<LogisticRegressionModelData> modelDataStream =
+                LogisticRegressionModelData.getModelDataStream(initModelDataTable);
+        DataStream<DenseVector[]> initModelData =
+                modelDataStream.map(
+                        (MapFunction<LogisticRegressionModelData, DenseVector[]>)
+                                value -> new DenseVector[] {value.coefficient});
+
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new FtrlIterationBody(
+                        getGlobalBatchSize(),
+                        getAlpha(),
+                        getBeta(),
+                        getL1(),
+                        getL2(),
+                        getFeaturesCol(),
+                        getLabelCol());
+
+        DataStream<LogisticRegressionModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData),
+                                DataStreamList.of(tEnv.toDataStream(inputs[0])),
+                                body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        OnlineLogisticRegressionModel model =
+                new OnlineLogisticRegressionModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Iteration body of ftrl algorithm. */
+    private static class FtrlIterationBody implements IterationBody {
+        private final int batchSize;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+        private final String featureCol;
+        private final String labelCol;
+
+        public FtrlIterationBody(
+                int batchSize,
+                double alpha,
+                double beta,
+                double l1,
+                double l2,
+                String featureCol,
+                String labelCol) {
+            this.batchSize = batchSize;
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = l1;
+            this.l2 = l2;
+            this.featureCol = featureCol;
+            this.labelCol = labelCol;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<DenseVector[]> modelData = variableStreams.get(0);
+
+            DataStream<Row> points = dataStreams.get(0);
+            int parallelism = points.getParallelism();
+
+            Preconditions.checkState(
+                    parallelism <= batchSize,
+                    "There are more subtasks in the training process than the number "
+                            + "of elements in each batch. Some subtasks might be idling forever.");
+
+            DataStream<DenseVector[]> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator<>())
+                            .flatMap(new GlobalBatchSplitter<>(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(DenseVector[].class),
+                                    new ModelDataLocalUpdater(
+                                            alpha, beta, l1, l2, featureCol, labelCol))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new ModelDataGlobalReducer());
+            DataStream<DenseVector[]> feedbackModelData =
+                    newModelData
+                            .map(
+                                    (MapFunction<DenseVector[], DenseVector[]>)
+                                            value -> {
+                                                double[] z = value[1].values;
+                                                double[] n = value[2].values;
+                                                for (int i = 0; i < z.length; ++i) {
+                                                    if (Math.abs(z[i]) <= l1) {
+                                                        value[0].values[i] = 0.0;
+                                                    } else {
+                                                        value[0].values[i] =
+                                                                ((z[i] < 0 ? -1 : 1) * l1 - z[i])
+                                                                        / ((beta + Math.sqrt(n[i]))
+                                                                                        / alpha
+                                                                                + l2);
+                                                    }
+                                                }
+                                                return new DenseVector[] {value[0]};
+                                            })
+                            .setParallelism(1);
+
+            DataStream<LogisticRegressionModelData> outputModelData =
+                    feedbackModelData
+                            .map(
+                                    (MapFunction<DenseVector[], LogisticRegressionModelData>)
+                                            value ->
+                                                    new LogisticRegressionModelData(
+                                                            value[0], System.nanoTime()))

Review Comment:
   I guess here you want to use the current timestamp as the model version. Could you follow the existing practice we have agreed on in `OnlineKmeans` and use incremental counts?
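For comparison, `System.nanoTime()` values are only meaningful as differences measured within one JVM (per its Javadoc), whereas a counter yields consecutive versions that stay comparable across restarts, provided the counter is restored from state. The plain-Java sketch below assumes versions start right after the initial model's version; `ModelVersionAssigner` and `VersionedModel` are hypothetical names, and the exact starting value and state handling used in `OnlineKmeans` are not reproduced here.

```java
/** Hypothetical sketch: tag each model update with a consecutive version, not System.nanoTime(). */
final class ModelVersionAssigner {

    /** A model snapshot with its version; stands in for LogisticRegressionModelData. */
    static final class VersionedModel {
        final double[] coefficient;
        final long modelVersion;

        VersionedModel(double[] coefficient, long modelVersion) {
            this.coefficient = coefficient;
            this.modelVersion = modelVersion;
        }
    }

    // In a Flink job this would have to live in checkpointed state of a parallelism-1 operator.
    private long nextVersion;

    /** @param initialVersion version of the initial model; assumed to be 0 for illustration */
    ModelVersionAssigner(long initialVersion) {
        this.nextVersion = initialVersion + 1;
    }

    /** Wraps a freshly computed coefficient vector with the next consecutive version. */
    VersionedModel assign(double[] coefficient) {
        return new VersionedModel(coefficient, nextVersion++);
    }
}
```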



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java:
##########
@@ -0,0 +1,382 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils.GlobalBatchCreator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils.GlobalBatchSplitter;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the FTRL-Proximal online learning algorithm proposed by H. Brendan
+ * McMahan et al.
+ *
+ * <p>See <a href="https://doi.org/10.1145/2487575.2488200">H. Brendan McMahan et al., Ad click
+ * prediction: a view from the trenches.</a>
+ */
+public class OnlineLogisticRegression
+        implements Estimator<OnlineLogisticRegression, OnlineLogisticRegressionModel>,
+                OnlineLogisticRegressionParams<OnlineLogisticRegression> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineLogisticRegression() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public OnlineLogisticRegressionModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<LogisticRegressionModelData> modelDataStream =
+                LogisticRegressionModelData.getModelDataStream(initModelDataTable);
+        DataStream<DenseVector[]> initModelData =
+                modelDataStream.map(
+                        (MapFunction<LogisticRegressionModelData, DenseVector[]>)
+                                value -> new DenseVector[] {value.coefficient});
+
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new FtrlIterationBody(
+                        getGlobalBatchSize(),
+                        getAlpha(),
+                        getBeta(),
+                        getL1(),
+                        getL2(),
+                        getFeaturesCol(),
+                        getLabelCol());
+
+        DataStream<LogisticRegressionModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData),
+                                DataStreamList.of(tEnv.toDataStream(inputs[0])),
+                                body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        OnlineLogisticRegressionModel model =
+                new OnlineLogisticRegressionModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Iteration body of ftrl algorithm. */
+    private static class FtrlIterationBody implements IterationBody {
+        private final int batchSize;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+        private final String featureCol;
+        private final String labelCol;
+
+        public FtrlIterationBody(
+                int batchSize,
+                double alpha,
+                double beta,
+                double l1,
+                double l2,
+                String featureCol,
+                String labelCol) {
+            this.batchSize = batchSize;
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = l1;
+            this.l2 = l2;
+            this.featureCol = featureCol;
+            this.labelCol = labelCol;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<DenseVector[]> modelData = variableStreams.get(0);
+
+            DataStream<Row> points = dataStreams.get(0);
+            int parallelism = points.getParallelism();
+
+            Preconditions.checkState(
+                    parallelism <= batchSize,
+                    "There are more subtasks in the training process than the number "
+                            + "of elements in each batch. Some subtasks might be idling forever.");
+
+            DataStream<DenseVector[]> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator<>())
+                            .flatMap(new GlobalBatchSplitter<>(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(DenseVector[].class),
+                                    new ModelDataLocalUpdater(
+                                            alpha, beta, l1, l2, featureCol, labelCol))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new ModelDataGlobalReducer());
+            DataStream<DenseVector[]> feedbackModelData =

Review Comment:
   It seems that `feedbackModelData` contains only one DenseVector. How about we change the type to `DenseVector`?



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinMaxScalerTest.java:
##########
@@ -111,7 +101,7 @@ private static void verifyPredictionResult(
                                 (MapFunction<Row, DenseVector>)
                                         row -> (DenseVector) row.getField(outputCol));
         List<DenseVector> result = IteratorUtils.toList(stream.executeAndCollect());
-        result.sort(MinMaxScalerTest::compare);
+        result.sort(TestUtils::compare);

Review Comment:
   nit: what about using `TestBaseUtils.compareResultCollections`?



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java:
##########
@@ -0,0 +1,382 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils.GlobalBatchCreator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils.GlobalBatchSplitter;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the FTRL-Proximal online learning algorithm proposed by H. Brendan
+ * McMahan et al.
+ *
+ * <p>See <a href="https://doi.org/10.1145/2487575.2488200">H. Brendan McMahan et al., Ad click
+ * prediction: a view from the trenches.</a>
+ */
+public class OnlineLogisticRegression
+        implements Estimator<OnlineLogisticRegression, OnlineLogisticRegressionModel>,
+                OnlineLogisticRegressionParams<OnlineLogisticRegression> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineLogisticRegression() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public OnlineLogisticRegressionModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<LogisticRegressionModelData> modelDataStream =
+                LogisticRegressionModelData.getModelDataStream(initModelDataTable);
+        DataStream<DenseVector[]> initModelData =

Review Comment:
   Should `initModelData` be a `LogisticRegressionModelData` rather than a `DenseVector[]`? Could you refactor this?



##########
flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java:
##########
@@ -71,6 +71,21 @@ public static double dot(DenseVector x, DenseVector y) {
         return JAVA_BLAS.ddot(x.size(), x.values, 1, y.values, 1);
     }
 
+    /** x \cdot y, where x may be a DenseVector or a SparseVector. */
+    public static double dot(Vector x, DenseVector y) {

Review Comment:
   How about refactoring the code to follow the structure of `hdot` and exposing the following function to users:
   `public static double dot(Vector, Vector)`?
   
   Moreover, unit tests should be added to verify this implementation.
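A rough sketch of the suggested shape: one public `dot(Vector, Vector)` that checks sizes and dispatches on the runtime types to private type-specific helpers, which is our assumption about how `hdot` is organized. To stay self-contained, `Vec`, `Dense`, `Sparse` and `BlasSketch` below are hypothetical stand-ins for the flink-ml classes (sparse indices are assumed sorted), and the helpers use plain loops rather than `JAVA_BLAS`.

```java
/** Minimal stand-ins for the flink-ml vector classes; hypothetical, for a self-contained sketch. */
interface Vec {
    int size();
}

final class Dense implements Vec {
    final double[] values;

    Dense(double... values) {
        this.values = values;
    }

    public int size() {
        return values.length;
    }
}

final class Sparse implements Vec {
    final int n;
    final int[] indices; // assumed strictly increasing
    final double[] values;

    Sparse(int n, int[] indices, double[] values) {
        this.n = n;
        this.indices = indices;
        this.values = values;
    }

    public int size() {
        return n;
    }
}

final class BlasSketch {
    /** Single public entry point that dispatches on the runtime types of both arguments. */
    static double dot(Vec x, Vec y) {
        if (x.size() != y.size()) {
            throw new IllegalArgumentException("Vector size mismatched.");
        }
        if (x instanceof Dense && y instanceof Dense) {
            return dot((Dense) x, (Dense) y);
        } else if (x instanceof Sparse && y instanceof Dense) {
            return dot((Sparse) x, (Dense) y);
        } else if (x instanceof Dense && y instanceof Sparse) {
            return dot((Sparse) y, (Dense) x); // the dot product is symmetric
        } else if (x instanceof Sparse && y instanceof Sparse) {
            return dot((Sparse) x, (Sparse) y);
        }
        throw new UnsupportedOperationException("Unsupported vector types.");
    }

    private static double dot(Dense x, Dense y) {
        double sum = 0.0;
        for (int i = 0; i < x.values.length; i++) {
            sum += x.values[i] * y.values[i];
        }
        return sum;
    }

    private static double dot(Sparse x, Dense y) {
        double sum = 0.0;
        for (int k = 0; k < x.indices.length; k++) {
            sum += x.values[k] * y.values[x.indices[k]];
        }
        return sum;
    }

    /** Two-pointer merge over the sorted index arrays. */
    private static double dot(Sparse x, Sparse y) {
        double sum = 0.0;
        int i = 0;
        int j = 0;
        while (i < x.indices.length && j < y.indices.length) {
            if (x.indices[i] == y.indices[j]) {
                sum += x.values[i++] * y.values[j++];
            } else if (x.indices[i] < y.indices[j]) {
                i++;
            } else {
                j++;
            }
        }
        return sum;
    }
}
```

Each of the four type combinations, plus a size-mismatch case, would be a natural unit test for this method.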



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java:
##########
@@ -232,7 +232,7 @@ public void testSaveLoadAndPredict() throws Exception {
         LogisticRegressionModel model = logisticRegression.fit(binomialDataTable);
         model = StageTestUtils.saveAndReload(tEnv, model, tempFolder.newFolder().getAbsolutePath());
         assertEquals(
-                Collections.singletonList("coefficient"),
+                Arrays.asList("coefficient", "modelVersion"),

Review Comment:
   Could you also verify the model version in `testGetModelData`?





[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator for OnlineLogisticRegression

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r881314027


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java:
##########
@@ -0,0 +1,434 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.FilterFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the FTRL-Proximal online learning algorithm proposed by H. Brendan
+ * McMahan et al.
+ *
+ * <p>See <a href="https://doi.org/10.1145/2487575.2488200">H. Brendan McMahan et al., Ad click
+ * prediction: a view from the trenches.</a>
+ */
+public class OnlineLogisticRegression
+        implements Estimator<OnlineLogisticRegression, OnlineLogisticRegressionModel>,
+                OnlineLogisticRegressionParams<OnlineLogisticRegression> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineLogisticRegression() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public OnlineLogisticRegressionModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<LogisticRegressionModelData> modelDataStream =
+                LogisticRegressionModelData.getModelDataStream(initModelDataTable);
+
+        DataStream<Row> points =
+                tEnv.toDataStream(inputs[0])
+                        .map(new FeaturesExtractor(getFeaturesCol(), getLabelCol()));
+
+        DataStream<DenseVector> initModelData =
+                modelDataStream.map(
+                        (MapFunction<LogisticRegressionModelData, DenseVector>)
+                                value -> value.coefficient);

Review Comment:
   That sounds good.





[GitHub] [flink-ml] weibozhao commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator for OnlineLogisticRegression

Posted by GitBox <gi...@apache.org>.
weibozhao commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r881173528


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasL2.java:
##########
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
+
+/** Interface for the shared L2 param. */
+public interface HasL2<T> extends WithParams<T> {
+    Param<Double> L_2 = new DoubleParam("l2", "The l2 param.", 0.1, ParamValidators.gt(0.0));

Review Comment:
   done
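   The `HasL2` param quoted above supplies the L2 strength used by FTRL-Proximal. Below is a rough, hypothetical sketch of the per-coordinate update from Algorithm 1 of McMahan et al. for logistic regression. It is not the code in this PR: the class name `FtrlSketch` and the `alpha`/`beta`/`l1`/`l2` fields are assumptions that follow the paper's notation. `l2` only appears in the denominator of the closed-form weight, so a larger value shrinks every non-zero weight. `l1` zeroes any coordinate whose accumulated `|z|` does not exceed it.

```java
/**
 * Hypothetical sketch of per-coordinate FTRL-Proximal (McMahan et al., 2013, Algorithm 1)
 * for logistic regression. Not the Flink ML implementation.
 */
public class FtrlSketch {
    private final double alpha; // learning-rate scale
    private final double beta;  // learning-rate smoothing term
    private final double l1;    // L1 strength: coordinates with |z| <= l1 stay exactly 0
    private final double l2;    // L2 strength: enlarges the denominator, shrinking weights
    private final double[] z;   // accumulated adjusted gradients
    private final double[] n;   // accumulated squared gradients

    public FtrlSketch(int dim, double alpha, double beta, double l1, double l2) {
        this.alpha = alpha;
        this.beta = beta;
        this.l1 = l1;
        this.l2 = l2;
        this.z = new double[dim];
        this.n = new double[dim];
    }

    /** Closed-form weight of coordinate i, computed lazily from z[i] and n[i]. */
    public double weight(int i) {
        if (Math.abs(z[i]) <= l1) {
            return 0.0;
        }
        double denom = (beta + Math.sqrt(n[i])) / alpha + l2;
        return -(z[i] - Math.signum(z[i]) * l1) / denom;
    }

    /** Returns P(y = 1 | x) for a dense feature vector. */
    public double predict(double[] x) {
        double dot = 0.0;
        for (int i = 0; i < x.length; i++) {
            dot += weight(i) * x[i];
        }
        return 1.0 / (1.0 + Math.exp(-dot));
    }

    /** One FTRL-Proximal step for a single example with label y in {0, 1}. */
    public void update(double[] x, double y) {
        double p = predict(x);
        for (int i = 0; i < x.length; i++) {
            if (x[i] == 0.0) {
                continue; // only non-zero coordinates are updated
            }
            double w = weight(i);      // weight before this step
            double g = (p - y) * x[i]; // gradient of log loss w.r.t. w_i
            double sigma = (Math.sqrt(n[i] + g * g) - Math.sqrt(n[i])) / alpha;
            z[i] += g - sigma * w;
            n[i] += g * g;
        }
    }

    public static void main(String[] args) {
        // l2 = 0.1 mirrors the default value in the HasL2 diff above.
        FtrlSketch model = new FtrlSketch(2, 0.1, 1.0, 0.0, 0.1);
        double[] pos = {1.0, 0.0};
        double[] neg = {0.0, 1.0};
        for (int epoch = 0; epoch < 100; epoch++) {
            model.update(pos, 1.0);
            model.update(neg, 0.0);
        }
        System.out.println(model.predict(pos) + " " + model.predict(neg));
    }
}
```

   Storing `z` and `n` and computing weights on demand, rather than storing the weights themselves, is what lets the L1 term keep many coordinates at exactly zero.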





[GitHub] [flink-ml] lindong28 commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator for OnlineLogisticRegression

Posted by GitBox <gi...@apache.org>.
lindong28 commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r884756711


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java:
##########
@@ -0,0 +1,434 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.FilterFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the FTRL-Proximal online learning algorithm proposed by H. Brendan
+ * McMahan et al.
+ *
+ * <p>See <a href="https://doi.org/10.1145/2487575.2488200">H. Brendan McMahan et al., Ad click
+ * prediction: a view from the trenches.</a>
+ */
+public class OnlineLogisticRegression
+        implements Estimator<OnlineLogisticRegression, OnlineLogisticRegressionModel>,
+                OnlineLogisticRegressionParams<OnlineLogisticRegression> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineLogisticRegression() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public OnlineLogisticRegressionModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<LogisticRegressionModelData> modelDataStream =
+                LogisticRegressionModelData.getModelDataStream(initModelDataTable);
+
+        DataStream<Row> points =
+                tEnv.toDataStream(inputs[0])
+                        .map(new FeaturesExtractor(getFeaturesCol(), getLabelCol()));
+
+        DataStream<DenseVector> initModelData =
+                modelDataStream.map(
+                        (MapFunction<LogisticRegressionModelData, DenseVector>)
+                                value -> value.coefficient);

Review Comment:
   I think the initial model version does need to be extracted from the input model data so that we can know the order of models across program restarts.
   
   Here is an example scenario:
   - Let's say the global batch size is 100 and there are 1000 records in the input datastream. The online training program should generate 10 model versions after processing these 1000 records.
   - The training process finished processing 300 records and generated 3 models with versions 1, 2, and 3. After successfully taking a checkpoint, the process exited due to a machine failure.
   - The training process is restarted from the last successful checkpoint. It should continue to read the input datastream starting from the 301st record, and it should read the latest model data generated before the restart.
   
   Ideally, we should hide the machine failure from users, meaning that the sequence of model versions should be 1, 2, 3, 4, ..., 10 as if the failure had never happened. Therefore, we have to set the initial model version to the model version from the input model data.
   
   Does this make sense?
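   A minimal simulation of the scenario above, assuming one model version is emitted per completed global batch and that the version counter is seeded from the version carried by the latest input model data. The class `ModelVersionSketch` and its `train` helper are hypothetical and only model the counting, not the actual training:

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical simulation: one model version per global batch, counter seeded on restart. */
public class ModelVersionSketch {

    /**
     * Processes records [startRecord, endRecord) in global batches of batchSize and returns
     * the versions emitted, continuing from initialVersion.
     */
    public static List<Long> train(long initialVersion, int startRecord, int endRecord, int batchSize) {
        List<Long> versions = new ArrayList<>();
        long version = initialVersion;
        for (int seen = startRecord; seen + batchSize <= endRecord; seen += batchSize) {
            version++; // a new model version for each completed global batch
            versions.add(version);
        }
        return versions;
    }

    public static void main(String[] args) {
        int batchSize = 100;
        // First run: 300 records processed, then a failure after the checkpoint.
        List<Long> beforeFailure = train(0L, 0, 300, batchSize);
        long latestVersion = beforeFailure.get(beforeFailure.size() - 1);
        // Restart: resume at the 301st record (index 300), seeded with the latest version.
        List<Long> afterRestart = train(latestVersion, 300, 1000, batchSize);
        List<Long> all = new ArrayList<>(beforeFailure);
        all.addAll(afterRestart);
        System.out.println(all); // [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    }
}
```

   Because the restarted run starts counting from the last observed version, downstream consumers see one gap-free, strictly increasing sequence.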
   
   





[GitHub] [flink-ml] lindong28 commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator for OnlineLogisticRegression

Posted by GitBox <gi...@apache.org>.
lindong28 commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r884756711


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java:
##########
@@ -0,0 +1,434 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.FilterFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the FTRL-Proximal online learning algorithm proposed by H. Brendan
+ * McMahan et al.
+ *
+ * <p>See <a href="https://doi.org/10.1145/2487575.2488200">H. Brendan McMahan et al., Ad click
+ * prediction: a view from the trenches.</a>
+ */
+public class OnlineLogisticRegression
+        implements Estimator<OnlineLogisticRegression, OnlineLogisticRegressionModel>,
+                OnlineLogisticRegressionParams<OnlineLogisticRegression> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineLogisticRegression() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public OnlineLogisticRegressionModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<LogisticRegressionModelData> modelDataStream =
+                LogisticRegressionModelData.getModelDataStream(initModelDataTable);
+
+        DataStream<Row> points =
+                tEnv.toDataStream(inputs[0])
+                        .map(new FeaturesExtractor(getFeaturesCol(), getLabelCol()));
+
+        DataStream<DenseVector> initModelData =
+                modelDataStream.map(
+                        (MapFunction<LogisticRegressionModelData, DenseVector>)
+                                value -> value.coefficient);

Review Comment:
   I believe the model version does need to take into account the version of the input model data so that we can know the order of models across program restarts.
   
   Here is an example scenario:
   - Let's say the global batch size is 100 and there are 1000 records in the input datastream. The online training program should generate 10 model versions after processing these 1000 records.
   - The training process finished processing 300 records and generated 3 models with versions 1, 2, and 3. After successfully taking a checkpoint, the process exited due to a machine failure.
   - The training process is restarted from the last successful checkpoint. It should continue to read the input datastream starting from the 301st record, and it should read the latest model data generated before the restart.
   
   Ideally, we should hide the machine failure from users, meaning that the sequence of model versions should be 1, 2, 3, 4, ..., 10 as if the failure had never happened. Therefore, we have to set the initial model version to the model version from the input model data.
   
   Does this make sense?
   
   





[GitHub] [flink-ml] weibozhao commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator for OnlineLogisticRegression

Posted by GitBox <gi...@apache.org>.
weibozhao commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r881258511


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java:
##########
@@ -0,0 +1,434 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.FilterFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the FTRL-Proximal online learning algorithm proposed by H. Brendan
+ * McMahan et al.
+ *
+ * <p>See <a href="https://doi.org/10.1145/2487575.2488200">H. Brendan McMahan et al., Ad click
+ * prediction: a view from the trenches.</a>
+ */
+public class OnlineLogisticRegression
+        implements Estimator<OnlineLogisticRegression, OnlineLogisticRegressionModel>,
+                OnlineLogisticRegressionParams<OnlineLogisticRegression> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineLogisticRegression() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public OnlineLogisticRegressionModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<LogisticRegressionModelData> modelDataStream =
+                LogisticRegressionModelData.getModelDataStream(initModelDataTable);
+
+        DataStream<Row> points =
+                tEnv.toDataStream(inputs[0])
+                        .map(new FeaturesExtractor(getFeaturesCol(), getLabelCol()));
+
+        DataStream<DenseVector> initModelData =
+                modelDataStream.map(
+                        (MapFunction<LogisticRegressionModelData, DenseVector>)
+                                value -> value.coefficient);

Review Comment:
   I don't think it's needed. In this algorithm, the model version starts from 1 and increases with every sink. Whether the start version is 0 or 1 makes no difference.
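   To make the trade-off concrete, here is a hypothetical sketch of a counter that starts from 1 and increments once per sink. It assumes the counter is not restored from checkpointed state and is not seeded from the input model data, for example when a new job is started from the latest model data. The class `UnseededVersionSketch` is illustrative only:

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical version counter that restarts at 1 whenever the job (re)starts. */
public class UnseededVersionSketch {
    private long version = 0L; // assumption: not seeded from input model data

    /** Called once per model emitted to the sink; returns the version attached to it. */
    public long nextVersion() {
        return ++version;
    }

    public static void main(String[] args) {
        List<Long> emitted = new ArrayList<>();
        UnseededVersionSketch firstRun = new UnseededVersionSketch();
        for (int i = 0; i < 3; i++) {
            emitted.add(firstRun.nextVersion()); // 3 global batches before the failure
        }
        UnseededVersionSketch restarted = new UnseededVersionSketch(); // counter starts over
        for (int i = 0; i < 7; i++) {
            emitted.add(restarted.nextVersion()); // remaining 7 global batches
        }
        System.out.println(emitted); // [1, 2, 3, 1, 2, 3, 4, 5, 6, 7]
    }
}
```

   Under that assumption, versions 1 to 3 appear twice, so a consumer that picks "the highest version" cannot tell the pre-failure models from the post-restart ones.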





[GitHub] [flink-ml] lindong28 commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator for OnlineLogisticRegression

Posted by GitBox <gi...@apache.org>.
lindong28 commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r884843120


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java:
##########
@@ -0,0 +1,434 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.FilterFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the FTRL-Proximal online learning algorithm proposed by H. Brendan
+ * McMahan et al.
+ *
+ * <p>See <a href="https://doi.org/10.1145/2487575.2488200">H. Brendan McMahan et al., Ad click
+ * prediction: a view from the trenches.</a>
+ */
+public class OnlineLogisticRegression
+        implements Estimator<OnlineLogisticRegression, OnlineLogisticRegressionModel>,
+                OnlineLogisticRegressionParams<OnlineLogisticRegression> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineLogisticRegression() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public OnlineLogisticRegressionModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<LogisticRegressionModelData> modelDataStream =
+                LogisticRegressionModelData.getModelDataStream(initModelDataTable);
+
+        DataStream<Row> points =
+                tEnv.toDataStream(inputs[0])
+                        .map(new FeaturesExtractor(getFeaturesCol(), getLabelCol()));
+
+        DataStream<DenseVector> initModelData =
+                modelDataStream.map(
+                        (MapFunction<LogisticRegressionModelData, DenseVector>)
+                                value -> value.coefficient);

Review Comment:
   Thanks for the explanation. Sounds good.





[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator for OnlineLogisticRegression

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r882282398


##########
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java:
##########
@@ -232,7 +232,7 @@ public void testSaveLoadAndPredict() throws Exception {
         LogisticRegressionModel model = logisticRegression.fit(binomialDataTable);
         model = StageTestUtils.saveAndReload(tEnv, model, tempFolder.newFolder().getAbsolutePath());
         assertEquals(
-                Collections.singletonList("coefficient"),
+                Arrays.asList("coefficient", "modelVersion"),

Review Comment:
   Please keep this issue open if it is not resolved~





[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator for OnlineLogisticRegression

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r882371295


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java:
##########
@@ -0,0 +1,434 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the FTRL-Proximal online learning algorithm proposed by H. Brendan
+ * McMahan et al.
+ *
+ * <p>See <a href="https://doi.org/10.1145/2487575.2488200">H. Brendan McMahan et al., Ad click
+ * prediction: a view from the trenches.</a>
+ */
+public class OnlineLogisticRegression
+        implements Estimator<OnlineLogisticRegression, OnlineLogisticRegressionModel>,
+                OnlineLogisticRegressionParams<OnlineLogisticRegression> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineLogisticRegression() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineLogisticRegressionModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<LogisticRegressionModelData> modelDataStream =
+                LogisticRegressionModelData.getModelDataStream(initModelDataTable);
+
+        DataStream<Row> points =
+                tEnv.toDataStream(inputs[0])
+                        .map(new FeaturesExtractor(getFeaturesCol(), getLabelCol()));
+
+        DataStream<DenseVector> initModelData =
+                modelDataStream.map(
+                        (MapFunction<LogisticRegressionModelData, DenseVector>)
+                                value -> value.coefficient);
+
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new FtrlIterationBody(
+                        getGlobalBatchSize(), getAlpha(), getBeta(), getReg(), getElasticNet());
+
+        DataStream<LogisticRegressionModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        OnlineLogisticRegressionModel model =
+                new OnlineLogisticRegressionModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class FeaturesExtractor implements MapFunction<Row, Row> {
+        private final String featuresCol;
+        private final String labelCol;
+
+        private FeaturesExtractor(String featuresCol, String labelCol) {
+            this.featuresCol = featuresCol;
+            this.labelCol = labelCol;
+        }
+
+        @Override
+        public Row map(Row row) throws Exception {
+            return Row.of(row.getField(featuresCol), row.getField(labelCol));
+        }
+    }
+
+    /**
+     * Implementation of the FTRL optimizer. In this implementation, gradients are calculated by
+     * distributed workers and reduced to one gradient. The reduced gradient is then used to update
+     * the model with the FTRL method.
+     *
+     * <p>See https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Ftrl
+     *
+     * <p>TODO: make FTRL a common optimizer and place it in org.apache.flink.ml.common in the
+     * future.
+     */
+    private static class FtrlIterationBody implements IterationBody {
+        private final int batchSize;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+
+        public FtrlIterationBody(
+                int batchSize, double alpha, double beta, double reg, double elasticNet) {
+            this.batchSize = batchSize;
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = elasticNet * reg;
+            this.l2 = (1 - elasticNet) * reg;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<DenseVector> modelData = variableStreams.get(0);
+
+            DataStream<Row> points = dataStreams.get(0);
+            int parallelism = points.getParallelism();
+            Preconditions.checkState(
+                    parallelism <= batchSize,
+                    "There are more subtasks in the training process than the number "
+                            + "of elements in each batch. Some subtasks might be idling forever.");
+
+            DataStream<DenseVector[]> newGradient =
+                    DataStreamUtils.generateBatchData(points, parallelism, batchSize)
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "LocalGradientCalculator",
+                                    TypeInformation.of(DenseVector[].class),
+                                    new CalculateLocalGradient())
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(
+                                    (ReduceFunction<DenseVector[]>)
+                                            (gradientInfo, newGradientInfo) -> {
+                                                for (int i = 0;
+                                                        i < newGradientInfo[1].size();
+                                                        ++i) {
+                                                    newGradientInfo[0].values[i] =
+                                                            gradientInfo[0].values[i]
+                                                                    + newGradientInfo[0].values[i];
+                                                    newGradientInfo[1].values[i] =
+                                                            gradientInfo[1].values[i]
+                                                                    + newGradientInfo[1].values[i];
+                                                    if (newGradientInfo[2] == null) {
+                                                        newGradientInfo[2] = gradientInfo[2];
+                                                    }
+                                                }
+                                                return newGradientInfo;
+                                            });
+            DataStream<DenseVector> feedbackModelData =
+                    newGradient
+                            .transform(
+                                    "ModelDataUpdater",
+                                    TypeInformation.of(DenseVector.class),
+                                    new UpdateModel(alpha, beta, l1, l2))
+                            .setParallelism(1);
+
+            DataStream<LogisticRegressionModelData> outputModelData =
+                    feedbackModelData.map(new CreateLrModelData()).setParallelism(1);
+            return new IterationBodyResult(
+                    DataStreamList.of(feedbackModelData), DataStreamList.of(outputModelData));
+        }
+    }
+
+    private static class CreateLrModelData
+            implements MapFunction<DenseVector, LogisticRegressionModelData>, CheckpointedFunction {
+        private Long modelVersion = 1L;
+        private transient ListState<Long> modelVersionState;
+
+        @Override
+        public LogisticRegressionModelData map(DenseVector denseVector) throws Exception {
+            return new LogisticRegressionModelData(denseVector, modelVersion++);
+        }
+
+        @Override
+        public void snapshotState(FunctionSnapshotContext functionSnapshotContext)
+                throws Exception {
+            modelVersionState.update(Collections.singletonList(modelVersion));
+        }
+
+        @Override
+        public void initializeState(FunctionInitializationContext context) throws Exception {
+            modelVersionState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelVersionState", Long.class));
+        }
+    }
+
+    /** Updates the model with the FTRL-Proximal rule. */
+    private static class UpdateModel extends AbstractStreamOperator<DenseVector>
+            implements OneInputStreamOperator<DenseVector[], DenseVector> {
+        private ListState<double[]> nParamState;
+        private ListState<double[]> zParamState;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+        private double[] nParam;
+        private double[] zParam;
+
+        public UpdateModel(double alpha, double beta, double l1, double l2) {
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = l1;
+            this.l2 = l2;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            nParamState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("nParamState", double[].class));
+            zParamState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("zParamState", double[].class));
+        }
+
+        @Override
+        public void processElement(StreamRecord<DenseVector[]> streamRecord) throws Exception {
+            DenseVector[] gradientInfo = streamRecord.getValue();
+            double[] coefficient = gradientInfo[2].values;
+            double[] g = gradientInfo[0].values;
+            for (int i = 0; i < g.length; ++i) {
+                if (gradientInfo[1].values[i] != 0.0) {
+                    g[i] = g[i] / gradientInfo[1].values[i];
+                }
+            }
+            if (zParam == null) {
+                zParam = new double[g.length];
+                nParam = new double[g.length];
+                nParamState.add(nParam);
+                zParamState.add(zParam);
+            }
+
+            for (int i = 0; i < zParam.length; ++i) {
+                double sigma = (Math.sqrt(nParam[i] + g[i] * g[i]) - Math.sqrt(nParam[i])) / alpha;
+                zParam[i] += g[i] - sigma * coefficient[i];
+                nParam[i] += g[i] * g[i];
+
+                if (Math.abs(zParam[i]) <= l1) {
+                    coefficient[i] = 0.0;
+                } else {
+                    coefficient[i] =
+                            ((zParam[i] < 0 ? -1 : 1) * l1 - zParam[i])
+                                    / ((beta + Math.sqrt(nParam[i])) / alpha + l2);
+                }
+            }
+            output.collect(new StreamRecord<>(new DenseVector(coefficient)));
+        }
+    }
+
+    private static class CalculateLocalGradient extends AbstractStreamOperator<DenseVector[]>
+            implements TwoInputStreamOperator<Row[], DenseVector, DenseVector[]> {
+        private ListState<DenseVector> modelDataState;
+        private ListState<Row[]> localBatchDataState;
+        private double[] gradient;
+        private double[] weight;
+        private int[] denseVectorIndices;
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", DenseVector.class));
+            TypeInformation<Row[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(TypeInformation.of(Row.class));
+            localBatchDataState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("localBatch", type));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<Row[]> pointsRecord) throws Exception {
+            localBatchDataState.add(pointsRecord.getValue());
+            calculateGradient();
+        }
+
+        private void calculateGradient() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !localBatchDataState.get().iterator().hasNext()) {
+                return;
+            }
+            DenseVector modelData =
+                    OperatorStateUtils.getUniqueElement(modelDataState, "modelData").get();
+            modelDataState.clear();
+
+            List<Row[]> pointsList = IteratorUtils.toList(localBatchDataState.get().iterator());
+            Row[] points = pointsList.remove(0);
+            localBatchDataState.update(pointsList);
+
+            for (Row point : points) {
+                Vector vec = point.getFieldAs(0);
+                double label = point.getFieldAs(1);
+                if (gradient == null) {
+                    gradient = new double[vec.size()];
+                    weight = new double[gradient.length];
+                    if (vec instanceof DenseVector) {
+                        denseVectorIndices = new int[vec.size()];
+                        for (int i = 0; i < denseVectorIndices.length; ++i) {
+                            denseVectorIndices[i] = i;
+                        }
+                    }
+                }
+
+                int[] indices;
+                double[] values;
+                if (vec instanceof DenseVector) {
+                    DenseVector denseVector = (DenseVector) vec;
+                    indices = denseVectorIndices;
+                    values = denseVector.values;
+                } else {
+                    SparseVector sparseVector = (SparseVector) vec;
+                    indices = sparseVector.indices;
+                    values = sparseVector.values;
+                }
+                double p = 0.0;
+                for (int i = 0; i < indices.length; ++i) {
+                    int idx = indices[i];
+                    p += modelData.values[idx] * values[i];
+                }
+                p = 1 / (1 + Math.exp(-p));
+                for (int i = 0; i < indices.length; ++i) {
+                    int idx = indices[i];
+                    gradient[idx] += (p - label) * values[i];
+                    weight[idx] += 1.0;
+                }
+            }
+
+            if (points.length > 0) {
+                output.collect(
+                        new StreamRecord<>(
+                                new DenseVector[] {
+                                    new DenseVector(gradient),
+                                    new DenseVector(weight),

Review Comment:
   The variable `weight` is a bit confusing here. Is `featureImportance` a better name? (I am not sure, but this is the best name I can come up with.)
   Also, could you add Javadoc here stating that for dense input this is the same as the TensorFlow implementation, while for sparse input it is different? It would be great if you could also explain why we handle sparse input differently.
   
   
   Moreover, for outputting the models, we could use a separate `OutputTag<LogisticRegressionModelData>` instead of mixing it with the communication. Please check out SGD#Line292 for an example.
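To make the dense/sparse distinction concrete, here is a minimal sketch in plain Java (no Flink dependencies) of what `CalculateLocalGradient` and `UpdateModel` do together: sum the logistic-loss gradient per coordinate, count how many points touched each coordinate (the role of `weight` in the diff), and divide by that count. The class and method names are hypothetical. With dense input every count equals the batch size, so the result is the plain batch mean; with sparse input, each coordinate is averaged only over the points in which it is non-zero.

```java
import java.util.Arrays;

/** Hypothetical sketch of the per-feature gradient accumulation used in the diff. */
final class SparseGradientSketch {
    /** Returns {gradientSum, featureCount} for one batch of sparse points. */
    static double[][] accumulate(
            double[] coefficient, int[][] indices, double[][] values, double[] labels) {
        int dim = coefficient.length;
        double[] gradient = new double[dim];
        double[] featureCount = new double[dim];
        for (int n = 0; n < labels.length; ++n) {
            // Dot product over the non-zero entries only.
            double dot = 0.0;
            for (int i = 0; i < indices[n].length; ++i) {
                dot += coefficient[indices[n][i]] * values[n][i];
            }
            double p = 1.0 / (1.0 + Math.exp(-dot)); // sigmoid
            for (int i = 0; i < indices[n].length; ++i) {
                int idx = indices[n][i];
                gradient[idx] += (p - labels[n]) * values[n][i];
                featureCount[idx] += 1.0; // what the diff calls `weight`
            }
        }
        return new double[][] {gradient, featureCount};
    }

    /** Divides each coordinate by the number of points that touched it, as UpdateModel does. */
    static double[] averagePerFeature(double[] gradient, double[] featureCount) {
        double[] avg = Arrays.copyOf(gradient, gradient.length);
        for (int i = 0; i < avg.length; ++i) {
            if (featureCount[i] != 0.0) {
                avg[i] /= featureCount[i];
            }
        }
        return avg;
    }
}
```

Under this reading, a name such as `featureCount` might describe the array more literally than `weight`; that is only a suggestion.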





[GitHub] [flink-ml] weibozhao commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator for OnlineLogisticRegression

Posted by GitBox <gi...@apache.org>.
weibozhao commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r869977839


##########
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java:
##########
@@ -182,4 +191,49 @@ public void snapshotState(StateSnapshotContext context) throws Exception {
             }
         }
     }
+
+    /**
+     * An operator that splits a global batch into evenly-sized local batches and distributes them
+     * to the downstream operator.
+     *
+     * @param <T> Data type of batch data.
+     */
+    public static class GlobalBatchSplitter<T> implements FlatMapFunction<T[], T[]> {

Review Comment:
   I think one function is better. I will refine it.
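A single function for the splitting step could look like the following sketch. It covers only the arithmetic of cutting one global batch into `numParts` local batches whose sizes differ by at most one; routing each part to a subtask, as the Flink operator does, is not shown. The class name is hypothetical. Note that when `numParts` exceeds the batch size, some local batches come out empty, which is the situation the `parallelism <= batchSize` precondition in `FtrlIterationBody` guards against.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Hypothetical sketch: splits a global batch into evenly-sized local batches. */
final class BatchSplitSketch {
    static <T> List<T[]> split(T[] globalBatch, int numParts) {
        if (numParts <= 0) {
            throw new IllegalArgumentException("numParts must be positive");
        }
        List<T[]> localBatches = new ArrayList<>(numParts);
        int base = globalBatch.length / numParts;
        int remainder = globalBatch.length % numParts;
        int offset = 0;
        for (int i = 0; i < numParts; ++i) {
            // The first `remainder` parts get one extra element.
            int size = base + (i < remainder ? 1 : 0);
            localBatches.add(Arrays.copyOfRange(globalBatch, offset, offset + size));
            offset += size;
        }
        return localBatches;
    }
}
```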





[GitHub] [flink-ml] lindong28 commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator for OnlineLogisticRegression

Posted by GitBox <gi...@apache.org>.
lindong28 commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r860545252


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasL2.java:
##########
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
+
+/** Interface for the shared L2 param. */
+public interface HasL2<T> extends WithParams<T> {
+    Param<Double> L_2 = new DoubleParam("l2", "The l2 param.", 0.1, ParamValidators.gt(0.0));

Review Comment:
   @weibozhao Currently the class `LogisticRegression` passes the value of `getReg()` to `LogisticGradient::l2`, which suggests that `getReg()` actually represents the hyper-parameter l2 for `LogisticRegression`.
   
   Could you explain why we cannot get the value of `l2` from `getReg()` for `OnlineLogisticRegression`?





[GitHub] [flink-ml] lindong28 commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator for OnlineLogisticRegression

Posted by GitBox <gi...@apache.org>.
lindong28 commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r860633423


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasL2.java:
##########
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
+
+/** Interface for the shared L2 param. */
+public interface HasL2<T> extends WithParams<T> {
+    Param<Double> L_2 = new DoubleParam("l2", "The l2 param.", 0.1, ParamValidators.gt(0.0));

Review Comment:
   According to the Wikipedia page on elastic net regularization [1], l1 and l2 could theoretically be translated into a regularization constant and the SVM regularization constant (i.e. the elastic net). Thus, in this context, the regularization constant is different from l1 and l2.
   
   As far as the mathematical formula is concerned, is the existing `HasReg::REG` the same as L2, or is it the regularization constant explained in [1]? If it is the same as L2, we should probably rename `HasReg` to e.g. `HasL2`. Otherwise, we should probably rename the variable `LogisticGradient::l2` as appropriate.
   
   [1] https://en.wikipedia.org/wiki/Elastic_net_regularization#:~:text=In%20statistics%20and%2C%20in%20particular,the%20lasso%20and%20ridge%20methods.
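For reference, `FtrlIterationBody` in the diff derives its strengths as `l1 = elasticNet * reg` and `l2 = (1 - elasticNet) * reg`. Under that mapping `reg` is the overall regularization constant and `elasticNet` is the L1 share, so `l2` equals `reg` only when `elasticNet` is 0. The sketch below reproduces that split and evaluates the resulting penalty. The `l2 / 2` factor on the squared-norm term is an assumption; it is the convention under which `l2` appears unscaled in the denominator of the FTRL closed form used in `UpdateModel`. The class name is hypothetical.

```java
/**
 * Hypothetical sketch: derives l1/l2 from (reg, elasticNet) as FtrlIterationBody does and
 * evaluates the resulting elastic-net penalty for a coefficient vector.
 */
final class ElasticNetSketch {
    final double l1;
    final double l2;

    ElasticNetSketch(double reg, double elasticNet) {
        if (elasticNet < 0.0 || elasticNet > 1.0) {
            throw new IllegalArgumentException("elasticNet must be in [0, 1]");
        }
        // Same split as the diff: reg is the overall strength, elasticNet the L1 share.
        this.l1 = elasticNet * reg;
        this.l2 = (1 - elasticNet) * reg;
    }

    /** Assumed convention: l1 * ||w||_1 + (l2 / 2) * ||w||_2^2. */
    double penalty(double[] w) {
        double absSum = 0.0;
        double sqSum = 0.0;
        for (double wi : w) {
            absSum += Math.abs(wi);
            sqSum += wi * wi;
        }
        return l1 * absSum + 0.5 * l2 * sqSum;
    }
}
```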





[GitHub] [flink-ml] lindong28 commented on pull request #83: [FLINK-27170] Add Transformer and Estimator for OnlineLogisticRegression

Posted by GitBox <gi...@apache.org>.
lindong28 commented on PR #83:
URL: https://github.com/apache/flink-ml/pull/83#issuecomment-1141176710

   Thanks for the update! LGTM.
   
   @zhipeng93 Do you want to take another look before merging this PR?




[GitHub] [flink-ml] weibozhao commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator for OnlineLogisticRegression

Posted by GitBox <gi...@apache.org>.
weibozhao commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r882428114


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java:
##########
@@ -0,0 +1,434 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.FilterFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the FTRL-Proximal online learning algorithm proposed by H. Brendan
+ * McMahan et al.
+ *
+ * <p>See <a href="https://doi.org/10.1145/2487575.2488200">H. Brendan McMahan et al., Ad click
+ * prediction: a view from the trenches.</a>
+ */
+public class OnlineLogisticRegression
+        implements Estimator<OnlineLogisticRegression, OnlineLogisticRegressionModel>,
+                OnlineLogisticRegressionParams<OnlineLogisticRegression> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineLogisticRegression() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public OnlineLogisticRegressionModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<LogisticRegressionModelData> modelDataStream =
+                LogisticRegressionModelData.getModelDataStream(initModelDataTable);
+
+        DataStream<Row> points =
+                tEnv.toDataStream(inputs[0])
+                        .map(new FeaturesExtractor(getFeaturesCol(), getLabelCol()));
+
+        DataStream<DenseVector> initModelData =
+                modelDataStream.map(
+                        (MapFunction<LogisticRegressionModelData, DenseVector>)
+                                value -> value.coefficient);

Review Comment:
   What is the meaning of a model version larger than the initial version? For an online training algorithm, we only need to know the order in which the models are output. @zhipeng93 
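If only the order matters, any strictly increasing sequence works; one possible reason to start after the initial model's version is that the initial model and every updated model then share a single ordered timeline. Whatever the starting point, the counter has to be restored from checkpointed state, otherwise numbering restarts after a failover. As quoted, `CreateLrModelData.initializeState` registers `modelVersionState` but does not appear to read the stored value back. The plain-Java sketch below (hypothetical names, no Flink API) shows the snapshot/restore contract such a counter needs.

```java
/** Hypothetical sketch: hands out strictly increasing model versions that survive a restart. */
final class ModelVersionCounter {
    private long nextVersion;

    private ModelVersionCounter(long nextVersion) {
        this.nextVersion = nextVersion;
    }

    /** Starts numbering right after the version of the initial model. */
    static ModelVersionCounter startingAfter(long initVersion) {
        return new ModelVersionCounter(initVersion + 1);
    }

    /** Rebuilds the counter from a value previously returned by snapshot(). */
    static ModelVersionCounter restore(long snapshot) {
        return new ModelVersionCounter(snapshot);
    }

    /** Returns the version for the next emitted model. */
    long next() {
        return nextVersion++;
    }

    /** The value to persist on checkpoint. */
    long snapshot() {
        return nextVersion;
    }
}
```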





[GitHub] [flink-ml] weibozhao commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator for OnlineLogisticRegression

Posted by GitBox <gi...@apache.org>.
weibozhao commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r881173256


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java:
##########
@@ -0,0 +1,428 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the OnlineLogisticRegression algorithm.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Online_machine_learning.

Review Comment:
   I have refined it.
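Since the refined Javadoc now cites the FTRL-Proximal paper, here is the per-coordinate update from `UpdateModel` restated in isolation, without Flink state, as a minimal sketch. For each coordinate it accumulates `z` and `n` and then applies the closed form: the coefficient is 0 when `|z| <= l1`, otherwise `(sgn(z) * l1 - z) / ((beta + sqrt(n)) / alpha + l2)`. The class name is hypothetical, and the gradient passed in is assumed to be already averaged.

```java
/** Hypothetical sketch of the per-coordinate FTRL-Proximal update performed by UpdateModel. */
final class FtrlProximalSketch {
    private final double alpha;
    private final double beta;
    private final double l1;
    private final double l2;
    private final double[] z;
    private final double[] n;

    FtrlProximalSketch(int dim, double alpha, double beta, double l1, double l2) {
        if (alpha <= 0.0) {
            throw new IllegalArgumentException("alpha must be positive");
        }
        this.alpha = alpha;
        this.beta = beta;
        this.l1 = l1;
        this.l2 = l2;
        this.z = new double[dim];
        this.n = new double[dim];
    }

    /** Applies gradient g at the current coefficients w, updating w in place. */
    void update(double[] w, double[] g) {
        for (int i = 0; i < w.length; ++i) {
            double sigma = (Math.sqrt(n[i] + g[i] * g[i]) - Math.sqrt(n[i])) / alpha;
            z[i] += g[i] - sigma * w[i];
            n[i] += g[i] * g[i];
            if (Math.abs(z[i]) <= l1) {
                w[i] = 0.0; // L1 keeps this coordinate exactly zero
            } else {
                double sign = z[i] < 0 ? -1.0 : 1.0;
                w[i] = (sign * l1 - z[i]) / ((beta + Math.sqrt(n[i])) / alpha + l2);
            }
        }
    }
}
```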





[GitHub] [flink-ml] weibozhao commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator for OnlineLogisticRegression

Posted by GitBox <gi...@apache.org>.
weibozhao commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r881623847


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionParams.java:
##########
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.common.param.HasElasticNet;
+import org.apache.flink.ml.common.param.HasGlobalBatchSize;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasReg;
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+
+/**
+ * Params of {@link OnlineLogisticRegression}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface OnlineLogisticRegressionParams<T>
+        extends HasLabelCol<T>,
+                HasBatchStrategy<T>,
+                HasGlobalBatchSize<T>,
+                HasReg<T>,
+                HasElasticNet<T>,
+                OnlineLogisticRegressionModelParams<T> {
+
+    Param<Double> ALPHA =
+            new DoubleParam("alpha", "The parameter alpha of ftrl.", 0.1, ParamValidators.gt(0.0));
+
+    default Double getAlpha() {
+        return get(ALPHA);
+    }
+
+    default T setAlpha(Double value) {
+        return set(ALPHA, value);
+    }
+
+    Param<Double> BETA =
+            new DoubleParam("beta", "The parameter beta of ftrl.", 0.1, ParamValidators.gt(0.0));
+
+    default Double getBeta() {
+        return get(BETA);
+    }
+
+    default T setBeta(Double value) {
+        return set(BETA, value);
+    }
+
+    Param<Integer> MODEL_SAVE_INTERVAL =
+            new IntParam(
+                    "modelSaveInterval",
+                    "The iteration steps between two output models.",

Review Comment:
   I will remove this parameter, as was done for OnlineKMeans.
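For readers following the ALPHA/BETA parameters above: the sketch below shows how alpha, beta and the L1/L2 strengths typically enter the per-coordinate FTRL-Proximal update described in the McMahan et al. paper cited in the OnlineLogisticRegression Javadoc. It is a minimal, dependency-free sketch under that assumption, not the PR's UpdateModel operator (whose body is not quoted in this thread), and the class and method names are hypothetical. The factory splits reg by elasticNet in the same way as the FtrlIterationBody constructor in the PR.

```java
/**
 * Hypothetical sketch of the per-coordinate FTRL-Proximal update (McMahan et al., 2013).
 * Not the PR's UpdateModel operator, whose body is not quoted in this thread.
 */
final class FtrlProximalSketch {
    private final double alpha; // per-coordinate learning-rate scale
    private final double beta; // smoothing term added to sqrt(n) in the denominator
    private final double l1; // L1 strength; |z| <= l1 keeps a weight at exactly zero
    private final double l2; // L2 strength
    private final double[] z; // accumulated gradients, adjusted by sigma * w
    private final double[] n; // accumulated squared gradients

    FtrlProximalSketch(int dim, double alpha, double beta, double l1, double l2) {
        this.alpha = alpha;
        this.beta = beta;
        this.l1 = l1;
        this.l2 = l2;
        this.z = new double[dim];
        this.n = new double[dim];
    }

    /** Splits reg into L1/L2 by elasticNet, mirroring the FtrlIterationBody constructor. */
    static FtrlProximalSketch withElasticNet(
            int dim, double alpha, double beta, double reg, double elasticNet) {
        return new FtrlProximalSketch(dim, alpha, beta, elasticNet * reg, (1 - elasticNet) * reg);
    }

    /** Computes the weights in closed form from the accumulated z and n. */
    double[] weights() {
        double[] w = new double[z.length];
        for (int i = 0; i < z.length; i++) {
            if (Math.abs(z[i]) <= l1) {
                w[i] = 0.0;
            } else {
                w[i] = -(z[i] - Math.signum(z[i]) * l1) / ((beta + Math.sqrt(n[i])) / alpha + l2);
            }
        }
        return w;
    }

    /** Applies one already-reduced gradient to every coordinate. */
    void update(double[] gradient) {
        double[] w = weights(); // the weights the gradient was computed against
        for (int i = 0; i < z.length; i++) {
            double g = gradient[i];
            double sigma = (Math.sqrt(n[i] + g * g) - Math.sqrt(n[i])) / alpha;
            z[i] += g - sigma * w[i];
            n[i] += g * g;
        }
    }
}
```

With alpha = 0.1, beta = 0.1 and no regularization, one update with gradient 1.0 gives w = -1 / ((0.1 + 1) / 0.1) = -1/11; with l1 = 2.0 the same update leaves the coordinate at exactly zero, which is the sparsity effect the L1 term is meant to produce.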





[GitHub] [flink-ml] weibozhao commented on a diff in pull request #83: [FLINK-27170] Add Transformer and Estimator for OnlineLogisticRegression

Posted by GitBox <gi...@apache.org>.
weibozhao commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r881266473


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java:
##########
@@ -0,0 +1,434 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.FilterFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the FTRL-Proximal online learning algorithm proposed by H. Brendan
+ * McMahan et al.
+ *
+ * <p>See <a href="https://doi.org/10.1145/2487575.2488200">H. Brendan McMahan et al., Ad click
+ * prediction: a view from the trenches.</a>
+ */
+public class OnlineLogisticRegression
+        implements Estimator<OnlineLogisticRegression, OnlineLogisticRegressionModel>,
+                OnlineLogisticRegressionParams<OnlineLogisticRegression> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineLogisticRegression() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public OnlineLogisticRegressionModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<LogisticRegressionModelData> modelDataStream =
+                LogisticRegressionModelData.getModelDataStream(initModelDataTable);
+
+        DataStream<Row> points =
+                tEnv.toDataStream(inputs[0])
+                        .map(new FeaturesExtractor(getFeaturesCol(), getLabelCol()));
+
+        DataStream<DenseVector> initModelData =
+                modelDataStream.map(
+                        (MapFunction<LogisticRegressionModelData, DenseVector>)
+                                value -> value.coefficient);
+
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new FtrlIterationBody(
+                        getGlobalBatchSize(),
+                        getAlpha(),
+                        getBeta(),
+                        getReg(),
+                        getElasticNet(),
+                        getModelSaveInterval());
+
+        DataStream<LogisticRegressionModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        OnlineLogisticRegressionModel model =
+                new OnlineLogisticRegressionModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class FeaturesExtractor implements MapFunction<Row, Row> {
+        private final String featuresCol;
+        private final String labelCol;
+
+        private FeaturesExtractor(String featuresCol, String labelCol) {
+            this.featuresCol = featuresCol;
+            this.labelCol = labelCol;
+        }
+
+        @Override
+        public Row map(Row row) throws Exception {
+            return Row.of(row.getField(featuresCol), row.getField(labelCol));
+        }
+    }
+
+    /**
+     * Implementation of the FTRL optimizer. In this implementation, gradients are calculated on
+     * distributed workers and reduced to a single gradient. The reduced gradient is then used to
+     * update the model with the FTRL method.
+     *
+     * <p>See https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Ftrl
+     */
+    private static class FtrlIterationBody implements IterationBody {
+        private final int batchSize;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+        private long modelVersion = 1L;
+        private final int modelSaveInterval;
+
+        public FtrlIterationBody(
+                int batchSize,
+                double alpha,
+                double beta,
+                double reg,
+                double elasticNet,
+                int modelSaveInterval) {
+            this.batchSize = batchSize;
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = elasticNet * reg;
+            this.l2 = (1 - elasticNet) * reg;
+            this.modelSaveInterval = modelSaveInterval;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<DenseVector> modelData = variableStreams.get(0);
+
+            DataStream<Row> points = dataStreams.get(0);
+            int parallelism = points.getParallelism();
+            Preconditions.checkState(
+                    parallelism <= batchSize,
+                    "There are more subtasks in the training process than the number "
+                            + "of elements in each batch. Some subtasks might be idling forever.");
+
+            DataStream<DenseVector[]> newGradient =
+                    DataStreamUtils.generateBatchData(points, parallelism, batchSize)
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "LocalGradientCalculator",
+                                    TypeInformation.of(DenseVector[].class),
+                                    new CalculateLocalGradient())
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(
+                                    (ReduceFunction<DenseVector[]>)
+                                            (gradientInfo, newGradientInfo) -> {
+                                                for (int i = 0;
+                                                        i < newGradientInfo[1].size();
+                                                        ++i) {
+                                                    newGradientInfo[0].values[i] =
+                                                            gradientInfo[0].values[i]
+                                                                    + newGradientInfo[0].values[i];
+                                                    newGradientInfo[1].values[i] =
+                                                            gradientInfo[1].values[i]
+                                                                    + newGradientInfo[1].values[i];
+                                                    if (newGradientInfo[2] == null) {
+                                                        newGradientInfo[2] = gradientInfo[2];
+                                                    }
+                                                }
+                                                return newGradientInfo;
+                                            });
+            DataStream<DenseVector> feedbackModelData =
+                    newGradient
+                            .transform(
+                                    "ModelDataUpdater",
+                                    TypeInformation.of(DenseVector.class),
+                                    new UpdateModel(alpha, beta, l1, l2))
+                            .setParallelism(1);
+
+            DataStream<LogisticRegressionModelData> outputModelData =
+                    feedbackModelData
+                            .filter(
+                                    new FilterFunction<DenseVector>() {
+                                        private int step = 0;
+
+                                        @Override
+                                        public boolean filter(DenseVector denseVector) {
+                                            step++;
+                                            return step % modelSaveInterval == 0;
+                                        }
+                                    })
+                            .setParallelism(1)

Review Comment:
   I think it is not needed now. If more than one algorithm uses this part of the code, we should extract it into a function or class.
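If the save-interval filter were extracted as suggested, a reusable class might look like the hypothetical EveryNthFilter below. In the Flink job it would implement org.apache.flink.api.common.functions.FilterFunction<T> with the same body; java.util.function.Predicate is used here only so the sketch compiles without Flink on the classpath.

```java
import java.util.function.Predicate;

/**
 * Hypothetical reusable "keep every N-th element" filter, sketching an extraction of the
 * inline modelSaveInterval filter. In Flink it would implement FilterFunction<T> instead.
 */
final class EveryNthFilter<T> implements Predicate<T> {
    private final int interval;
    private long step = 0; // per-instance counter, like the anonymous filter's `step` field

    EveryNthFilter(int interval) {
        if (interval <= 0) {
            throw new IllegalArgumentException("interval must be positive: " + interval);
        }
        this.interval = interval;
    }

    @Override
    public boolean test(T value) {
        step++;
        return step % interval == 0; // true for the interval-th, 2*interval-th, ... element
    }
}
```

As in the anonymous filter, the counter is a plain field, so it is not part of Flink's checkpointed state and would restart from zero after a failover unless it is moved into managed state.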

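The Preconditions check in FtrlIterationBody (parallelism <= batchSize) guards against subtasks that never receive data. Assuming the global batch is divided as evenly as possible across subtasks (an assumption: the PR only shows the DataStreamUtils.generateBatchData call site, not its internals), the hypothetical helper below shows why: with more subtasks than elements, some local batch sizes are zero.

```java
/**
 * Hypothetical even split of a global batch size across parallel subtasks. This is an
 * assumed sizing rule, not the actual logic of DataStreamUtils.generateBatchData.
 */
final class LocalBatchSizes {
    static int[] split(int globalBatchSize, int parallelism) {
        int[] sizes = new int[parallelism];
        int base = globalBatchSize / parallelism;
        int remainder = globalBatchSize % parallelism;
        for (int i = 0; i < parallelism; i++) {
            // The first `remainder` subtasks take one extra element so the sizes sum to the total.
            sizes[i] = base + (i < remainder ? 1 : 0);
        }
        return sizes;
    }
}
```

For example, split(10, 3) gives [4, 3, 3], while split(2, 3) gives [1, 1, 0], leaving the third subtask with an empty local batch.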

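The ReduceFunction lambda in FtrlIterationBody sums slots 0 and 1 of the DenseVector[] element-wise and keeps slot 2 from whichever side has it. A side-effect-free variant over plain double arrays is sketched below; the class name is hypothetical, and since the excerpt does not say what each slot holds, the slots are treated as opaque.

```java
/**
 * Hypothetical side-effect-free merge of two partial results, each modelled as double[3][]:
 * slots 0 and 1 are summed element-wise, slot 2 is carried over from whichever input has it.
 */
final class GradientMergeSketch {
    static double[][] merge(double[][] acc, double[][] next) {
        double[][] out = new double[3][];
        out[0] = add(acc[0], next[0]);
        out[1] = add(acc[1], next[1]);
        // Resolved once, independent of the vector length.
        out[2] = next[2] != null ? next[2] : acc[2];
        return out;
    }

    private static double[] add(double[] a, double[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("size mismatch: " + a.length + " vs " + b.length);
        }
        double[] sum = new double[a.length];
        for (int i = 0; i < a.length; i++) {
            sum[i] = a[i] + b[i];
        }
        return sum;
    }
}
```

Two small differences from the lambda: each slot is summed over its own length (the lambda bounds both sums by newGradientInfo[1].size()), and the slot-2 fallback is resolved once rather than on every loop iteration, so it also applies when the vectors are empty.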
