Posted to issues@flink.apache.org by GitBox <gi...@apache.org> on 2022/04/25 06:56:29 UTC

[GitHub] [flink-ml] zhipeng93 opened a new pull request, #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

zhipeng93 opened a new pull request, #90:
URL: https://github.com/apache/flink-ml/pull/90

   ## What is the purpose of the change
   - Add Transformer and Estimator of LinearRegression in FlinkML. 
   
   ## Brief change log
   - Abstracted a base class for general linear models.
   - Refactored the implementation of LogisticRegression with the proposed base class.
   - Added Transformer and Estimator of LinearRegression.
   - Added unit tests for the Transformer and Estimator of LinearRegression.
   - Added `HasElasticNet` param for linear models.
   - Added `WithRegularization` for regularization terms.
   - Added a method `axpy(double a, Vector x, DenseVector y, int k)` to BLAS and a corresponding unit test.
   
   ## Does this pull request potentially affect one of the following parts:
   - Dependencies (does it add or upgrade a dependency): (no)
   - The public API, i.e., is any changed class annotated with @Public(Evolving): (no)
   - Does this pull request introduce a new feature? (no)
   - If yes, how is the feature documented? (Java doc)


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r858415540


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/glm/LocalTrainer.java:
##########
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.glm;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.regression.linearregression.LinearRegression;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
+
+import java.util.Iterator;
+
+/**
+ * A local trainer is a trainer that uses a batch of training data to compute an update of the model
+ * locally.
+ *
+ * @param <T> Class type of training data.
+ */
+public abstract class LocalTrainer<T> implements CheckpointedFunction {

Review Comment:
   Thanks for the insightful comments.
   
   > move CacheDataAndDoTrain to an independent class
   
   It is an independent class in the old PR. Do you mean to merge it with LocalTrainer and WithRegularization?
   
   > merge LocalTrainer and WithRegularization and remove methods like getReg, as now they are only used internally
   
   It seems hard to merge these two classes, because the logic of `WithRegularization` is meant to be used in `LocalTrainer#updateModel`. I have renamed `WithRegularization` to `RegularizationUtils` and removed methods like `withReg`. What do you think?
   
   > rename the merged class. A name like LocalTrainer might be too general to be associated with linear algorithms.
   
   I have renamed the class to `LocalLinearTrainer`, though I still do not think it is a very good name. We could discuss the naming further.
   
   > merge CacheDataAndDoTrain, LocalTrainer and WithRegularization.
   
   I did not merge `CacheDataAndDoTrain` with the other two for now, for the following reasons:
   - `CacheDataAndDoTrain` is more of an infrastructure component and involves distributed concepts, such as two-input operators.
   - `LocalTrainer` is a friendlier and cleaner concept for machine learning users, since it only involves local operations.
   
   > change the API and implementation of trainOnBatchData. I have the sense that each time only one data, instead of a batch of data, is enough. I think the API of org.apache.flink.api.common.functions.AggregateFunction is a good reference.
   
   I did not make this change, for the following reasons:
   - Mini-batch training is a common concept in machine learning.
   - Users may want to perform operations before or after each mini-batch training step.
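   For illustration, a mini-batch local update of the kind discussed above could look roughly like the plain-Java sketch below. It assumes a least-squares loss and a weighted average gradient; `MiniBatchLocalUpdate`, `Point`, and `trainOnBatch` are hypothetical names, not the `LocalTrainer` API in this PR.

```java
import java.util.List;

/**
 * Hypothetical sketch (not the Flink ML API): one mini-batch gradient-descent step for
 * least-squares linear regression, applied locally to a coefficient array.
 */
class MiniBatchLocalUpdate {

    /** A weighted training point; a simplified stand-in for LabeledPointWithWeight. */
    static final class Point {
        final double[] features;
        final double label;
        final double weight;

        Point(double[] features, double label, double weight) {
            this.features = features;
            this.label = label;
            this.weight = weight;
        }
    }

    /** Updates coef in place using the weighted average gradient of the whole mini-batch. */
    static void trainOnBatch(double[] coef, List<Point> batch, double learningRate) {
        double[] gradient = new double[coef.length];
        double weightSum = 0.0;
        for (Point p : batch) {
            double prediction = 0.0;
            for (int i = 0; i < coef.length; i++) {
                prediction += coef[i] * p.features[i];
            }
            // Gradient of 0.5 * weight * (prediction - label)^2 with respect to coef.
            double error = prediction - p.label;
            for (int i = 0; i < coef.length; i++) {
                gradient[i] += p.weight * error * p.features[i];
            }
            weightSum += p.weight;
        }
        if (weightSum == 0.0) {
            return; // Empty batch or zero total weight: nothing to update.
        }
        for (int i = 0; i < coef.length; i++) {
            coef[i] -= learningRate * gradient[i] / weightSum;
        }
    }
}
```

   Because the whole batch is visible at once, per-batch work before or after the update (here, normalizing by the total weight) fits naturally; with a per-record `add` in the style of `AggregateFunction`, such steps would have to be moved into the accumulator or the final result step.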



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java:
##########
@@ -42,12 +43,9 @@
  * <p>This class also provides methods to convert model data from Table to Datastream, and classes
  * to save/load model data.
  */
-public class LogisticRegressionModelData {
-
-    public DenseVector coefficient;
-
+public class LogisticRegressionModelData extends GeneralLinearAlgoModelData {

Review Comment:
   I also tried this but failed, because `ModelDataEncoder` needs to construct instances of `LinearRegressionModelData` and `LogisticRegressionModelData`. If we pass a class type instead, we may need to resort to reflection, which is usually discouraged.
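   One possible way around reflection, sketched below under assumptions, is to pass a factory such as a constructor reference (`Supplier<T>`) instead of a `Class` object. The class names are hypothetical simplifications of the model data classes; in Flink the factory would probably also need to be `Serializable`, which this sketch does not handle.

```java
import java.util.function.Supplier;

/** Hypothetical base class holding the coefficient shared by linear model data classes. */
class LinearModelDataBase {
    double[] coefficient;
}

/** Hypothetical stand-in for LinearRegressionModelData. */
class LinearRegressionData extends LinearModelDataBase {}

/** Hypothetical stand-in for LogisticRegressionModelData. */
class LogisticRegressionData extends LinearModelDataBase {}

/**
 * Hypothetical generic decoder: the concrete type is created through a Supplier (for example a
 * constructor reference) rather than through Class#newInstance-style reflection.
 */
class LinearModelDataDecoder<T extends LinearModelDataBase> {
    private final Supplier<T> factory;

    LinearModelDataDecoder(Supplier<T> factory) {
        this.factory = factory;
    }

    /** Builds an instance of the concrete type from decoded coefficients (copied defensively). */
    T decode(double[] coefficient) {
        T data = factory.get();
        data.coefficient = coefficient.clone();
        return data;
    }
}
```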



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java:
##########
@@ -18,51 +18,28 @@
 
 package org.apache.flink.ml.classification.logisticregression;
 
-import org.apache.flink.api.common.functions.RichMapFunction;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
-import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.api.java.typeutils.RowTypeInfo;
-import org.apache.flink.ml.api.Model;
-import org.apache.flink.ml.common.broadcast.BroadcastUtils;
-import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.common.glm.GeneralLinearAlgoModel;
 import org.apache.flink.ml.linalg.BLAS;
 import org.apache.flink.ml.linalg.DenseVector;
 import org.apache.flink.ml.linalg.Vectors;
-import org.apache.flink.ml.param.Param;
-import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.ReadWriteUtils;
-import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.table.api.Table;
 import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
-import org.apache.flink.table.api.internal.TableImpl;
 import org.apache.flink.types.Row;
-import org.apache.flink.util.Preconditions;
 
 import org.apache.commons.lang3.ArrayUtils;
 
 import java.io.IOException;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.Map;
+import java.io.Serializable;
 
-/** A Model which classifies data using the model data computed by {@link LogisticRegression}. */
-public class LogisticRegressionModel
-        implements Model<LogisticRegressionModel>,
-                LogisticRegressionModelParams<LogisticRegressionModel> {
-
-    private final Map<Param<?>, Object> paramMap = new HashMap<>();
-
-    private Table modelDataTable;
-
-    public LogisticRegressionModel() {
-        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
-    }
-
-    @Override
-    public Map<Param<?>, Object> getParamMap() {
-        return paramMap;
-    }
+/**
+ * A Model which classifies the input using the model data computed by {@link LogisticRegression}.
+ */
+public class LogisticRegressionModel extends GeneralLinearAlgoModel<LogisticRegressionModel>
+        implements LogisticRegressionModelParams<LogisticRegressionModel>, Serializable {
 
     @Override
     public void save(String path) throws IOException {

Review Comment:
   The change seems infeasible for now...


[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

zhipeng93 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r860683058


##########
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java:
##########
@@ -67,6 +72,28 @@ public static <IN, OUT> DataStream<OUT> mapPartition(
                 .setParallelism(input.getParallelism());
     }
 
+    /**
+     * Applies a {@link ReduceFunction} on a bounded data stream. The output stream contains at most
+     * one stream record and its parallelism is one.
+     *
+     * @param input The input data stream.
+     * @param func The user defined reduce function.
+     * @param <T> The class type of the input.
+     * @return The result data stream.
+     */
+    public static <T> DataStream<T> reduce(DataStream<T> input, ReduceFunction<T> func) {
+        DataStream<T> partialReducedStream =
+                input.transform("partialReduce", input.getType(), new ReduceOperator<>(func))
+                        .setParallelism(input.getParallelism());

Review Comment:
   Yeah, you are right and it is not chained. But from the job graph, I could see that the edge is already `FORWARD` and there is no network shuffle.
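   The two-phase pattern in `DataStreamUtils#reduce` (a partial reduce that keeps the input's parallelism, followed by a final reduce with parallelism one) can be modeled in plain Java as below. This only simulates the data flow; `TwoPhaseReduce` is a hypothetical name and no Flink API is used. Because the partial step keeps the upstream parallelism, each input partition can feed its own partial reducer over a `FORWARD` edge, consistent with the job-graph observation in the comment above.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.function.BinaryOperator;

/**
 * Plain-Java simulation of a two-phase reduce on bounded input: each partition is reduced
 * locally first ("partialReduce"), then the partial results are reduced once more by a single
 * task (parallelism one).
 */
class TwoPhaseReduce {

    /** Reduces one partition; an empty partition produces no partial result. */
    static <T> Optional<T> reducePartition(List<T> partition, BinaryOperator<T> func) {
        return partition.stream().reduce(func);
    }

    /** Reduces all partitions; the result holds at most one element. */
    static <T> Optional<T> reduce(List<List<T>> partitions, BinaryOperator<T> func) {
        List<T> partials = new ArrayList<>();
        for (List<T> partition : partitions) {
            reducePartition(partition, func).ifPresent(partials::add);
        }
        // Final phase: conceptually a single task combining all partial results.
        return partials.stream().reduce(func);
    }
}
```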


[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

zhipeng93 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r858401503


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/glm/GeneralLinearAlgo.java:
##########
@@ -0,0 +1,414 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.glm;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.classification.logisticregression.BinaryLogisticTrainer;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.regression.linearregression.LinearRegressionTrainer;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Base class for general linear machine learning models.
+ *
+ * @param <E> Class type of {@link Estimator}.
+ * @param <M> Class type of {@link Model}.
+ */
+public abstract class GeneralLinearAlgo<

Review Comment:
   Sounds great. The PR has been updated according to the comments.


[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

yunfengzhou-hub commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r860433041


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/Optimizer.java:
##########
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.streaming.api.datastream.DataStream;
+
+/**
+ * An optimizer is a function to modify the weight of a machine learning model, which aims to find
+ * the optimal parameter configuration for a machine learning model. Examples of optimizers could be
+ * stochastic gradient descent (SGD), L-BFGS, etc.
+ *
+ * @param <ParamType> Type of the optimizer-related parameter.
+ */
+@Internal
+public abstract class Optimizer<ParamType> {
+    /**
+     * Optimize the given loss function using the init model and the training data.
+     *
+     * @param bcInitModel The broadcasted init model. Note that each task contains one DenseVector

Review Comment:
   nit: `broadcast` rather than `broadcasted`.



##########
flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java:
##########
@@ -34,10 +34,16 @@ public static double asum(DenseVector x) {
     /** y += a * x . */
     public static void axpy(double a, Vector x, DenseVector y) {
         Preconditions.checkArgument(x.size() == y.size(), "Vector size mismatched.");
+        axpy(a, x, y, x.size());
+    }
+
+    /** y += a * x for the first k dimensions, with the other dimensions unchanged. */
+    public static void axpy(double a, Vector x, DenseVector y, int k) {

Review Comment:
   In which cases do we need to do computation on only the first `k` dimensions?
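   To make the semantics under discussion concrete, here is a plain-array sketch of the two overloads as described in the diff's Javadoc: the full version checks sizes and delegates, and the `k` version updates only indices `0..k-1`. It uses `double[]` instead of Flink ML's `Vector`/`DenseVector` (so sparse inputs are not covered), the range check on `k` is an added assumption, and it does not answer which callers need the `k` variant.

```java
/** Hypothetical plain-array sketch of the axpy overloads described in the diff. */
class AxpySketch {

    /** y += a * x for all dimensions; sizes must match. */
    static void axpy(double a, double[] x, double[] y) {
        if (x.length != y.length) {
            throw new IllegalArgumentException("Vector size mismatched.");
        }
        axpy(a, x, y, x.length);
    }

    /** y += a * x for the first k dimensions; y[k..] is left unchanged. */
    static void axpy(double a, double[] x, double[] y, int k) {
        if (k < 0 || k > x.length || k > y.length) {
            throw new IllegalArgumentException("k is out of range.");
        }
        for (int i = 0; i < k; i++) {
            y[i] += a * x[i];
        }
    }
}
```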



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/regression/LinearRegressionTest.java:
##########
@@ -0,0 +1,240 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.regression;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.regression.linearregression.LinearRegression;
+import org.apache.flink.ml.regression.linearregression.LinearRegressionModel;
+import org.apache.flink.ml.regression.linearregression.LinearRegressionModelData;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+
+/** Tests {@link LinearRegression} and {@link LinearRegressionModel}. */
+public class LinearRegressionTest {
+
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private StreamExecutionEnvironment env;
+
+    private StreamTableEnvironment tEnv;
+
+    private static final List<Row> trainData =
+            Arrays.asList(
+                    Row.of(Vectors.dense(2, 1), 4.0, 1.0),
+                    Row.of(Vectors.dense(3, 2), 7.0, 1.0),
+                    Row.of(Vectors.dense(4, 3), 10.0, 1.0),
+                    Row.of(Vectors.dense(2, 4), 10.0, 1.0),
+                    Row.of(Vectors.dense(2, 2), 6.0, 1.0),
+                    Row.of(Vectors.dense(4, 3), 10.0, 1.0),
+                    Row.of(Vectors.dense(1, 2), 5.0, 1.0),
+                    Row.of(Vectors.dense(5, 3), 11.0, 1.0));
+
+    private static final double[] expectedCoefficient = new double[] {1.0, 2.0};
+
+    private static final double TOLERANCE = 1e-7;
+
+    private Table trainDataTable;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        Collections.shuffle(trainData);
+        trainDataTable =
+                tEnv.fromDataStream(
+                        env.fromCollection(
+                                trainData,
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(DenseVector.class),

Review Comment:
   nit: use `DenseVectorTypeInfo.INSTANCE` instead of `TypeInformation.of(DenseVector.class)`.



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/regression/LinearRegressionTest.java:
##########
@@ -0,0 +1,240 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.regression;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.regression.linearregression.LinearRegression;
+import org.apache.flink.ml.regression.linearregression.LinearRegressionModel;
+import org.apache.flink.ml.regression.linearregression.LinearRegressionModelData;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+
+/** Tests {@link LinearRegression} and {@link LinearRegressionModel}. */
+public class LinearRegressionTest {
+
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private StreamExecutionEnvironment env;
+
+    private StreamTableEnvironment tEnv;
+
+    private static final List<Row> trainData =
+            Arrays.asList(
+                    Row.of(Vectors.dense(2, 1), 4.0, 1.0),
+                    Row.of(Vectors.dense(3, 2), 7.0, 1.0),
+                    Row.of(Vectors.dense(4, 3), 10.0, 1.0),
+                    Row.of(Vectors.dense(2, 4), 10.0, 1.0),
+                    Row.of(Vectors.dense(2, 2), 6.0, 1.0),
+                    Row.of(Vectors.dense(4, 3), 10.0, 1.0),
+                    Row.of(Vectors.dense(1, 2), 5.0, 1.0),
+                    Row.of(Vectors.dense(5, 3), 11.0, 1.0));
+
+    private static final double[] expectedCoefficient = new double[] {1.0, 2.0};
+
+    private static final double TOLERANCE = 1e-7;
+
+    private Table trainDataTable;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        Collections.shuffle(trainData);
+        trainDataTable =
+                tEnv.fromDataStream(
+                        env.fromCollection(
+                                trainData,
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(DenseVector.class),
+                                            Types.DOUBLE,
+                                            Types.DOUBLE
+                                        },
+                                        new String[] {"features", "label", "weight"})));
+    }
+
+    @SuppressWarnings("unchecked")
+    private void verifyPredictionResult(
+            Table output, String labelCol, String weightCol, String predictionCol)
+            throws Exception {
+        List<Row> predResult = IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        double lossSum = 0;
+        for (Row predictionRow : predResult) {
+            double label = (double) predictionRow.getField(labelCol);
+            double prediction = (double) predictionRow.getField(predictionCol);
+            double weight = (double) predictionRow.getField(weightCol);
+            lossSum += weight * Math.pow(label - prediction, 2);
+        }
+        assertTrue(lossSum < 1.0);

Review Comment:
   I'm not sure whether this criterion can guarantee the correctness of the algorithm. Would it be possible to estimate the expected prediction results, for example using `expectedCoefficient`, and test that the actual predictions are close enough to those expected values?
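   A sketch of the suggested check, assuming the features and predictions have already been extracted from the output rows into arrays: each prediction is compared with the dot product of `expectedCoefficient` and the features. `PredictionCheck` is a hypothetical helper. Every training label above equals `x1 + 2 * x2`, so the expected predictions coincide with the labels. A tolerance as tight as `TOLERANCE = 1e-7` may be too strict depending on how far the optimizer converges, which the source does not state.

```java
/** Hypothetical helper comparing predictions with the values implied by expectedCoefficient. */
class PredictionCheck {

    /** Dot product of the coefficient and one feature vector. */
    static double dot(double[] coefficient, double[] features) {
        double sum = 0.0;
        for (int i = 0; i < coefficient.length; i++) {
            sum += coefficient[i] * features[i];
        }
        return sum;
    }

    /** Throws if any prediction differs from expectedCoefficient . features by more than tolerance. */
    static void assertPredictionsClose(
            double[][] features, double[] predictions, double[] expectedCoefficient, double tolerance) {
        for (int i = 0; i < features.length; i++) {
            double expected = dot(expectedCoefficient, features[i]);
            if (Math.abs(predictions[i] - expected) > tolerance) {
                throw new AssertionError(
                        "Row " + i + ": expected " + expected + " but got " + predictions[i]);
            }
        }
    }
}
```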



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasElasticNetParam.java:
##########
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
+
+/**
+ * Interface for the shared elasticNet param, which specifies the mixing of L1 and L2 penalty:
+ * <li>If the value is zero, it is L2 penalty.
+ * <li>If the value is one, it is L1 penalty.
+ * <li>For value in (0,1), it is a combination of L1 and L2 penalty.

Review Comment:
   nit: adding `<ul></ul>` around the `<li>` tags would make the rendered JavaDoc look better.
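   For reference, a common way to turn the mixing parameter into a penalty (the convention used by, for example, Spark ML and glmnet) is `reg * (elasticNet * ||w||_1 + (1 - elasticNet) / 2 * ||w||_2^2)`; whether Flink ML uses the same scaling is an assumption. The hypothetical helper below computes that penalty, and its Javadoc shows the list wrapped in `<ul></ul>` as suggested.

```java
/**
 * Hypothetical helper for the elastic-net penalty. The mixing parameter behaves as follows:
 *
 * <ul>
 *   <li>If the value is zero, the penalty is pure L2.
 *   <li>If the value is one, the penalty is pure L1.
 *   <li>For a value in (0, 1), it is a combination of L1 and L2.
 * </ul>
 */
class ElasticNetPenalty {

    /** Returns reg * (elasticNet * ||w||_1 + (1 - elasticNet) / 2 * ||w||_2^2). */
    static double penalty(double[] w, double reg, double elasticNet) {
        if (elasticNet < 0 || elasticNet > 1) {
            throw new IllegalArgumentException("elasticNet must be in [0, 1].");
        }
        double l1 = 0.0;
        double l2Squared = 0.0;
        for (double v : w) {
            l1 += Math.abs(v);
            l2Squared += v * v;
        }
        return reg * (elasticNet * l1 + (1 - elasticNet) / 2 * l2Squared);
    }
}
```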



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LossFunc.java:
##########
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.lossfunc;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.linalg.DenseVector;
+
+import java.io.Serializable;
+
+/**
+ * A loss function is to compute the loss and gradient with the given coefficient and training data.
+ */
+@Internal
+public interface LossFunc extends Serializable {

Review Comment:
   Would it be better to move `LossFunc` and its implementations into the `optimizer` package? Will `LossFunc` be used anywhere other than in `Optimizer` subclasses?
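   As a concrete picture of what a `LossFunc` implementation computes, here is a hypothetical weighted least-squares loss with its gradient. The interface and method signatures are simplified assumptions (plain `double[]` instead of `DenseVector` and `LabeledPointWithWeight`), not the ones in this PR. The implementation needs only the coefficient and one data point, nothing optimizer-specific.

```java
/** Hypothetical, simplified counterpart of a loss-function interface. */
interface SimpleLossFunc {
    /** Returns the loss of one weighted data point under the given coefficient. */
    double computeLoss(double[] features, double label, double weight, double[] coefficient);

    /** Adds the gradient of that loss with respect to the coefficient into cumGradient. */
    void computeGradient(
            double[] features, double label, double weight, double[] coefficient, double[] cumGradient);
}

/** Weighted least-squares loss: weight * 0.5 * (coefficient . features - label)^2. */
class LeastSquareLossSketch implements SimpleLossFunc {

    private static double dot(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            sum += a[i] * b[i];
        }
        return sum;
    }

    @Override
    public double computeLoss(double[] features, double label, double weight, double[] coefficient) {
        double error = dot(coefficient, features) - label;
        return weight * 0.5 * error * error;
    }

    @Override
    public void computeGradient(
            double[] features, double label, double weight, double[] coefficient, double[] cumGradient) {
        double error = dot(coefficient, features) - label;
        for (int i = 0; i < cumGradient.length; i++) {
            // d/dw of weight * 0.5 * error^2 is weight * error * x.
            cumGradient[i] += weight * error * features[i];
        }
    }
}
```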



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/common/optimizer/RegularizationUtilsTest.java:
##########
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.ml.linalg.DenseVector;
+
+import org.apache.commons.lang3.RandomUtils;
+import org.junit.Test;
+
+import static org.junit.Assert.assertArrayEquals;
+
+/** Tests {@link RegularizationUtils}. */
+public class RegularizationUtilsTest {
+
+    private static final double learningRate = 0.1;
+    private static final double TOLERANCE = 1e-7;
+    private static final DenseVector coefficient = new DenseVector(new double[] {1.0, -2.0, 0});
+
+    @Test
+    public void testReg0() {

Review Comment:
   The test cases in this class seem to differ only in their input values. Merging them into a single test case might improve readability.
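A minimal sketch of the suggested consolidation, written as a plain-Java table of cases so that it runs without JUnit. `RegularizationUtils`' real API is not shown in this diff, so `applyL2Decay` below is a hypothetical stand-in (simple weight decay, `w_i *= 1 - learningRate * reg`), not the PR's regularization logic. In the real test, each row would call `RegularizationUtils` instead, and the loop would sit in one `@Test` method using `assertArrayEquals(expected, actual, TOLERANCE)`.

```java
import java.util.Arrays;

public class RegularizationCasesSketch {
    private static final double LEARNING_RATE = 0.1;
    private static final double TOLERANCE = 1e-7;

    // Hypothetical stand-in for the method under test: plain L2 weight decay.
    static double[] applyL2Decay(double[] coefficient, double reg, double learningRate) {
        double[] result = coefficient.clone();
        for (int i = 0; i < result.length; i++) {
            result[i] *= 1 - learningRate * reg;
        }
        return result;
    }

    // One row per former test method: {reg, expected coefficient after the update}.
    static final Object[][] CASES = {
        {0.0, new double[] {1.0, -2.0, 0.0}},
        {0.5, new double[] {0.95, -1.9, 0.0}},
        {1.0, new double[] {0.9, -1.8, 0.0}},
    };

    // Runs every case and fails on the first mismatch, naming the offending reg value.
    static void runAllCases() {
        double[] coefficient = {1.0, -2.0, 0.0};
        for (Object[] testCase : CASES) {
            double reg = (Double) testCase[0];
            double[] expected = (double[]) testCase[1];
            double[] actual = applyL2Decay(coefficient, reg, LEARNING_RATE);
            for (int i = 0; i < expected.length; i++) {
                if (Math.abs(expected[i] - actual[i]) > TOLERANCE) {
                    throw new AssertionError(
                            "reg=" + reg + ": expected " + Arrays.toString(expected)
                                    + " but got " + Arrays.toString(actual));
                }
            }
        }
    }

    public static void main(String[] args) {
        runAllCases();
        System.out.println("all cases passed");
    }
}
```

Adding a new scenario then means adding one row rather than another near-identical method.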



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasElasticNetParam.java:
##########
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
+
+/**
+ * Interface for the shared elasticNet param, which specifies the mixing of L1 and L2 penalty:
+ * <li>If the value is zero, it is L2 penalty.
+ * <li>If the value is one, it is L1 penalty.
+ * <li>For value in (0,1), it is a combination of L1 and L2 penalty.
+ */
+public interface HasElasticNetParam<T> extends WithParams<T> {

Review Comment:
   nit: `HasElasticNet` would be more consistent with the other class names.
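The Javadoc above only says which end of the range maps to which penalty. One common convention (used, for example, by Spark MLlib and glmnet; assumed here, since the PR's `WithRegularization` code is not shown) combines the two as `reg * (elasticNet * ||w||_1 + (1 - elasticNet) / 2 * ||w||_2^2)`. A minimal sketch with a hypothetical `penalty` helper:

```java
public class ElasticNetPenaltySketch {
    /**
     * Hypothetical helper: elastic-net penalty of coefficient w under the convention above.
     * elasticNet = 0 gives pure L2, elasticNet = 1 gives pure L1.
     */
    static double penalty(double[] w, double reg, double elasticNet) {
        if (elasticNet < 0 || elasticNet > 1) {
            throw new IllegalArgumentException("elasticNet must be in [0, 1]");
        }
        double l1 = 0;
        double squaredL2 = 0;
        for (double v : w) {
            l1 += Math.abs(v);
            squaredL2 += v * v;
        }
        return reg * (elasticNet * l1 + (1 - elasticNet) / 2 * squaredL2);
    }

    public static void main(String[] args) {
        double[] w = {1.0, -2.0, 0.0};
        System.out.println(penalty(w, 0.1, 0.0)); // pure L2 penalty
        System.out.println(penalty(w, 0.1, 1.0)); // pure L1 penalty
        System.out.println(penalty(w, 0.1, 0.5)); // equal mix of both
    }
}
```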



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasElasticNetParam.java:
##########
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
+
+/**
+ * Interface for the shared elasticNet param, which specifies the mixing of L1 and L2 penalty:
+ * <li>If the value is zero, it is L2 penalty.
+ * <li>If the value is one, it is L1 penalty.
+ * <li>For value in (0,1), it is a combination of L1 and L2 penalty.
+ */
+public interface HasElasticNetParam<T> extends WithParams<T> {
+    Param<Double> ELASTICNET =

Review Comment:
   nit: `ELASTIC_NET` would be more consistent with the other param names.



##########
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java:
##########
@@ -67,6 +72,28 @@ public static <IN, OUT> DataStream<OUT> mapPartition(
                 .setParallelism(input.getParallelism());
     }
 
+    /**
+     * Applies a {@link ReduceFunction} on a bounded data stream. The output stream contains at most
+     * one stream record and its parallelism is one.
+     *
+     * @param input The input data stream.
+     * @param func The user defined reduce function.
+     * @param <T> The class type of the input.
+     * @return The result data stream.
+     */
+    public static <T> DataStream<T> reduce(DataStream<T> input, ReduceFunction<T> func) {
+        DataStream<T> partialReducedStream =
+                input.transform("partialReduce", input.getType(), new ReduceOperator<>(func))
+                        .setParallelism(input.getParallelism());

Review Comment:
   It might be better to use `forward()` and `colocate()` here so that the reduce performs better.
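The two-stage shape of `reduce` in the diff above (a `partialReduce` at the input's parallelism, then a final reduce at parallelism one) can be simulated in plain Java to show why the output holds at most one record. This is only a model of the semantics, with partitions represented as lists; it does not use Flink's operator API, and the `forward()`/chaining question concerns how Flink schedules the first stage, which a sequential sketch cannot show.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.function.BinaryOperator;

public class TwoPhaseReduceSketch {
    // Stage 1: reduce each partition locally; empty partitions contribute nothing.
    static <T> List<T> partialReduce(List<List<T>> partitions, BinaryOperator<T> func) {
        List<T> partials = new ArrayList<>();
        for (List<T> partition : partitions) {
            partition.stream().reduce(func).ifPresent(partials::add);
        }
        return partials;
    }

    // Stage 2: a single global reduce over the partial results (parallelism one).
    // Returns an empty Optional when the whole input is empty, i.e. "at most one record".
    static <T> Optional<T> reduce(List<List<T>> partitions, BinaryOperator<T> func) {
        return partialReduce(partitions, func).stream().reduce(func);
    }

    public static void main(String[] args) {
        List<List<Integer>> partitions =
                Arrays.asList(
                        Arrays.asList(1, 2, 3),
                        Collections.<Integer>emptyList(),
                        Arrays.asList(4, 5));
        System.out.println(reduce(partitions, Integer::sum)); // prints Optional[15]
    }
}
```

Only the small per-partition results cross into the final stage, which is why keeping the first stage local to its upstream data matters for performance.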



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LeastSquareLoss.java:
##########
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.lossfunc;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.regression.linearregression.LinearRegression;
+
+/** The loss function for linear regression. See {@link LinearRegression} */
+@Internal
+public class LeastSquareLoss implements LossFunc {

Review Comment:
   It seems that this class can be used in any algorithm that uses the least squares method. Maybe we can refactor its Javadoc to reflect this.
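For reference, here is the weighted least squares loss and its gradient for a single point with features x, label y and weight ω under coefficient w. The factor 1/2 is a common convention assumed here; this diff does not confirm it. Nothing in the formula is specific to linear regression, which supports a more general Javadoc.

```latex
\ell(w; x, y, \omega) = \frac{\omega}{2}\,\bigl(w^\top x - y\bigr)^2,
\qquad
\nabla_w \ell(w; x, y, \omega) = \omega\,\bigl(w^\top x - y\bigr)\,x
```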





[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r860631049


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LossFunc.java:
##########
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.lossfunc;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.linalg.DenseVector;
+
+import java.io.Serializable;
+
+/**
+ * A loss function is to compute the loss and gradient with the given coefficient and training data.
+ */
+@Internal
+public interface LossFunc extends Serializable {

Review Comment:
   `LossFunc` and `Optimizer` are two orthogonal concepts. For `LossFunc`, there could be logistic loss, least squares loss, hinge loss, etc. For `Optimizer`, there could be SGD, Adam, L-BFGS, etc.
   
   Given a loss function, we can use different optimizers. Given an optimizer, we can optimize different loss functions.
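A minimal sketch of this orthogonality in plain Java. `Loss`, `LEAST_SQUARE`, `LOGISTIC` and `gradientDescent` are hypothetical names with simplified signatures (`double[]` instead of `DenseVector`, no sample weights); they are not the PR's `LossFunc`/`Optimizer` API. The point is that the optimizer depends only on the loss interface, so either loss can be plugged into it.

```java
import java.util.Arrays;
import java.util.List;

public class LossOptimizerSketch {
    /** Hypothetical loss abstraction; names and signatures are illustrative only. */
    interface Loss {
        double loss(double[] x, double y, double[] w);

        /** Adds the gradient of this point's loss into cumGradient, in place. */
        void addGradient(double[] x, double y, double[] w, double[] cumGradient);
    }

    static double dot(double[] a, double[] b) {
        double s = 0;
        for (int i = 0; i < a.length; i++) {
            s += a[i] * b[i];
        }
        return s;
    }

    /** Least squares loss: 0.5 * (w.x - y)^2. */
    static final Loss LEAST_SQUARE =
            new Loss() {
                public double loss(double[] x, double y, double[] w) {
                    double r = dot(x, w) - y;
                    return 0.5 * r * r;
                }

                public void addGradient(double[] x, double y, double[] w, double[] g) {
                    double r = dot(x, w) - y;
                    for (int i = 0; i < x.length; i++) {
                        g[i] += r * x[i];
                    }
                }
            };

    /** Logistic loss for labels in {0, 1}: log(1 + e^z) - y * z, where z = w.x. */
    static final Loss LOGISTIC =
            new Loss() {
                public double loss(double[] x, double y, double[] w) {
                    double z = dot(x, w);
                    return Math.log1p(Math.exp(z)) - y * z;
                }

                public void addGradient(double[] x, double y, double[] w, double[] g) {
                    double p = 1 / (1 + Math.exp(-dot(x, w)));
                    for (int i = 0; i < x.length; i++) {
                        g[i] += (p - y) * x[i];
                    }
                }
            };

    /** Full-batch gradient descent on the mean loss; works with any Loss. */
    static double[] gradientDescent(
            Loss loss, List<double[]> xs, double[] ys, double[] init,
            double learningRate, int iterations) {
        double[] w = init.clone();
        for (int iter = 0; iter < iterations; iter++) {
            double[] g = new double[w.length];
            for (int j = 0; j < xs.size(); j++) {
                loss.addGradient(xs.get(j), ys[j], w, g);
            }
            for (int i = 0; i < w.length; i++) {
                w[i] -= learningRate * g[i] / xs.size();
            }
        }
        return w;
    }

    public static void main(String[] args) {
        // Features and labels from LinearRegressionTest's trainData (all weights are 1.0).
        List<double[]> xs =
                Arrays.asList(
                        new double[] {2, 1}, new double[] {3, 2}, new double[] {4, 3},
                        new double[] {2, 4}, new double[] {2, 2}, new double[] {4, 3},
                        new double[] {1, 2}, new double[] {5, 3});
        double[] ys = {4, 7, 10, 10, 6, 10, 5, 11};
        double[] w = gradientDescent(LEAST_SQUARE, xs, ys, new double[2], 0.05, 1000);
        System.out.println(Arrays.toString(w)); // close to [1.0, 2.0]
    }
}
```

Swapping in a different optimizer (for example, an SGD variant) would not require touching either loss, and vice versa.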
   
   





[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r865540818


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/Optimizer.java:
##########
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.streaming.api.datastream.DataStream;
+
+/**
+ * An optimizer is a function to modify the weight of a machine learning model, which aims to find
+ * the optimal parameter configuration for a machine learning model. Examples of optimizers could be
+ * stochastic gradient descent (SGD), L-BFGS, etc.
+ */
+@Internal
+public interface Optimizer {
+    /**
+     * Optimize the given loss function using the init model and the training data.
+     *
+     * @param bcInitModel The broadcast init model. Note that each task contains one DenseVector as

Review Comment:
   `broadcast` and `broadcasted` are both ok here. [1]
   
   I have updated `task` to `partition`.
   
   
   [1] https://www.usingenglish.com/reference/irregular-verbs/broadcast.html





[GitHub] [flink-ml] lindong28 commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

Posted by GitBox <gi...@apache.org>.
lindong28 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r865898564


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/Optimizer.java:
##########
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.streaming.api.datastream.DataStream;
+
+/**
+ * An optimizer is a function to modify the weight of a machine learning model, which aims to find
+ * the optimal parameter configuration for a machine learning model. Examples of optimizers could be
+ * stochastic gradient descent (SGD), L-BFGS, etc.
+ */
+@Internal
+public interface Optimizer {
+    /**
+     * Optimize the given loss function using the init model and the training data.
+     *
+     * @param bcInitModel The broadcast init model. Note that each task contains one DenseVector as

Review Comment:
   Thanks for the explanation.





[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r860612197


##########
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java:
##########
@@ -67,6 +72,28 @@ public static <IN, OUT> DataStream<OUT> mapPartition(
                 .setParallelism(input.getParallelism());
     }
 
+    /**
+     * Applies a {@link ReduceFunction} on a bounded data stream. The output stream contains at most
+     * one stream record and its parallelism is one.
+     *
+     * @param input The input data stream.
+     * @param func The user defined reduce function.
+     * @param <T> The class type of the input.
+     * @return The result data stream.
+     */
+    public static <T> DataStream<T> reduce(DataStream<T> input, ReduceFunction<T> func) {
+        DataStream<T> partialReducedStream =
+                input.transform("partialReduce", input.getType(), new ReduceOperator<>(func))
+                        .setParallelism(input.getParallelism());

Review Comment:
   Shouldn't these two operators be chained together?





[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r860608492


##########
flink-ml-lib/src/test/java/org/apache/flink/ml/regression/LinearRegressionTest.java:
##########
@@ -0,0 +1,240 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.regression;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.regression.linearregression.LinearRegression;
+import org.apache.flink.ml.regression.linearregression.LinearRegressionModel;
+import org.apache.flink.ml.regression.linearregression.LinearRegressionModelData;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+
+/** Tests {@link LinearRegression} and {@link LinearRegressionModel}. */
+public class LinearRegressionTest {
+
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private StreamExecutionEnvironment env;
+
+    private StreamTableEnvironment tEnv;
+
+    private static final List<Row> trainData =
+            Arrays.asList(
+                    Row.of(Vectors.dense(2, 1), 4.0, 1.0),
+                    Row.of(Vectors.dense(3, 2), 7.0, 1.0),
+                    Row.of(Vectors.dense(4, 3), 10.0, 1.0),
+                    Row.of(Vectors.dense(2, 4), 10.0, 1.0),
+                    Row.of(Vectors.dense(2, 2), 6.0, 1.0),
+                    Row.of(Vectors.dense(4, 3), 10.0, 1.0),
+                    Row.of(Vectors.dense(1, 2), 5.0, 1.0),
+                    Row.of(Vectors.dense(5, 3), 11.0, 1.0));
+
+    private static final double[] expectedCoefficient = new double[] {1.0, 2.0};
+
+    private static final double TOLERANCE = 1e-7;
+
+    private Table trainDataTable;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        Collections.shuffle(trainData);
+        trainDataTable =
+                tEnv.fromDataStream(
+                        env.fromCollection(
+                                trainData,
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(DenseVector.class),
+                                            Types.DOUBLE,
+                                            Types.DOUBLE
+                                        },
+                                        new String[] {"features", "label", "weight"})));
+    }
+
+    @SuppressWarnings("unchecked")
+    private void verifyPredictionResult(
+            Table output, String labelCol, String weightCol, String predictionCol)
+            throws Exception {
+        List<Row> predResult = IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        double lossSum = 0;
+        for (Row predictionRow : predResult) {
+            double label = (double) predictionRow.getField(labelCol);
+            double prediction = (double) predictionRow.getField(predictionCol);
+            double weight = (double) predictionRow.getField(weightCol);
+            lossSum += weight * Math.pow(label - prediction, 2);
+        }
+        assertTrue(lossSum < 1.0);

Review Comment:
   This is a regression task, so the prediction is NOT a label but a continuous double value. We cannot assert that each prediction exactly matches its label, as we would in a classification task.
   
   I have refined the function to assert that the average loss is smaller than 0.1.
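A minimal sketch of such a check, assuming "average loss" means the weight-normalized mean squared error (the refined function itself is not shown here). With the test's `trainData` and `expectedCoefficient = {1.0, 2.0}`, every label is exactly `features · coefficient` (for example, 2·1 + 1·2 = 4), so a well-trained model should land far below the 0.1 threshold.

```java
public class AverageLossCheckSketch {
    /** Weighted mean squared error: sum(w * (label - pred)^2) / sum(w). Assumes sum(w) > 0. */
    static double weightedAverageLoss(double[] labels, double[] predictions, double[] weights) {
        double lossSum = 0;
        double weightSum = 0;
        for (int i = 0; i < labels.length; i++) {
            double diff = labels[i] - predictions[i];
            lossSum += weights[i] * diff * diff;
            weightSum += weights[i];
        }
        return lossSum / weightSum;
    }

    public static void main(String[] args) {
        double[][] features = {{2, 1}, {3, 2}, {4, 3}, {2, 4}, {2, 2}, {4, 3}, {1, 2}, {5, 3}};
        double[] labels = {4, 7, 10, 10, 6, 10, 5, 11};
        double[] weights = {1, 1, 1, 1, 1, 1, 1, 1};
        double[] coefficient = {1.0, 2.0};

        // Predict with the expected coefficient: prediction = features . coefficient.
        double[] predictions = new double[labels.length];
        for (int i = 0; i < features.length; i++) {
            predictions[i] = features[i][0] * coefficient[0] + features[i][1] * coefficient[1];
        }

        double avgLoss = weightedAverageLoss(labels, predictions, weights);
        if (!(avgLoss < 0.1)) {
            throw new AssertionError("average loss too high: " + avgLoss);
        }
        System.out.println("average loss = " + avgLoss); // prints "average loss = 0.0"
    }
}
```

Normalizing by the total weight keeps the 0.1 threshold meaningful regardless of how many rows the test data has.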





[GitHub] [flink-ml] lindong28 commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

Posted by GitBox <gi...@apache.org>.
lindong28 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r865893645


##########
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java:
##########
@@ -67,6 +72,28 @@ public static <IN, OUT> DataStream<OUT> mapPartition(
                 .setParallelism(input.getParallelism());
     }
 
+    /**
+     * Applies a {@link ReduceFunction} on a bounded data stream. The output stream contains at most
+     * one stream record and its parallelism is one.
+     *
+     * @param input The input data stream.
+     * @param func The user defined reduce function.
+     * @param <T> The class type of the input.
+     * @return The result data stream.
+     */
+    public static <T> DataStream<T> reduce(DataStream<T> input, ReduceFunction<T> func) {

Review Comment:
   You are right. Thanks for the explanation.





[GitHub] [flink-ml] lindong28 commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

Posted by GitBox <gi...@apache.org>.
lindong28 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r865908793


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegressionModel.java:
##########
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.regression.linearregression;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/** A Model which predicts data using the model data computed by {@link LinearRegression}. */
+public class LinearRegressionModel
+        implements Model<LinearRegressionModel>,
+                LinearRegressionModelParams<LinearRegressionModel> {
+
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    private Table modelDataTable;
+
+    public LinearRegressionModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    @Override
+    public void save(String path) throws IOException {

Review Comment:
   Sounds great. Thanks for the update.





[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r865539963


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/Optimizer.java:
##########
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.streaming.api.datastream.DataStream;
+
+/**
+ * An optimizer is a function to modify the weight of a machine learning model, which aims to find
+ * the optimal parameter configuration for a machine learning model. Examples of optimizers could be
+ * stochastic gradient descent (SGD), L-BFGS, etc.
+ */
+@Internal
+public interface Optimizer {
+    /**
+     * Optimize the given loss function using the init model and the training data.

Review Comment:
   I think we probably should not restrict `trainData` to be bounded, since we could also implement an online optimizer here.
   
   What do you think?
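To illustrate the point, here is a minimal online (per-record) SGD sketch. It consumes an `Iterator` and updates the model after every record, so it never needs the input to end. It is plain Java with hypothetical names (`Point`, `OnlineLeastSquaresSgd`, `train`) and a least squares update chosen for concreteness. It is not the `Optimizer` interface from this PR. The `maxRecords` cap exists only so the demo terminates; a real streaming job would keep running.

```java
import java.util.Iterator;
import java.util.Random;

public class OnlineSgdSketch {
    /** One training record: features, label, and sample weight. */
    static final class Point {
        final double[] x;
        final double y;
        final double weight;

        Point(double[] x, double y, double weight) {
            this.x = x;
            this.y = y;
            this.weight = weight;
        }
    }

    /** Hypothetical online optimizer: one least squares SGD step per incoming record. */
    static final class OnlineLeastSquaresSgd {
        private final double[] w;
        private final double learningRate;

        OnlineLeastSquaresSgd(int dim, double learningRate) {
            this.w = new double[dim];
            this.learningRate = learningRate;
        }

        void update(Point p) {
            double residual = -p.y;
            for (int i = 0; i < w.length; i++) {
                residual += w[i] * p.x[i];
            }
            for (int i = 0; i < w.length; i++) {
                w[i] -= learningRate * p.weight * residual * p.x[i];
            }
        }

        double[] coefficient() {
            return w.clone();
        }
    }

    /** Processes at most maxRecords from a possibly unbounded source. */
    static double[] train(Iterator<Point> source, OnlineLeastSquaresSgd sgd, long maxRecords) {
        for (long n = 0; n < maxRecords && source.hasNext(); n++) {
            sgd.update(source.next());
        }
        return sgd.coefficient();
    }

    /** An endless, seeded generator of points whose label is 1 * x0 + 2 * x1. */
    static Iterator<Point> endlessSource(long seed) {
        Random random = new Random(seed);
        return new Iterator<Point>() {
            public boolean hasNext() {
                return true;
            }

            public Point next() {
                double[] x = {random.nextDouble(), random.nextDouble()};
                return new Point(x, x[0] + 2 * x[1], 1.0);
            }
        };
    }

    public static void main(String[] args) {
        double[] w = train(endlessSource(42), new OnlineLeastSquaresSgd(2, 0.1), 20000);
        System.out.printf("w = [%.3f, %.3f]%n", w[0], w[1]);
    }
}
```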





[GitHub] [flink-ml] zhipeng93 commented on pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on PR #90:
URL: https://github.com/apache/flink-ml/pull/90#issuecomment-1118258266

   Hi @lindong28, thanks for the review. I have addressed your comments in the latest version of the PR.




[GitHub] [flink-ml] lindong28 commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

Posted by GitBox <gi...@apache.org>.
lindong28 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r866020158


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LossFunc.java:
##########
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.lossfunc;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.linalg.DenseVector;
+
+import java.io.Serializable;
+
+/**
+ * A loss function is to compute the loss and gradient with the given coefficient and training data.
+ */
+@Internal
+public interface LossFunc extends Serializable {
+
+    /**
+     * Computes the loss on the given data point.
+     *
+     * @param dataPoint A training data point.
+     * @param coefficient The model parameters.
+     * @return The loss of the input data.
+     */
+    double computeLoss(LabeledPointWithWeight dataPoint, DenseVector coefficient);
+
+    /**
+     * Computes the gradient on the given data point.

Review Comment:
   nit: It is not clear what the side effect of invoking this method is.
   
   Could we update the Java doc so that it is clear this method adds the gradient to `cumGradient`?
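To make the side effect the reviewer asks about concrete, here is a minimal sketch of a least-squares loss where `computeGradient` returns nothing and instead adds the point's gradient into a caller-owned `cumGradient` buffer. It uses plain `double[]` arrays in place of `LabeledPointWithWeight` and `DenseVector`. The class name `LeastSquareLossSketch` and the `weight * error^2 / 2` loss convention are assumptions for illustration, not the PR's actual code.

```java
import java.util.Arrays;

/** Hypothetical stand-in for a least-squares LossFunc, using plain arrays. */
class LeastSquareLossSketch {

    static double dot(double[] a, double[] b) {
        double sum = 0;
        for (int i = 0; i < a.length; i++) {
            sum += a[i] * b[i];
        }
        return sum;
    }

    /** Returns weight * (x . coefficient - label)^2 / 2; has no side effects. */
    static double computeLoss(
            double[] features, double label, double weight, double[] coefficient) {
        double error = dot(features, coefficient) - label;
        return weight * error * error / 2;
    }

    /**
     * Adds this point's gradient, weight * (x . coefficient - label) * x, to cumGradient in
     * place. Nothing is returned: the update of cumGradient is the whole effect of the call.
     */
    static void computeGradient(
            double[] features,
            double label,
            double weight,
            double[] coefficient,
            double[] cumGradient) {
        double error = dot(features, coefficient) - label;
        for (int i = 0; i < features.length; i++) {
            cumGradient[i] += weight * error * features[i];
        }
    }

    public static void main(String[] args) {
        double[] coefficient = {1.0, 2.0};
        double[] cumGradient = new double[2];
        // Two points accumulate into the same buffer, as a mini-batch would.
        computeGradient(new double[] {1.0, 1.0}, 2.0, 1.0, coefficient, cumGradient);
        computeGradient(new double[] {2.0, 0.0}, 1.0, 2.0, coefficient, cumGradient);
        System.out.println(Arrays.toString(cumGradient)); // [5.0, 1.0]
    }
}
```

Sharing one buffer across calls lets a mini-batch sum gradients without extra allocation; the flip side is that the caller must zero `cumGradient` before each batch, which is the kind of contract worth spelling out in the Javadoc.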





[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r858421894


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java:
##########
@@ -18,51 +18,28 @@
 
 package org.apache.flink.ml.classification.logisticregression;
 
-import org.apache.flink.api.common.functions.RichMapFunction;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
-import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.api.java.typeutils.RowTypeInfo;
-import org.apache.flink.ml.api.Model;
-import org.apache.flink.ml.common.broadcast.BroadcastUtils;
-import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.common.glm.GeneralLinearAlgoModel;
 import org.apache.flink.ml.linalg.BLAS;
 import org.apache.flink.ml.linalg.DenseVector;
 import org.apache.flink.ml.linalg.Vectors;
-import org.apache.flink.ml.param.Param;
-import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.ReadWriteUtils;
-import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.table.api.Table;
 import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
-import org.apache.flink.table.api.internal.TableImpl;
 import org.apache.flink.types.Row;
-import org.apache.flink.util.Preconditions;
 
 import org.apache.commons.lang3.ArrayUtils;
 
 import java.io.IOException;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.Map;
+import java.io.Serializable;
 
-/** A Model which classifies data using the model data computed by {@link LogisticRegression}. */
-public class LogisticRegressionModel
-        implements Model<LogisticRegressionModel>,
-                LogisticRegressionModelParams<LogisticRegressionModel> {
-
-    private final Map<Param<?>, Object> paramMap = new HashMap<>();
-
-    private Table modelDataTable;
-
-    public LogisticRegressionModel() {
-        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
-    }
-
-    @Override
-    public Map<Param<?>, Object> getParamMap() {
-        return paramMap;
-    }
+/**
+ * A Model which classifies the input using the model data computed by {@link LogisticRegression}.
+ */
+public class LogisticRegressionModel extends GeneralLinearAlgoModel<LogisticRegressionModel>
+        implements LogisticRegressionModelParams<LogisticRegressionModel>, Serializable {
 
     @Override
     public void save(String path) throws IOException {

Review Comment:
   The change seems infeasible for now...





[GitHub] [flink-ml] lindong28 commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

Posted by GitBox <gi...@apache.org>.
lindong28 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r865976547


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/SGD.java:
##########
@@ -0,0 +1,402 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol;
+import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.regression.linearregression.LinearRegression;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Stochastic Gradient Descent (SGD) is the most widely used optimizer for optimizing machine
+ * learning models. It iteratively makes small adjustments to the machine learning model according
+ * to the gradient at each step, to decrease the error of the model.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Stochastic_gradient_descent.
+ */
+@Internal
+public class SGD implements Optimizer {
+    /** Params for SGD optimizer. */
+    private final SGDParams params;
+
+    public SGD(
+            int maxIter,
+            double learningRate,
+            int globalBatchSize,
+            double tol,
+            double reg,
+            double elasticNet) {
+        this.params = new SGDParams(maxIter, learningRate, globalBatchSize, tol, reg, elasticNet);
+    }
+
+    @Override
+    public DataStream<DenseVector> optimize(
+            DataStream<DenseVector> bcInitModel,
+            DataStream<LabeledPointWithWeight> trainData,
+            LossFunc lossFunc) {
+        DataStreamList resultList =
+                Iterations.iterateBoundedStreamsUntilTermination(
+                        DataStreamList.of(bcInitModel.map(modelVec -> modelVec.values)),
+                        ReplayableDataStreamList.notReplay(trainData.rebalance()),
+                        IterationConfig.newBuilder().build(),
+                        new TrainIterationBody(lossFunc, params));
+        return resultList.get(0);
+    }
+
+    /** The iteration implementation for training process. */
+    private static class TrainIterationBody implements IterationBody {
+        private final LossFunc lossFunc;
+        private final SGDParams params;
+
+        public TrainIterationBody(LossFunc lossFunc, SGDParams params) {
+            this.lossFunc = lossFunc;
+            this.params = params;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            // The variable stream at the first iteration is the initialized model data.
+            // In the following iterations, it contains: the model update, weightSum and
+            // lossSum.
+            DataStream<double[]> variableStream = variableStreams.get(0);
+            DataStream<LabeledPointWithWeight> trainData = dataStreams.get(0);
+            final OutputTag<DenseVector> modelDataOutputTag =
+                    new OutputTag<DenseVector>("MODEL_OUTPUT") {};
+
+            SingleOutputStreamOperator<double[]> modelUpdateAndWeightAndLoss =
+                    trainData
+                            .connect(variableStream)
+                            .transform(
+                                    "CacheDataAndDoTrain",
+                                    PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO,
+                                    new CacheDataAndDoTrain(lossFunc, params, modelDataOutputTag));
+
+            DataStreamList feedbackVariableStream =
+                    IterationBody.forEachRound(
+                            DataStreamList.of(modelUpdateAndWeightAndLoss),
+                            input -> {
+                                DataStream<double[]> feedback =
+                                        DataStreamUtils.allReduceSum(input.get(0));
+                                return DataStreamList.of(feedback);
+                            });
+
+            DataStream<Integer> terminationCriteria =
+                    feedbackVariableStream
+                            .get(0)
+                            .map(
+                                    reducedUpdateAndWeightAndLoss -> {
+                                        double[] value = (double[]) reducedUpdateAndWeightAndLoss;
+                                        return value[value.length - 1] / value[value.length - 2];
+                                    })
+                            .flatMap(new TerminateOnMaxIterOrTol(params.maxIter, params.tol));
+
+            return new IterationBodyResult(
+                    DataStreamList.of(feedbackVariableStream.get(0)),
+                    DataStreamList.of(
+                            modelUpdateAndWeightAndLoss.getSideOutput(modelDataOutputTag)),
+                    terminationCriteria);
+        }
+    }
+
+    /**
+     * A stream operator that caches the training data in the first iteration and updates the model
+     * iteratively. The first input is the training data, and the second input is the initialized
+     * model data or feedback of model update, weight and loss.
+     */
+    private static class CacheDataAndDoTrain extends AbstractStreamOperator<double[]>
+            implements TwoInputStreamOperator<LabeledPointWithWeight, double[], double[]>,
+                    IterationListener<double[]> {
+        /** The cached training data. */
+        private List<LabeledPointWithWeight> trainData;
+
+        private ListState<LabeledPointWithWeight> trainDataState;
+
+        /** The start index (offset) of the next mini-batch data for training. */
+        private int nextBatchOffset = 0;
+
+        private ListState<Integer> nextBatchOffsetState;
+
+        /** The model coefficient. */
+        private DenseVector coefficient;
+
+        private ListState<DenseVector> coefficientState;
+        /** The dimension of the coefficient. */
+        private int coeffiDim;
+
+        /**
+         * The double array to sync among all workers. For example, when training {@link
+         * LinearRegression}, this double array consists of {modelUpdate, weightSum, lossSum}.
+         */
+        private double[] feedbackArray;
+
+        private ListState<double[]> feedbackArrayState;
+        /** The outputTag to output the model data when iteration ends. */
+        private final OutputTag<DenseVector> modelDataOutputTag;
+        /** The batch size on this task. */
+        private int localBatchSize;
+        /** Optimizer-related parameters. */
+        private final SGDParams params;
+        /** The loss function to optimize. */
+        private final LossFunc lossFunc;
+
+        private CacheDataAndDoTrain(
+                LossFunc lossFunc, SGDParams params, OutputTag<DenseVector> modelDataOutputTag) {
+            this.lossFunc = lossFunc;
+            this.params = params;
+            this.modelDataOutputTag = modelDataOutputTag;
+        }
+
+        @Override
+        public void open() {
+            int numTasks = getRuntimeContext().getNumberOfParallelSubtasks();
+            int taskId = getRuntimeContext().getIndexOfThisSubtask();
+            localBatchSize = params.globalBatchSize / numTasks;
+            if (params.globalBatchSize % numTasks > taskId) {
+                localBatchSize++;
+            }
+        }
+
+        /**
+         * Gets the weight sum of the processed elements.
+         *
+         * @return The weight sum.
+         */
+        private double getWeightSum() {
+            return feedbackArray[coeffiDim];
+        }
+
+        /**
+         * Sets the weight sum of the processed elements.
+         *
+         * @param weightSum The weight sum.
+         */
+        private void setWeightSum(double weightSum) {
+            feedbackArray[coeffiDim] = weightSum;
+        }
+
+        /**
+         * Gets the loss sum of the processed elements.
+         *
+         * @return The loss sum.
+         */
+        private double getLoss() {
+            return feedbackArray[coeffiDim + 1];
+        }
+
+        /**
+         * Sets the loss sum of the processed elements.
+         *
+         * @param loss The loss sum.
+         */
+        private void setLoss(double loss) {
+            feedbackArray[coeffiDim + 1] = loss;
+        }
+
+        @Override
+        public void onEpochWatermarkIncremented(
+                int epochWatermark, Context context, Collector<double[]> collector)
+                throws Exception {
+            if (epochWatermark == 0) {
+                coefficient = new DenseVector(feedbackArray);
+                coeffiDim = coefficient.size();
+                feedbackArray = new double[coefficient.size() + 2];
+            } else {
+                if (getWeightSum() > 0) {
+                    BLAS.axpy(
+                            -params.learningRate / getWeightSum(),
+                            new DenseVector(feedbackArray),
+                            coefficient,
+                            coeffiDim);
+                    double regLoss =
+                            RegularizationUtils.regularize(
+                                    coefficient,
+                                    params.reg,
+                                    params.elasticNet,
+                                    params.learningRate);
+                    setLoss(getLoss() + regLoss);
+                }
+            }
+
+            if (trainData == null) {
+                trainData = IteratorUtils.toList(trainDataState.get().iterator());
+            }
+
+            // TODO: supports efficient shuffle of training set on each partition.
+            if (trainData.size() > 0) {
+                List<LabeledPointWithWeight> miniBatchData =
+                        trainData.subList(
+                                nextBatchOffset,
+                                Math.min(nextBatchOffset + localBatchSize, trainData.size()));

Review Comment:
   Sounds good. Thanks for the explanation.
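The quoted `open()` and `onEpochWatermarkIncremented` code splits `globalBatchSize` across the parallel subtasks and then slices the cached training data with `subList`. A minimal standalone sketch of both steps follows. The class and method names are hypothetical, and because the diff is cut off before `nextBatchOffset` is advanced, the wrap-to-zero offset update is an assumption.

```java
import java.util.List;

/** Sketch of per-subtask batch sizing and mini-batch slicing over cached data. */
class MiniBatchSketch {

    /** globalBatchSize / numTasks, plus one for the first (globalBatchSize % numTasks) subtasks. */
    static int localBatchSize(int globalBatchSize, int numTasks, int taskId) {
        int size = globalBatchSize / numTasks;
        if (globalBatchSize % numTasks > taskId) {
            size++;
        }
        return size;
    }

    /** Same bounds as the quoted subList call; returns a view, not a copy. */
    static <T> List<T> nextBatch(List<T> cached, int offset, int localBatchSize) {
        return cached.subList(offset, Math.min(offset + localBatchSize, cached.size()));
    }

    /** ASSUMPTION: advance by the batch size and wrap to 0 once the cache is exhausted. */
    static int nextOffset(int offset, int localBatchSize, int cachedSize) {
        int next = offset + localBatchSize;
        return next >= cachedSize ? 0 : next;
    }

    public static void main(String[] args) {
        // A global batch of 10 over 4 subtasks is split as 3, 3, 2, 2.
        for (int taskId = 0; taskId < 4; taskId++) {
            System.out.print(localBatchSize(10, 4, taskId) + " ");
        }
        System.out.println();

        // Seven cached records read in batches of 3: [0, 1, 2], [3, 4, 5], [6], [0, 1, 2].
        List<Integer> cached = List.of(0, 1, 2, 3, 4, 5, 6);
        int offset = 0;
        for (int step = 0; step < 4; step++) {
            System.out.println(nextBatch(cached, offset, 3));
            offset = nextOffset(offset, 3, cached.size());
        }
    }
}
```

With this split, the local batch sizes always add up to `globalBatchSize`, and the last batch of a pass may be shorter than `localBatchSize`.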





[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r866422926


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/Optimizer.java:
##########
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.streaming.api.datastream.DataStream;
+
+/**
+ * An optimizer is a function to modify the weight of a machine learning model, which aims to find
+ * the optimal parameter configuration for a machine learning model. Examples of optimizers could be
+ * stochastic gradient descent (SGD), L-BFGS, etc.
+ */
+@Internal
+public interface Optimizer {
+    /**
+     * Optimize the given loss function using the init model and the training data.

Review Comment:
   Sure, I will modify the Java doc to explain that we only support `bounded trainData`.
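Because `optimize` is meant to handle only bounded `trainData`, a rough single-process analogue of one `SGD` round can be written as a plain loop. This sketch assumes a least-squares loss and a single worker: it sums the weighted gradient, `weightSum` and `lossSum` over one mini-batch, moves the coefficient by `-learningRate / weightSum` times the gradient sum (as the quoted `BLAS.axpy` call does), and stops after `maxIter` rounds or once `lossSum / weightSum` falls below `tol`. The exact semantics of `TerminateOnMaxIterOrTol` are assumed here, and `LocalSgdSketch` is a hypothetical name; no Flink iteration, broadcast or all-reduce is involved.

```java
/** Single-process analogue of an SGD optimize() call over bounded, in-memory data. */
class LocalSgdSketch {

    static double dot(double[] a, double[] b) {
        double sum = 0;
        for (int i = 0; i < a.length; i++) {
            sum += a[i] * b[i];
        }
        return sum;
    }

    static double[] optimize(
            double[] initModel,
            double[][] features,
            double[] labels,
            double[] weights,
            int maxIter,
            double learningRate,
            int batchSize,
            double tol) {
        double[] coefficient = initModel.clone();
        int offset = 0;
        for (int iter = 0; iter < maxIter; iter++) {
            // Mirrors the synced values: gradient sum, weightSum and lossSum of one batch.
            double[] gradientSum = new double[coefficient.length];
            double weightSum = 0;
            double lossSum = 0;
            int end = Math.min(offset + batchSize, features.length);
            for (int i = offset; i < end; i++) {
                double error = dot(features[i], coefficient) - labels[i];
                lossSum += weights[i] * error * error / 2;
                weightSum += weights[i];
                for (int j = 0; j < coefficient.length; j++) {
                    gradientSum[j] += weights[i] * error * features[i][j];
                }
            }
            offset = end >= features.length ? 0 : end;
            if (weightSum > 0) {
                // coefficient -= learningRate / weightSum * gradientSum
                for (int j = 0; j < coefficient.length; j++) {
                    coefficient[j] -= learningRate / weightSum * gradientSum[j];
                }
                // ASSUMPTION: stop once the mean loss of the batch drops below tol.
                if (lossSum / weightSum < tol) {
                    break;
                }
            }
        }
        return coefficient;
    }

    public static void main(String[] args) {
        // Fit y = 2x with no intercept; the whole data set fits in one batch.
        double[][] x = {{1}, {2}, {3}, {4}};
        double[] y = {2, 4, 6, 8};
        double[] w = {1, 1, 1, 1};
        double[] model = optimize(new double[] {0}, x, y, w, 100, 0.1, 4, 1e-10);
        System.out.println(model[0]); // close to 2.0
    }
}
```

The loop only terminates on its own because the data is finite; with an unbounded stream there would be no well-defined "round" to sum over, which is the restriction the Javadoc needs to state.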





[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r858415540


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/glm/LocalTrainer.java:
##########
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.glm;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.regression.linearregression.LinearRegression;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
+
+import java.util.Iterator;
+
+/**
+ * A local trainer is a trainer that uses a batch of training data to compute an update of the model
+ * locally.
+ *
+ * @param <T> Class type of training data.
+ */
+public abstract class LocalTrainer<T> implements CheckpointedFunction {

Review Comment:
   > move CacheDataAndDoTrain to an independent class
   
   It is an independent class in the old PR. Do you mean merging it with `LocalTrainer` and `WithRegularization`?
   
   > merge LocalTrainer and WithRegularization and remove methods like getReg, as now they are only used internally
   
   It seems hard to merge these two classes, because the logic of `WithRegularization` is meant to be used in `LocalTrainer#updateModel`. I have renamed `WithRegularization` to `RegularizationUtils` and removed methods like `withReg`. What do you think?
   
   > rename the merged class. A name like LocalTrainer might be too general to be associated with linear algorithms.
   
   I have renamed the class to `LocalLinearTrainer`, though I still do not think it is a very good name. We could discuss the naming further.
   
   > merge CacheDataAndDoTrain, LocalTrainer and WithRegularization.
   
   I did not merge `CacheDataAndDoTrain` with the other two for now, for the following reasons:
   - `CacheDataAndDoTrain` is more of an infrastructure component and involves distributed concepts, such as two-input operators.
   - `LocalTrainer` is a friendlier and cleaner concept for machine learning users, since it only involves local operations.
   
   > change the API and implementation of trainOnBatchData. I have the sense that each time only one data, instead of a batch of data, is enough. I think the API of org.apache.flink.api.common.functions.AggregateFunction is a good reference.
   
   I did not make this change, for the following reasons:
   - Mini-batch training is a common concept in machine learning.
   - Users may want to do operations before/after mini-batch training.
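As an illustration of the trade-off discussed above, the hypothetical `MiniBatchTrainer` below keeps a per-record `accumulate` step (close in spirit to `AggregateFunction#add`, which the reviewer pointed to) while still exposing hooks that run before and after each mini-batch. It is a sketch of one possible API shape under these assumptions, not the PR's `LocalTrainer`/`LocalLinearTrainer`; `MeanEstimator` is an invented example that runs SGD on the squared error `(estimate - x)^2 / 2`.

```java
import java.util.List;

class MiniBatchTrainerSketch {

    /** Hypothetical API: per-record logic plus hooks around each mini-batch. */
    abstract static class MiniBatchTrainer<T> {

        /** Runs once before the records of a mini-batch are processed. */
        protected void beforeBatch(List<T> batch) {}

        /** Runs once per record. */
        protected abstract void accumulate(T record);

        /** Runs once after all records of a mini-batch have been processed. */
        protected void afterBatch(List<T> batch) {}

        /** Template method that drives one mini-batch. */
        public final void trainOnBatch(List<T> batch) {
            beforeBatch(batch);
            for (T record : batch) {
                accumulate(record);
            }
            afterBatch(batch);
        }
    }

    /** Estimates a mean: the gradient of (estimate - x)^2 / 2 is (estimate - x). */
    static class MeanEstimator extends MiniBatchTrainer<Double> {
        private final double learningRate;
        private double estimate;
        private double gradientSum;

        MeanEstimator(double learningRate) {
            this.learningRate = learningRate;
        }

        double estimate() {
            return estimate;
        }

        @Override
        protected void beforeBatch(List<Double> batch) {
            gradientSum = 0; // reset the per-batch buffer
        }

        @Override
        protected void accumulate(Double x) {
            gradientSum += estimate - x;
        }

        @Override
        protected void afterBatch(List<Double> batch) {
            if (!batch.isEmpty()) {
                estimate -= learningRate * gradientSum / batch.size(); // averaged update
            }
        }
    }

    public static void main(String[] args) {
        MeanEstimator trainer = new MeanEstimator(1.0);
        trainer.trainOnBatch(List.of(1.0, 2.0, 3.0));
        System.out.println(trainer.estimate()); // 2.0
    }
}
```

In this shape the per-record logic stays as small as an `AggregateFunction`, while batch-level work such as resetting buffers or applying the averaged update lives in the hooks.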





[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r860652423


##########
flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java:
##########
@@ -34,10 +34,16 @@ public static double asum(DenseVector x) {
     /** y += a * x . */
     public static void axpy(double a, Vector x, DenseVector y) {
         Preconditions.checkArgument(x.size() == y.size(), "Vector size mismatched.");
+        axpy(a, x, y, x.size());
+    }
+
+    /** y += a * x for the first k dimensions, with the other dimensions unchanged. */
+    public static void axpy(double a, Vector x, DenseVector y, int k) {

Review Comment:
   Currently it is used to compute the loss and update the model in `BinaryLogisticLoss`, `LeastSquareLoss` and `SGD`. It avoids creating a new double array instance, which improves performance.
   
   Let's see what others think. Plan B could be to remove this method from `BLAS` and add an internal utility function instead.

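The motivation for the `k` overload is that the synced feedback array is longer than the model: it holds the gradient followed by `weightSum` and `lossSum`, so only its first `coeffiDim` entries should touch the coefficient. A minimal plain-array sketch of that behavior follows. `AxpySketch` is a hypothetical class, and the bounds check is an assumption, since the diff does not show what the four-argument overload validates.

```java
import java.util.Arrays;

class AxpySketch {

    /** y[i] += a * x[i] for i < k; y[k], y[k + 1], ... are left unchanged. */
    static void axpy(double a, double[] x, double[] y, int k) {
        if (k < 0 || k > x.length || k > y.length) {
            throw new IllegalArgumentException("k must be in [0, min(x.length, y.length)]");
        }
        for (int i = 0; i < k; i++) {
            y[i] += a * x[i];
        }
    }

    public static void main(String[] args) {
        // feedback = {gradient_0, gradient_1, weightSum, lossSum}; only 2 entries hit the model.
        double[] feedback = {4.0, -2.0, 2.0, 10.0};
        double[] coefficient = {1.0, 1.0};
        double learningRate = 0.5;
        int dim = coefficient.length;
        axpy(-learningRate / feedback[dim], feedback, coefficient, dim);
        System.out.println(Arrays.toString(coefficient)); // [0.0, 1.5]
    }
}
```

Passing the whole feedback array with `k = dim` avoids copying the gradient prefix into a new array of length `dim` on every round.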




[GitHub] [flink-ml] lindong28 commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

Posted by GitBox <gi...@apache.org>.
lindong28 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r865897261


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/RegularizationUtils.java:
##########
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+
+/**
+ * A utility class for algorithms that need to handle regularization. The regularization term is
+ * defined as:
+ *
+ * <p>elasticNet * reg * norm1(coefficient) + (1 - elasticNet) * (reg/2) * (norm2(coefficient))^2
+ *
+ * <p>See https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.ElasticNet.html.
+ */
+@Internal
+class RegularizationUtils {
+
+    /**
+     * Regularize the model coefficient. The gradient of each dimension could be computed as:
+     * {elasticNet * reg * Math.sign(c_i) + (1 - elasticNet) * reg * c_i}. Here c_i is the value of
+     * coefficient at i-th dimension.
+     *
+     * @param coefficient The model coefficient.
+     * @param reg The reg param.
+     * @param elasticNet The elasticNet param.
+     * @param learningRate The learningRate param.
+     * @return The loss introduced by regularization.
+     */
+    public static double regularize(
+            DenseVector coefficient,
+            final double reg,
+            final double elasticNet,
+            final double learningRate) {
+
+        if (Double.compare(reg, 0) == 0) {
+            return 0;
+        } else {
+            if (Double.compare(elasticNet, 0) == 0) {
+                // Only L2 regularization.
+                double loss = reg / 2 * BLAS.norm2(coefficient);
+                BLAS.scal(1 - learningRate * reg, coefficient);
+                return loss;
+            } else if (Double.compare(elasticNet, 1) == 0) {
+                // Only L1 regularization.
+                double loss = 0;
+                double[] coefficientArray = coefficient.values;
+                for (int i = 0; i < coefficientArray.length; i++) {
+                    if (Double.compare(coefficientArray[i], 0) == 0) {
+                        continue;
+                    }
+                    loss += elasticNet * reg * Math.signum(coefficientArray[i]);
+                    coefficientArray[i] -=
+                            learningRate * elasticNet * reg * Math.signum(coefficientArray[i]);
+                }
+                return loss;
+            } else {
+                // Both L1 and L2 are not zero.
+                double loss = 0;
+                double[] coefficientArray = coefficient.values;
+                for (int i = 0; i < coefficientArray.length; i++) {
+                    loss +=

Review Comment:
   Sounds good. Thanks for the explanation.
   
   Should we remove the code which is commented out?
   
   ```
   // if (Double.compare(coefficientArray[i], 0) == 0) {
   //    continue;
   // }
   ```
   
   And would it make the code a bit more compact to change line 55 to `else if (Double.compare(elasticNet, 0) == 0)`?

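For reference, a compact variant of `regularize` along the lines suggested above: the zero check is dropped because `Math.signum(0.0)` is `0.0`, and the pure-L1 case folds into the general elastic-net branch. This is a hypothetical plain-array sketch (`RegularizationSketch` is not in the PR). It evaluates the loss exactly as the class Javadoc's formula states, with the squared L2 norm and absolute values in the L1 term, on the coefficient before the subgradient step given in the method Javadoc.

```java
import java.util.Arrays;

class RegularizationSketch {

    /**
     * Returns elasticNet * reg * ||c||_1 + (1 - elasticNet) * reg / 2 * ||c||_2^2 for the
     * coefficient before the update, then applies one subgradient step in place:
     * c_i -= learningRate * (elasticNet * reg * sign(c_i) + (1 - elasticNet) * reg * c_i).
     */
    static double regularize(
            double[] coefficient, double reg, double elasticNet, double learningRate) {
        if (Double.compare(reg, 0) == 0) {
            return 0;
        } else if (Double.compare(elasticNet, 0) == 0) {
            // Pure L2: the step is a uniform shrink of every dimension.
            double squaredNorm = 0;
            for (double c : coefficient) {
                squaredNorm += c * c;
            }
            for (int i = 0; i < coefficient.length; i++) {
                coefficient[i] *= 1 - learningRate * reg;
            }
            return reg / 2 * squaredNorm;
        } else {
            // L1 or mixed. Math.signum(0.0) == 0.0, so zero entries need no special case.
            double loss = 0;
            for (int i = 0; i < coefficient.length; i++) {
                double c = coefficient[i];
                loss += elasticNet * reg * Math.abs(c) + (1 - elasticNet) * reg / 2 * c * c;
                coefficient[i] =
                        c
                                - learningRate
                                        * (elasticNet * reg * Math.signum(c)
                                                + (1 - elasticNet) * reg * c);
            }
            return loss;
        }
    }

    public static void main(String[] args) {
        double[] coefficient = {1.0, -2.0};
        double loss = regularize(coefficient, 0.1, 0.5, 1.0);
        // Approximately 0.275 and [0.9, -1.85], up to floating-point rounding.
        System.out.println(loss + " " + Arrays.toString(coefficient));
    }
}
```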




[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r866424812


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/SGD.java:
##########
@@ -0,0 +1,402 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol;
+import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.regression.linearregression.LinearRegression;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Stochastic Gradient Descent (SGD) is the most widely used optimizer for optimizing machine
+ * learning models. It iteratively makes small adjustments to the machine learning model according
+ * to the gradient at each step, to decrease the error of the model.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Stochastic_gradient_descent.
+ */
+@Internal
+public class SGD implements Optimizer {
+    /** Params for SGD optimizer. */
+    private final SGDParams params;
+
+    public SGD(
+            int maxIter,
+            double learningRate,
+            int globalBatchSize,
+            double tol,
+            double reg,
+            double elasticNet) {
+        this.params = new SGDParams(maxIter, learningRate, globalBatchSize, tol, reg, elasticNet);
+    }
+
+    @Override
+    public DataStream<DenseVector> optimize(
+            DataStream<DenseVector> bcInitModel,
+            DataStream<LabeledPointWithWeight> trainData,
+            LossFunc lossFunc) {
+        DataStreamList resultList =
+                Iterations.iterateBoundedStreamsUntilTermination(
+                        DataStreamList.of(bcInitModel.map(modelVec -> modelVec.values)),
+                        ReplayableDataStreamList.notReplay(trainData.rebalance()),
+                        IterationConfig.newBuilder().build(),
+                        new TrainIterationBody(lossFunc, params));
+        return resultList.get(0);
+    }
+
+    /** The iteration implementation for training process. */
+    private static class TrainIterationBody implements IterationBody {
+        private final LossFunc lossFunc;
+        private final SGDParams params;
+
+        public TrainIterationBody(LossFunc lossFunc, SGDParams params) {
+            this.lossFunc = lossFunc;
+            this.params = params;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            // The variable stream at the first iteration is the initialized model data.
+            // In the following iterations, it contains: the model update, weightSum and
+            // lossSum.
+            DataStream<double[]> variableStream = variableStreams.get(0);
+            DataStream<LabeledPointWithWeight> trainData = dataStreams.get(0);
+            final OutputTag<DenseVector> modelDataOutputTag =
+                    new OutputTag<DenseVector>("MODEL_OUTPUT") {};
+
+            SingleOutputStreamOperator<double[]> modelUpdateAndWeightAndLoss =
+                    trainData
+                            .connect(variableStream)
+                            .transform(
+                                    "CacheDataAndDoTrain",
+                                    PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO,
+                                    new CacheDataAndDoTrain(lossFunc, params, modelDataOutputTag));
+
+            DataStreamList feedbackVariableStream =
+                    IterationBody.forEachRound(
+                            DataStreamList.of(modelUpdateAndWeightAndLoss),
+                            input -> {
+                                DataStream<double[]> feedback =
+                                        DataStreamUtils.allReduceSum(input.get(0));
+                                return DataStreamList.of(feedback);
+                            });
+
+            DataStream<Integer> terminationCriteria =
+                    feedbackVariableStream
+                            .get(0)
+                            .map(
+                                    reducedUpdateAndWeightAndLoss -> {
+                                        double[] value = (double[]) reducedUpdateAndWeightAndLoss;
+                                        return value[value.length - 1] / value[value.length - 2];
+                                    })
+                            .flatMap(new TerminateOnMaxIterOrTol(params.maxIter, params.tol));
+
+            return new IterationBodyResult(
+                    DataStreamList.of(feedbackVariableStream.get(0)),
+                    DataStreamList.of(
+                            modelUpdateAndWeightAndLoss.getSideOutput(modelDataOutputTag)),
+                    terminationCriteria);
+        }
+    }
+
+    /**
+     * A stream operator that caches the training data in the first iteration and updates the model
+     * iteratively. The first input is the training data, and the second input is the initialized
+     * model data or feedback of model update, weight and loss.
+     */
+    private static class CacheDataAndDoTrain extends AbstractStreamOperator<double[]>
+            implements TwoInputStreamOperator<LabeledPointWithWeight, double[], double[]>,
+                    IterationListener<double[]> {
+        /** The cached training data. */
+        private List<LabeledPointWithWeight> trainData;
+
+        private ListState<LabeledPointWithWeight> trainDataState;
+
+        /** The start index (offset) of the next mini-batch data for training. */
+        private int nextBatchOffset = 0;
+
+        private ListState<Integer> nextBatchOffsetState;
+
+        /** The model coefficient. */
+        private DenseVector coefficient;
+
+        private ListState<DenseVector> coefficientState;
+        /** The dimension of the coefficient. */
+        private int coeffiDim;
+
+        /**
+         * The double array to sync among all workers. For example, when training {@link
+         * LinearRegression}, this double array consists of {modelUpdate, weightSum, lossSum}.
+         */
+        private double[] feedbackArray;
+
+        private ListState<double[]> feedbackArrayState;
+        /** The outputTag to output the model data when iteration ends. */
+        private final OutputTag<DenseVector> modelDataOutputTag;
+        /** The batch size on this task. */
+        private int localBatchSize;
+        /** Optimizer-related parameters. */
+        private final SGDParams params;
+        /** The loss function to optimize. */
+        private final LossFunc lossFunc;
+
+        private CacheDataAndDoTrain(
+                LossFunc lossFunc, SGDParams params, OutputTag<DenseVector> modelDataOutputTag) {
+            this.lossFunc = lossFunc;
+            this.params = params;
+            this.modelDataOutputTag = modelDataOutputTag;
+        }
+
+        @Override
+        public void open() {
+            int numTasks = getRuntimeContext().getNumberOfParallelSubtasks();
+            int taskId = getRuntimeContext().getIndexOfThisSubtask();
+            localBatchSize = params.globalBatchSize / numTasks;
+            if (params.globalBatchSize % numTasks > taskId) {
+                localBatchSize++;
+            }
+        }
+
+        /**
+         * Gets the weight sum of the processed elements.
+         *
+         * @return The weight sum.
+         */
+        private double getWeightSum() {

Review Comment:
   Thanks for the comment :) I think `totalWeight` and `totalLoss` are better names, and I will make the change.
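For readers following the SGD diff above, here is a minimal, Flink-free sketch of two pieces of bookkeeping it appears to perform. First, `open()` splits `globalBatchSize` across subtasks so that the first `globalBatchSize % numTasks` subtasks get one extra record. Second, the model update is packed together with the weight and loss totals into one `double[]`, so that a single element-wise all-reduce aggregates everything; the average loss handed to `TerminateOnMaxIterOrTol` is then `lossSum / weightSum`. The class `SgdBookkeeping` and its methods are hypothetical illustrations, not Flink ML APIs, and `sum` only stands in for `DataStreamUtils.allReduceSum`.

```java
import java.util.Arrays;

/** Hypothetical helper (not part of Flink ML) mirroring bookkeeping seen in the SGD diff. */
final class SgdBookkeeping {
    private SgdBookkeeping() {}

    /** Splits globalBatchSize over numTasks; tasks with id < (global % numTasks) get one extra. */
    static int localBatchSize(int globalBatchSize, int numTasks, int taskId) {
        int size = globalBatchSize / numTasks;
        if (globalBatchSize % numTasks > taskId) {
            size++;
        }
        return size;
    }

    /** Packs {modelUpdate..., weightSum, lossSum} so one element-wise sum aggregates all three. */
    static double[] packFeedback(double[] modelUpdate, double weightSum, double lossSum) {
        double[] feedback = Arrays.copyOf(modelUpdate, modelUpdate.length + 2);
        feedback[feedback.length - 2] = weightSum;
        feedback[feedback.length - 1] = lossSum;
        return feedback;
    }

    /** Element-wise sum of equally sized arrays; a local stand-in for an all-reduce. */
    static double[] sum(double[]... arrays) {
        double[] result = new double[arrays[0].length];
        for (double[] a : arrays) {
            for (int i = 0; i < a.length; i++) {
                result[i] += a[i];
            }
        }
        return result;
    }

    /** Average loss = lossSum / weightSum, read from the last two slots. */
    static double averageLoss(double[] reducedFeedback) {
        int n = reducedFeedback.length;
        return reducedFeedback[n - 1] / reducedFeedback[n - 2];
    }
}
```

For example, with `globalBatchSize = 10` and three subtasks, the local batch sizes are 4, 3 and 3.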




[GitHub] [flink-ml] lindong28 commented on pull request #90: [FLINK-27093] Add Transformer and Estimator for LinearRegression

Posted by GitBox <gi...@apache.org>.
lindong28 commented on PR #90:
URL: https://github.com/apache/flink-ml/pull/90#issuecomment-1119358841

   Thanks for the update. LGTM.



[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r858401503


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/glm/GeneralLinearAlgo.java:
##########
@@ -0,0 +1,414 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.glm;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.classification.logisticregression.BinaryLogisticTrainer;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.regression.linearregression.LinearRegressionTrainer;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Base class for general linear machine learning models.
+ *
+ * @param <E> Class type of {@link Estimator}.
+ * @param <M> Class type of {@link Model}.
+ */
+public abstract class GeneralLinearAlgo<

Review Comment:
   Thanks for the comment. I have re-organized the PR and extracted `Optimizer` and `LossFunc`, so `LocalLearner` and `LinearModelType` no longer exist.




[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r858421894


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java:
##########
@@ -18,51 +18,28 @@
 
 package org.apache.flink.ml.classification.logisticregression;
 
-import org.apache.flink.api.common.functions.RichMapFunction;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
-import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.api.java.typeutils.RowTypeInfo;
-import org.apache.flink.ml.api.Model;
-import org.apache.flink.ml.common.broadcast.BroadcastUtils;
-import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.common.glm.GeneralLinearAlgoModel;
 import org.apache.flink.ml.linalg.BLAS;
 import org.apache.flink.ml.linalg.DenseVector;
 import org.apache.flink.ml.linalg.Vectors;
-import org.apache.flink.ml.param.Param;
-import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.ReadWriteUtils;
-import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.table.api.Table;
 import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
-import org.apache.flink.table.api.internal.TableImpl;
 import org.apache.flink.types.Row;
-import org.apache.flink.util.Preconditions;
 
 import org.apache.commons.lang3.ArrayUtils;
 
 import java.io.IOException;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.Map;
+import java.io.Serializable;
 
-/** A Model which classifies data using the model data computed by {@link LogisticRegression}. */
-public class LogisticRegressionModel
-        implements Model<LogisticRegressionModel>,
-                LogisticRegressionModelParams<LogisticRegressionModel> {
-
-    private final Map<Param<?>, Object> paramMap = new HashMap<>();
-
-    private Table modelDataTable;
-
-    public LogisticRegressionModel() {
-        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
-    }
-
-    @Override
-    public Map<Param<?>, Object> getParamMap() {
-        return paramMap;
-    }
+/**
+ * A Model which classifies the input using the model data computed by {@link LogisticRegression}.
+ */
+public class LogisticRegressionModel extends GeneralLinearAlgoModel<LogisticRegressionModel>
+        implements LogisticRegressionModelParams<LogisticRegressionModel>, Serializable {
 
     @Override
     public void save(String path) throws IOException {

Review Comment:
   This change does not seem feasible for me right now.




[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r858419504


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java:
##########
@@ -42,12 +43,9 @@
  * <p>This class also provides methods to convert model data from Table to Datastream, and classes
  * to save/load model data.
  */
-public class LogisticRegressionModelData {
-
-    public DenseVector coefficient;
-
+public class LogisticRegressionModelData extends GeneralLinearAlgoModelData {

Review Comment:
   I also tried to do this but failed, because in `ModelDataEncoder` we need to construct an instance of `LinearRegressionModelData` or `LogisticRegressionModelData`. If we passed a class type instead, we would probably need to use reflection, which is usually discouraged.




[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r860633366


##########
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java:
##########
@@ -67,6 +72,28 @@ public static <IN, OUT> DataStream<OUT> mapPartition(
                 .setParallelism(input.getParallelism());
     }
 
+    /**
+     * Applies a {@link ReduceFunction} on a bounded data stream. The output stream contains at most
+     * one stream record and its parallelism is one.
+     *
+     * @param input The input data stream.
+     * @param func The user defined reduce function.
+     * @param <T> The class type of the input.
+     * @return The result data stream.
+     */
+    public static <T> DataStream<T> reduce(DataStream<T> input, ReduceFunction<T> func) {
+        DataStream<T> partialReducedStream =
+                input.transform("partialReduce", input.getType(), new ReduceOperator<>(func))
+                        .setParallelism(input.getParallelism());

Review Comment:
   As far as I can see from the `JobGraph` of `LinearRegressionTest.testFitAndPredict`, `partialReduce` is a separate job vertex that is not yet chained with other operators.
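The `reduce` helper in the diff appears to follow the usual two-phase pattern: a `partialReduce` step runs with the input's parallelism, and its partial results are then combined by a single parallelism-one step. The sketch below shows that pattern on in-memory lists, one list per partition. `TwoPhaseReduce` is a hypothetical name, and `java.util.function.BinaryOperator` stands in for Flink's `ReduceFunction`; the sketch does not model operator chaining, which is what the comment above is about.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.function.BinaryOperator;

/** Hypothetical, Flink-free illustration of a two-phase (partial, then global) reduce. */
final class TwoPhaseReduce {
    private TwoPhaseReduce() {}

    /** Phase 1: reduce one partition locally; an empty partition yields no partial result. */
    static <T> Optional<T> partialReduce(List<T> partition, BinaryOperator<T> func) {
        return partition.stream().reduce(func);
    }

    /** Phase 2: combine all partial results into at most one value, like a parallelism-one step. */
    static <T> Optional<T> reduce(List<List<T>> partitions, BinaryOperator<T> func) {
        List<T> partials = new ArrayList<>();
        for (List<T> partition : partitions) {
            partialReduce(partition, func).ifPresent(partials::add);
        }
        return partials.stream().reduce(func);
    }
}
```

Reducing locally first means only one record per partition has to cross the network to the final step.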




[GitHub] [flink-ml] zhipeng93 commented on pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on PR #90:
URL: https://github.com/apache/flink-ml/pull/90#issuecomment-1110903992

   @yunfengzhou-hub Thanks for the insightful comments.
   After some digging into the Spark/Alink code, I just realized that different linear models may also have different model data, so I gave up on the base linear model class.
   
   I have re-organized the PR and extracted `Optimizer` and `LossFunc`. I have also created an `SGD` optimizer and loss functions to hold the common logic. Please take a look.
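To illustrate the separation described above, here is a minimal single-process sketch in which a mini-batch SGD loop only sees a gradient callback, so swapping the callback (for example, squared loss versus logistic loss) would let the same optimizer serve different linear models. `LocalMiniBatchSgd` and `GradientFn` are hypothetical names; this is not the `SGD` class in the PR, which runs on Flink iterations and all-reduces gradients across workers.

```java
import java.util.List;

/** Hypothetical single-process sketch of an optimizer that is agnostic to the loss function. */
final class LocalMiniBatchSgd {
    private LocalMiniBatchSgd() {}

    /** Hypothetical stand-in for a gradient method: adds the gradient into cumGradient. */
    @FunctionalInterface
    interface GradientFn {
        void addGradient(double[] features, double label, double[] coef, double[] cumGradient);
    }

    static double[] optimize(
            double[] initCoef,
            List<double[]> features,
            List<Double> labels,
            GradientFn gradientFn,
            int maxIter,
            double learningRate,
            int batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive");
        }
        double[] coef = initCoef.clone();
        int n = features.size();
        if (n == 0) {
            return coef;
        }
        int offset = 0; // start of the next mini-batch, wrapping around the data
        for (int iter = 0; iter < maxIter; iter++) {
            double[] cumGradient = new double[coef.length];
            int count = Math.min(batchSize, n);
            for (int k = 0; k < count; k++) {
                int idx = (offset + k) % n;
                gradientFn.addGradient(features.get(idx), labels.get(idx), coef, cumGradient);
            }
            offset = (offset + count) % n;
            // Step against the averaged mini-batch gradient.
            for (int i = 0; i < coef.length; i++) {
                coef[i] -= learningRate * cumGradient[i] / count;
            }
        }
        return coef;
    }
}
```

For example, with a squared-loss callback and the points (1, 2) and (2, 4), the single coefficient converges toward 2.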



[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r866427486


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LossFunc.java:
##########
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.lossfunc;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.linalg.DenseVector;
+
+import java.io.Serializable;
+
+/**
+ * A loss function computes the loss and gradient for the given coefficient and training data.
+ */
+@Internal
+public interface LossFunc extends Serializable {
+
+    /**
+     * Computes the loss on the given data point.
+     *
+     * @param dataPoint A training data point.
+     * @param coefficient The model parameters.
+     * @return The loss of the input data.
+     */
+    double computeLoss(LabeledPointWithWeight dataPoint, DenseVector coefficient);
+
+    /**
+     * Computes the gradient on the given data point.

Review Comment:
   Thanks for the comment. The Javadoc is updated as follows:
   `Computes the gradient on the given data point and adds the computed gradient to cumGradient.`
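A sketch of that contract, assuming a weighted least-squares loss `weight * 0.5 * (x·c - label)^2` on plain `double[]` vectors: `computeGradient` adds `weight * (x·c - label) * x` into a caller-owned `cumGradient` buffer instead of returning a new vector, so a worker can sum a whole mini-batch without a per-point allocation. `LeastSquaresLossSketch` is a hypothetical class, not the PR's loss implementation, and it uses raw arrays instead of `LabeledPointWithWeight` and `DenseVector`.

```java
/** Hypothetical sketch of a weighted least-squares loss with an accumulating gradient. */
final class LeastSquaresLossSketch {
    private LeastSquaresLossSketch() {}

    static double dot(double[] x, double[] y) {
        double s = 0.0;
        for (int i = 0; i < x.length; i++) {
            s += x[i] * y[i];
        }
        return s;
    }

    /** Returns weight * 0.5 * (x . coefficient - label)^2. */
    static double computeLoss(double[] features, double label, double weight, double[] coefficient) {
        double error = dot(features, coefficient) - label;
        return weight * 0.5 * error * error;
    }

    /** Adds weight * (x . coefficient - label) * x into cumGradient (does not reset it). */
    static void computeGradient(
            double[] features, double label, double weight, double[] coefficient, double[] cumGradient) {
        double scale = weight * (dot(features, coefficient) - label);
        for (int i = 0; i < features.length; i++) {
            cumGradient[i] += scale * features[i];
        }
    }
}
```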




[GitHub] [flink-ml] lindong28 merged pull request #90: [FLINK-27093] Add Transformer and Estimator for LinearRegression

Posted by GitBox <gi...@apache.org>.
lindong28 merged PR #90:
URL: https://github.com/apache/flink-ml/pull/90



[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r865539013


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/RegularizationUtils.java:
##########
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+
+/**
+ * A utility class for algorithms that need to handle regularization. The regularization term is
+ * defined as:
+ *
+ * <p>elasticNet * reg * norm1(coefficient) + (1 - elasticNet) * (reg/2) * (norm2(coefficient))^2
+ *
+ * <p>See https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.ElasticNet.html.
+ */
+@Internal
+class RegularizationUtils {
+
+    /**
+     * Regularizes the model coefficient. The gradient of each dimension can be computed as:
+     * {elasticNet * reg * Math.signum(c_i) + (1 - elasticNet) * reg * c_i}. Here c_i is the value of
+     * coefficient at i-th dimension.
+     *
+     * @param coefficient The model coefficient.
+     * @param reg The reg param.
+     * @param elasticNet The elasticNet param.
+     * @param learningRate The learningRate param.
+     * @return The loss introduced by regularization.
+     */
+    public static double regularize(
+            DenseVector coefficient,
+            final double reg,
+            final double elasticNet,
+            final double learningRate) {
+
+        if (Double.compare(reg, 0) == 0) {
+            return 0;
+        } else {
+            if (Double.compare(elasticNet, 0) == 0) {
+                // Only L2 regularization.
+                double loss = reg / 2 * BLAS.norm2(coefficient);
+                BLAS.scal(1 - learningRate * reg, coefficient);
+                return loss;
+            } else if (Double.compare(elasticNet, 1) == 0) {
+                // Only L1 regularization.
+                double loss = 0;
+                double[] coefficientArray = coefficient.values;
+                for (int i = 0; i < coefficientArray.length; i++) {
+                    if (Double.compare(coefficientArray[i], 0) == 0) {
+                        continue;
+                    }
+                    loss += elasticNet * reg * Math.signum(coefficientArray[i]);
+                    coefficientArray[i] -=
+                            learningRate * elasticNet * reg * Math.signum(coefficientArray[i]);
+                }
+                return loss;
+            } else {
+                // Both the L1 and L2 terms are non-zero.
+                double loss = 0;
+                double[] coefficientArray = coefficient.values;
+                for (int i = 0; i < coefficientArray.length; i++) {
+                    loss +=

Review Comment:
   Good catch. Both adding and omitting the check `Double.compare(coefficientArray[i], 0) == 0` are correct in both cases; we add it in the `elasticNet==1` case for efficiency.
   
   In the case where `elasticNet==1` (only L1 regularization), we do not need to update the dimensions where the model parameter is zero. Thus we can do the check to skip computing the loss and the model update for those dimensions.
   
   In the case where elasticNet is neither zero nor one (i.e., we have both L1 and L2 regularization), we need to update the model at each dimension regardless of whether the model parameter at that dimension is zero. So we removed the check there.
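The sketch below applies one subgradient step of the elastic-net penalty from the class Javadoc, `elasticNet * reg * norm1(c) + (1 - elasticNet) * (reg/2) * norm2(c)^2`, using the per-dimension gradient `elasticNet * reg * signum(c_i) + (1 - elasticNet) * reg * c_i`. Because `Math.signum(0.0)` is `0.0` and the L2 term also vanishes at `c_i == 0`, a zero dimension receives no update either way, which is consistent with the observation that keeping or dropping the zero check is correct in both cases. `ElasticNetStep` is a hypothetical helper, not the PR's `RegularizationUtils`; it evaluates the L1 term with `Math.abs` and the L2 term as a squared norm, and returns the penalty measured before the step.

```java
/** Hypothetical sketch of one in-place subgradient step on the elastic-net penalty. */
final class ElasticNetStep {
    private ElasticNetStep() {}

    /** Updates c in place and returns the penalty evaluated on c before the update. */
    static double regularize(double[] c, double reg, double elasticNet, double learningRate) {
        if (reg == 0.0) {
            return 0.0;
        }
        double l1 = 0.0;
        double l2Squared = 0.0;
        for (int i = 0; i < c.length; i++) {
            double ci = c[i];
            l1 += Math.abs(ci);
            l2Squared += ci * ci;
            // signum(0.0) == 0.0 and the L2 part is 0 at ci == 0, so zero entries stay zero.
            double grad = elasticNet * reg * Math.signum(ci) + (1 - elasticNet) * reg * ci;
            c[i] = ci - learningRate * grad;
        }
        return elasticNet * reg * l1 + (1 - elasticNet) * reg / 2 * l2Squared;
    }
}
```

With `elasticNet = 0` this reduces to scaling each coefficient by `1 - learningRate * reg`, the same update as the L2-only branch in the diff.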




[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r865598619


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/SGD.java:
##########
@@ -0,0 +1,402 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol;
+import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.regression.linearregression.LinearRegression;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Stochastic Gradient Descent (SGD) is the most widely used optimizer for optimizing machine
+ * learning models. It iteratively makes small adjustments to the machine learning model according
+ * to the gradient at each step, to decrease the error of the model.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Stochastic_gradient_descent.
+ */
+@Internal
+public class SGD implements Optimizer {
+    /** Params for SGD optimizer. */
+    private final SGDParams params;
+
+    public SGD(
+            int maxIter,
+            double learningRate,
+            int globalBatchSize,
+            double tol,
+            double reg,
+            double elasticNet) {
+        this.params = new SGDParams(maxIter, learningRate, globalBatchSize, tol, reg, elasticNet);
+    }
+
+    @Override
+    public DataStream<DenseVector> optimize(
+            DataStream<DenseVector> bcInitModel,
+            DataStream<LabeledPointWithWeight> trainData,
+            LossFunc lossFunc) {
+        DataStreamList resultList =
+                Iterations.iterateBoundedStreamsUntilTermination(
+                        DataStreamList.of(bcInitModel.map(modelVec -> modelVec.values)),
+                        ReplayableDataStreamList.notReplay(trainData.rebalance()),
+                        IterationConfig.newBuilder().build(),
+                        new TrainIterationBody(lossFunc, params));
+        return resultList.get(0);
+    }
+
+    /** The iteration implementation for training process. */
+    private static class TrainIterationBody implements IterationBody {
+        private final LossFunc lossFunc;
+        private final SGDParams params;
+
+        public TrainIterationBody(LossFunc lossFunc, SGDParams params) {
+            this.lossFunc = lossFunc;
+            this.params = params;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            // The variable stream at the first iteration is the initialized model data.
+            // In the following iterations, it contains: the model update, weightSum and
+            // lossSum.
+            DataStream<double[]> variableStream = variableStreams.get(0);
+            DataStream<LabeledPointWithWeight> trainData = dataStreams.get(0);
+            final OutputTag<DenseVector> modelDataOutputTag =
+                    new OutputTag<DenseVector>("MODEL_OUTPUT") {};
+
+            SingleOutputStreamOperator<double[]> modelUpdateAndWeightAndLoss =
+                    trainData
+                            .connect(variableStream)
+                            .transform(
+                                    "CacheDataAndDoTrain",
+                                    PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO,
+                                    new CacheDataAndDoTrain(lossFunc, params, modelDataOutputTag));
+
+            DataStreamList feedbackVariableStream =
+                    IterationBody.forEachRound(
+                            DataStreamList.of(modelUpdateAndWeightAndLoss),
+                            input -> {
+                                DataStream<double[]> feedback =
+                                        DataStreamUtils.allReduceSum(input.get(0));
+                                return DataStreamList.of(feedback);
+                            });
+
+            DataStream<Integer> terminationCriteria =
+                    feedbackVariableStream
+                            .get(0)
+                            .map(
+                                    reducedUpdateAndWeightAndLoss -> {
+                                        double[] value = (double[]) reducedUpdateAndWeightAndLoss;
+                                        return value[value.length - 1] / value[value.length - 2];
+                                    })
+                            .flatMap(new TerminateOnMaxIterOrTol(params.maxIter, params.tol));
+
+            return new IterationBodyResult(
+                    DataStreamList.of(feedbackVariableStream.get(0)),
+                    DataStreamList.of(
+                            modelUpdateAndWeightAndLoss.getSideOutput(modelDataOutputTag)),
+                    terminationCriteria);
+        }
+    }
+
+    /**
+     * A stream operator that caches the training data in the first iteration and updates the model
+     * iteratively. The first input is the training data, and the second input is the initialized
+     * model data or feedback of model update, weight and loss.
+     */
+    private static class CacheDataAndDoTrain extends AbstractStreamOperator<double[]>
+            implements TwoInputStreamOperator<LabeledPointWithWeight, double[], double[]>,
+                    IterationListener<double[]> {
+        /** The cached training data. */
+        private List<LabeledPointWithWeight> trainData;
+
+        private ListState<LabeledPointWithWeight> trainDataState;
+
+        /** The start index (offset) of the next mini-batch data for training. */
+        private int nextBatchOffset = 0;
+
+        private ListState<Integer> nextBatchOffsetState;
+
+        /** The model coefficient. */
+        private DenseVector coefficient;
+
+        private ListState<DenseVector> coefficientState;
+        /** The dimension of the coefficient. */
+        private int coeffiDim;
+
+        /**
+         * The double array to sync among all workers. For example, when training {@link
+         * LinearRegression}, this double array consists of {modelUpdate, weightSum, lossSum}.
+         */
+        private double[] feedbackArray;
+
+        private ListState<double[]> feedbackArrayState;
+        /** The outputTag to output the model data when iteration ends. */
+        private final OutputTag<DenseVector> modelDataOutputTag;
+        /** The batch size on this task. */
+        private int localBatchSize;
+        /** Optimizer-related parameters. */
+        private final SGDParams params;
+        /** The loss function to optimize. */
+        private final LossFunc lossFunc;
+
+        private CacheDataAndDoTrain(
+                LossFunc lossFunc, SGDParams params, OutputTag<DenseVector> modelDataOutputTag) {
+            this.lossFunc = lossFunc;
+            this.params = params;
+            this.modelDataOutputTag = modelDataOutputTag;
+        }
+
+        @Override
+        public void open() {
+            int numTasks = getRuntimeContext().getNumberOfParallelSubtasks();
+            int taskId = getRuntimeContext().getIndexOfThisSubtask();
+            localBatchSize = params.globalBatchSize / numTasks;
+            if (params.globalBatchSize % numTasks > taskId) {
+                localBatchSize++;
+            }
+        }
+
+        /**
+         * Gets the weight sum of the processed elements.
+         *
+         * @return The weight sum.
+         */
+        private double getWeightSum() {
+            return feedbackArray[coeffiDim];
+        }
+
+        /**
+         * Sets the weight sum of the processed elements.
+         *
+         * @param weightSum The weight sum.
+         */
+        private void setWeightSum(double weightSum) {
+            feedbackArray[coeffiDim] = weightSum;
+        }
+
+        /**
+         * Gets the loss sum of the processed elements.
+         *
+         * @return The loss sum.
+         */
+        private double getLoss() {
+            return feedbackArray[coeffiDim + 1];
+        }
+
+        /**
+         * Sets the loss sum of the processed elements.
+         *
+         * @param loss The loss sum.
+         */
+        private void setLoss(double loss) {
+            feedbackArray[coeffiDim + 1] = loss;
+        }
+
+        @Override
+        public void onEpochWatermarkIncremented(
+                int epochWatermark, Context context, Collector<double[]> collector)
+                throws Exception {
+            if (epochWatermark == 0) {
+                coefficient = new DenseVector(feedbackArray);
+                coeffiDim = coefficient.size();
+                feedbackArray = new double[coefficient.size() + 2];
+            } else {
+                if (getWeightSum() > 0) {
+                    BLAS.axpy(
+                            -params.learningRate / getWeightSum(),
+                            new DenseVector(feedbackArray),
+                            coefficient,
+                            coeffiDim);
+                    double regLoss =
+                            RegularizationUtils.regularize(
+                                    coefficient,
+                                    params.reg,
+                                    params.elasticNet,
+                                    params.learningRate);
+                    setLoss(getLoss() + regLoss);
+                }
+            }
+
+            if (trainData == null) {
+                trainData = IteratorUtils.toList(trainDataState.get().iterator());
+            }
+
+            // TODO: support efficient shuffling of the training set on each partition.
+            if (trainData.size() > 0) {
+                List<LabeledPointWithWeight> miniBatchData =
+                        trainData.subList(
+                                nextBatchOffset,
+                                Math.min(nextBatchOffset + localBatchSize, trainData.size()));

Review Comment:
   Yes. Letting the last mini-batch be smaller than `localBatchSize` is consistent with existing libraries [1] [2].
   
   [1] https://www.tensorflow.org/api_docs/python/tf/data/Dataset (search for "last batch")
   [2] https://stackoverflow.com/questions/59553964/incomplete-last-batch-in-tensorflow
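The sketch below illustrates the slicing behavior discussed above, under stated assumptions: it mirrors the quoted `trainData.subList(nextBatchOffset, Math.min(nextBatchOffset + localBatchSize, trainData.size()))` call and assumes the offset then advances by `localBatchSize` and wraps to zero at the end of the cached data (as the later quote of `SGD.java` in this thread shows). `MiniBatchSlicer`, `nextBatch` and `advance` are hypothetical names used only for illustration; this is not Flink ML code.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical helper mimicking the mini-batch slicing of the cached training data. */
final class MiniBatchSlicer {
    private MiniBatchSlicer() {}

    /** Returns up to localBatchSize elements starting at offset; the last batch may be smaller. */
    static <T> List<T> nextBatch(List<T> data, int offset, int localBatchSize) {
        int end = Math.min(offset + localBatchSize, data.size());
        // Copy the subList view so the batch stays valid even if the underlying list changes.
        return new ArrayList<>(data.subList(offset, end));
    }

    /** Advances the offset by localBatchSize and wraps to 0 at the end of the data. */
    static int advance(int offset, int localBatchSize, int dataSize) {
        int next = offset + localBatchSize;
        return next >= dataSize ? 0 : next;
    }
}
```

With 10 cached elements and a `localBatchSize` of 4, successive calls yield batches of 4, 4 and 2 elements before the offset wraps back to 0, so the last batch of each pass is incomplete rather than padded or dropped.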



[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

zhipeng93 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r858420942


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/glm/GeneralLinearAlgo.java:
##########
@@ -0,0 +1,414 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.glm;

Review Comment:
   I have renamed the package `glm` to `linear`.
   
   I have also marked all classes in this package as `@Internal` and checked the visibility of all methods in this package.



[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

zhipeng93 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r865520528


##########
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java:
##########
@@ -67,6 +72,28 @@ public static <IN, OUT> DataStream<OUT> mapPartition(
                 .setParallelism(input.getParallelism());
     }
 
+    /**
+     * Applies a {@link ReduceFunction} on a bounded data stream. The output stream contains at most
+     * one stream record and its parallelism is one.
+     *
+     * @param input The input data stream.
+     * @param func The user defined reduce function.
+     * @param <T> The class type of the input.
+     * @return The result data stream.
+     */
+    public static <T> DataStream<T> reduce(DataStream<T> input, ReduceFunction<T> func) {

Review Comment:
   As far as I know, `reduce` is a widely used name for a function that applies a reduce function to a collection of elements, e.g., Spark `RDD#reduce` [1] and Flink `DataSet#reduce` [2].
   
   
   [1] https://sparkbyexamples.com/apache-spark-rdd/spark-rdd-reduce-function-example/
   [2] https://nightlies.apache.org/flink/flink-docs-master/docs/dev/dataset/transformations/#reduce-on-full-dataset
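As an analogy using only the JDK (not the Flink ML implementation), the contract in the `DataStreamUtils#reduce` Javadoc, a bounded input reduced to at most one result, matches `java.util.stream.Stream#reduce(BinaryOperator)`, which returns an empty `Optional` for empty input. The hypothetical `sumArrays` helper shows the kind of element-wise `double[]` aggregation that `SGD.java` applies to its feedback arrays through `DataStreamUtils.allReduceSum`; `ReduceAnalogy` and its methods are illustrative names only.

```java
import java.util.List;
import java.util.Optional;
import java.util.function.BinaryOperator;

/** Hypothetical JDK-only analogy for reducing a bounded collection to at most one element. */
final class ReduceAnalogy {
    private ReduceAnalogy() {}

    /** Returns Optional.empty() for empty input, otherwise the result of folding func over it. */
    static <T> Optional<T> reduceAll(List<T> input, BinaryOperator<T> func) {
        return input.stream().reduce(func);
    }

    /** Element-wise sum of equally sized arrays, e.g. per-task {update..., weightSum, lossSum}. */
    static Optional<double[]> sumArrays(List<double[]> arrays) {
        return reduceAll(
                arrays,
                (a, b) -> {
                    double[] sum = new double[a.length];
                    for (int i = 0; i < a.length; i++) {
                        sum[i] = a[i] + b[i];
                    }
                    return sum;
                });
    }
}
```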



[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

zhipeng93 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r865596624


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/SGD.java:
##########
@@ -0,0 +1,402 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol;
+import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.regression.linearregression.LinearRegression;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Stochastic Gradient Descent (SGD) is the most widely used optimizer for optimizing machine
+ * learning models. It iteratively makes small adjustments to the machine learning model according
+ * to the gradient at each step, to decrease the error of the model.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Stochastic_gradient_descent.
+ */
+@Internal
+public class SGD implements Optimizer {
+    /** Params for SGD optimizer. */
+    private final SGDParams params;
+
+    public SGD(
+            int maxIter,
+            double learningRate,
+            int globalBatchSize,
+            double tol,
+            double reg,
+            double elasticNet) {
+        this.params = new SGDParams(maxIter, learningRate, globalBatchSize, tol, reg, elasticNet);
+    }
+
+    @Override
+    public DataStream<DenseVector> optimize(
+            DataStream<DenseVector> bcInitModel,
+            DataStream<LabeledPointWithWeight> trainData,
+            LossFunc lossFunc) {
+        DataStreamList resultList =
+                Iterations.iterateBoundedStreamsUntilTermination(
+                        DataStreamList.of(bcInitModel.map(modelVec -> modelVec.values)),
+                        ReplayableDataStreamList.notReplay(trainData.rebalance()),
+                        IterationConfig.newBuilder().build(),
+                        new TrainIterationBody(lossFunc, params));
+        return resultList.get(0);
+    }
+
+    /** The iteration implementation for training process. */
+    private static class TrainIterationBody implements IterationBody {
+        private final LossFunc lossFunc;
+        private final SGDParams params;
+
+        public TrainIterationBody(LossFunc lossFunc, SGDParams params) {
+            this.lossFunc = lossFunc;
+            this.params = params;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            // The variable stream at the first iteration is the initialized model data.
+            // In the following iterations, it contains: the model update, weightSum and
+            // lossSum.
+            DataStream<double[]> variableStream = variableStreams.get(0);
+            DataStream<LabeledPointWithWeight> trainData = dataStreams.get(0);
+            final OutputTag<DenseVector> modelDataOutputTag =
+                    new OutputTag<DenseVector>("MODEL_OUTPUT") {};
+
+            SingleOutputStreamOperator<double[]> modelUpdateAndWeightAndLoss =
+                    trainData
+                            .connect(variableStream)
+                            .transform(
+                                    "CacheDataAndDoTrain",
+                                    PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO,
+                                    new CacheDataAndDoTrain(lossFunc, params, modelDataOutputTag));
+
+            DataStreamList feedbackVariableStream =
+                    IterationBody.forEachRound(
+                            DataStreamList.of(modelUpdateAndWeightAndLoss),
+                            input -> {
+                                DataStream<double[]> feedback =
+                                        DataStreamUtils.allReduceSum(input.get(0));
+                                return DataStreamList.of(feedback);
+                            });
+
+            DataStream<Integer> terminationCriteria =
+                    feedbackVariableStream
+                            .get(0)
+                            .map(
+                                    reducedUpdateAndWeightAndLoss -> {
+                                        double[] value = (double[]) reducedUpdateAndWeightAndLoss;
+                                        return value[value.length - 1] / value[value.length - 2];
+                                    })
+                            .flatMap(new TerminateOnMaxIterOrTol(params.maxIter, params.tol));
+
+            return new IterationBodyResult(
+                    DataStreamList.of(feedbackVariableStream.get(0)),
+                    DataStreamList.of(
+                            modelUpdateAndWeightAndLoss.getSideOutput(modelDataOutputTag)),
+                    terminationCriteria);
+        }
+    }
+
+    /**
+     * A stream operator that caches the training data in the first iteration and updates the model
+     * iteratively. The first input is the training data, and the second input is the initialized
+     * model data or feedback of model update, weight and loss.
+     */
+    private static class CacheDataAndDoTrain extends AbstractStreamOperator<double[]>
+            implements TwoInputStreamOperator<LabeledPointWithWeight, double[], double[]>,
+                    IterationListener<double[]> {
+        /** The cached training data. */
+        private List<LabeledPointWithWeight> trainData;
+
+        private ListState<LabeledPointWithWeight> trainDataState;
+
+        /** The start index (offset) of the next mini-batch data for training. */
+        private int nextBatchOffset = 0;
+
+        private ListState<Integer> nextBatchOffsetState;
+
+        /** The model coefficient. */
+        private DenseVector coefficient;
+
+        private ListState<DenseVector> coefficientState;
+        /** The dimension of the coefficient. */
+        private int coeffiDim;
+
+        /**
+         * The double array to sync among all workers. For example, when training {@link
+         * LinearRegression}, this double array consists of {modelUpdate, weightSum, lossSum}.
+         */
+        private double[] feedbackArray;
+
+        private ListState<double[]> feedbackArrayState;
+        /** The outputTag to output the model data when iteration ends. */
+        private final OutputTag<DenseVector> modelDataOutputTag;
+        /** The batch size on this task. */
+        private int localBatchSize;
+        /** Optimizer-related parameters. */
+        private final SGDParams params;
+        /** The loss function to optimize. */
+        private final LossFunc lossFunc;
+
+        private CacheDataAndDoTrain(
+                LossFunc lossFunc, SGDParams params, OutputTag<DenseVector> modelDataOutputTag) {
+            this.lossFunc = lossFunc;
+            this.params = params;
+            this.modelDataOutputTag = modelDataOutputTag;
+        }
+
+        @Override
+        public void open() {
+            int numTasks = getRuntimeContext().getNumberOfParallelSubtasks();
+            int taskId = getRuntimeContext().getIndexOfThisSubtask();
+            localBatchSize = params.globalBatchSize / numTasks;
+            if (params.globalBatchSize % numTasks > taskId) {
+                localBatchSize++;
+            }
+        }
+
+        /**
+         * Gets the weight sum of the processed elements.
+         *
+         * @return The weight sum.
+         */
+        private double getWeightSum() {
+            return feedbackArray[coeffiDim];
+        }
+
+        /**
+         * Sets the weight sum of the processed elements.
+         *
+         * @param weightSum The weight sum.
+         */
+        private void setWeightSum(double weightSum) {
+            feedbackArray[coeffiDim] = weightSum;
+        }
+
+        /**
+         * Gets the loss sum of the processed elements.
+         *
+         * @return The loss sum.
+         */
+        private double getLoss() {
+            return feedbackArray[coeffiDim + 1];
+        }
+
+        /**
+         * Sets the loss sum of the processed elements.
+         *
+         * @param loss The loss sum.
+         */
+        private void setLoss(double loss) {
+            feedbackArray[coeffiDim + 1] = loss;
+        }
+
+        @Override
+        public void onEpochWatermarkIncremented(
+                int epochWatermark, Context context, Collector<double[]> collector)
+                throws Exception {
+            if (epochWatermark == 0) {
+                coefficient = new DenseVector(feedbackArray);
+                coeffiDim = coefficient.size();
+                feedbackArray = new double[coefficient.size() + 2];
+            } else {
+                if (getWeightSum() > 0) {
+                    BLAS.axpy(
+                            -params.learningRate / getWeightSum(),
+                            new DenseVector(feedbackArray),
+                            coefficient,
+                            coeffiDim);
+                    double regLoss =
+                            RegularizationUtils.regularize(
+                                    coefficient,
+                                    params.reg,
+                                    params.elasticNet,
+                                    params.learningRate);
+                    setLoss(getLoss() + regLoss);
+                }
+            }
+
+            if (trainData == null) {
+                trainData = IteratorUtils.toList(trainDataState.get().iterator());
+            }
+
+            // TODO: support efficient shuffling of the training set on each partition.
+            if (trainData.size() > 0) {
+                List<LabeledPointWithWeight> miniBatchData =
+                        trainData.subList(
+                                nextBatchOffset,
+                                Math.min(nextBatchOffset + localBatchSize, trainData.size()));
+                nextBatchOffset += localBatchSize;
+                nextBatchOffset = nextBatchOffset >= trainData.size() ? 0 : nextBatchOffset;

Review Comment:
   The previous approach is essentially a toy implementation and does not treat samples on different partitions uniformly. Suppose there are two workers and the globalBatchSize is 32, and that `worker1` and `worker2` hold 10 and 20 elements, respectively. The previous approach then picks 16 elements on `worker1` and 16 elements on `worker2`, so each element on `worker1` has a larger probability of being picked than each element on `worker2`.
   
   Spark uses a `mini-batch fraction` to sample the training data in each iteration, which we discussed in [1].
   
   [1] https://github.com/apache/flink-ml/pull/28#discussion_r753641174
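To make the non-uniformity argument concrete, here is a minimal sketch that reuses the remainder-spreading rule from `CacheDataAndDoTrain#open()` to compute each task's fixed batch size, and contrasts it with a generic Bernoulli sampler in which every element is kept independently with the same probability, in the spirit of a mini-batch fraction. `BatchSamplingSketch` and its methods are hypothetical names, and the sampler is not Spark's actual code.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

/** Hypothetical comparison of fixed per-worker batches vs. fraction-based sampling. */
final class BatchSamplingSketch {
    private BatchSamplingSketch() {}

    /** Same split rule as CacheDataAndDoTrain#open(): the remainder goes to the first tasks. */
    static int localBatchSize(int globalBatchSize, int numTasks, int taskId) {
        int size = globalBatchSize / numTasks;
        if (globalBatchSize % numTasks > taskId) {
            size++;
        }
        return size;
    }

    /** Expected number of uses of one local element per iteration under the fixed split. */
    static double expectedUsesFixedSplit(
            int globalBatchSize, int numTasks, int taskId, int localCount) {
        return (double) localBatchSize(globalBatchSize, numTasks, taskId) / localCount;
    }

    /** Keeps each element independently with probability `fraction` (Bernoulli sampling). */
    static <T> List<T> sampleByFraction(List<T> data, double fraction, Random rng) {
        List<T> batch = new ArrayList<>();
        for (T element : data) {
            // nextDouble() is uniform in [0, 1), so a fraction of 1.0 keeps every element.
            if (rng.nextDouble() < fraction) {
                batch.add(element);
            }
        }
        return batch;
    }
}
```

With the numbers from the comment (`globalBatchSize = 32`, two workers holding 10 and 20 elements), `expectedUsesFixedSplit` returns 1.6 for elements on `worker1` and 0.8 for elements on `worker2`. One possible choice for the fraction sampler is `globalBatchSize / totalCount` capped at 1.0 (here 32 / 30, so the cap applies and every element is kept), which gives all elements the same inclusion probability regardless of the partition they live on.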
   



[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

zhipeng93 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r865539013


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/RegularizationUtils.java:
##########
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+
+/**
+ * A utility class for algorithms that need to handle regularization. The regularization term is
+ * defined as:
+ *
+ * <p>elasticNet * reg * norm1(coefficient) + (1 - elasticNet) * (reg/2) * (norm2(coefficient))^2
+ *
+ * <p>See https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.ElasticNet.html.
+ */
+@Internal
+class RegularizationUtils {
+
+    /**
+     * Regularize the model coefficient. The gradient of each dimension could be computed as:
+     * {elasticNet * reg * Math.sign(c_i) + (1 - elasticNet) * reg * c_i}. Here c_i is the value of
+     * coefficient at i-th dimension.
+     *
+     * @param coefficient The model coefficient.
+     * @param reg The reg param.
+     * @param elasticNet The elasticNet param.
+     * @param learningRate The learningRate param.
+     * @return The loss introduced by regularization.
+     */
+    public static double regularize(
+            DenseVector coefficient,
+            final double reg,
+            final double elasticNet,
+            final double learningRate) {
+
+        if (Double.compare(reg, 0) == 0) {
+            return 0;
+        } else {
+            if (Double.compare(elasticNet, 0) == 0) {
+                // Only L2 regularization.
+                double loss = reg / 2 * BLAS.norm2(coefficient);
+                BLAS.scal(1 - learningRate * reg, coefficient);
+                return loss;
+            } else if (Double.compare(elasticNet, 1) == 0) {
+                // Only L1 regularization.
+                double loss = 0;
+                double[] coefficientArray = coefficient.values;
+                for (int i = 0; i < coefficientArray.length; i++) {
+                    if (Double.compare(coefficientArray[i], 0) == 0) {
+                        continue;
+                    }
+                    loss += elasticNet * reg * Math.signum(coefficientArray[i]);
+                    coefficientArray[i] -=
+                            learningRate * elasticNet * reg * Math.signum(coefficientArray[i]);
+                }
+                return loss;
+            } else {
+                // Both the L1 and the L2 terms are non-zero.
+                double loss = 0;
+                double[] coefficientArray = coefficient.values;
+                for (int i = 0; i < coefficientArray.length; i++) {
+                    loss +=

Review Comment:
   Good catch. The result is correct in both cases whether or not the check `Double.compare(coefficientArray[i], 0) == 0` is present. It is added in the `elasticNet == 1` case only for efficiency.
   
   When `elasticNet == 1` (L1 regularization only), we do not need to update the zero dimensions of the model parameters, so the check lets us skip computing the loss and the model update for those dimensions.
   
   When `elasticNet` is neither zero nor one (i.e., we have both L1 and L2 regularization), we update the model at every dimension regardless of whether the parameter at that dimension is zero, so we removed the check there.
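A minimal sketch of one regularization step on a plain `double[]`, independent of the Flink ML `BLAS` and `DenseVector` classes. It assumes the loss and per-dimension gradient exactly as stated in the `RegularizationUtils` Javadoc (L1 norm plus squared L2 norm); `ElasticNetSketch` is a hypothetical name. It also illustrates why the zero check is optional: because `Math.signum(0.0)` is `0.0`, a coefficient that is zero receives a zero update for any `elasticNet`, so skipping it only saves work.

```java
/** Hypothetical elastic-net step following the formula in the RegularizationUtils Javadoc. */
final class ElasticNetSketch {
    private ElasticNetSketch() {}

    /** Loss = elasticNet * reg * ||c||_1 + (1 - elasticNet) * (reg / 2) * ||c||_2^2. */
    static double loss(double[] c, double reg, double elasticNet) {
        double norm1 = 0;
        double squaredNorm2 = 0;
        for (double v : c) {
            norm1 += Math.abs(v);
            squaredNorm2 += v * v;
        }
        return elasticNet * reg * norm1 + (1 - elasticNet) * (reg / 2) * squaredNorm2;
    }

    /**
     * Applies c_i -= learningRate * (elasticNet * reg * sign(c_i) + (1 - elasticNet) * reg * c_i)
     * in place and returns the regularization loss measured before the step.
     */
    static double regularize(double[] c, double reg, double elasticNet, double learningRate) {
        double lossBeforeStep = loss(c, reg, elasticNet);
        for (int i = 0; i < c.length; i++) {
            double gradient = elasticNet * reg * Math.signum(c[i]) + (1 - elasticNet) * reg * c[i];
            // If c[i] == 0 the gradient is 0 (Math.signum(0.0) == 0.0), so zero entries stay zero.
            c[i] -= learningRate * gradient;
        }
        return lossBeforeStep;
    }
}
```

For `elasticNet = 0` the step reduces to scaling every coefficient by `1 - learningRate * reg`, and for `elasticNet = 1` to subtracting `learningRate * reg * sign(c_i)`, which matches the two special-cased branches in the quoted code.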



[GitHub] [flink-ml] lindong28 commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

lindong28 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r865904573


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/SGD.java:
##########
@@ -0,0 +1,402 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol;
+import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.regression.linearregression.LinearRegression;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Stochastic Gradient Descent (SGD) is the mostly wide-used optimizer for optimizing machine
+ * learning models. It iteratively makes small adjustments to the machine learning model according
+ * to the gradient at each step, to decrease the error of the model.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Stochastic_gradient_descent.
+ */
+@Internal
+public class SGD implements Optimizer {
+    /** Params for SGD optimizer. */
+    private final SGDParams params;
+
+    public SGD(
+            int maxIter,
+            double learningRate,
+            int globalBatchSize,
+            double tol,
+            double reg,
+            double elasticNet) {
+        this.params = new SGDParams(maxIter, learningRate, globalBatchSize, tol, reg, elasticNet);
+    }
+
+    @Override
+    public DataStream<DenseVector> optimize(
+            DataStream<DenseVector> bcInitModel,
+            DataStream<LabeledPointWithWeight> trainData,
+            LossFunc lossFunc) {
+        DataStreamList resultList =
+                Iterations.iterateBoundedStreamsUntilTermination(
+                        DataStreamList.of(bcInitModel.map(modelVec -> modelVec.values)),
+                        ReplayableDataStreamList.notReplay(trainData.rebalance()),
+                        IterationConfig.newBuilder().build(),
+                        new TrainIterationBody(lossFunc, params));
+        return resultList.get(0);
+    }
+
+    /** The iteration implementation for training process. */
+    private static class TrainIterationBody implements IterationBody {
+        private final LossFunc lossFunc;
+        private final SGDParams params;
+
+        public TrainIterationBody(LossFunc lossFunc, SGDParams params) {
+            this.lossFunc = lossFunc;
+            this.params = params;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            // The variable stream at the first iteration is the initialized model data.
+            // In the following iterations, it contains: the model update, weightSum and
+            // lossSum.
+            DataStream<double[]> variableStream = variableStreams.get(0);
+            DataStream<LabeledPointWithWeight> trainData = dataStreams.get(0);
+            final OutputTag<DenseVector> modelDataOutputTag =
+                    new OutputTag<DenseVector>("MODEL_OUTPUT") {};
+
+            SingleOutputStreamOperator<double[]> modelUpdateAndWeightAndLoss =
+                    trainData
+                            .connect(variableStream)
+                            .transform(
+                                    "CacheDataAndDoTrain",
+                                    PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO,
+                                    new CacheDataAndDoTrain(lossFunc, params, modelDataOutputTag));
+
+            DataStreamList feedbackVariableStream =
+                    IterationBody.forEachRound(
+                            DataStreamList.of(modelUpdateAndWeightAndLoss),
+                            input -> {
+                                DataStream<double[]> feedback =
+                                        DataStreamUtils.allReduceSum(input.get(0));
+                                return DataStreamList.of(feedback);
+                            });
+
+            DataStream<Integer> terminationCriteria =
+                    feedbackVariableStream
+                            .get(0)
+                            .map(
+                                    reducedUpdateAndWeightAndLoss -> {
+                                        double[] value = (double[]) reducedUpdateAndWeightAndLoss;
+                                        return value[value.length - 1] / value[value.length - 2];
+                                    })
+                            .flatMap(new TerminateOnMaxIterOrTol(params.maxIter, params.tol));
+
+            return new IterationBodyResult(
+                    DataStreamList.of(feedbackVariableStream.get(0)),
+                    DataStreamList.of(
+                            modelUpdateAndWeightAndLoss.getSideOutput(modelDataOutputTag)),
+                    terminationCriteria);
+        }
+    }
+
+    /**
+     * A stream operator that caches the training data in the first iteration and updates the model
+     * iteratively. The first input is the training data, and the second input is the initialized
+     * model data or feedback of model update, weight and loss.
+     */
+    private static class CacheDataAndDoTrain extends AbstractStreamOperator<double[]>
+            implements TwoInputStreamOperator<LabeledPointWithWeight, double[], double[]>,
+                    IterationListener<double[]> {
+        /** The cached training data. */
+        private List<LabeledPointWithWeight> trainData;
+
+        private ListState<LabeledPointWithWeight> trainDataState;
+
+        /** The start index (offset) of the next mini-batch data for training. */
+        private int nextBatchOffset = 0;
+
+        private ListState<Integer> nextBatchOffsetState;
+
+        /** The model coefficient. */
+        private DenseVector coefficient;
+
+        private ListState<DenseVector> coefficientState;
+        /** The dimension of the coefficient. */
+        private int coeffiDim;
+
+        /**
+         * The double array to sync among all workers. For example, when training {@link
+         * LinearRegression}, this double array is consisted of {modelUpdate, weightSum, lossSum}.
+         */
+        private double[] feedbackArray;
+
+        private ListState<double[]> feedbackArrayState;
+        /** The outputTag to output the model data when iteration ends. */
+        private final OutputTag<DenseVector> modelDataOutputTag;
+        /** The batch size on this task. */
+        private int localBatchSize;
+        /** Optimizer-related parameters. */
+        private final SGDParams params;
+        /** The loss function to optimize. */
+        private final LossFunc lossFunc;
+
+        private CacheDataAndDoTrain(
+                LossFunc lossFunc, SGDParams params, OutputTag<DenseVector> modelDataOutputTag) {
+            this.lossFunc = lossFunc;
+            this.params = params;
+            this.modelDataOutputTag = modelDataOutputTag;
+        }
+
+        @Override
+        public void open() {
+            int numTasks = getRuntimeContext().getNumberOfParallelSubtasks();
+            int taskId = getRuntimeContext().getIndexOfThisSubtask();
+            localBatchSize = params.globalBatchSize / numTasks;
+            if (params.globalBatchSize % numTasks > taskId) {
+                localBatchSize++;
+            }
+        }
+
+        /**
+         * Gets the weight sum of the processed elements.
+         *
+         * @return The weight sum.
+         */
+        private double getWeightSum() {

Review Comment:
   Ah, I see.
   
   Do you think it would be a bit more self-explanatory to name the last two elements of `feedbackArray` `totalWeight` and `totalLoss`?
   
   Since it is an implementation detail, the existing names also work for me.
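To make the suggested naming concrete, here is a hypothetical helper that names the last two slots of the array `totalWeight` and `totalLoss`. It assumes the layout described in the diff's Javadoc (`{modelUpdate, weightSum, lossSum}`). `averageLoss` reproduces the `value[value.length - 1] / value[value.length - 2]` expression that the diff passes to `TerminateOnMaxIterOrTol`. The class and method names are assumptions, not flink-ml API.

```java
/**
 * Hypothetical helper (not flink-ml code) for the feedback array layout:
 * [modelUpdate_0, ..., modelUpdate_{d-1}, totalWeight, totalLoss].
 */
final class FeedbackArraySketch {
    private FeedbackArraySketch() {}

    static double[] pack(double[] modelUpdate, double totalWeight, double totalLoss) {
        double[] feedback = new double[modelUpdate.length + 2];
        System.arraycopy(modelUpdate, 0, feedback, 0, modelUpdate.length);
        feedback[feedback.length - 2] = totalWeight;
        feedback[feedback.length - 1] = totalLoss;
        return feedback;
    }

    static double totalWeight(double[] feedback) {
        return feedback[feedback.length - 2];
    }

    static double totalLoss(double[] feedback) {
        return feedback[feedback.length - 1];
    }

    /** Element-wise sum, i.e. what summing the arrays of two workers produces. */
    static double[] sum(double[] a, double[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("Length mismatch");
        }
        double[] result = new double[a.length];
        for (int i = 0; i < a.length; i++) {
            result[i] = a[i] + b[i];
        }
        return result;
    }

    /** Loss per unit of weight on the reduced array: the value used for termination. */
    static double averageLoss(double[] reducedFeedback) {
        return totalLoss(reducedFeedback) / totalWeight(reducedFeedback);
    }
}
```

Summing `pack({0.1, 0.2}, 3.0, 1.5)` and `pack({0.3, -0.2}, 1.0, 0.5)` yields a total weight of `4.0`, a total loss of `2.0`, and an average loss of `0.5`.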




[GitHub] [flink-ml] lindong28 commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

Posted by GitBox <gi...@apache.org>.
lindong28 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r865911454


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/Optimizer.java:
##########
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.streaming.api.datastream.DataStream;
+
+/**
+ * An optimizer is a function to modify the weight of a machine learning model, which aims to find
+ * the optimal parameter configuration for a machine learning model. Examples of optimizers could be
+ * stochastic gradient descent (SGD), L-BFGS, etc.
+ */
+@Internal
+public interface Optimizer {
+    /**
+     * Optimize the given loss function using the init model and the training data.

Review Comment:
   It would be great if this method could also support unbounded `trainData`.
   
   Note that currently this method supports only bounded `trainData`. How about adding a TODO here, so that we make sure to either fix this method or document it properly before the next release?




[GitHub] [flink-ml] lindong28 commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

Posted by GitBox <gi...@apache.org>.
lindong28 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r865962589


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/Optimizer.java:
##########
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.streaming.api.datastream.DataStream;
+
+/**
+ * An optimizer is a function to modify the weight of a machine learning model, which aims to find
+ * the optimal parameter configuration for a machine learning model. Examples of optimizers could be
+ * stochastic gradient descent (SGD), L-BFGS, etc.
+ */
+@Internal
+public interface Optimizer {
+    /**
+     * Optimize the given loss function using the init model and the training data.
+     *
+     * @param bcInitModel The broadcast init model. Note that each task contains one DenseVector as
+     *     the model data and the model data on each task are exactly the same.
+     * @param trainData The training data.
+     * @param lossFunc The loss function to optimize.
+     * @return The fitted model. Note that the parallelism of the returned stream is one.
+     */
+    DataStream<DenseVector> optimize(
+            DataStream<DenseVector> bcInitModel,

Review Comment:
   The word `initial` seems unnecessary: even without it, the Java doc can still describe the semantics correctly, and more concisely. For example, the Java doc of this method could be `Optimizes the given loss function using the given model data and the training data`.
   
   I am also OK to keep the existing name `initialModelData`.




[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r860608492


##########
flink-ml-lib/src/test/java/org/apache/flink/ml/regression/LinearRegressionTest.java:
##########
@@ -0,0 +1,240 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.regression;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.regression.linearregression.LinearRegression;
+import org.apache.flink.ml.regression.linearregression.LinearRegressionModel;
+import org.apache.flink.ml.regression.linearregression.LinearRegressionModelData;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+
+/** Tests {@link LinearRegression} and {@link LinearRegressionModel}. */
+public class LinearRegressionTest {
+
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private StreamExecutionEnvironment env;
+
+    private StreamTableEnvironment tEnv;
+
+    private static final List<Row> trainData =
+            Arrays.asList(
+                    Row.of(Vectors.dense(2, 1), 4.0, 1.0),
+                    Row.of(Vectors.dense(3, 2), 7.0, 1.0),
+                    Row.of(Vectors.dense(4, 3), 10.0, 1.0),
+                    Row.of(Vectors.dense(2, 4), 10.0, 1.0),
+                    Row.of(Vectors.dense(2, 2), 6.0, 1.0),
+                    Row.of(Vectors.dense(4, 3), 10.0, 1.0),
+                    Row.of(Vectors.dense(1, 2), 5.0, 1.0),
+                    Row.of(Vectors.dense(5, 3), 11.0, 1.0));
+
+    private static final double[] expectedCoefficient = new double[] {1.0, 2.0};
+
+    private static final double TOLERANCE = 1e-7;
+
+    private Table trainDataTable;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        Collections.shuffle(trainData);
+        trainDataTable =
+                tEnv.fromDataStream(
+                        env.fromCollection(
+                                trainData,
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(DenseVector.class),
+                                            Types.DOUBLE,
+                                            Types.DOUBLE
+                                        },
+                                        new String[] {"features", "label", "weight"})));
+    }
+
+    @SuppressWarnings("unchecked")
+    private void verifyPredictionResult(
+            Table output, String labelCol, String weightCol, String predictionCol)
+            throws Exception {
+        List<Row> predResult = IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        double lossSum = 0;
+        for (Row predictionRow : predResult) {
+            double label = (double) predictionRow.getField(labelCol);
+            double prediction = (double) predictionRow.getField(predictionCol);
+            double weight = (double) predictionRow.getField(weightCol);
+            lossSum += weight * Math.pow(label - prediction, 2);
+        }
+        assertTrue(lossSum < 1.0);

Review Comment:
   It is a regression task, so the prediction is NOT a label. It is not possible to assert that each prediction exactly equals its label.
   
   I have refined the function to assert that the average loss is smaller than 0.1.
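The refined check can be sketched in plain Java as a weighted average of squared errors, weighting each row by its `weight` column, as the `lossSum` loop in the diff does. The class and method names are hypothetical. Dividing by the weight sum is an assumption about what "average loss" means here, and `predict` assumes a model without an intercept term, which matches the expected coefficient `{1.0, 2.0}` for the test data.

```java
/**
 * Hypothetical helper (not flink-ml code): sum(w_i * (y_i - p_i)^2) / sum(w_i).
 */
final class WeightedLossCheck {
    private WeightedLossCheck() {}

    static double weightedAverageSquaredLoss(
            double[] labels, double[] predictions, double[] weights) {
        if (labels.length != predictions.length || labels.length != weights.length) {
            throw new IllegalArgumentException("All arrays must have the same length");
        }
        double lossSum = 0.0;
        double weightSum = 0.0;
        for (int i = 0; i < labels.length; i++) {
            double diff = labels[i] - predictions[i];
            lossSum += weights[i] * diff * diff;
            weightSum += weights[i];
        }
        if (weightSum <= 0.0) {
            throw new IllegalArgumentException("Weight sum must be positive");
        }
        return lossSum / weightSum;
    }

    /** Linear prediction without an intercept: dot(features, coefficient). */
    static double predict(double[] features, double[] coefficient) {
        if (features.length != coefficient.length) {
            throw new IllegalArgumentException("Dimension mismatch");
        }
        double result = 0.0;
        for (int i = 0; i < features.length; i++) {
            result += features[i] * coefficient[i];
        }
        return result;
    }
}
```

With the test's eight rows and the coefficient `{1.0, 2.0}`, every prediction equals its label, so the average loss is `0.0`. With `{1.0, 2.1}`, it is about `0.07`, which would still pass a `< 0.1` assertion.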




[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r860633366


##########
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java:
##########
@@ -67,6 +72,28 @@ public static <IN, OUT> DataStream<OUT> mapPartition(
                 .setParallelism(input.getParallelism());
     }
 
+    /**
+     * Applies a {@link ReduceFunction} on a bounded data stream. The output stream contains at most
+     * one stream record and its parallelism is one.
+     *
+     * @param input The input data stream.
+     * @param func The user defined reduce function.
+     * @param <T> The class type of the input.
+     * @return The result data stream.
+     */
+    public static <T> DataStream<T> reduce(DataStream<T> input, ReduceFunction<T> func) {
+        DataStream<T> partialReducedStream =
+                input.transform("partialReduce", input.getType(), new ReduceOperator<>(func))
+                        .setParallelism(input.getParallelism());

Review Comment:
   As far as I can see from the `JobGraph` of `LinearRegressionTest.testFitAndPredict`, `reduce` and `partialReduce` belong to different `JobVertex` instances and are not chained together yet.
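The two-phase structure under discussion (a `partialReduce` at the input's parallelism, followed by a final `reduce` at parallelism one) can be illustrated without Flink. This plain-Java sketch only shows the data flow. It says nothing about operator chaining, and the class and method names are hypothetical.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.function.BinaryOperator;

/**
 * Hypothetical illustration (not Flink code) of a two-phase "reduce all": each partition
 * is reduced locally, then a single reducer combines the partial results.
 */
final class TwoPhaseReduceSketch {
    private TwoPhaseReduceSketch() {}

    static <T> Optional<T> reduceAll(List<List<T>> partitions, BinaryOperator<T> func) {
        List<T> partials = new ArrayList<>();
        for (List<T> partition : partitions) {
            // Phase 1 ("partialReduce"): empty partitions emit nothing.
            partition.stream().reduce(func).ifPresent(partials::add);
        }
        // Phase 2 (final reduce, parallelism one): at most one result overall.
        return partials.stream().reduce(func);
    }
}
```

Returning an `Optional` mirrors the Java doc's statement that the output "contains at most one stream record": an input with no elements produces no result.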





[GitHub] [flink-ml] lindong28 commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

Posted by GitBox <gi...@apache.org>.
lindong28 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r861482738


##########
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java:
##########
@@ -67,6 +72,28 @@ public static <IN, OUT> DataStream<OUT> mapPartition(
                 .setParallelism(input.getParallelism());
     }
 
+    /**
+     * Applies a {@link ReduceFunction} on a bounded data stream. The output stream contains at most
+     * one stream record and its parallelism is one.
+     *
+     * @param input The input data stream.
+     * @param func The user defined reduce function.
+     * @param <T> The class type of the input.
+     * @return The result data stream.
+     */
+    public static <T> DataStream<T> reduce(DataStream<T> input, ReduceFunction<T> func) {

Review Comment:
   Would renaming the method to `reduceAll` make it more self-explanatory? This would also make the name more consistent with `DataStream::windowAll` etc.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/SGD.java:
##########
@@ -0,0 +1,402 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol;
+import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.regression.linearregression.LinearRegression;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Stochastic Gradient Descent (SGD) is the mostly wide-used optimizer for optimizing machine
+ * learning models. It iteratively makes small adjustments to the machine learning model according
+ * to the gradient at each step, to decrease the error of the model.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Stochastic_gradient_descent.
+ */
+@Internal
+public class SGD implements Optimizer {
+    /** Params for SGD optimizer. */
+    private final SGDParams params;
+
+    public SGD(
+            int maxIter,
+            double learningRate,
+            int globalBatchSize,
+            double tol,
+            double reg,
+            double elasticNet) {
+        this.params = new SGDParams(maxIter, learningRate, globalBatchSize, tol, reg, elasticNet);
+    }
+
+    @Override
+    public DataStream<DenseVector> optimize(
+            DataStream<DenseVector> bcInitModel,
+            DataStream<LabeledPointWithWeight> trainData,
+            LossFunc lossFunc) {
+        DataStreamList resultList =
+                Iterations.iterateBoundedStreamsUntilTermination(
+                        DataStreamList.of(bcInitModel.map(modelVec -> modelVec.values)),
+                        ReplayableDataStreamList.notReplay(trainData.rebalance()),
+                        IterationConfig.newBuilder().build(),
+                        new TrainIterationBody(lossFunc, params));
+        return resultList.get(0);
+    }
+
+    /** The iteration implementation for training process. */
+    private static class TrainIterationBody implements IterationBody {
+        private final LossFunc lossFunc;
+        private final SGDParams params;
+
+        public TrainIterationBody(LossFunc lossFunc, SGDParams params) {
+            this.lossFunc = lossFunc;
+            this.params = params;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            // The variable stream at the first iteration is the initialized model data.
+            // In the following iterations, it contains: the model update, weightSum and
+            // lossSum.
+            DataStream<double[]> variableStream = variableStreams.get(0);
+            DataStream<LabeledPointWithWeight> trainData = dataStreams.get(0);
+            final OutputTag<DenseVector> modelDataOutputTag =
+                    new OutputTag<DenseVector>("MODEL_OUTPUT") {};
+
+            SingleOutputStreamOperator<double[]> modelUpdateAndWeightAndLoss =
+                    trainData
+                            .connect(variableStream)
+                            .transform(
+                                    "CacheDataAndDoTrain",
+                                    PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO,
+                                    new CacheDataAndDoTrain(lossFunc, params, modelDataOutputTag));
+
+            DataStreamList feedbackVariableStream =
+                    IterationBody.forEachRound(
+                            DataStreamList.of(modelUpdateAndWeightAndLoss),
+                            input -> {
+                                DataStream<double[]> feedback =
+                                        DataStreamUtils.allReduceSum(input.get(0));
+                                return DataStreamList.of(feedback);
+                            });
+
+            DataStream<Integer> terminationCriteria =
+                    feedbackVariableStream
+                            .get(0)
+                            .map(
+                                    reducedUpdateAndWeightAndLoss -> {
+                                        double[] value = (double[]) reducedUpdateAndWeightAndLoss;
+                                        return value[value.length - 1] / value[value.length - 2];
+                                    })
+                            .flatMap(new TerminateOnMaxIterOrTol(params.maxIter, params.tol));
+
+            return new IterationBodyResult(
+                    DataStreamList.of(feedbackVariableStream.get(0)),
+                    DataStreamList.of(
+                            modelUpdateAndWeightAndLoss.getSideOutput(modelDataOutputTag)),
+                    terminationCriteria);
+        }
+    }
+
+    /**
+     * A stream operator that caches the training data in the first iteration and updates the model
+     * iteratively. The first input is the training data, and the second input is the initialized
+     * model data or feedback of model update, weight and loss.
+     */
+    private static class CacheDataAndDoTrain extends AbstractStreamOperator<double[]>
+            implements TwoInputStreamOperator<LabeledPointWithWeight, double[], double[]>,
+                    IterationListener<double[]> {
+        /** The cached training data. */
+        private List<LabeledPointWithWeight> trainData;
+
+        private ListState<LabeledPointWithWeight> trainDataState;
+
+        /** The start index (offset) of the next mini-batch data for training. */
+        private int nextBatchOffset = 0;
+
+        private ListState<Integer> nextBatchOffsetState;
+
+        /** The model coefficient. */
+        private DenseVector coefficient;
+
+        private ListState<DenseVector> coefficientState;
+        /** The dimension of the coefficient. */
+        private int coeffiDim;
+
+        /**
+         * The double array to sync among all workers. For example, when training {@link
+         * LinearRegression}, this double array is consisted of {modelUpdate, weightSum, lossSum}.
+         */
+        private double[] feedbackArray;
+
+        private ListState<double[]> feedbackArrayState;
+        /** The outputTag to output the model data when iteration ends. */
+        private final OutputTag<DenseVector> modelDataOutputTag;
+        /** The batch size on this task. */
+        private int localBatchSize;
+        /** Optimizer-related parameters. */
+        private final SGDParams params;
+        /** The loss function to optimize. */
+        private final LossFunc lossFunc;
+
+        private CacheDataAndDoTrain(
+                LossFunc lossFunc, SGDParams params, OutputTag<DenseVector> modelDataOutputTag) {
+            this.lossFunc = lossFunc;
+            this.params = params;
+            this.modelDataOutputTag = modelDataOutputTag;
+        }
+
+        @Override
+        public void open() {
+            int numTasks = getRuntimeContext().getNumberOfParallelSubtasks();
+            int taskId = getRuntimeContext().getIndexOfThisSubtask();
+            localBatchSize = params.globalBatchSize / numTasks;
+            if (params.globalBatchSize % numTasks > taskId) {
+                localBatchSize++;
+            }
+        }
+
+        /**
+         * Gets the weight sum of the processed elements.
+         *
+         * @return The weight sum.
+         */
+        private double getWeightSum() {

Review Comment:
   Should it be `getWeightedSum()`?
   
   Could we update the Java doc and the variables used in this file so that they mention `weighted sum` instead of `weight sum`?
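To make the naming question concrete, here is a minimal plain-Java sketch (no Flink dependencies) of the feedback array layout described in the diff: the first `coeffiDim` entries hold the summed model update, followed by the sum of sample weights and then the loss sum. The class and method names are hypothetical. It assumes the array has already been all-reduced across workers. `applyUpdate` mirrors `BLAS.axpy(-learningRate / getWeightSum(), ..., coefficient, coeffiDim)`, and `averageLoss` mirrors the `value[value.length - 1] / value[value.length - 2]` termination value.

```java
import java.util.Arrays;

/** Hypothetical sketch of the layout {modelUpdate[0..dim), sumOfWeights, lossSum}. */
public class FeedbackArraySketch {

    /** Index of the sum of sample weights, stored right after the model update. */
    static int weightSumIndex(int dim) {
        return dim;
    }

    /** Index of the loss sum, stored right after the sum of sample weights. */
    static int lossSumIndex(int dim) {
        return dim + 1;
    }

    /**
     * Applies one SGD step in place: coefficient -= learningRate / sumOfWeights * modelUpdate.
     * Only the first dim entries of the feedback array are read as the model update.
     */
    static void applyUpdate(double[] coefficient, double[] feedback, double learningRate) {
        int dim = coefficient.length;
        double weightSum = feedback[weightSumIndex(dim)];
        if (weightSum <= 0) {
            return; // No sample was processed in this round, so keep the coefficient unchanged.
        }
        double alpha = -learningRate / weightSum;
        for (int i = 0; i < dim; i++) {
            coefficient[i] += alpha * feedback[i];
        }
    }

    /** Weighted average loss used by the termination criterion: lossSum / sumOfWeights. */
    static double averageLoss(double[] feedback) {
        int n = feedback.length;
        return feedback[n - 1] / feedback[n - 2];
    }

    public static void main(String[] args) {
        double[] coefficient = {1.0, 2.0};
        // All-reduced feedback: modelUpdate = {4, -2}, sumOfWeights = 2, lossSum = 3.
        double[] feedback = {4.0, -2.0, 2.0, 3.0};
        applyUpdate(coefficient, feedback, 0.5);
        System.out.println(Arrays.toString(coefficient)); // prints [0.0, 2.5]
        System.out.println(averageLoss(feedback)); // prints 1.5
    }
}
```

Whichever name is chosen, the value at index `coeffiDim` is a sum of per-sample weights, which is what the optimizer divides both the update and the loss by.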



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/SGD.java:
##########
@@ -0,0 +1,402 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol;
+import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.regression.linearregression.LinearRegression;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Stochastic Gradient Descent (SGD) is the mostly wide-used optimizer for optimizing machine
+ * learning models. It iteratively makes small adjustments to the machine learning model according
+ * to the gradient at each step, to decrease the error of the model.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Stochastic_gradient_descent.
+ */
+@Internal
+public class SGD implements Optimizer {
+    /** Params for SGD optimizer. */
+    private final SGDParams params;
+
+    public SGD(
+            int maxIter,
+            double learningRate,
+            int globalBatchSize,
+            double tol,
+            double reg,
+            double elasticNet) {
+        this.params = new SGDParams(maxIter, learningRate, globalBatchSize, tol, reg, elasticNet);
+    }
+
+    @Override
+    public DataStream<DenseVector> optimize(
+            DataStream<DenseVector> bcInitModel,
+            DataStream<LabeledPointWithWeight> trainData,
+            LossFunc lossFunc) {
+        DataStreamList resultList =
+                Iterations.iterateBoundedStreamsUntilTermination(
+                        DataStreamList.of(bcInitModel.map(modelVec -> modelVec.values)),
+                        ReplayableDataStreamList.notReplay(trainData.rebalance()),
+                        IterationConfig.newBuilder().build(),
+                        new TrainIterationBody(lossFunc, params));
+        return resultList.get(0);
+    }
+
+    /** The iteration implementation for training process. */
+    private static class TrainIterationBody implements IterationBody {
+        private final LossFunc lossFunc;
+        private final SGDParams params;
+
+        public TrainIterationBody(LossFunc lossFunc, SGDParams params) {
+            this.lossFunc = lossFunc;
+            this.params = params;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            // The variable stream at the first iteration is the initialized model data.
+            // In the following iterations, it contains: the model update, weightSum and
+            // lossSum.
+            DataStream<double[]> variableStream = variableStreams.get(0);
+            DataStream<LabeledPointWithWeight> trainData = dataStreams.get(0);
+            final OutputTag<DenseVector> modelDataOutputTag =
+                    new OutputTag<DenseVector>("MODEL_OUTPUT") {};
+
+            SingleOutputStreamOperator<double[]> modelUpdateAndWeightAndLoss =
+                    trainData
+                            .connect(variableStream)
+                            .transform(
+                                    "CacheDataAndDoTrain",
+                                    PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO,
+                                    new CacheDataAndDoTrain(lossFunc, params, modelDataOutputTag));
+
+            DataStreamList feedbackVariableStream =
+                    IterationBody.forEachRound(
+                            DataStreamList.of(modelUpdateAndWeightAndLoss),
+                            input -> {
+                                DataStream<double[]> feedback =
+                                        DataStreamUtils.allReduceSum(input.get(0));
+                                return DataStreamList.of(feedback);
+                            });
+
+            DataStream<Integer> terminationCriteria =
+                    feedbackVariableStream
+                            .get(0)
+                            .map(
+                                    reducedUpdateAndWeightAndLoss -> {
+                                        double[] value = (double[]) reducedUpdateAndWeightAndLoss;
+                                        return value[value.length - 1] / value[value.length - 2];
+                                    })
+                            .flatMap(new TerminateOnMaxIterOrTol(params.maxIter, params.tol));
+
+            return new IterationBodyResult(
+                    DataStreamList.of(feedbackVariableStream.get(0)),
+                    DataStreamList.of(
+                            modelUpdateAndWeightAndLoss.getSideOutput(modelDataOutputTag)),
+                    terminationCriteria);
+        }
+    }
+
+    /**
+     * A stream operator that caches the training data in the first iteration and updates the model
+     * iteratively. The first input is the training data, and the second input is the initialized
+     * model data or feedback of model update, weight and loss.
+     */
+    private static class CacheDataAndDoTrain extends AbstractStreamOperator<double[]>
+            implements TwoInputStreamOperator<LabeledPointWithWeight, double[], double[]>,
+                    IterationListener<double[]> {
+        /** The cached training data. */
+        private List<LabeledPointWithWeight> trainData;
+
+        private ListState<LabeledPointWithWeight> trainDataState;
+
+        /** The start index (offset) of the next mini-batch data for training. */
+        private int nextBatchOffset = 0;
+
+        private ListState<Integer> nextBatchOffsetState;
+
+        /** The model coefficient. */
+        private DenseVector coefficient;
+
+        private ListState<DenseVector> coefficientState;
+        /** The dimension of the coefficient. */
+        private int coeffiDim;
+
+        /**
+         * The double array to sync among all workers. For example, when training {@link
+         * LinearRegression}, this double array is consisted of {modelUpdate, weightSum, lossSum}.
+         */
+        private double[] feedbackArray;
+
+        private ListState<double[]> feedbackArrayState;
+        /** The outputTag to output the model data when iteration ends. */
+        private final OutputTag<DenseVector> modelDataOutputTag;
+        /** The batch size on this task. */
+        private int localBatchSize;
+        /** Optimizer-related parameters. */
+        private final SGDParams params;
+        /** The loss function to optimize. */
+        private final LossFunc lossFunc;
+
+        private CacheDataAndDoTrain(
+                LossFunc lossFunc, SGDParams params, OutputTag<DenseVector> modelDataOutputTag) {
+            this.lossFunc = lossFunc;
+            this.params = params;
+            this.modelDataOutputTag = modelDataOutputTag;
+        }
+
+        @Override
+        public void open() {
+            int numTasks = getRuntimeContext().getNumberOfParallelSubtasks();
+            int taskId = getRuntimeContext().getIndexOfThisSubtask();
+            localBatchSize = params.globalBatchSize / numTasks;
+            if (params.globalBatchSize % numTasks > taskId) {
+                localBatchSize++;
+            }
+        }
+
+        /**
+         * Gets the weight sum of the processed elements.
+         *
+         * @return The weight sum.
+         */
+        private double getWeightSum() {
+            return feedbackArray[coeffiDim];
+        }
+
+        /**
+         * Sets the weight sum of the processed elements.
+         *
+         * @param weightSum The weight sum.
+         */
+        private void setWeightSum(double weightSum) {
+            feedbackArray[coeffiDim] = weightSum;
+        }
+
+        /**
+         * Gets the loss sum of the processed elements.
+         *
+         * @return The loss sum.
+         */
+        private double getLoss() {
+            return feedbackArray[coeffiDim + 1];
+        }
+
+        /**
+         * Sets the loss sum of the processed elements.
+         *
+         * @param loss The loss sum.
+         */
+        private void setLoss(double loss) {
+            feedbackArray[coeffiDim + 1] = loss;
+        }
+
+        @Override
+        public void onEpochWatermarkIncremented(
+                int epochWatermark, Context context, Collector<double[]> collector)
+                throws Exception {
+            if (epochWatermark == 0) {
+                coefficient = new DenseVector(feedbackArray);
+                coeffiDim = coefficient.size();
+                feedbackArray = new double[coefficient.size() + 2];
+            } else {
+                if (getWeightSum() > 0) {
+                    BLAS.axpy(
+                            -params.learningRate / getWeightSum(),
+                            new DenseVector(feedbackArray),
+                            coefficient,
+                            coeffiDim);
+                    double regLoss =
+                            RegularizationUtils.regularize(
+                                    coefficient,
+                                    params.reg,
+                                    params.elasticNet,
+                                    params.learningRate);
+                    setLoss(getLoss() + regLoss);
+                }
+            }
+
+            if (trainData == null) {
+                trainData = IteratorUtils.toList(trainDataState.get().iterator());
+            }
+
+            // TODO: supports efficient shuffle of training set on each partition.
+            if (trainData.size() > 0) {
+                List<LabeledPointWithWeight> miniBatchData =
+                        trainData.subList(
+                                nextBatchOffset,
+                                Math.min(nextBatchOffset + localBatchSize, trainData.size()));
+                nextBatchOffset += localBatchSize;
+                nextBatchOffset = nextBatchOffset >= trainData.size() ? 0 : nextBatchOffset;

Review Comment:
   In the previous implementation of `LogisticRegression`, each task randomly samples a local batch in each round. This PR changes this logic so that each task selects its local batch deterministically.
   
   Is the new approach strictly better than the previous one in terms of convergence rate and converged accuracy? Which approach does Spark ML use for LogisticRegression and LinearRegression?
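For comparison, here is a minimal plain-Java sketch of the two selection schemes. `localBatchSize` and `SequentialBatcher` follow the logic shown in `open()` and `onEpochWatermarkIncremented` in this diff. `randomBatch` samples uniformly with replacement, which is an assumption and may not match exactly how the previous `LogisticRegression` implementation sampled. All names here are hypothetical. As the sketch shows, the deterministic scheme visits local samples in the same order in every pass (the diff's TODO about shuffling), and the last slice of a pass can be shorter than the local batch size.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;

/** Hypothetical sketch: two ways a task could choose its local mini-batch in each round. */
public class MiniBatchSketch {

    /** Splits globalBatchSize over numTasks; the first (globalBatchSize % numTasks) tasks get one extra. */
    static int localBatchSize(int globalBatchSize, int numTasks, int taskId) {
        int size = globalBatchSize / numTasks;
        if (globalBatchSize % numTasks > taskId) {
            size++;
        }
        return size;
    }

    /** Deterministic selection: contiguous slices that wrap back to the start of the local data. */
    static class SequentialBatcher<T> {
        private final List<T> data;
        private final int batchSize;
        private int nextOffset = 0;

        SequentialBatcher(List<T> data, int batchSize) {
            this.data = data;
            this.batchSize = batchSize;
        }

        List<T> next() {
            int end = Math.min(nextOffset + batchSize, data.size());
            // The last slice of a pass can be shorter than batchSize.
            List<T> batch = new ArrayList<>(data.subList(nextOffset, end));
            nextOffset += batchSize;
            if (nextOffset >= data.size()) {
                nextOffset = 0;
            }
            return batch;
        }
    }

    /** Random selection: draws batchSize elements uniformly, with replacement; data must be non-empty. */
    static <T> List<T> randomBatch(List<T> data, int batchSize, Random random) {
        List<T> batch = new ArrayList<>(batchSize);
        for (int i = 0; i < batchSize; i++) {
            batch.add(data.get(random.nextInt(data.size())));
        }
        return batch;
    }

    public static void main(String[] args) {
        List<Integer> data = Arrays.asList(0, 1, 2, 3, 4);
        SequentialBatcher<Integer> batcher = new SequentialBatcher<>(data, 2);
        // Prints [0, 1], [2, 3], [4], [0, 1] on successive lines.
        for (int round = 0; round < 4; round++) {
            System.out.println("sequential: " + batcher.next());
        }
        System.out.println("random: " + randomBatch(data, 2, new Random(42)));
    }
}
```

An empirical comparison along these lines (same data, same seeds, both schemes) would be one way to answer the convergence question.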
   



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java:
##########
@@ -186,6 +185,7 @@ public void testParam() {
         assertEquals(logisticRegression.getLearningRate(), 0.5, TOLERANCE);
         assertEquals(logisticRegression.getGlobalBatchSize(), 1000);
         assertEquals(logisticRegression.getReg(), 0.1, TOLERANCE);
+        assertEquals(logisticRegression.getElasticNet(), 0.5, TOLERANCE);

Review Comment:
   Now that we have added the `elasticNet` param to `LogisticRegression`, do we need to test cases where `elasticNet` differs from its default value?
   
   According to the Java doc of this parameter, it seems that we might need to test three cases, i.e. elasticNet = 0, 0.5 and 1.
   
   We can add this as a TODO and do it in a separate PR if it involves non-trivial work.
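The three values map onto the three branches of `RegularizationUtils.regularize` (`elasticNet == 0`: L2 only, `elasticNet == 1`: L1 only, otherwise both terms). A test could compare the returned loss against a reference computed from the class-level formula `elasticNet * reg * norm1(c) + (1 - elasticNet) * (reg/2) * norm2(c)^2`. Below is a minimal sketch of such a reference; `expectedPenalty` is a hypothetical helper, not part of the PR. A check like this would, for example, surface that the L1 branch in the diff accumulates `Math.signum(c_i)` rather than `|c_i|`. Assuming `BLAS.norm2` returns the unsquared Euclidean norm, it would also flag the L2-only branch.

```java
/** Hypothetical reference for the elastic-net penalty defined in the RegularizationUtils Java doc. */
public class ElasticNetPenaltySketch {

    /** Returns elasticNet * reg * ||c||_1 + (1 - elasticNet) * (reg / 2) * ||c||_2^2. */
    static double expectedPenalty(double[] coefficient, double reg, double elasticNet) {
        double norm1 = 0;
        double squaredNorm2 = 0;
        for (double c : coefficient) {
            norm1 += Math.abs(c);
            squaredNorm2 += c * c;
        }
        return elasticNet * reg * norm1 + (1 - elasticNet) * (reg / 2) * squaredNorm2;
    }

    public static void main(String[] args) {
        double[] coefficient = {3.0, -4.0}; // ||c||_1 = 7, ||c||_2^2 = 25
        double reg = 0.1;
        // 0 -> L2 only, 0.5 -> both terms, 1 -> L1 only.
        // Prints penalty=1.2500, penalty=0.9750 and penalty=0.7000.
        for (double elasticNet : new double[] {0.0, 0.5, 1.0}) {
            System.out.printf(
                    "elasticNet=%.1f penalty=%.4f%n",
                    elasticNet, expectedPenalty(coefficient, reg, elasticNet));
        }
    }
}
```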



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/Optimizer.java:
##########
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.streaming.api.datastream.DataStream;
+
+/**
+ * An optimizer is a function to modify the weight of a machine learning model, which aims to find
+ * the optimal parameter configuration for a machine learning model. Examples of optimizers could be
+ * stochastic gradient descent (SGD), L-BFGS, etc.
+ */
+@Internal
+public interface Optimizer {
+    /**
+     * Optimize the given loss function using the init model and the training data.
+     *
+     * @param bcInitModel The broadcast init model. Note that each task contains one DenseVector as

Review Comment:
   nits: `broadcast init model` -> `broadcasted init model`
   
   Given that the Java doc of `DataStreamUtils::allReduceSum` says `each partition contains one double array`, would it be more consistent to use `partition` instead of `task` here?



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/Optimizer.java:
##########
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.streaming.api.datastream.DataStream;
+
+/**
+ * An optimizer is a function to modify the weight of a machine learning model, which aims to find
+ * the optimal parameter configuration for a machine learning model. Examples of optimizers could be
+ * stochastic gradient descent (SGD), L-BFGS, etc.
+ */
+@Internal
+public interface Optimizer {
+    /**
+     * Optimize the given loss function using the init model and the training data.
+     *
+     * @param bcInitModel The broadcast init model. Note that each task contains one DenseVector as
+     *     the model data and the model data on each task are exactly the same.
+     * @param trainData The training data.
+     * @param lossFunc The loss function to optimize.
+     * @return The fitted model. Note that the parallelism of the returned stream is one.
+     */
+    DataStream<DenseVector> optimize(
+            DataStream<DenseVector> bcInitModel,

Review Comment:
   `Model` might be confused with the model class in Flink ML. How about we consistently use `modelData` to refer to the model data in the Java doc and function signature?
   
   I am not sure we need to specifically mention `initial` in the function signature, since this word does not seem necessary to understand the behavior of this function.
   
   Instead of calling the `broadcast()` outside this method, how about we invoke the broadcast inside this method, so that both the caller code and the method signature could be simpler?
   



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/Optimizer.java:
##########
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.streaming.api.datastream.DataStream;
+
+/**
+ * An optimizer is a function to modify the weight of a machine learning model, which aims to find
+ * the optimal parameter configuration for a machine learning model. Examples of optimizers could be
+ * stochastic gradient descent (SGD), L-BFGS, etc.
+ */
+@Internal
+public interface Optimizer {
+    /**
+     * Optimize the given loss function using the init model and the training data.

Review Comment:
   nits: `Optimize the given loss` -> `Optimizes the given loss` to be consistent with other comments.
   
   Given that some algorithms in Flink ML could handle unbounded model data, would it be useful to explicitly mention that `trainData` must be bounded?



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/RegularizationUtils.java:
##########
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+
+/**
+ * A utility class for algorithms that need to handle regularization. The regularization term is
+ * defined as:
+ *
+ * <p>elasticNet * reg * norm1(coefficient) + (1 - elasticNet) * (reg/2) * (norm2(coefficient))^2
+ *
+ * <p>See https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.ElasticNet.html.
+ */
+@Internal
+class RegularizationUtils {
+
+    /**
+     * Regularize the model coefficient. The gradient of each dimension could be computed as:
+     * {elasticNet * reg * Math.sign(c_i) + (1 - elasticNet) * reg * c_i}. Here c_i is the value of
+     * coefficient at i-th dimension.
+     *
+     * @param coefficient The model coefficient.
+     * @param reg The reg param.
+     * @param elasticNet The elasticNet param.
+     * @param learningRate The learningRate param.
+     * @return The loss introduced by regularization.
+     */
+    public static double regularize(
+            DenseVector coefficient,
+            final double reg,
+            final double elasticNet,
+            final double learningRate) {
+
+        if (Double.compare(reg, 0) == 0) {
+            return 0;
+        } else {
+            if (Double.compare(elasticNet, 0) == 0) {
+                // Only L2 regularization.
+                double loss = reg / 2 * BLAS.norm2(coefficient);
+                BLAS.scal(1 - learningRate * reg, coefficient);
+                return loss;
+            } else if (Double.compare(elasticNet, 1) == 0) {
+                // Only L1 regularization.
+                double loss = 0;
+                double[] coefficientArray = coefficient.values;
+                for (int i = 0; i < coefficientArray.length; i++) {
+                    if (Double.compare(coefficientArray[i], 0) == 0) {
+                        continue;
+                    }
+                    loss += elasticNet * reg * Math.signum(coefficientArray[i]);
+                    coefficientArray[i] -=
+                            learningRate * elasticNet * reg * Math.signum(coefficientArray[i]);
+                }
+                return loss;
+            } else {
+                // Both L1 and L2 are not zero.
+                double loss = 0;
+                double[] coefficientArray = coefficient.values;
+                for (int i = 0; i < coefficientArray.length; i++) {
+                    loss +=
+                            elasticNet * reg * Math.signum(coefficientArray[i])
+                                    + (1 - elasticNet)
+                                            * (reg / 2)
+                                            * coefficientArray[i]
+                                            * coefficientArray[i];
+                    coefficientArray[i] =

Review Comment:
   nits: maybe change to `coefficientArray[i] -= ...` for consistency with the above code?
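Along the lines of this suggestion, here is a minimal plain-Java sketch in which every `elasticNet` value uses the same `c[i] -= learningRate * gradient` form. The gradient is the one given in the Java doc of `regularize`: `elasticNet * reg * sign(c_i) + (1 - elasticNet) * reg * c_i`. The class name is hypothetical and the method operates on a `double[]` instead of `DenseVector`. It also folds the three branches into one loop, dropping the `BLAS.scal` fast path of the pure-L2 branch. The returned loss follows the class-level formula (`|c_i|` and `c_i^2`, evaluated before the step).

```java
import java.util.Arrays;

/** Hypothetical sketch: one elastic-net regularization step using a single `-=` update form. */
public class ElasticNetStepSketch {

    /**
     * Applies one gradient step for elasticNet * reg * ||c||_1 + (1 - elasticNet) * (reg / 2) * ||c||_2^2
     * in place and returns that penalty evaluated before the step.
     */
    static double regularize(double[] c, double reg, double elasticNet, double learningRate) {
        if (reg == 0.0) {
            return 0.0;
        }
        double loss = 0.0;
        for (int i = 0; i < c.length; i++) {
            double ci = c[i];
            loss += elasticNet * reg * Math.abs(ci) + (1 - elasticNet) * (reg / 2) * ci * ci;
            // Math.signum(0.0) is 0.0, so a zero coefficient receives no L1 push,
            // matching the `continue` in the L1-only branch of the diff.
            double gradient = elasticNet * reg * Math.signum(ci) + (1 - elasticNet) * reg * ci;
            c[i] -= learningRate * gradient;
        }
        return loss;
    }

    public static void main(String[] args) {
        double[] l2Only = {1.0, -2.0};
        double loss = regularize(l2Only, 1.0, 0.0, 0.5);
        // For elasticNet = 0 this equals scaling by (1 - learningRate * reg) = 0.5.
        System.out.println(loss + " " + Arrays.toString(l2Only)); // prints 2.5 [0.5, -1.0]
    }
}
```

With `elasticNet = 0` the update reduces to `c_i * (1 - learningRate * reg)`, which is what the `BLAS.scal` call computes, so keeping `scal` as a fast path and using `-=` in the other two branches would behave the same.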



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/SGD.java:
##########
@@ -0,0 +1,402 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol;
+import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.regression.linearregression.LinearRegression;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Stochastic Gradient Descent (SGD) is the mostly wide-used optimizer for optimizing machine
+ * learning models. It iteratively makes small adjustments to the machine learning model according
+ * to the gradient at each step, to decrease the error of the model.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Stochastic_gradient_descent.
+ */
+@Internal
+public class SGD implements Optimizer {
+    /** Params for SGD optimizer. */
+    private final SGDParams params;
+
+    public SGD(
+            int maxIter,
+            double learningRate,
+            int globalBatchSize,
+            double tol,
+            double reg,
+            double elasticNet) {
+        this.params = new SGDParams(maxIter, learningRate, globalBatchSize, tol, reg, elasticNet);
+    }
+
+    @Override
+    public DataStream<DenseVector> optimize(
+            DataStream<DenseVector> bcInitModel,
+            DataStream<LabeledPointWithWeight> trainData,
+            LossFunc lossFunc) {
+        DataStreamList resultList =
+                Iterations.iterateBoundedStreamsUntilTermination(
+                        DataStreamList.of(bcInitModel.map(modelVec -> modelVec.values)),
+                        ReplayableDataStreamList.notReplay(trainData.rebalance()),
+                        IterationConfig.newBuilder().build(),
+                        new TrainIterationBody(lossFunc, params));
+        return resultList.get(0);
+    }
+
+    /** The iteration implementation for training process. */
+    private static class TrainIterationBody implements IterationBody {
+        private final LossFunc lossFunc;
+        private final SGDParams params;
+
+        public TrainIterationBody(LossFunc lossFunc, SGDParams params) {
+            this.lossFunc = lossFunc;
+            this.params = params;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            // The variable stream at the first iteration is the initialized model data.
+            // In the following iterations, it contains: the model update, weightSum and
+            // lossSum.
+            DataStream<double[]> variableStream = variableStreams.get(0);
+            DataStream<LabeledPointWithWeight> trainData = dataStreams.get(0);
+            final OutputTag<DenseVector> modelDataOutputTag =
+                    new OutputTag<DenseVector>("MODEL_OUTPUT") {};
+
+            SingleOutputStreamOperator<double[]> modelUpdateAndWeightAndLoss =
+                    trainData
+                            .connect(variableStream)
+                            .transform(
+                                    "CacheDataAndDoTrain",
+                                    PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO,
+                                    new CacheDataAndDoTrain(lossFunc, params, modelDataOutputTag));
+
+            DataStreamList feedbackVariableStream =
+                    IterationBody.forEachRound(
+                            DataStreamList.of(modelUpdateAndWeightAndLoss),
+                            input -> {
+                                DataStream<double[]> feedback =
+                                        DataStreamUtils.allReduceSum(input.get(0));
+                                return DataStreamList.of(feedback);
+                            });
+
+            DataStream<Integer> terminationCriteria =
+                    feedbackVariableStream
+                            .get(0)
+                            .map(
+                                    reducedUpdateAndWeightAndLoss -> {
+                                        double[] value = (double[]) reducedUpdateAndWeightAndLoss;
+                                        return value[value.length - 1] / value[value.length - 2];
+                                    })
+                            .flatMap(new TerminateOnMaxIterOrTol(params.maxIter, params.tol));
+
+            return new IterationBodyResult(
+                    DataStreamList.of(feedbackVariableStream.get(0)),
+                    DataStreamList.of(
+                            modelUpdateAndWeightAndLoss.getSideOutput(modelDataOutputTag)),
+                    terminationCriteria);
+        }
+    }
+
+    /**
+     * A stream operator that caches the training data in the first iteration and updates the model
+     * iteratively. The first input is the training data, and the second input is the initialized
+     * model data or feedback of model update, weight and loss.
+     */
+    private static class CacheDataAndDoTrain extends AbstractStreamOperator<double[]>
+            implements TwoInputStreamOperator<LabeledPointWithWeight, double[], double[]>,
+                    IterationListener<double[]> {
+        /** The cached training data. */
+        private List<LabeledPointWithWeight> trainData;
+
+        private ListState<LabeledPointWithWeight> trainDataState;
+
+        /** The start index (offset) of the next mini-batch data for training. */
+        private int nextBatchOffset = 0;
+
+        private ListState<Integer> nextBatchOffsetState;
+
+        /** The model coefficient. */
+        private DenseVector coefficient;
+
+        private ListState<DenseVector> coefficientState;
+        /** The dimension of the coefficient. */
+        private int coeffiDim;

Review Comment:
   nits: Given that we already use `coefficient` instead of `coeffi` above, how about renaming this variable to `coefficientDim` for consistency?
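
Regarding the `nextBatchOffset` field above: here is a minimal standalone sketch of how such an offset could advance over the cached training data. It assumes each worker draws a fixed local batch (for example `globalBatchSize` divided by the parallelism) and uses `Double` as a stand-in for `LabeledPointWithWeight`. `MiniBatchSketch` and `nextBatch` are hypothetical names. Wrapping around to the start of the cache is only one possible policy; truncating the last batch and resetting the offset to 0 would be another.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Hypothetical helper that slices mini-batches out of locally cached training data. */
public class MiniBatchSketch {
    private final List<Double> cachedData;
    private final int localBatchSize;
    // Start index of the next mini-batch, playing the role of nextBatchOffset.
    private int nextBatchOffset = 0;

    public MiniBatchSketch(List<Double> cachedData, int localBatchSize) {
        if (localBatchSize <= 0) {
            throw new IllegalArgumentException("localBatchSize must be positive");
        }
        this.cachedData = cachedData;
        this.localBatchSize = localBatchSize;
    }

    /** Returns the next batch, wrapping around to the start of the cache when needed. */
    public List<Double> nextBatch() {
        int size = cachedData.size();
        List<Double> batch = new ArrayList<>();
        if (size == 0) {
            return batch;
        }
        // Never take more than the whole cache, so no element repeats within one batch.
        int count = Math.min(localBatchSize, size);
        for (int i = 0; i < count; i++) {
            batch.add(cachedData.get((nextBatchOffset + i) % size));
        }
        nextBatchOffset = (nextBatchOffset + count) % size;
        return batch;
    }

    public static void main(String[] args) {
        MiniBatchSketch sketch = new MiniBatchSketch(Arrays.asList(1.0, 2.0, 3.0, 4.0, 5.0), 2);
        System.out.println(sketch.nextBatch()); // [1.0, 2.0]
        System.out.println(sketch.nextBatch()); // [3.0, 4.0]
        System.out.println(sketch.nextBatch()); // [5.0, 1.0]
    }
}
```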



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/SGD.java:
##########
@@ -0,0 +1,402 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+    /**
+     * A stream operator that caches the training data in the first iteration and updates the model
+     * iteratively. The first input is the training data, and the second input is the initialized
+     * model data or feedback of model update, weight and loss.

Review Comment:
   nits: `model update, weight and loss` -> `model update, weight, and loss`
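
For reference, the termination check in `TrainIterationBody` reads the reduced feedback array as {modelUpdate..., weightSum, lossSum} and divides `value[value.length - 1]` by `value[value.length - 2]`. The plain-Java sketch below mimics that on in-memory arrays. `sumArrays` stands in for the element-wise sum that `DataStreamUtils.allReduceSum` produces across workers, and `averageLoss` is the ratio passed to `TerminateOnMaxIterOrTol`. Both helpers are hypothetical. The ratio is a weighted mean loss only under the assumption that each worker's lossSum already includes the sample weights.

```java
import java.util.Arrays;
import java.util.List;

public class FeedbackSketch {
    /** Element-wise sum of equally sized arrays, like an all-reduce sum across workers. */
    static double[] sumArrays(List<double[]> perWorker) {
        double[] result = new double[perWorker.get(0).length];
        for (double[] array : perWorker) {
            for (int i = 0; i < result.length; i++) {
                result[i] += array[i];
            }
        }
        return result;
    }

    /** Layout {modelUpdate..., weightSum, lossSum}: returns lossSum / weightSum. */
    static double averageLoss(double[] reduced) {
        return reduced[reduced.length - 1] / reduced[reduced.length - 2];
    }

    public static void main(String[] args) {
        // Two workers, a 2-dimensional model update, then weightSum and lossSum.
        double[] worker0 = {0.5, -0.25, 3.0, 1.5};
        double[] worker1 = {0.25, 0.5, 1.0, 0.5};
        double[] reduced = sumArrays(Arrays.asList(worker0, worker1));
        System.out.println(Arrays.toString(reduced)); // [0.75, 0.25, 4.0, 2.0]
        System.out.println(averageLoss(reduced)); // 0.5
    }
}
```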



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/RegularizationUtils.java:
##########
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+
+/**
+ * A utility class for algorithms that need to handle regularization. The regularization term is
+ * defined as:
+ *
+ * <p>elasticNet * reg * norm1(coefficient) + (1 - elasticNet) * (reg/2) * (norm2(coefficient))^2
+ *
+ * <p>See https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.ElasticNet.html.
+ */
+@Internal
+class RegularizationUtils {
+
+    /**
+     * Regularize the model coefficient. The gradient of each dimension could be computed as:
+     * {elasticNet * reg * Math.sign(c_i) + (1 - elasticNet) * reg * c_i}. Here c_i is the value of
+     * coefficient at i-th dimension.
+     *
+     * @param coefficient The model coefficient.
+     * @param reg The reg param.
+     * @param elasticNet The elasticNet param.
+     * @param learningRate The learningRate param.
+     * @return The loss introduced by regularization.
+     */
+    public static double regularize(
+            DenseVector coefficient,
+            final double reg,
+            final double elasticNet,
+            final double learningRate) {
+
+        if (Double.compare(reg, 0) == 0) {
+            return 0;
+        } else {
+            if (Double.compare(elasticNet, 0) == 0) {
+                // Only L2 regularization.
+                double loss = reg / 2 * BLAS.norm2(coefficient);
+                BLAS.scal(1 - learningRate * reg, coefficient);
+                return loss;
+            } else if (Double.compare(elasticNet, 1) == 0) {
+                // Only L1 regularization.
+                double loss = 0;
+                double[] coefficientArray = coefficient.values;
+                for (int i = 0; i < coefficientArray.length; i++) {
+                    if (Double.compare(coefficientArray[i], 0) == 0) {
+                        continue;
+                    }
+                    loss += elasticNet * reg * Math.signum(coefficientArray[i]);
+                    coefficientArray[i] -=
+                            learningRate * elasticNet * reg * Math.signum(coefficientArray[i]);
+                }
+                return loss;
+            } else {
+                // Both L1 and L2 are not zero.
+                double loss = 0;
+                double[] coefficientArray = coefficient.values;
+                for (int i = 0; i < coefficientArray.length; i++) {
+                    loss +=

Review Comment:
   Given that we checked `Double.compare(coefficientArray[i], 0) == 0` in the case where `elasticNet == 1`, do we also need to do the same check here?
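
On this question: `Math.signum(0.0)` returns `0.0`, and the L2 part `(1 - elasticNet) * reg * c_i` is also zero when `c_i` is zero. For an update of the form given in the javadoc, a zero coefficient is therefore left unchanged, and the check would only skip work. The standalone sketch below applies that gradient step to a plain `double[]` and computes the loss from the class-level formula, evaluated before the step. It illustrates the formula under those assumptions and is not the `regularize` method from this PR.

```java
import java.util.Arrays;

public class ElasticNetSketch {
    /**
     * Applies one elastic-net gradient step in place and returns
     * elasticNet * reg * norm1 + (1 - elasticNet) * (reg / 2) * norm2^2,
     * evaluated on the coefficients before the step.
     */
    static double regularize(
            double[] coefficient, double reg, double elasticNet, double learningRate) {
        if (reg == 0) {
            return 0;
        }
        double norm1 = 0;
        double norm2Squared = 0;
        for (double c : coefficient) {
            norm1 += Math.abs(c);
            norm2Squared += c * c;
        }
        double loss = elasticNet * reg * norm1 + (1 - elasticNet) * (reg / 2) * norm2Squared;
        for (int i = 0; i < coefficient.length; i++) {
            double c = coefficient[i];
            // signum(0.0) == 0.0 and c == 0 zeroes the L2 term, so zero entries stay zero.
            double gradient = elasticNet * reg * Math.signum(c) + (1 - elasticNet) * reg * c;
            coefficient[i] = c - learningRate * gradient;
        }
        return loss;
    }

    public static void main(String[] args) {
        double[] coefficient = {2.0, 0.0, -1.0};
        System.out.println(regularize(coefficient, 0.5, 0.5, 0.5)); // 1.375
        System.out.println(Arrays.toString(coefficient)); // [1.625, 0.0, -0.75]
    }
}
```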



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/BinaryLogisticLoss.java:
##########
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.lossfunc;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.ml.classification.logisticregression.LogisticRegression;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+
+/** The loss function for binary logistic loss. See {@link LogisticRegression} for example. */
+@Internal
+public class BinaryLogisticLoss implements LossFunc {

Review Comment:
   This is pretty nice and clean.
   
   Since this is part of the infrastructure and might be used by multiple algorithms in the future, should we add unit tests for `BinaryLogisticLoss` and `LeastSquareLoss`?
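
A test along these lines could compare each analytic gradient with a central finite-difference estimate. Because the `LossFunc` method signatures are not shown in this hunk, the sketch below re-implements a common formulation of the two losses on plain arrays. It uses weighted logistic loss with labels in {0, 1} mapped to -1/+1, and weighted least-squares loss `0.5 * weight * (x . coef - label)^2`. The formulas and all class and method names are assumptions; a real test would call `BinaryLogisticLoss` and `LeastSquareLoss` directly.

```java
public class LossGradientCheckSketch {
    /** Callback that evaluates a loss at a given coefficient vector. */
    interface Loss {
        double apply(double[] coef);
    }

    static double dot(double[] a, double[] b) {
        double sum = 0;
        for (int i = 0; i < a.length; i++) {
            sum += a[i] * b[i];
        }
        return sum;
    }

    /** weight * log(1 + exp(-y * dot)), where y = 2 * label - 1 maps {0, 1} to {-1, +1}. */
    static double logisticLoss(double[] x, double label, double weight, double[] coef) {
        double y = 2 * label - 1;
        return weight * Math.log1p(Math.exp(-y * dot(x, coef)));
    }

    static double[] logisticGradient(double[] x, double label, double weight, double[] coef) {
        double y = 2 * label - 1;
        double multiplier = weight * (-y / (1 + Math.exp(y * dot(x, coef))));
        double[] grad = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            grad[i] = multiplier * x[i];
        }
        return grad;
    }

    /** 0.5 * weight * (dot - label)^2. */
    static double squareLoss(double[] x, double label, double weight, double[] coef) {
        double residual = dot(x, coef) - label;
        return 0.5 * weight * residual * residual;
    }

    static double[] squareGradient(double[] x, double label, double weight, double[] coef) {
        double residual = dot(x, coef) - label;
        double[] grad = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            grad[i] = weight * residual * x[i];
        }
        return grad;
    }

    /** Checks every coordinate of the analytic gradient against a central difference. */
    static boolean gradientMatches(Loss loss, double[] analytic, double[] coef) {
        double h = 1e-6;
        for (int i = 0; i < coef.length; i++) {
            double[] plus = coef.clone();
            double[] minus = coef.clone();
            plus[i] += h;
            minus[i] -= h;
            double numeric = (loss.apply(plus) - loss.apply(minus)) / (2 * h);
            if (Math.abs(numeric - analytic[i]) > 1e-5) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        double[] x = {1.0, 2.0};
        double[] coef = {0.3, -0.1};
        boolean logisticOk = gradientMatches(
                c -> logisticLoss(x, 0.0, 2.0, c), logisticGradient(x, 0.0, 2.0, coef), coef);
        boolean squareOk = gradientMatches(
                c -> squareLoss(x, 1.5, 2.0, c), squareGradient(x, 1.5, 2.0, coef), coef);
        System.out.println(logisticOk && squareOk); // true
    }
}
```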



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/SGD.java:
##########
@@ -0,0 +1,402 @@
+    /**
+     * A stream operator that caches the training data in the first iteration and updates the model
+     * iteratively. The first input is the training data, and the second input is the initialized
+     * model data or feedback of model update, weight and loss.
+     */
+    private static class CacheDataAndDoTrain extends AbstractStreamOperator<double[]>
+            implements TwoInputStreamOperator<LabeledPointWithWeight, double[], double[]>,
+                    IterationListener<double[]> {
+        /** The cached training data. */
+        private List<LabeledPointWithWeight> trainData;
+
+        private ListState<LabeledPointWithWeight> trainDataState;
+
+        /** The start index (offset) of the next mini-batch data for training. */
+        private int nextBatchOffset = 0;
+
+        private ListState<Integer> nextBatchOffsetState;
+
+        /** The model coefficient. */
+        private DenseVector coefficient;
+
+        private ListState<DenseVector> coefficientState;
+        /** The dimension of the coefficient. */
+        private int coeffiDim;
+
+        /**
+         * The double array to sync among all workers. For example, when training {@link
+         * LinearRegression}, this double array consists of {modelUpdate, weightSum, lossSum}.
+         */
+        private double[] feedbackArray;
+
+        private ListState<double[]> feedbackArrayState;
+        /** The outputTag to output the model data when iteration ends. */
+        private final OutputTag<DenseVector> modelDataOutputTag;

Review Comment:
   nits: The code above has a blank line between variable declarations. Could we make these declarations follow the same style?



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/SGD.java:
##########
@@ -0,0 +1,402 @@
+    /**
+     * A stream operator that caches the training data in the first iteration and updates the model
+     * iteratively. The first input is the training data, and the second input is the initialized
+     * model data or feedback of model update, weight and loss.
+     */
+    private static class CacheDataAndDoTrain extends AbstractStreamOperator<double[]>
+            implements TwoInputStreamOperator<LabeledPointWithWeight, double[], double[]>,
+                    IterationListener<double[]> {
+        /** The cached training data. */
+        private List<LabeledPointWithWeight> trainData;
+
+        private ListState<LabeledPointWithWeight> trainDataState;
+
+        /** The start index (offset) of the next mini-batch data for training. */
+        private int nextBatchOffset = 0;
+
+        private ListState<Integer> nextBatchOffsetState;
+
+        /** The model coefficient. */
+        private DenseVector coefficient;
+
+        private ListState<DenseVector> coefficientState;
+        /** The dimension of the coefficient. */
+        private int coeffiDim;
+
+        /**
+         * The double array to sync among all workers. For example, when training {@link
+         * LinearRegression}, this double array is consisted of {modelUpdate, weightSum, lossSum}.
+         */
+        private double[] feedbackArray;
+
+        private ListState<double[]> feedbackArrayState;
+        /** The outputTag to output the model data when iteration ends. */
+        private final OutputTag<DenseVector> modelDataOutputTag;
+        /** The batch size on this task. */
+        private int localBatchSize;
+        /** Optimizer-related parameters. */
+        private final SGDParams params;

Review Comment:
   nits: We typically declare final variables before non-final variables. Could we reorder the fields here accordingly?
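   For illustration, here is a minimal sketch of the suggested field order. It uses a hypothetical, trimmed-down class (`FieldOrderSketch`) whose field names mirror those in the diff, with types reduced to JDK types so that it compiles standalone. It is not the actual `CacheDataAndDoTrain` operator.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical, trimmed-down class illustrating "final fields before non-final fields". */
public class FieldOrderSketch {
    // Final fields first: assigned exactly once, in the constructor.

    /** Optimizer-related parameter (reduced here to the global batch size). */
    private final int globalBatchSize;

    /** Name of the side output used to emit the model data when the iteration ends. */
    private final String modelDataOutputTag;

    // Non-final fields next: runtime state that is initialized or updated during training.

    /** The cached training data. */
    private List<double[]> trainData;

    /** The start index (offset) of the next mini-batch data for training. */
    private int nextBatchOffset = 0;

    /** The batch size on this task. */
    private int localBatchSize;

    public FieldOrderSketch(int globalBatchSize, String modelDataOutputTag) {
        this.globalBatchSize = globalBatchSize;
        this.modelDataOutputTag = modelDataOutputTag;
    }

    /** Initializes the mutable state; this simplified split ignores any remainder. */
    public void open(int numTasks) {
        trainData = new ArrayList<>();
        nextBatchOffset = 0;
        localBatchSize = globalBatchSize / numTasks;
    }

    public int getLocalBatchSize() {
        return localBatchSize;
    }

    public String getModelDataOutputTag() {
        return modelDataOutputTag;
    }
}
```

   The final fields form the operator's fixed configuration, so grouping them at the top makes it clear at a glance which state never changes after construction.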



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/SGD.java:
##########
@@ -0,0 +1,402 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol;
+import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.regression.linearregression.LinearRegression;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Stochastic Gradient Descent (SGD) is the mostly wide-used optimizer for optimizing machine
+ * learning models. It iteratively makes small adjustments to the machine learning model according
+ * to the gradient at each step, to decrease the error of the model.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Stochastic_gradient_descent.
+ */
+@Internal
+public class SGD implements Optimizer {
+    /** Params for SGD optimizer. */
+    private final SGDParams params;
+
+    public SGD(
+            int maxIter,
+            double learningRate,
+            int globalBatchSize,
+            double tol,
+            double reg,
+            double elasticNet) {
+        this.params = new SGDParams(maxIter, learningRate, globalBatchSize, tol, reg, elasticNet);
+    }
+
+    @Override
+    public DataStream<DenseVector> optimize(
+            DataStream<DenseVector> bcInitModel,
+            DataStream<LabeledPointWithWeight> trainData,
+            LossFunc lossFunc) {
+        DataStreamList resultList =
+                Iterations.iterateBoundedStreamsUntilTermination(
+                        DataStreamList.of(bcInitModel.map(modelVec -> modelVec.values)),
+                        ReplayableDataStreamList.notReplay(trainData.rebalance()),
+                        IterationConfig.newBuilder().build(),
+                        new TrainIterationBody(lossFunc, params));
+        return resultList.get(0);
+    }
+
+    /** The iteration implementation for training process. */
+    private static class TrainIterationBody implements IterationBody {
+        private final LossFunc lossFunc;
+        private final SGDParams params;
+
+        public TrainIterationBody(LossFunc lossFunc, SGDParams params) {
+            this.lossFunc = lossFunc;
+            this.params = params;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            // The variable stream at the first iteration is the initialized model data.
+            // In the following iterations, it contains: the model update, weightSum and
+            // lossSum.
+            DataStream<double[]> variableStream = variableStreams.get(0);
+            DataStream<LabeledPointWithWeight> trainData = dataStreams.get(0);
+            final OutputTag<DenseVector> modelDataOutputTag =
+                    new OutputTag<DenseVector>("MODEL_OUTPUT") {};
+
+            SingleOutputStreamOperator<double[]> modelUpdateAndWeightAndLoss =
+                    trainData
+                            .connect(variableStream)
+                            .transform(
+                                    "CacheDataAndDoTrain",
+                                    PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO,
+                                    new CacheDataAndDoTrain(lossFunc, params, modelDataOutputTag));
+
+            DataStreamList feedbackVariableStream =
+                    IterationBody.forEachRound(
+                            DataStreamList.of(modelUpdateAndWeightAndLoss),
+                            input -> {
+                                DataStream<double[]> feedback =
+                                        DataStreamUtils.allReduceSum(input.get(0));
+                                return DataStreamList.of(feedback);
+                            });
+
+            DataStream<Integer> terminationCriteria =
+                    feedbackVariableStream
+                            .get(0)
+                            .map(
+                                    reducedUpdateAndWeightAndLoss -> {
+                                        double[] value = (double[]) reducedUpdateAndWeightAndLoss;
+                                        return value[value.length - 1] / value[value.length - 2];
+                                    })
+                            .flatMap(new TerminateOnMaxIterOrTol(params.maxIter, params.tol));
+
+            return new IterationBodyResult(
+                    DataStreamList.of(feedbackVariableStream.get(0)),
+                    DataStreamList.of(
+                            modelUpdateAndWeightAndLoss.getSideOutput(modelDataOutputTag)),
+                    terminationCriteria);
+        }
+    }
+
+    /**
+     * A stream operator that caches the training data in the first iteration and updates the model
+     * iteratively. The first input is the training data, and the second input is the initialized
+     * model data or feedback of model update, weight and loss.
+     */
+    private static class CacheDataAndDoTrain extends AbstractStreamOperator<double[]>
+            implements TwoInputStreamOperator<LabeledPointWithWeight, double[], double[]>,
+                    IterationListener<double[]> {
+        /** The cached training data. */
+        private List<LabeledPointWithWeight> trainData;
+
+        private ListState<LabeledPointWithWeight> trainDataState;
+
+        /** The start index (offset) of the next mini-batch data for training. */
+        private int nextBatchOffset = 0;
+
+        private ListState<Integer> nextBatchOffsetState;
+
+        /** The model coefficient. */
+        private DenseVector coefficient;
+
+        private ListState<DenseVector> coefficientState;
+        /** The dimension of the coefficient. */
+        private int coeffiDim;
+
+        /**
+         * The double array to sync among all workers. For example, when training {@link
+         * LinearRegression}, this double array is consisted of {modelUpdate, weightSum, lossSum}.
+         */
+        private double[] feedbackArray;
+
+        private ListState<double[]> feedbackArrayState;
+        /** The outputTag to output the model data when iteration ends. */
+        private final OutputTag<DenseVector> modelDataOutputTag;
+        /** The batch size on this task. */
+        private int localBatchSize;
+        /** Optimizer-related parameters. */
+        private final SGDParams params;
+        /** The loss function to optimize. */
+        private final LossFunc lossFunc;
+
+        private CacheDataAndDoTrain(
+                LossFunc lossFunc, SGDParams params, OutputTag<DenseVector> modelDataOutputTag) {
+            this.lossFunc = lossFunc;
+            this.params = params;
+            this.modelDataOutputTag = modelDataOutputTag;
+        }
+
+        @Override
+        public void open() {
+            int numTasks = getRuntimeContext().getNumberOfParallelSubtasks();
+            int taskId = getRuntimeContext().getIndexOfThisSubtask();
+            localBatchSize = params.globalBatchSize / numTasks;
+            if (params.globalBatchSize % numTasks > taskId) {
+                localBatchSize++;
+            }
+        }
+
+        /**
+         * Gets the weight sum of the processed elements.
+         *
+         * @return The weight sum.
+         */
+        private double getWeightSum() {
+            return feedbackArray[coeffiDim];
+        }
+
+        /**
+         * Sets the weight sum of the processed elements.
+         *
+         * @param weightSum The weight sum.
+         */
+        private void setWeightSum(double weightSum) {
+            feedbackArray[coeffiDim] = weightSum;
+        }
+
+        /**
+         * Gets the loss sum of the processed elements.

Review Comment:
   nits: Should it be `loss of ...` or `total loss of ...`?
   
   It might be simpler to just remove the Javadoc here, since this is a private method whose logic is straightforward.
   
   Same for `setLoss(...)`, `setWeightSum(...)` and `getWeightSum()`.
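   The `open()` method quoted in this hunk splits `globalBatchSize` across subtasks: each subtask gets `globalBatchSize / numTasks` samples, and the first `globalBatchSize % numTasks` subtasks get one extra sample each. The following standalone sketch reproduces that arithmetic. The class name `LocalBatchSizeSketch` is hypothetical; only the formula is taken from the diff.

```java
import java.util.Arrays;

/** Hypothetical standalone reproduction of the local batch size computation in open(). */
public class LocalBatchSizeSketch {
    /** Subtasks with index < (globalBatchSize % numTasks) receive one extra sample. */
    public static int localBatchSize(int globalBatchSize, int numTasks, int taskId) {
        int size = globalBatchSize / numTasks;
        if (globalBatchSize % numTasks > taskId) {
            size++;
        }
        return size;
    }

    public static void main(String[] args) {
        int globalBatchSize = 10;
        int numTasks = 4;
        int[] sizes = new int[numTasks];
        for (int taskId = 0; taskId < numTasks; taskId++) {
            sizes[taskId] = localBatchSize(globalBatchSize, numTasks, taskId);
        }
        // The local sizes always add up to the global batch size.
        System.out.println(Arrays.toString(sizes)); // prints [3, 3, 2, 2]
    }
}
```

   Because the remainder is spread one sample at a time, local batch sizes differ by at most one, and their sum always equals `globalBatchSize`.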



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/SGD.java:
##########
@@ -0,0 +1,402 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol;
+import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.regression.linearregression.LinearRegression;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Stochastic Gradient Descent (SGD) is the mostly wide-used optimizer for optimizing machine
+ * learning models. It iteratively makes small adjustments to the machine learning model according
+ * to the gradient at each step, to decrease the error of the model.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Stochastic_gradient_descent.
+ */
+@Internal
+public class SGD implements Optimizer {
+    /** Params for SGD optimizer. */
+    private final SGDParams params;
+
+    public SGD(
+            int maxIter,
+            double learningRate,
+            int globalBatchSize,
+            double tol,
+            double reg,
+            double elasticNet) {
+        this.params = new SGDParams(maxIter, learningRate, globalBatchSize, tol, reg, elasticNet);
+    }
+
+    @Override
+    public DataStream<DenseVector> optimize(
+            DataStream<DenseVector> bcInitModel,
+            DataStream<LabeledPointWithWeight> trainData,
+            LossFunc lossFunc) {
+        DataStreamList resultList =
+                Iterations.iterateBoundedStreamsUntilTermination(
+                        DataStreamList.of(bcInitModel.map(modelVec -> modelVec.values)),
+                        ReplayableDataStreamList.notReplay(trainData.rebalance()),
+                        IterationConfig.newBuilder().build(),
+                        new TrainIterationBody(lossFunc, params));
+        return resultList.get(0);
+    }
+
+    /** The iteration implementation for training process. */
+    private static class TrainIterationBody implements IterationBody {
+        private final LossFunc lossFunc;
+        private final SGDParams params;
+
+        public TrainIterationBody(LossFunc lossFunc, SGDParams params) {
+            this.lossFunc = lossFunc;
+            this.params = params;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            // The variable stream at the first iteration is the initialized model data.
+            // In the following iterations, it contains: the model update, weightSum and
+            // lossSum.
+            DataStream<double[]> variableStream = variableStreams.get(0);
+            DataStream<LabeledPointWithWeight> trainData = dataStreams.get(0);
+            final OutputTag<DenseVector> modelDataOutputTag =
+                    new OutputTag<DenseVector>("MODEL_OUTPUT") {};
+
+            SingleOutputStreamOperator<double[]> modelUpdateAndWeightAndLoss =
+                    trainData
+                            .connect(variableStream)
+                            .transform(
+                                    "CacheDataAndDoTrain",
+                                    PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO,
+                                    new CacheDataAndDoTrain(lossFunc, params, modelDataOutputTag));
+
+            DataStreamList feedbackVariableStream =
+                    IterationBody.forEachRound(
+                            DataStreamList.of(modelUpdateAndWeightAndLoss),
+                            input -> {
+                                DataStream<double[]> feedback =
+                                        DataStreamUtils.allReduceSum(input.get(0));
+                                return DataStreamList.of(feedback);
+                            });
+
+            DataStream<Integer> terminationCriteria =
+                    feedbackVariableStream
+                            .get(0)
+                            .map(
+                                    reducedUpdateAndWeightAndLoss -> {
+                                        double[] value = (double[]) reducedUpdateAndWeightAndLoss;
+                                        return value[value.length - 1] / value[value.length - 2];
+                                    })
+                            .flatMap(new TerminateOnMaxIterOrTol(params.maxIter, params.tol));
+
+            return new IterationBodyResult(
+                    DataStreamList.of(feedbackVariableStream.get(0)),
+                    DataStreamList.of(
+                            modelUpdateAndWeightAndLoss.getSideOutput(modelDataOutputTag)),
+                    terminationCriteria);
+        }
+    }
+
+    /**
+     * A stream operator that caches the training data in the first iteration and updates the model
+     * iteratively. The first input is the training data, and the second input is the initialized
+     * model data or feedback of model update, weight and loss.
+     */
+    private static class CacheDataAndDoTrain extends AbstractStreamOperator<double[]>
+            implements TwoInputStreamOperator<LabeledPointWithWeight, double[], double[]>,
+                    IterationListener<double[]> {
+        /** The cached training data. */
+        private List<LabeledPointWithWeight> trainData;
+
+        private ListState<LabeledPointWithWeight> trainDataState;
+
+        /** The start index (offset) of the next mini-batch data for training. */
+        private int nextBatchOffset = 0;
+
+        private ListState<Integer> nextBatchOffsetState;
+
+        /** The model coefficient. */
+        private DenseVector coefficient;
+
+        private ListState<DenseVector> coefficientState;
+        /** The dimension of the coefficient. */
+        private int coeffiDim;
+
+        /**
+         * The double array to sync among all workers. For example, when training {@link
+         * LinearRegression}, this double array is consisted of {modelUpdate, weightSum, lossSum}.

Review Comment:
   nits: `is consisted of` -> `consists of`
   
   Would it be better to change `{modelUpdate, weightSum, lossSum}` to `[modelUpdate, weightSum, lossSum]` to indicate that this is an array?
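   As a concrete reading of that layout, the diff shows `getWeightSum()` reading `feedbackArray[coeffiDim]`, and the termination check computing `value[value.length - 1] / value[value.length - 2]`. Together these imply an array of length `coeffiDim + 2` that consists of [modelUpdate (coeffiDim values), weightSum, lossSum]. The sketch below is a hypothetical illustration. The class name and the `getLossSum`/`setLossSum` names are assumptions, since the actual loss accessor is not visible here. It shows Javadoc-free accessors, the element-wise sum that an all-reduce-sum produces over workers' arrays, and the average loss that the termination criterion uses.

```java
/**
 * Hypothetical helper. The feedback array consists of
 * [modelUpdate (coeffiDim values), weightSum, lossSum].
 */
public class FeedbackArraySketch {
    private final int coeffiDim;
    private final double[] feedbackArray;

    public FeedbackArraySketch(int coeffiDim) {
        this.coeffiDim = coeffiDim;
        this.feedbackArray = new double[coeffiDim + 2];
    }

    public double getWeightSum() {
        return feedbackArray[coeffiDim];
    }

    public void setWeightSum(double weightSum) {
        feedbackArray[coeffiDim] = weightSum;
    }

    public double getLossSum() {
        return feedbackArray[coeffiDim + 1];
    }

    public void setLossSum(double lossSum) {
        feedbackArray[coeffiDim + 1] = lossSum;
    }

    public double[] toArray() {
        return feedbackArray;
    }

    /** Element-wise sum over all workers' arrays, i.e. the result of an all-reduce-sum. */
    public static double[] sum(double[]... arrays) {
        double[] result = new double[arrays[0].length];
        for (double[] a : arrays) {
            for (int i = 0; i < a.length; i++) {
                result[i] += a[i];
            }
        }
        return result;
    }

    /** Average loss = lossSum / weightSum, read from the last two slots. */
    public static double averageLoss(double[] reduced) {
        return reduced[reduced.length - 1] / reduced[reduced.length - 2];
    }
}
```

   Because the weight sum and loss sum sit in fixed trailing slots, a single summed `double[]` carries everything the next round needs, and the termination check never has to know `coeffiDim`.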



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegressionModel.java:
##########
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.regression.linearregression;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/** A Model which predicts data using the model data computed by {@link LinearRegression}. */
+public class LinearRegressionModel
+        implements Model<LinearRegressionModel>,
+                LinearRegressionModelParams<LinearRegressionModel> {
+
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    private Table modelDataTable;
+
+    public LinearRegressionModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    @Override
+    public void save(String path) throws IOException {

Review Comment:
   nits: Currently, the order of save/load/fit/getParamMap is not consistent across algorithms. I think it would be nice to make it consistent. We can do it in a separate PR.
   
   Since `fit()` is the most important method for each Estimator, how about we order them as fit, save, load, getParamMap?
   
   And for Model, we can order them as transform, setModelData, getModelData, save, load, getParamMap.
   
   What do you think?
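   A minimal skeleton of the proposed Estimator ordering (fit, save, load, getParamMap) might look like the following. `EstimatorOrderSketch` is hypothetical, and its method signatures are simplified stand-ins with no Flink dependencies. They are not the Flink ML `Estimator` API.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical Estimator skeleton showing the proposed method order. */
public class EstimatorOrderSketch {
    private final Map<String, Object> paramMap = new HashMap<>();

    // 1. fit(): the most important method of an Estimator comes first.
    public double[] fit(double[] input) {
        // Placeholder "training": returns a copy of the input as model data.
        return input.clone();
    }

    // 2. save(): persistence (a no-op stub in this sketch).
    public void save(String path) {
        // Intentionally empty.
    }

    // 3. load(): static factory restoring an instance (stubbed in this sketch).
    public static EstimatorOrderSketch load(String path) {
        return new EstimatorOrderSketch();
    }

    // 4. getParamMap(): plain accessor last.
    public Map<String, Object> getParamMap() {
        return paramMap;
    }
}
```

   Under the same proposal, a Model would order its methods as transform, setModelData, getModelData, save, load, getParamMap.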



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/SGD.java:
##########
@@ -0,0 +1,402 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol;
+import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.regression.linearregression.LinearRegression;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Stochastic Gradient Descent (SGD) is the mostly wide-used optimizer for optimizing machine
+ * learning models. It iteratively makes small adjustments to the machine learning model according
+ * to the gradient at each step, to decrease the error of the model.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Stochastic_gradient_descent.
+ */
+@Internal
+public class SGD implements Optimizer {
+    /** Params for SGD optimizer. */
+    private final SGDParams params;
+
+    public SGD(
+            int maxIter,
+            double learningRate,
+            int globalBatchSize,
+            double tol,
+            double reg,
+            double elasticNet) {
+        this.params = new SGDParams(maxIter, learningRate, globalBatchSize, tol, reg, elasticNet);
+    }
+
+    @Override
+    public DataStream<DenseVector> optimize(
+            DataStream<DenseVector> bcInitModel,
+            DataStream<LabeledPointWithWeight> trainData,
+            LossFunc lossFunc) {
+        DataStreamList resultList =
+                Iterations.iterateBoundedStreamsUntilTermination(
+                        DataStreamList.of(bcInitModel.map(modelVec -> modelVec.values)),
+                        ReplayableDataStreamList.notReplay(trainData.rebalance()),
+                        IterationConfig.newBuilder().build(),
+                        new TrainIterationBody(lossFunc, params));
+        return resultList.get(0);
+    }
+
+    /** The iteration implementation for training process. */
+    private static class TrainIterationBody implements IterationBody {
+        private final LossFunc lossFunc;
+        private final SGDParams params;
+
+        public TrainIterationBody(LossFunc lossFunc, SGDParams params) {
+            this.lossFunc = lossFunc;
+            this.params = params;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            // The variable stream at the first iteration is the initialized model data.
+            // In the following iterations, it contains: the model update, weightSum and
+            // lossSum.
+            DataStream<double[]> variableStream = variableStreams.get(0);
+            DataStream<LabeledPointWithWeight> trainData = dataStreams.get(0);
+            final OutputTag<DenseVector> modelDataOutputTag =
+                    new OutputTag<DenseVector>("MODEL_OUTPUT") {};
+
+            SingleOutputStreamOperator<double[]> modelUpdateAndWeightAndLoss =
+                    trainData
+                            .connect(variableStream)
+                            .transform(
+                                    "CacheDataAndDoTrain",
+                                    PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO,
+                                    new CacheDataAndDoTrain(lossFunc, params, modelDataOutputTag));
+
+            DataStreamList feedbackVariableStream =
+                    IterationBody.forEachRound(
+                            DataStreamList.of(modelUpdateAndWeightAndLoss),
+                            input -> {
+                                DataStream<double[]> feedback =
+                                        DataStreamUtils.allReduceSum(input.get(0));
+                                return DataStreamList.of(feedback);
+                            });
+
+            DataStream<Integer> terminationCriteria =
+                    feedbackVariableStream
+                            .get(0)
+                            .map(
+                                    reducedUpdateAndWeightAndLoss -> {
+                                        double[] value = (double[]) reducedUpdateAndWeightAndLoss;
+                                        return value[value.length - 1] / value[value.length - 2];
+                                    })
+                            .flatMap(new TerminateOnMaxIterOrTol(params.maxIter, params.tol));
+
+            return new IterationBodyResult(
+                    DataStreamList.of(feedbackVariableStream.get(0)),
+                    DataStreamList.of(
+                            modelUpdateAndWeightAndLoss.getSideOutput(modelDataOutputTag)),
+                    terminationCriteria);
+        }
+    }
+
+    /**
+     * A stream operator that caches the training data in the first iteration and updates the model
+     * iteratively. The first input is the training data, and the second input is the initialized
+     * model data or feedback of model update, weight and loss.
+     */
+    private static class CacheDataAndDoTrain extends AbstractStreamOperator<double[]>
+            implements TwoInputStreamOperator<LabeledPointWithWeight, double[], double[]>,
+                    IterationListener<double[]> {
+        /** The cached training data. */
+        private List<LabeledPointWithWeight> trainData;
+
+        private ListState<LabeledPointWithWeight> trainDataState;
+
+        /** The start index (offset) of the next mini-batch data for training. */
+        private int nextBatchOffset = 0;
+
+        private ListState<Integer> nextBatchOffsetState;
+
+        /** The model coefficient. */
+        private DenseVector coefficient;
+
+        private ListState<DenseVector> coefficientState;
+        /** The dimension of the coefficient. */
+        private int coeffiDim;
+
+        /**
+         * The double array to sync among all workers. For example, when training {@link
+         * LinearRegression}, this double array consists of {modelUpdate, weightSum, lossSum}.
+         */
+        private double[] feedbackArray;
+
+        private ListState<double[]> feedbackArrayState;
+        /** The outputTag to output the model data when iteration ends. */
+        private final OutputTag<DenseVector> modelDataOutputTag;
+        /** The batch size on this task. */
+        private int localBatchSize;
+        /** Optimizer-related parameters. */
+        private final SGDParams params;
+        /** The loss function to optimize. */
+        private final LossFunc lossFunc;
+
+        private CacheDataAndDoTrain(
+                LossFunc lossFunc, SGDParams params, OutputTag<DenseVector> modelDataOutputTag) {
+            this.lossFunc = lossFunc;
+            this.params = params;
+            this.modelDataOutputTag = modelDataOutputTag;
+        }
+
+        @Override
+        public void open() {
+            int numTasks = getRuntimeContext().getNumberOfParallelSubtasks();
+            int taskId = getRuntimeContext().getIndexOfThisSubtask();
+            localBatchSize = params.globalBatchSize / numTasks;
+            if (params.globalBatchSize % numTasks > taskId) {
+                localBatchSize++;
+            }
+        }
+
+        /**
+         * Gets the weight sum of the processed elements.
+         *
+         * @return The weight sum.
+         */
+        private double getWeightSum() {
+            return feedbackArray[coeffiDim];
+        }
+
+        /**
+         * Sets the weight sum of the processed elements.
+         *
+         * @param weightSum The weight sum.
+         */
+        private void setWeightSum(double weightSum) {
+            feedbackArray[coeffiDim] = weightSum;
+        }
+
+        /**
+         * Gets the loss sum of the processed elements.
+         *
+         * @return The loss sum.
+         */
+        private double getLoss() {
+            return feedbackArray[coeffiDim + 1];
+        }
+
+        /**
+         * Sets the loss sum of the processed elements.
+         *
+         * @param loss The loss sum.
+         */
+        private void setLoss(double loss) {
+            feedbackArray[coeffiDim + 1] = loss;
+        }
+
+        @Override
+        public void onEpochWatermarkIncremented(
+                int epochWatermark, Context context, Collector<double[]> collector)
+                throws Exception {
+            if (epochWatermark == 0) {
+                coefficient = new DenseVector(feedbackArray);
+                coeffiDim = coefficient.size();
+                feedbackArray = new double[coefficient.size() + 2];
+            } else {
+                if (getWeightSum() > 0) {
+                    BLAS.axpy(
+                            -params.learningRate / getWeightSum(),
+                            new DenseVector(feedbackArray),
+                            coefficient,
+                            coeffiDim);
+                    double regLoss =
+                            RegularizationUtils.regularize(
+                                    coefficient,
+                                    params.reg,
+                                    params.elasticNet,
+                                    params.learningRate);
+                    setLoss(getLoss() + regLoss);
+                }
+            }
+
+            if (trainData == null) {
+                trainData = IteratorUtils.toList(trainDataState.get().iterator());
+            }
+
+            // TODO: support efficient shuffling of the training set on each partition.
+            if (trainData.size() > 0) {
+                List<LabeledPointWithWeight> miniBatchData =
+                        trainData.subList(
+                                nextBatchOffset,
+                                Math.min(nextBatchOffset + localBatchSize, trainData.size()));

Review Comment:
   It seems that we could get `miniBatchData.size() < localBatchSize` even if `trainData.size() > localBatchSize`, i.e. whenever `nextBatchOffset + localBatchSize` runs past `trainData.size()`. Is this expected?
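A minimal standalone sketch (plain Java, hypothetical class and method names) that replays the two pieces of bookkeeping involved: the `open()` split of `globalBatchSize` across subtasks, and the `subList` bounds quoted above. It assumes the offset advances and wraps as in the later revision of `SGD.java` (`nextBatchOffset += localBatchSize`, reset to 0 once it reaches `trainData.size()`). Under that assumption, a short batch appears whenever the local data size is not a multiple of `localBatchSize`.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical helper replaying the mini-batch bookkeeping of CacheDataAndDoTrain. */
public class MiniBatchSizes {

    /** Splits globalBatchSize across numTasks; the first (global % numTasks) tasks get one extra. */
    static int localBatchSize(int globalBatchSize, int numTasks, int taskId) {
        int size = globalBatchSize / numTasks;
        if (globalBatchSize % numTasks > taskId) {
            size++;
        }
        return size;
    }

    /** Sizes of the first numRounds mini-batches drawn from a partition of localDataSize elements. */
    static List<Integer> batchSizes(int localDataSize, int localBatchSize, int numRounds) {
        List<Integer> sizes = new ArrayList<>();
        int nextBatchOffset = 0;
        for (int round = 0; round < numRounds; round++) {
            // Same bounds as trainData.subList(nextBatchOffset, min(nextBatchOffset + localBatchSize, size)).
            int end = Math.min(nextBatchOffset + localBatchSize, localDataSize);
            sizes.add(end - nextBatchOffset);
            // Advance, then wrap to 0 once the end of the local data has been reached.
            nextBatchOffset += localBatchSize;
            nextBatchOffset = nextBatchOffset >= localDataSize ? 0 : nextBatchOffset;
        }
        return sizes;
    }

    public static void main(String[] args) {
        // A global batch of 32 on 3 subtasks is split as 11, 11, 10.
        for (int taskId = 0; taskId < 3; taskId++) {
            System.out.println("task " + taskId + ": " + localBatchSize(32, 3, taskId));
        }
        // 10 local samples consumed in batches of 4.
        System.out.println(batchSizes(10, 4, 6));
    }
}
```

With 10 local samples and `localBatchSize = 4`, the replayed sizes are `[4, 4, 2, 4, 4, 2]`; with `localBatchSize = 16`, every round simply uses all 10 samples.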



[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r860681944


##########
flink-ml-lib/src/test/java/org/apache/flink/ml/regression/LinearRegressionTest.java:
##########
@@ -0,0 +1,240 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.regression;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.regression.linearregression.LinearRegression;
+import org.apache.flink.ml.regression.linearregression.LinearRegressionModel;
+import org.apache.flink.ml.regression.linearregression.LinearRegressionModelData;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+
+/** Tests {@link LinearRegression} and {@link LinearRegressionModel}. */
+public class LinearRegressionTest {
+
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private StreamExecutionEnvironment env;
+
+    private StreamTableEnvironment tEnv;
+
+    private static final List<Row> trainData =
+            Arrays.asList(
+                    Row.of(Vectors.dense(2, 1), 4.0, 1.0),
+                    Row.of(Vectors.dense(3, 2), 7.0, 1.0),
+                    Row.of(Vectors.dense(4, 3), 10.0, 1.0),
+                    Row.of(Vectors.dense(2, 4), 10.0, 1.0),
+                    Row.of(Vectors.dense(2, 2), 6.0, 1.0),
+                    Row.of(Vectors.dense(4, 3), 10.0, 1.0),
+                    Row.of(Vectors.dense(1, 2), 5.0, 1.0),
+                    Row.of(Vectors.dense(5, 3), 11.0, 1.0));
+
+    private static final double[] expectedCoefficient = new double[] {1.0, 2.0};
+
+    private static final double TOLERANCE = 1e-7;
+
+    private Table trainDataTable;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        Collections.shuffle(trainData);
+        trainDataTable =
+                tEnv.fromDataStream(
+                        env.fromCollection(
+                                trainData,
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(DenseVector.class),
+                                            Types.DOUBLE,
+                                            Types.DOUBLE
+                                        },
+                                        new String[] {"features", "label", "weight"})));
+    }
+
+    @SuppressWarnings("unchecked")
+    private void verifyPredictionResult(
+            Table output, String labelCol, String weightCol, String predictionCol)
+            throws Exception {
+        List<Row> predResult = IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        double lossSum = 0;
+        for (Row predictionRow : predResult) {
+            double label = (double) predictionRow.getField(labelCol);
+            double prediction = (double) predictionRow.getField(predictionCol);
+            double weight = (double) predictionRow.getField(weightCol);
+            lossSum += weight * Math.pow(label - prediction, 2);
+        }
+        assertTrue(lossSum < 1.0);

Review Comment:
   Sounds good. I have updated the PR to make assertions on the prediction results.
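A sketch of what per-row assertions could look like, using plain arrays instead of Flink `Row`s; `PredictionCheck`, `dot` and `assertPredictionsClose` are hypothetical names, and the tolerance a trained SGD model actually reaches is not established here. It relies on one property of `trainData` that can be checked by hand: every label equals `features · expectedCoefficient`, e.g. (2, 4) · (1.0, 2.0) = 10.0.

```java
/** Hypothetical per-row check, as an alternative to bounding the weighted squared-error sum. */
public class PredictionCheck {

    /** Dot product of a feature vector with a coefficient vector. */
    static double dot(double[] x, double[] w) {
        double sum = 0;
        for (int i = 0; i < x.length; i++) {
            sum += x[i] * w[i];
        }
        return sum;
    }

    /** Throws if any prediction differs from its label by more than tolerance. */
    static void assertPredictionsClose(double[] labels, double[] predictions, double tolerance) {
        if (labels.length != predictions.length) {
            throw new AssertionError(
                    "Expected " + labels.length + " predictions, got " + predictions.length);
        }
        for (int i = 0; i < labels.length; i++) {
            double error = Math.abs(labels[i] - predictions[i]);
            if (error > tolerance) {
                throw new AssertionError(
                        "Row " + i + ": |label - prediction| = " + error + " > " + tolerance);
            }
        }
    }

    public static void main(String[] args) {
        // Same features and labels as trainData in LinearRegressionTest.
        double[][] features = {{2, 1}, {3, 2}, {4, 3}, {2, 4}, {2, 2}, {4, 3}, {1, 2}, {5, 3}};
        double[] labels = {4, 7, 10, 10, 6, 10, 5, 11};
        double[] expectedCoefficient = {1.0, 2.0};
        // Predictions computed with the expected coefficient reproduce every label exactly.
        double[] predictions = new double[features.length];
        for (int i = 0; i < features.length; i++) {
            predictions[i] = dot(features[i], expectedCoefficient);
        }
        assertPredictionsClose(labels, predictions, 1e-7);
        System.out.println("all " + labels.length + " predictions within tolerance");
    }
}
```

Unlike `lossSum < 1.0`, a per-row check reports which row is off and by how much.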



[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r865596624


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/SGD.java:
##########
@@ -0,0 +1,402 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol;
+import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.regression.linearregression.LinearRegression;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Stochastic Gradient Descent (SGD) is the most widely used optimizer for optimizing machine
+ * learning models. It iteratively makes small adjustments to the machine learning model according
+ * to the gradient at each step, to decrease the error of the model.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Stochastic_gradient_descent.
+ */
+@Internal
+public class SGD implements Optimizer {
+    /** Params for SGD optimizer. */
+    private final SGDParams params;
+
+    public SGD(
+            int maxIter,
+            double learningRate,
+            int globalBatchSize,
+            double tol,
+            double reg,
+            double elasticNet) {
+        this.params = new SGDParams(maxIter, learningRate, globalBatchSize, tol, reg, elasticNet);
+    }
+
+    @Override
+    public DataStream<DenseVector> optimize(
+            DataStream<DenseVector> bcInitModel,
+            DataStream<LabeledPointWithWeight> trainData,
+            LossFunc lossFunc) {
+        DataStreamList resultList =
+                Iterations.iterateBoundedStreamsUntilTermination(
+                        DataStreamList.of(bcInitModel.map(modelVec -> modelVec.values)),
+                        ReplayableDataStreamList.notReplay(trainData.rebalance()),
+                        IterationConfig.newBuilder().build(),
+                        new TrainIterationBody(lossFunc, params));
+        return resultList.get(0);
+    }
+
+    /** The iteration implementation for training process. */
+    private static class TrainIterationBody implements IterationBody {
+        private final LossFunc lossFunc;
+        private final SGDParams params;
+
+        public TrainIterationBody(LossFunc lossFunc, SGDParams params) {
+            this.lossFunc = lossFunc;
+            this.params = params;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            // The variable stream at the first iteration is the initialized model data.
+            // In the following iterations, it contains: the model update, weightSum and
+            // lossSum.
+            DataStream<double[]> variableStream = variableStreams.get(0);
+            DataStream<LabeledPointWithWeight> trainData = dataStreams.get(0);
+            final OutputTag<DenseVector> modelDataOutputTag =
+                    new OutputTag<DenseVector>("MODEL_OUTPUT") {};
+
+            SingleOutputStreamOperator<double[]> modelUpdateAndWeightAndLoss =
+                    trainData
+                            .connect(variableStream)
+                            .transform(
+                                    "CacheDataAndDoTrain",
+                                    PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO,
+                                    new CacheDataAndDoTrain(lossFunc, params, modelDataOutputTag));
+
+            DataStreamList feedbackVariableStream =
+                    IterationBody.forEachRound(
+                            DataStreamList.of(modelUpdateAndWeightAndLoss),
+                            input -> {
+                                DataStream<double[]> feedback =
+                                        DataStreamUtils.allReduceSum(input.get(0));
+                                return DataStreamList.of(feedback);
+                            });
+
+            DataStream<Integer> terminationCriteria =
+                    feedbackVariableStream
+                            .get(0)
+                            .map(
+                                    reducedUpdateAndWeightAndLoss -> {
+                                        double[] value = (double[]) reducedUpdateAndWeightAndLoss;
+                                        return value[value.length - 1] / value[value.length - 2];
+                                    })
+                            .flatMap(new TerminateOnMaxIterOrTol(params.maxIter, params.tol));
+
+            return new IterationBodyResult(
+                    DataStreamList.of(feedbackVariableStream.get(0)),
+                    DataStreamList.of(
+                            modelUpdateAndWeightAndLoss.getSideOutput(modelDataOutputTag)),
+                    terminationCriteria);
+        }
+    }
+
+    /**
+     * A stream operator that caches the training data in the first iteration and updates the model
+     * iteratively. The first input is the training data, and the second input is the initialized
+     * model data or feedback of model update, weight and loss.
+     */
+    private static class CacheDataAndDoTrain extends AbstractStreamOperator<double[]>
+            implements TwoInputStreamOperator<LabeledPointWithWeight, double[], double[]>,
+                    IterationListener<double[]> {
+        /** The cached training data. */
+        private List<LabeledPointWithWeight> trainData;
+
+        private ListState<LabeledPointWithWeight> trainDataState;
+
+        /** The start index (offset) of the next mini-batch data for training. */
+        private int nextBatchOffset = 0;
+
+        private ListState<Integer> nextBatchOffsetState;
+
+        /** The model coefficient. */
+        private DenseVector coefficient;
+
+        private ListState<DenseVector> coefficientState;
+        /** The dimension of the coefficient. */
+        private int coeffiDim;
+
+        /**
+         * The double array to sync among all workers. For example, when training {@link
+         * LinearRegression}, this double array consists of {modelUpdate, weightSum, lossSum}.
+         */
+        private double[] feedbackArray;
+
+        private ListState<double[]> feedbackArrayState;
+        /** The outputTag to output the model data when iteration ends. */
+        private final OutputTag<DenseVector> modelDataOutputTag;
+        /** The batch size on this task. */
+        private int localBatchSize;
+        /** Optimizer-related parameters. */
+        private final SGDParams params;
+        /** The loss function to optimize. */
+        private final LossFunc lossFunc;
+
+        private CacheDataAndDoTrain(
+                LossFunc lossFunc, SGDParams params, OutputTag<DenseVector> modelDataOutputTag) {
+            this.lossFunc = lossFunc;
+            this.params = params;
+            this.modelDataOutputTag = modelDataOutputTag;
+        }
+
+        @Override
+        public void open() {
+            int numTasks = getRuntimeContext().getNumberOfParallelSubtasks();
+            int taskId = getRuntimeContext().getIndexOfThisSubtask();
+            localBatchSize = params.globalBatchSize / numTasks;
+            if (params.globalBatchSize % numTasks > taskId) {
+                localBatchSize++;
+            }
+        }
+
+        /**
+         * Gets the weight sum of the processed elements.
+         *
+         * @return The weight sum.
+         */
+        private double getWeightSum() {
+            return feedbackArray[coeffiDim];
+        }
+
+        /**
+         * Sets the weight sum of the processed elements.
+         *
+         * @param weightSum The weight sum.
+         */
+        private void setWeightSum(double weightSum) {
+            feedbackArray[coeffiDim] = weightSum;
+        }
+
+        /**
+         * Gets the loss sum of the processed elements.
+         *
+         * @return The loss sum.
+         */
+        private double getLoss() {
+            return feedbackArray[coeffiDim + 1];
+        }
+
+        /**
+         * Sets the loss sum of the processed elements.
+         *
+         * @param loss The loss sum.
+         */
+        private void setLoss(double loss) {
+            feedbackArray[coeffiDim + 1] = loss;
+        }
+
+        @Override
+        public void onEpochWatermarkIncremented(
+                int epochWatermark, Context context, Collector<double[]> collector)
+                throws Exception {
+            if (epochWatermark == 0) {
+                coefficient = new DenseVector(feedbackArray);
+                coeffiDim = coefficient.size();
+                feedbackArray = new double[coefficient.size() + 2];
+            } else {
+                if (getWeightSum() > 0) {
+                    BLAS.axpy(
+                            -params.learningRate / getWeightSum(),
+                            new DenseVector(feedbackArray),
+                            coefficient,
+                            coeffiDim);
+                    double regLoss =
+                            RegularizationUtils.regularize(
+                                    coefficient,
+                                    params.reg,
+                                    params.elasticNet,
+                                    params.learningRate);
+                    setLoss(getLoss() + regLoss);
+                }
+            }
+
+            if (trainData == null) {
+                trainData = IteratorUtils.toList(trainDataState.get().iterator());
+            }
+
+            // TODO: support efficient shuffling of the training set on each partition.
+            if (trainData.size() > 0) {
+                List<LabeledPointWithWeight> miniBatchData =
+                        trainData.subList(
+                                nextBatchOffset,
+                                Math.min(nextBatchOffset + localBatchSize, trainData.size()));
+                nextBatchOffset += localBatchSize;
+                nextBatchOffset = nextBatchOffset >= trainData.size() ? 0 : nextBatchOffset;

Review Comment:
   The previous approach is essentially a toy implementation: it does not treat samples on different partitions uniformly.
   
   Suppose there are two workers and `globalBatchSize` is 32, with 10 elements on `worker1` and 20 elements on `worker2`. In the previous approach, each worker's local batch size is 16 regardless of how many elements it holds, so each element on `worker1` has a larger probability of being picked than each element on `worker2`.
   
   Spark uses a `mini-batch fraction` to sample the training data in each iteration; we discussed this in [1].
   
   [1] https://github.com/apache/flink-ml/pull/28#discussion_r753641174
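A minimal sketch of the fraction-based alternative, assuming each worker keeps every cached sample independently with the same probability `fraction` (conceptually similar to Spark's mini-batch fraction; `FractionSampling` and `sample` are hypothetical names, not Flink ML API). Because the inclusion probability no longer depends on how many samples a partition holds, per-sample rates match across workers; what varies instead is the batch size from round to round, with an expected global size of `fraction * totalCount`. The sketch uses 100 and 200 samples so that `fraction` stays below 1.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

/** Hypothetical fraction-based (Bernoulli) mini-batch sampling on one partition. */
public class FractionSampling {

    /** Keeps each element independently with probability fraction. */
    static <T> List<T> sample(List<T> partition, double fraction, Random random) {
        List<T> batch = new ArrayList<>();
        for (T element : partition) {
            // nextDouble() is in [0, 1), so fraction 1.0 keeps everything and 0.0 keeps nothing.
            if (random.nextDouble() < fraction) {
                batch.add(element);
            }
        }
        return batch;
    }

    public static void main(String[] args) {
        Random random = new Random(42);
        List<Integer> worker1 = new ArrayList<>();
        List<Integer> worker2 = new ArrayList<>();
        for (int i = 0; i < 100; i++) {
            worker1.add(i);
        }
        for (int i = 0; i < 200; i++) {
            worker2.add(i);
        }
        // A global batch of 30 out of 300 samples gives the same fraction (0.1) on every worker.
        double fraction = 30.0 / (worker1.size() + worker2.size());
        int rounds = 1000;
        long picked1 = 0;
        long picked2 = 0;
        for (int r = 0; r < rounds; r++) {
            picked1 += sample(worker1, fraction, random).size();
            picked2 += sample(worker2, fraction, random).size();
        }
        // Both per-sample inclusion rates should be close to fraction.
        System.out.printf("worker1 rate: %.3f, worker2 rate: %.3f%n",
                (double) picked1 / (rounds * worker1.size()),
                (double) picked2 / (rounds * worker2.size()));
    }
}
```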
   



[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r865596624


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/SGD.java:
##########
@@ -0,0 +1,402 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol;
+import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.regression.linearregression.LinearRegression;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Stochastic Gradient Descent (SGD) is the most widely used optimizer for optimizing machine
+ * learning models. It iteratively makes small adjustments to the machine learning model according
+ * to the gradient at each step, to decrease the error of the model.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Stochastic_gradient_descent.
+ */
+@Internal
+public class SGD implements Optimizer {
+    /** Params for SGD optimizer. */
+    private final SGDParams params;
+
+    public SGD(
+            int maxIter,
+            double learningRate,
+            int globalBatchSize,
+            double tol,
+            double reg,
+            double elasticNet) {
+        this.params = new SGDParams(maxIter, learningRate, globalBatchSize, tol, reg, elasticNet);
+    }
+
+    @Override
+    public DataStream<DenseVector> optimize(
+            DataStream<DenseVector> bcInitModel,
+            DataStream<LabeledPointWithWeight> trainData,
+            LossFunc lossFunc) {
+        DataStreamList resultList =
+                Iterations.iterateBoundedStreamsUntilTermination(
+                        DataStreamList.of(bcInitModel.map(modelVec -> modelVec.values)),
+                        ReplayableDataStreamList.notReplay(trainData.rebalance()),
+                        IterationConfig.newBuilder().build(),
+                        new TrainIterationBody(lossFunc, params));
+        return resultList.get(0);
+    }
+
+    /** The iteration implementation for training process. */
+    private static class TrainIterationBody implements IterationBody {
+        private final LossFunc lossFunc;
+        private final SGDParams params;
+
+        public TrainIterationBody(LossFunc lossFunc, SGDParams params) {
+            this.lossFunc = lossFunc;
+            this.params = params;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            // The variable stream at the first iteration is the initialized model data.
+            // In the following iterations, it contains: the model update, weightSum and
+            // lossSum.
+            DataStream<double[]> variableStream = variableStreams.get(0);
+            DataStream<LabeledPointWithWeight> trainData = dataStreams.get(0);
+            final OutputTag<DenseVector> modelDataOutputTag =
+                    new OutputTag<DenseVector>("MODEL_OUTPUT") {};
+
+            SingleOutputStreamOperator<double[]> modelUpdateAndWeightAndLoss =
+                    trainData
+                            .connect(variableStream)
+                            .transform(
+                                    "CacheDataAndDoTrain",
+                                    PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO,
+                                    new CacheDataAndDoTrain(lossFunc, params, modelDataOutputTag));
+
+            DataStreamList feedbackVariableStream =
+                    IterationBody.forEachRound(
+                            DataStreamList.of(modelUpdateAndWeightAndLoss),
+                            input -> {
+                                DataStream<double[]> feedback =
+                                        DataStreamUtils.allReduceSum(input.get(0));
+                                return DataStreamList.of(feedback);
+                            });
+
+            DataStream<Integer> terminationCriteria =
+                    feedbackVariableStream
+                            .get(0)
+                            .map(
+                                    reducedUpdateAndWeightAndLoss -> {
+                                        double[] value = (double[]) reducedUpdateAndWeightAndLoss;
+                                        return value[value.length - 1] / value[value.length - 2];
+                                    })
+                            .flatMap(new TerminateOnMaxIterOrTol(params.maxIter, params.tol));
+
+            return new IterationBodyResult(
+                    DataStreamList.of(feedbackVariableStream.get(0)),
+                    DataStreamList.of(
+                            modelUpdateAndWeightAndLoss.getSideOutput(modelDataOutputTag)),
+                    terminationCriteria);
+        }
+    }
+
+    /**
+     * A stream operator that caches the training data in the first iteration and updates the model
+     * iteratively. The first input is the training data, and the second input is the initialized
+     * model data or feedback of model update, weight and loss.
+     */
+    private static class CacheDataAndDoTrain extends AbstractStreamOperator<double[]>
+            implements TwoInputStreamOperator<LabeledPointWithWeight, double[], double[]>,
+                    IterationListener<double[]> {
+        /** The cached training data. */
+        private List<LabeledPointWithWeight> trainData;
+
+        private ListState<LabeledPointWithWeight> trainDataState;
+
+        /** The start index (offset) of the next mini-batch data for training. */
+        private int nextBatchOffset = 0;
+
+        private ListState<Integer> nextBatchOffsetState;
+
+        /** The model coefficient. */
+        private DenseVector coefficient;
+
+        private ListState<DenseVector> coefficientState;
+        /** The dimension of the coefficient. */
+        private int coeffiDim;
+
+        /**
+         * The double array to sync among all workers. For example, when training {@link
+         * LinearRegression}, this double array consists of {modelUpdate, weightSum, lossSum}.
+         */
+        private double[] feedbackArray;
+
+        private ListState<double[]> feedbackArrayState;
+        /** The outputTag to output the model data when iteration ends. */
+        private final OutputTag<DenseVector> modelDataOutputTag;
+        /** The batch size on this task. */
+        private int localBatchSize;
+        /** Optimizer-related parameters. */
+        private final SGDParams params;
+        /** The loss function to optimize. */
+        private final LossFunc lossFunc;
+
+        private CacheDataAndDoTrain(
+                LossFunc lossFunc, SGDParams params, OutputTag<DenseVector> modelDataOutputTag) {
+            this.lossFunc = lossFunc;
+            this.params = params;
+            this.modelDataOutputTag = modelDataOutputTag;
+        }
+
+        @Override
+        public void open() {
+            int numTasks = getRuntimeContext().getNumberOfParallelSubtasks();
+            int taskId = getRuntimeContext().getIndexOfThisSubtask();
+            localBatchSize = params.globalBatchSize / numTasks;
+            if (params.globalBatchSize % numTasks > taskId) {
+                localBatchSize++;
+            }
+        }
+
+        /**
+         * Gets the weight sum of the processed elements.
+         *
+         * @return The weight sum.
+         */
+        private double getWeightSum() {
+            return feedbackArray[coeffiDim];
+        }
+
+        /**
+         * Sets the weight sum of the processed elements.
+         *
+         * @param weightSum The weight sum.
+         */
+        private void setWeightSum(double weightSum) {
+            feedbackArray[coeffiDim] = weightSum;
+        }
+
+        /**
+         * Gets the loss sum of the processed elements.
+         *
+         * @return The loss sum.
+         */
+        private double getLoss() {
+            return feedbackArray[coeffiDim + 1];
+        }
+
+        /**
+         * Sets the loss sum of the processed elements.
+         *
+         * @param loss The loss sum.
+         */
+        private void setLoss(double loss) {
+            feedbackArray[coeffiDim + 1] = loss;
+        }
+
+        @Override
+        public void onEpochWatermarkIncremented(
+                int epochWatermark, Context context, Collector<double[]> collector)
+                throws Exception {
+            if (epochWatermark == 0) {
+                coefficient = new DenseVector(feedbackArray);
+                coeffiDim = coefficient.size();
+                feedbackArray = new double[coefficient.size() + 2];
+            } else {
+                if (getWeightSum() > 0) {
+                    BLAS.axpy(
+                            -params.learningRate / getWeightSum(),
+                            new DenseVector(feedbackArray),
+                            coefficient,
+                            coeffiDim);
+                    double regLoss =
+                            RegularizationUtils.regularize(
+                                    coefficient,
+                                    params.reg,
+                                    params.elasticNet,
+                                    params.learningRate);
+                    setLoss(getLoss() + regLoss);
+                }
+            }
+
+            if (trainData == null) {
+                trainData = IteratorUtils.toList(trainDataState.get().iterator());
+            }
+
+            // TODO: support efficient shuffling of the training set on each partition.
+            if (trainData.size() > 0) {
+                List<LabeledPointWithWeight> miniBatchData =
+                        trainData.subList(
+                                nextBatchOffset,
+                                Math.min(nextBatchOffset + localBatchSize, trainData.size()));
+                nextBatchOffset += localBatchSize;
+                nextBatchOffset = nextBatchOffset >= trainData.size() ? 0 : nextBatchOffset;

Review Comment:
   The previous approach was rather a toy implementation: it needs random access to all of the data points, which could be expensive when the data is cached on disk.
   
   We now replace it with sequential reads plus periodic shuffling (left as a TODO), which is common practice in machine learning training [1].
   
   Spark uses `mini-batch fraction` to sample training data for each iteration; we discussed that approach in [2].
   
   [1] https://www.tensorflow.org/api_docs/python/tf/data/Dataset (search "make sure to call shuffle after calling cache.")
   [2] https://github.com/apache/flink-ml/pull/28#discussion_r753641174
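A minimal sketch of the sequential-read approach described above, assuming each task holds its cached partition as an in-memory `List`. `SequentialMiniBatcher` and its methods are hypothetical illustrations, not Flink ML APIs: `computeLocalBatchSize` mirrors how `CacheDataAndDoTrain#open()` splits `globalBatchSize` across subtasks, and `nextBatch` reads contiguous slices with the same wrap-around offset logic. It additionally reshuffles the cache after each full pass, which is one possible way to fill in the shuffle TODO (using `Collections.shuffle` is an assumption here).

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Hypothetical helper: sequential mini-batch reads over a cached partition. */
final class SequentialMiniBatcher<T> {
    private final List<T> cache;
    private final int localBatchSize;
    private final Random random;
    /** Start index of the next mini-batch, like nextBatchOffset in CacheDataAndDoTrain. */
    private int nextBatchOffset = 0;

    SequentialMiniBatcher(List<T> data, int localBatchSize, long seed) {
        this.cache = new ArrayList<>(data);
        this.localBatchSize = localBatchSize;
        this.random = new Random(seed);
    }

    /** Splits globalBatchSize over numTasks; the first (globalBatchSize % numTasks) tasks get one extra. */
    static int computeLocalBatchSize(int globalBatchSize, int numTasks, int taskId) {
        int size = globalBatchSize / numTasks;
        if (globalBatchSize % numTasks > taskId) {
            size++;
        }
        return size;
    }

    /** Returns the next contiguous slice and reshuffles the cache after each full pass. */
    List<T> nextBatch() {
        if (cache.isEmpty()) {
            return Collections.emptyList();
        }
        int end = Math.min(nextBatchOffset + localBatchSize, cache.size());
        // Copy the slice: subList is a view backed by the cache, which the shuffle below modifies.
        List<T> batch = new ArrayList<>(cache.subList(nextBatchOffset, end));
        nextBatchOffset += localBatchSize;
        if (nextBatchOffset >= cache.size()) {
            // End of a pass: wrap around and shuffle so the next pass sees a different order.
            nextBatchOffset = 0;
            Collections.shuffle(cache, random);
        }
        return batch;
    }
}
```

Reads stay sequential within a pass, so a disk-backed cache would only need a full rewrite once per pass rather than random access on every mini-batch.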
   




[GitHub] [flink-ml] lindong28 commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

Posted by GitBox <gi...@apache.org>.
lindong28 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r865907433


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/SGD.java:
##########
@@ -0,0 +1,402 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol;
+import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.regression.linearregression.LinearRegression;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
+

Review Comment:
   Sounds good. Thanks for the explanation!




[GitHub] [flink-ml] lindong28 commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

Posted by GitBox <gi...@apache.org>.
lindong28 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r865911454


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/Optimizer.java:
##########
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.streaming.api.datastream.DataStream;
+
+/**
+ * An optimizer is a function to modify the weights of a machine learning model, which aims to find
+ * the optimal parameter configuration for a machine learning model. Examples of optimizers could be
+ * stochastic gradient descent (SGD), L-BFGS, etc.
+ */
+@Internal
+public interface Optimizer {
+    /**
+     * Optimize the given loss function using the init model and the training data.

Review Comment:
   It would be great if this method could also support unbounded `trainData`.
   
   Note that this method currently supports only bounded `trainData`. How about we add a TODO here so that we make sure to either fix this method or document it properly before the next release?




[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r858415540


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/glm/LocalTrainer.java:
##########
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.glm;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.regression.linearregression.LinearRegression;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
+
+import java.util.Iterator;
+
+/**
+ * A local trainer is a trainer that uses a batch of training data to compute an update of the model
+ * locally.
+ *
+ * @param <T> Class type of training data.
+ */
+public abstract class LocalTrainer<T> implements CheckpointedFunction {

Review Comment:
   Thanks for the insightful comments.
   
   > move CacheDataAndDoTrain to an independent class
   
   It is an independent class in the old PR. Do you mean to merge it with LocalTrainer and WithRegularization?
   
   > merge LocalTrainer and WithRegularization and remove methods like getReg, as now they are only used internally
   
   It seems hard to merge these two classes, because the logic of `WithRegularization` is supposed to be used in `LocalTrainer#updateModel`. I have renamed `WithRegularization` to `RegularizationUtils` and removed methods like `withReg`. What do you think?
   
   > rename the merged class. A name like LocalTrainer might be too general to be associated with linear algorithms.
   
   I have renamed the class to `LocalLinearTrainer`, though I still don't think it is a very good name. We could discuss the naming further.
   
   > merge CacheDataAndDoTrain, LocalTrainer and WithRegularization.
   
   I did not merge `CacheDataAndDoTrain` with the other two for now, for the following reasons:
   - `CacheDataAndDoTrain` is more of an infrastructure component and involves distributed concepts, such as two-input operators.
   - `LocalTrainer` is a friendlier and cleaner concept for machine learning users, since it only involves local operations.
   
   > change the API and implementation of trainOnBatchData. I have the sense that each time only one data, instead of a batch of data, is enough. I think the API of org.apache.flink.api.common.functions.AggregateFunction is a good reference.
   
   I did not make this change, for the following reasons:
   - Mini-batch training is a common concept in machine learning.
   - Users may want to perform operations before or after training on each mini-batch.
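To illustrate why the regularization logic sits naturally inside a model-update step, here is a hedged sketch of one elastic-net step applied after the gradient update. `ElasticNetSketch.regularize` is hypothetical: it only mirrors the argument order of the `RegularizationUtils.regularize(coefficient, reg, elasticNet, learningRate)` call in `SGD` and uses a plain `double[]` instead of `DenseVector`; the PR's actual implementation may differ. It assumes the common Spark-style penalty reg * (elasticNet * ||w||_1 + (1 - elasticNet) / 2 * ||w||_2^2), handles the L2 part as multiplicative shrinkage and the L1 part by soft-thresholding, and returns the penalty of the updated coefficient so the caller can add it to the loss sum.

```java
/** Hypothetical elastic-net regularization step; not the Flink ML implementation. */
final class ElasticNetSketch {
    private ElasticNetSketch() {}

    /**
     * Regularizes the coefficient in place and returns the penalty of the updated coefficient:
     * reg * (elasticNet * ||w||_1 + (1 - elasticNet) / 2 * ||w||_2^2).
     */
    static double regularize(
            double[] coefficient, double reg, double elasticNet, double learningRate) {
        if (reg == 0) {
            return 0;
        }
        double l1Reg = reg * elasticNet;
        double l2Reg = reg * (1 - elasticNet);
        double threshold = learningRate * l1Reg;
        double l1Norm = 0;
        double l2NormSquared = 0;
        for (int i = 0; i < coefficient.length; i++) {
            // L2 part: a gradient step on (l2Reg / 2) * w^2 shrinks w multiplicatively.
            double w = coefficient[i] * (1 - learningRate * l2Reg);
            // L1 part: soft-thresholding, which sets small weights exactly to zero.
            w = Math.signum(w) * Math.max(Math.abs(w) - threshold, 0);
            coefficient[i] = w;
            l1Norm += Math.abs(w);
            l2NormSquared += w * w;
        }
        return l1Reg * l1Norm + 0.5 * l2Reg * l2NormSquared;
    }
}
```

Because the step needs the freshly updated coefficient and the learning rate, it is easier to call from the update method than to expose it through getters such as `getReg`.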




[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r866427486


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LossFunc.java:
##########
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.lossfunc;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.linalg.DenseVector;
+
+import java.io.Serializable;
+
+/**
+ * A loss function computes the loss and gradient for the given coefficient and training data.
+ */
+@Internal
+public interface LossFunc extends Serializable {
+
+    /**
+     * Computes the loss on the given data point.
+     *
+     * @param dataPoint A training data point.
+     * @param coefficient The model parameters.
+     * @return The loss of the input data.
+     */
+    double computeLoss(LabeledPointWithWeight dataPoint, DenseVector coefficient);
+
+    /**
+     * Computes the gradient on the given data point.

Review Comment:
   Thanks for the comment. The Javadoc has been updated as follows:
   `Computes the gradient on the given data point and adds the computed gradient to the cumGradient.`
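The accumulate-into-`cumGradient` contract lets a worker sum gradients over its whole mini-batch in one buffer before the all-reduce. The sketch below assumes a least-squares loss with a 1/2 factor and uses hypothetical stand-ins (`WeightedPoint` with plain `double[]` features, and `LeastSquareLossSketch`) in place of Flink ML's `LabeledPointWithWeight`, `DenseVector` and `LossFunc`.

```java
/** Hypothetical stand-in for a weighted, labeled training point. */
final class WeightedPoint {
    final double[] features;
    final double label;
    final double weight;

    WeightedPoint(double[] features, double label, double weight) {
        this.features = features;
        this.label = label;
        this.weight = weight;
    }
}

/** Hypothetical least-squares loss: weight * (x . w - y)^2 / 2. */
final class LeastSquareLossSketch {
    private LeastSquareLossSketch() {}

    private static double dot(double[] a, double[] b) {
        double sum = 0;
        for (int i = 0; i < a.length; i++) {
            sum += a[i] * b[i];
        }
        return sum;
    }

    /** Returns the weighted squared error of a single point. */
    static double computeLoss(WeightedPoint point, double[] coefficient) {
        double error = dot(point.features, coefficient) - point.label;
        return point.weight * error * error / 2;
    }

    /** Adds weight * (x . w - y) * x to cumGradient instead of overwriting it. */
    static void computeGradient(WeightedPoint point, double[] coefficient, double[] cumGradient) {
        double scale = point.weight * (dot(point.features, coefficient) - point.label);
        for (int i = 0; i < cumGradient.length; i++) {
            cumGradient[i] += scale * point.features[i];
        }
    }
}
```

Calling `computeGradient` once per point of a mini-batch leaves the batch gradient in `cumGradient`, which can then be all-reduced together with the weight sum and scaled by learningRate / weightSum.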




[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r860637007


##########
flink-ml-lib/src/test/java/org/apache/flink/ml/regression/LinearRegressionTest.java:
##########
@@ -0,0 +1,240 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.regression;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.regression.linearregression.LinearRegression;
+import org.apache.flink.ml.regression.linearregression.LinearRegressionModel;
+import org.apache.flink.ml.regression.linearregression.LinearRegressionModelData;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+
+/** Tests {@link LinearRegression} and {@link LinearRegressionModel}. */
+public class LinearRegressionTest {
+
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private StreamExecutionEnvironment env;
+
+    private StreamTableEnvironment tEnv;
+
+    private static final List<Row> trainData =
+            Arrays.asList(
+                    Row.of(Vectors.dense(2, 1), 4.0, 1.0),
+                    Row.of(Vectors.dense(3, 2), 7.0, 1.0),
+                    Row.of(Vectors.dense(4, 3), 10.0, 1.0),
+                    Row.of(Vectors.dense(2, 4), 10.0, 1.0),
+                    Row.of(Vectors.dense(2, 2), 6.0, 1.0),
+                    Row.of(Vectors.dense(4, 3), 10.0, 1.0),
+                    Row.of(Vectors.dense(1, 2), 5.0, 1.0),
+                    Row.of(Vectors.dense(5, 3), 11.0, 1.0));
+
+    private static final double[] expectedCoefficient = new double[] {1.0, 2.0};
+
+    private static final double TOLERANCE = 1e-7;
+
+    private Table trainDataTable;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        Collections.shuffle(trainData);
+        trainDataTable =
+                tEnv.fromDataStream(
+                        env.fromCollection(
+                                trainData,
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(DenseVector.class),
+                                            Types.DOUBLE,
+                                            Types.DOUBLE
+                                        },
+                                        new String[] {"features", "label", "weight"})));
+    }
+
+    @SuppressWarnings("unchecked")
+    private void verifyPredictionResult(
+            Table output, String labelCol, String weightCol, String predictionCol)
+            throws Exception {
+        List<Row> predResult = IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        double lossSum = 0;
+        for (Row predictionRow : predResult) {
+            double label = (double) predictionRow.getField(labelCol);
+            double prediction = (double) predictionRow.getField(predictionCol);
+            double weight = (double) predictionRow.getField(weightCol);
+            lossSum += weight * Math.pow(label - prediction, 2);
+        }
+        assertTrue(lossSum < 1.0);

Review Comment:
   Since `testGetModelData` already verifies the coefficient values, the predictions computed from that coefficient and the input data should also be deterministic. Could we assert on the expected prediction of each row instead of only checking that the loss sum is below 1.0?
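Every training row satisfies `label = 1.0 * x0 + 2.0 * x1` exactly, which supports the point above: once the fitted coefficient is close to `expectedCoefficient`, each prediction is known in advance. A minimal sketch of such a per-row check follows; the class and helper names are hypothetical, plain arrays replace the `Row`/`DenseVector` types of the real test, and a real JUnit assertion would need a tolerance that reflects how closely SGD actually converges, which this sketch does not model.

```java
/**
 * Hypothetical exact-prediction check for the LinearRegressionTest data.
 */
public class ExactPredictionCheck {
    // Same rows as trainData in the test: (features, label); every weight is 1.0.
    static final double[][] FEATURES = {
        {2, 1}, {3, 2}, {4, 3}, {2, 4}, {2, 2}, {4, 3}, {1, 2}, {5, 3}
    };
    static final double[] LABELS = {4.0, 7.0, 10.0, 10.0, 6.0, 10.0, 5.0, 11.0};
    static final double[] EXPECTED_COEFFICIENT = {1.0, 2.0};

    static double predict(double[] features, double[] coefficient) {
        double prediction = 0.0;
        for (int i = 0; i < features.length; i++) {
            prediction += features[i] * coefficient[i];
        }
        return prediction;
    }

    /** Returns the index of the first row whose prediction misses its label, or -1. */
    static int firstMismatch(double[] coefficient, double tolerance) {
        for (int i = 0; i < FEATURES.length; i++) {
            if (Math.abs(predict(FEATURES[i], coefficient) - LABELS[i]) > tolerance) {
                return i;
            }
        }
        return -1;
    }

    public static void main(String[] args) {
        System.out.println(firstMismatch(EXPECTED_COEFFICIENT, 1e-7)); // prints -1
    }
}
```

In the JUnit test itself, the same idea would translate to an `assertEquals(label, prediction, tolerance)` per output row.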





[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r865585957


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/Optimizer.java:
##########
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.streaming.api.datastream.DataStream;
+
+/**
+ * An optimizer is a function to modify the weight of a machine learning model, which aims to find
+ * the optimal parameter configuration for a machine learning model. Examples of optimizers could be
+ * stochastic gradient descent (SGD), L-BFGS, etc.
+ */
+@Internal
+public interface Optimizer {
+    /**
+     * Optimize the given loss function using the init model and the training data.
+     *
+     * @param bcInitModel The broadcast init model. Note that each task contains one DenseVector as
+     *     the model data and the model data on each task are exactly the same.
+     * @param trainData The training data.
+     * @param lossFunc The loss function to optimize.
+     * @return The fitted model. Note that the parallelism of the returned stream is one.
+     */
+    DataStream<DenseVector> optimize(
+            DataStream<DenseVector> bcInitModel,

Review Comment:
   - Sure.
   - `initial` refers to the initial value of the model data (i.e., the coefficient). Why is it unnecessary?
   - Sure :)
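To illustrate why the optimizer needs an initial model value, here is a minimal single-process sketch of full-batch gradient descent that copies the caller's initial coefficient and refines it, in place of the broadcast `bcInitModel` stream. Everything here is hypothetical: there are no Flink types, no mini-batching, no regularization, and no tolerance-based stopping, and the learning rate of 0.1 was chosen only because it is stable for this particular data.

```java
import java.util.Arrays;

/** Hypothetical local analogue of Optimizer#optimize; not the Flink ML API. */
public class LocalGradientDescent {

    /**
     * Minimizes 0.5 * sum(w * (x . c - y)^2) / sum(w), starting from a copy of
     * initCoefficient so that the caller's array is left untouched.
     */
    static double[] optimize(
            double[] initCoefficient,
            double[][] features,
            double[] labels,
            double[] weights,
            int maxIter,
            double learningRate) {
        double[] coefficient = Arrays.copyOf(initCoefficient, initCoefficient.length);
        double weightSum = 0.0;
        for (double w : weights) {
            weightSum += w;
        }
        for (int iter = 0; iter < maxIter; iter++) {
            double[] gradient = new double[coefficient.length];
            for (int n = 0; n < features.length; n++) {
                // error = x . c - y for the n-th row
                double error = -labels[n];
                for (int i = 0; i < coefficient.length; i++) {
                    error += features[n][i] * coefficient[i];
                }
                for (int i = 0; i < coefficient.length; i++) {
                    gradient[i] += weights[n] * error * features[n][i];
                }
            }
            for (int i = 0; i < coefficient.length; i++) {
                coefficient[i] -= learningRate * gradient[i] / weightSum;
            }
        }
        return coefficient;
    }

    public static void main(String[] args) {
        double[][] x = {{2, 1}, {3, 2}, {4, 3}, {2, 4}, {2, 2}, {4, 3}, {1, 2}, {5, 3}};
        double[] y = {4, 7, 10, 10, 6, 10, 5, 11};
        double[] w = new double[x.length];
        Arrays.fill(w, 1.0);
        double[] fitted = optimize(new double[] {0.0, 0.0}, x, y, w, 1000, 0.1);
        System.out.println(Arrays.toString(fitted)); // approximately [1.0, 2.0]
    }
}
```

Starting from a different initial coefficient changes the path but not the minimizer here, since the least-squares objective on this data is strictly convex.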
   





[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r865589640


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/SGD.java:
##########
@@ -0,0 +1,402 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol;
+import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.regression.linearregression.LinearRegression;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Stochastic Gradient Descent (SGD) is the most widely used optimizer for optimizing machine
+ * learning models. It iteratively makes small adjustments to the machine learning model according
+ * to the gradient at each step, to decrease the error of the model.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Stochastic_gradient_descent.
+ */
+@Internal
+public class SGD implements Optimizer {
+    /** Params for SGD optimizer. */
+    private final SGDParams params;
+
+    public SGD(
+            int maxIter,
+            double learningRate,
+            int globalBatchSize,
+            double tol,
+            double reg,
+            double elasticNet) {
+        this.params = new SGDParams(maxIter, learningRate, globalBatchSize, tol, reg, elasticNet);
+    }
+
+    @Override
+    public DataStream<DenseVector> optimize(
+            DataStream<DenseVector> bcInitModel,
+            DataStream<LabeledPointWithWeight> trainData,
+            LossFunc lossFunc) {
+        DataStreamList resultList =
+                Iterations.iterateBoundedStreamsUntilTermination(
+                        DataStreamList.of(bcInitModel.map(modelVec -> modelVec.values)),
+                        ReplayableDataStreamList.notReplay(trainData.rebalance()),
+                        IterationConfig.newBuilder().build(),
+                        new TrainIterationBody(lossFunc, params));
+        return resultList.get(0);
+    }
+
+    /** The iteration implementation for training process. */
+    private static class TrainIterationBody implements IterationBody {
+        private final LossFunc lossFunc;
+        private final SGDParams params;
+
+        public TrainIterationBody(LossFunc lossFunc, SGDParams params) {
+            this.lossFunc = lossFunc;
+            this.params = params;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            // The variable stream at the first iteration is the initialized model data.
+            // In the following iterations, it contains: the model update, weightSum and
+            // lossSum.
+            DataStream<double[]> variableStream = variableStreams.get(0);
+            DataStream<LabeledPointWithWeight> trainData = dataStreams.get(0);
+            final OutputTag<DenseVector> modelDataOutputTag =
+                    new OutputTag<DenseVector>("MODEL_OUTPUT") {};
+
+            SingleOutputStreamOperator<double[]> modelUpdateAndWeightAndLoss =
+                    trainData
+                            .connect(variableStream)
+                            .transform(
+                                    "CacheDataAndDoTrain",
+                                    PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO,
+                                    new CacheDataAndDoTrain(lossFunc, params, modelDataOutputTag));
+
+            DataStreamList feedbackVariableStream =
+                    IterationBody.forEachRound(
+                            DataStreamList.of(modelUpdateAndWeightAndLoss),
+                            input -> {
+                                DataStream<double[]> feedback =
+                                        DataStreamUtils.allReduceSum(input.get(0));
+                                return DataStreamList.of(feedback);
+                            });
+
+            DataStream<Integer> terminationCriteria =
+                    feedbackVariableStream
+                            .get(0)
+                            .map(
+                                    reducedUpdateAndWeightAndLoss -> {
+                                        double[] value = (double[]) reducedUpdateAndWeightAndLoss;
+                                        return value[value.length - 1] / value[value.length - 2];
+                                    })
+                            .flatMap(new TerminateOnMaxIterOrTol(params.maxIter, params.tol));
+
+            return new IterationBodyResult(
+                    DataStreamList.of(feedbackVariableStream.get(0)),
+                    DataStreamList.of(
+                            modelUpdateAndWeightAndLoss.getSideOutput(modelDataOutputTag)),
+                    terminationCriteria);
+        }
+    }
+
+    /**
+     * A stream operator that caches the training data in the first iteration and updates the model
+     * iteratively. The first input is the training data, and the second input is the initialized
+     * model data or feedback of model update, weight and loss.
+     */
+    private static class CacheDataAndDoTrain extends AbstractStreamOperator<double[]>
+            implements TwoInputStreamOperator<LabeledPointWithWeight, double[], double[]>,
+                    IterationListener<double[]> {
+        /** The cached training data. */
+        private List<LabeledPointWithWeight> trainData;
+
+        private ListState<LabeledPointWithWeight> trainDataState;
+
+        /** The start index (offset) of the next mini-batch data for training. */
+        private int nextBatchOffset = 0;
+
+        private ListState<Integer> nextBatchOffsetState;
+
+        /** The model coefficient. */
+        private DenseVector coefficient;
+
+        private ListState<DenseVector> coefficientState;
+        /** The dimension of the coefficient. */
+        private int coeffiDim;
+
+        /**
+         * The double array to sync among all workers. For example, when training {@link
+         * LinearRegression}, this double array consists of {modelUpdate, weightSum, lossSum}.
+         */
+        private double[] feedbackArray;
+
+        private ListState<double[]> feedbackArrayState;
+        /** The outputTag to output the model data when iteration ends. */
+        private final OutputTag<DenseVector> modelDataOutputTag;
+        /** The batch size on this task. */
+        private int localBatchSize;
+        /** Optimizer-related parameters. */
+        private final SGDParams params;
+        /** The loss function to optimize. */
+        private final LossFunc lossFunc;
+
+        private CacheDataAndDoTrain(
+                LossFunc lossFunc, SGDParams params, OutputTag<DenseVector> modelDataOutputTag) {
+            this.lossFunc = lossFunc;
+            this.params = params;
+            this.modelDataOutputTag = modelDataOutputTag;
+        }
+
+        @Override
+        public void open() {
+            int numTasks = getRuntimeContext().getNumberOfParallelSubtasks();
+            int taskId = getRuntimeContext().getIndexOfThisSubtask();
+            localBatchSize = params.globalBatchSize / numTasks;
+            if (params.globalBatchSize % numTasks > taskId) {
+                localBatchSize++;
+            }
+        }
+
+        /**
+         * Gets the weight sum of the processed elements.
+         *
+         * @return The weight sum.
+         */
+        private double getWeightSum() {
+            return feedbackArray[coeffiDim];
+        }
+
+        /**
+         * Sets the weight sum of the processed elements.
+         *
+         * @param weightSum The weight sum.
+         */
+        private void setWeightSum(double weightSum) {
+            feedbackArray[coeffiDim] = weightSum;
+        }
+
+        /**
+         * Gets the loss sum of the processed elements.

Review Comment:
   Thanks for the comment. I have removed the Javadoc of these private methods.
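The `open()` method above splits `globalBatchSize` across subtasks so that per-task sizes differ by at most one and add up to the global size, while `nextBatchOffset` marks where the next mini-batch starts in the cached data. A small plain-Java sketch of both ideas follows; `localBatchSize` mirrors the arithmetic in `open()`, whereas the wrap-around index selection in the hypothetical `nextBatchIndices` helper is an assumption for illustration, since the PR's batch-selection code is not shown here.

```java
import java.util.Arrays;

/** Hypothetical helpers; localBatchSize copies the arithmetic in CacheDataAndDoTrain#open(). */
public class MiniBatchSketch {

    static int localBatchSize(int globalBatchSize, int numTasks, int taskId) {
        int size = globalBatchSize / numTasks;
        // The first (globalBatchSize % numTasks) tasks each take one extra element.
        if (globalBatchSize % numTasks > taskId) {
            size++;
        }
        return size;
    }

    /**
     * Assumed selection policy: take batchSize cached indices starting at offset,
     * wrapping to the beginning when the end of the cache is reached.
     */
    static int[] nextBatchIndices(int offset, int batchSize, int numCached) {
        int n = Math.min(batchSize, numCached);
        int[] indices = new int[n];
        for (int i = 0; i < n; i++) {
            indices[i] = (offset + i) % numCached;
        }
        return indices;
    }

    public static void main(String[] args) {
        int[] sizes = new int[4];
        for (int taskId = 0; taskId < 4; taskId++) {
            sizes[taskId] = localBatchSize(10, 4, taskId);
        }
        System.out.println(Arrays.toString(sizes)); // prints [3, 3, 2, 2]
        System.out.println(Arrays.toString(nextBatchIndices(6, 3, 8))); // prints [6, 7, 0]
    }
}
```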





[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r865587651


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/SGD.java:
##########
@@ -0,0 +1,402 @@
+        /**
+         * Gets the weight sum of the processed elements.
+         *
+         * @return The weight sum.
+         */
+        private double getWeightSum() {

Review Comment:
   This function returns the sum of the weights of the processed elements, not a weighted sum of their values. Why should we rename it to `getWeightedSum`?
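The naming question comes down to the layout of `feedbackArray`: `getWeightSum()` reads index `coeffiDim`, and the termination criterion divides the last element by the second-to-last, so the array appears to hold the model update in `[0, coeffiDim)`, the sum of weights at `coeffiDim`, and the loss sum at `coeffiDim + 1`. The plain-Java sketch below assumes that layout; the element-wise `sum` helper is a hypothetical stand-in for what `DataStreamUtils.allReduceSum` is expected to compute across tasks.

```java
import java.util.Arrays;

/** Hypothetical sketch of the per-task feedback array and its reduction. */
public class FeedbackArraySketch {

    /** Builds [modelUpdate..., weightSum, lossSum] for one task. */
    static double[] feedback(double[] modelUpdate, double weightSum, double lossSum) {
        int dim = modelUpdate.length;
        double[] array = Arrays.copyOf(modelUpdate, dim + 2); // pads two zero slots
        array[dim] = weightSum; // what getWeightSum() reads: a plain sum of weights
        array[dim + 1] = lossSum;
        return array;
    }

    /** Element-wise sum across tasks (stand-in for an all-reduce). */
    static double[] sum(double[]... arrays) {
        double[] result = new double[arrays[0].length];
        for (double[] array : arrays) {
            for (int i = 0; i < result.length; i++) {
                result[i] += array[i];
            }
        }
        return result;
    }

    /** Same expression as the termination map: lossSum / weightSum. */
    static double averageLoss(double[] reduced) {
        return reduced[reduced.length - 1] / reduced[reduced.length - 2];
    }

    public static void main(String[] args) {
        double[] taskA = feedback(new double[] {1.0, 2.0}, 3.0, 6.0);
        double[] taskB = feedback(new double[] {0.5, -1.0}, 1.0, 2.0);
        double[] reduced = sum(taskA, taskB);
        System.out.println(Arrays.toString(reduced)); // prints [1.5, 1.0, 4.0, 8.0]
        System.out.println(averageLoss(reduced)); // prints 2.0
    }
}
```

Under this layout the value passed to `TerminateOnMaxIterOrTol` is the weighted average loss, which is why the divisor is a sum of weights rather than a weighted sum.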





[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r865596624


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/SGD.java:
##########
@@ -0,0 +1,402 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol;
+import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.regression.linearregression.LinearRegression;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Stochastic Gradient Descent (SGD) is one of the most widely used optimizers for machine learning
+ * models. It iteratively makes small adjustments to the model according to the gradient computed at
+ * each step, to decrease the error of the model.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Stochastic_gradient_descent.
+ */
+@Internal
+public class SGD implements Optimizer {
+    /** Params for SGD optimizer. */
+    private final SGDParams params;
+
+    public SGD(
+            int maxIter,
+            double learningRate,
+            int globalBatchSize,
+            double tol,
+            double reg,
+            double elasticNet) {
+        this.params = new SGDParams(maxIter, learningRate, globalBatchSize, tol, reg, elasticNet);
+    }
+
+    @Override
+    public DataStream<DenseVector> optimize(
+            DataStream<DenseVector> bcInitModel,
+            DataStream<LabeledPointWithWeight> trainData,
+            LossFunc lossFunc) {
+        DataStreamList resultList =
+                Iterations.iterateBoundedStreamsUntilTermination(
+                        DataStreamList.of(bcInitModel.map(modelVec -> modelVec.values)),
+                        ReplayableDataStreamList.notReplay(trainData.rebalance()),
+                        IterationConfig.newBuilder().build(),
+                        new TrainIterationBody(lossFunc, params));
+        return resultList.get(0);
+    }
+
+    /** The iteration implementation for training process. */
+    private static class TrainIterationBody implements IterationBody {
+        private final LossFunc lossFunc;
+        private final SGDParams params;
+
+        public TrainIterationBody(LossFunc lossFunc, SGDParams params) {
+            this.lossFunc = lossFunc;
+            this.params = params;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            // The variable stream at the first iteration is the initialized model data.
+            // In the following iterations, it contains: the model update, weightSum and
+            // lossSum.
+            DataStream<double[]> variableStream = variableStreams.get(0);
+            DataStream<LabeledPointWithWeight> trainData = dataStreams.get(0);
+            final OutputTag<DenseVector> modelDataOutputTag =
+                    new OutputTag<DenseVector>("MODEL_OUTPUT") {};
+
+            SingleOutputStreamOperator<double[]> modelUpdateAndWeightAndLoss =
+                    trainData
+                            .connect(variableStream)
+                            .transform(
+                                    "CacheDataAndDoTrain",
+                                    PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO,
+                                    new CacheDataAndDoTrain(lossFunc, params, modelDataOutputTag));
+
+            DataStreamList feedbackVariableStream =
+                    IterationBody.forEachRound(
+                            DataStreamList.of(modelUpdateAndWeightAndLoss),
+                            input -> {
+                                DataStream<double[]> feedback =
+                                        DataStreamUtils.allReduceSum(input.get(0));
+                                return DataStreamList.of(feedback);
+                            });
+
+            DataStream<Integer> terminationCriteria =
+                    feedbackVariableStream
+                            .get(0)
+                            .map(
+                                    reducedUpdateAndWeightAndLoss -> {
+                                        double[] value = (double[]) reducedUpdateAndWeightAndLoss;
+                                        return value[value.length - 1] / value[value.length - 2];
+                                    })
+                            .flatMap(new TerminateOnMaxIterOrTol(params.maxIter, params.tol));
+
+            return new IterationBodyResult(
+                    DataStreamList.of(feedbackVariableStream.get(0)),
+                    DataStreamList.of(
+                            modelUpdateAndWeightAndLoss.getSideOutput(modelDataOutputTag)),
+                    terminationCriteria);
+        }
+    }
+
+    /**
+     * A stream operator that caches the training data in the first iteration and updates the model
+     * iteratively. The first input is the training data, and the second input is the initialized
+     * model data or feedback of model update, weight and loss.
+     */
+    private static class CacheDataAndDoTrain extends AbstractStreamOperator<double[]>
+            implements TwoInputStreamOperator<LabeledPointWithWeight, double[], double[]>,
+                    IterationListener<double[]> {
+        /** The cached training data. */
+        private List<LabeledPointWithWeight> trainData;
+
+        private ListState<LabeledPointWithWeight> trainDataState;
+
+        /** The start index (offset) of the next mini-batch data for training. */
+        private int nextBatchOffset = 0;
+
+        private ListState<Integer> nextBatchOffsetState;
+
+        /** The model coefficient. */
+        private DenseVector coefficient;
+
+        private ListState<DenseVector> coefficientState;
+        /** The dimension of the coefficient. */
+        private int coeffiDim;
+
+        /**
+         * The double array to sync among all workers. For example, when training {@link
+         * LinearRegression}, this double array consists of {modelUpdate, weightSum, lossSum}.
+         */
+        private double[] feedbackArray;
+
+        private ListState<double[]> feedbackArrayState;
+        /** The outputTag to output the model data when iteration ends. */
+        private final OutputTag<DenseVector> modelDataOutputTag;
+        /** The batch size on this task. */
+        private int localBatchSize;
+        /** Optimizer-related parameters. */
+        private final SGDParams params;
+        /** The loss function to optimize. */
+        private final LossFunc lossFunc;
+
+        private CacheDataAndDoTrain(
+                LossFunc lossFunc, SGDParams params, OutputTag<DenseVector> modelDataOutputTag) {
+            this.lossFunc = lossFunc;
+            this.params = params;
+            this.modelDataOutputTag = modelDataOutputTag;
+        }
+
+        @Override
+        public void open() {
+            int numTasks = getRuntimeContext().getNumberOfParallelSubtasks();
+            int taskId = getRuntimeContext().getIndexOfThisSubtask();
+            localBatchSize = params.globalBatchSize / numTasks;
+            if (params.globalBatchSize % numTasks > taskId) {
+                localBatchSize++;
+            }
+        }
+
+        /**
+         * Gets the weight sum of the processed elements.
+         *
+         * @return The weight sum.
+         */
+        private double getWeightSum() {
+            return feedbackArray[coeffiDim];
+        }
+
+        /**
+         * Sets the weight sum of the processed elements.
+         *
+         * @param weightSum The weight sum.
+         */
+        private void setWeightSum(double weightSum) {
+            feedbackArray[coeffiDim] = weightSum;
+        }
+
+        /**
+         * Gets the loss sum of the processed elements.
+         *
+         * @return The loss sum.
+         */
+        private double getLoss() {
+            return feedbackArray[coeffiDim + 1];
+        }
+
+        /**
+         * Sets the loss sum of the processed elements.
+         *
+         * @param loss The loss sum.
+         */
+        private void setLoss(double loss) {
+            feedbackArray[coeffiDim + 1] = loss;
+        }
+
+        @Override
+        public void onEpochWatermarkIncremented(
+                int epochWatermark, Context context, Collector<double[]> collector)
+                throws Exception {
+            if (epochWatermark == 0) {
+                coefficient = new DenseVector(feedbackArray);
+                coeffiDim = coefficient.size();
+                feedbackArray = new double[coefficient.size() + 2];
+            } else {
+                if (getWeightSum() > 0) {
+                    BLAS.axpy(
+                            -params.learningRate / getWeightSum(),
+                            new DenseVector(feedbackArray),
+                            coefficient,
+                            coeffiDim);
+                    double regLoss =
+                            RegularizationUtils.regularize(
+                                    coefficient,
+                                    params.reg,
+                                    params.elasticNet,
+                                    params.learningRate);
+                    setLoss(getLoss() + regLoss);
+                }
+            }
+
+            if (trainData == null) {
+                trainData = IteratorUtils.toList(trainDataState.get().iterator());
+            }
+
+            // TODO: support efficient shuffling of the training set on each partition.
+            if (trainData.size() > 0) {
+                List<LabeledPointWithWeight> miniBatchData =
+                        trainData.subList(
+                                nextBatchOffset,
+                                Math.min(nextBatchOffset + localBatchSize, trainData.size()));
+                nextBatchOffset += localBatchSize;
+                nextBatchOffset = nextBatchOffset >= trainData.size() ? 0 : nextBatchOffset;

Review Comment:
   The previous approach was essentially a toy implementation: it needed random access to all of the data points, which could be expensive when the data is cached on disk.
   
   The new approach reads the data sequentially and shuffles it from time to time (the shuffle is still a TODO). This is a common practice in machine learning training [1].
   
   Spark uses `mini-batch fraction` to sample the training data for each iteration; we discussed this approach in [2].
   
   [1] https://www.tensorflow.org/api_docs/python/tf/data/Dataset (search "make sure to call shuffle after calling cache.")
   [2] https://github.com/apache/flink-ml/pull/28#discussion_r753641174
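The sequential read described above, together with the per-task batch sizing done in `open()`, can be sketched in plain Java as follows. This is a minimal, Flink-independent sketch: `MiniBatchSketch` and `SequentialBatcher` are hypothetical names, and the cached data is assumed to fit in an in-memory `List` rather than in operator state.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class MiniBatchSketch {

    /**
     * Splits the global batch size across tasks: every task gets globalBatchSize / numTasks
     * samples, and the first (globalBatchSize % numTasks) tasks get one extra sample.
     */
    static int localBatchSize(int globalBatchSize, int numTasks, int taskId) {
        int size = globalBatchSize / numTasks;
        if (globalBatchSize % numTasks > taskId) {
            size++;
        }
        return size;
    }

    /** Reads cached data sequentially, one mini-batch per call, wrapping to the start at the end. */
    static final class SequentialBatcher<T> {
        private final List<T> cached;
        private final int batchSize;
        private int nextBatchOffset = 0;

        SequentialBatcher(List<T> cached, int batchSize) {
            this.cached = cached;
            this.batchSize = batchSize;
        }

        List<T> nextBatch() {
            if (cached.isEmpty()) {
                return new ArrayList<>();
            }
            int end = Math.min(nextBatchOffset + batchSize, cached.size());
            // Copy the sub-list view so the caller gets an independent list.
            List<T> batch = new ArrayList<>(cached.subList(nextBatchOffset, end));
            nextBatchOffset += batchSize;
            // No shuffling here: the next pass starts again from the first element.
            if (nextBatchOffset >= cached.size()) {
                nextBatchOffset = 0;
            }
            return batch;
        }
    }

    public static void main(String[] args) {
        // 10 samples over 3 tasks gives local batch sizes 4, 3 and 3.
        for (int taskId = 0; taskId < 3; taskId++) {
            System.out.println("task " + taskId + ": " + localBatchSize(10, 3, taskId));
        }
        SequentialBatcher<Integer> batcher =
                new SequentialBatcher<>(Arrays.asList(1, 2, 3, 4, 5), 2);
        for (int i = 0; i < 4; i++) {
            System.out.println(batcher.nextBatch());
        }
    }
}
```

As in the operator under review, the last batch of a pass may be shorter than the batch size, and the offset wraps to 0 without reordering the data.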
   





[GitHub] [flink-ml] lindong28 commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

Posted by GitBox <gi...@apache.org>.
lindong28 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r865897261


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/RegularizationUtils.java:
##########
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+
+/**
+ * A utility class for algorithms that need to handle regularization. The regularization term is
+ * defined as:
+ *
+ * <p>elasticNet * reg * norm1(coefficient) + (1 - elasticNet) * (reg/2) * (norm2(coefficient))^2
+ *
+ * <p>See https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.ElasticNet.html.
+ */
+@Internal
+class RegularizationUtils {
+
+    /**
+     * Regularizes the model coefficient in place. The gradient of each dimension is computed as:
+     * {elasticNet * reg * Math.signum(c_i) + (1 - elasticNet) * reg * c_i}, where c_i is the value of
+     * the coefficient at the i-th dimension.
+     *
+     * @param coefficient The model coefficient.
+     * @param reg The reg param.
+     * @param elasticNet The elasticNet param.
+     * @param learningRate The learningRate param.
+     * @return The loss introduced by regularization.
+     */
+    public static double regularize(
+            DenseVector coefficient,
+            final double reg,
+            final double elasticNet,
+            final double learningRate) {
+
+        if (Double.compare(reg, 0) == 0) {
+            return 0;
+        } else {
+            if (Double.compare(elasticNet, 0) == 0) {
+                // Only L2 regularization.
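+                // The L2 term uses the squared norm, as in the class-level formula.
+                double norm = BLAS.norm2(coefficient);
+                double loss = reg / 2 * norm * norm;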
+                BLAS.scal(1 - learningRate * reg, coefficient);
+                return loss;
+            } else if (Double.compare(elasticNet, 1) == 0) {
+                // Only L1 regularization.
+                double loss = 0;
+                double[] coefficientArray = coefficient.values;
+                for (int i = 0; i < coefficientArray.length; i++) {
+                    if (Double.compare(coefficientArray[i], 0) == 0) {
+                        continue;
+                    }
+                    loss += elasticNet * reg * Math.abs(coefficientArray[i]);
+                    coefficientArray[i] -=
+                            learningRate * elasticNet * reg * Math.signum(coefficientArray[i]);
+                }
+                return loss;
+            } else {
+                // Both L1 and L2 are not zero.
+                double loss = 0;
+                double[] coefficientArray = coefficient.values;
+                for (int i = 0; i < coefficientArray.length; i++) {
+                    loss +=

Review Comment:
   Sounds good. Thanks for the explanation.
   
   Should we remove the following commented-out code?
   
   ```
   // if (Double.compare(coefficientArray[i], 0) == 0) {
   //    continue;
   // }
   ```
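For reference, the regularization term documented in the Javadoc above can be sketched on a plain `double[]`. This is a minimal illustration, not the Flink ML implementation: it does not use `DenseVector` or `BLAS`, it computes the loss with the L1 norm and the squared L2 norm exactly as the formula states, and it applies a plain subgradient step. Such a step can make coefficients oscillate around zero; soft-thresholding (proximal) updates are a common alternative for the L1 part. The class name `ElasticNetSketch` is hypothetical.

```java
import java.util.Arrays;

public class ElasticNetSketch {

    /**
     * Returns the elastic-net regularization loss of the coefficient,
     * elasticNet * reg * ||c||_1 + (1 - elasticNet) * (reg / 2) * ||c||_2^2,
     * evaluated before the update, and then applies one subgradient step in place.
     */
    static double regularize(
            double[] coefficient, double reg, double elasticNet, double learningRate) {
        if (reg == 0) {
            return 0;
        }
        double l1Norm = 0;
        double squaredL2Norm = 0;
        for (double c : coefficient) {
            l1Norm += Math.abs(c);
            squaredL2Norm += c * c;
        }
        double loss = elasticNet * reg * l1Norm + (1 - elasticNet) * reg / 2 * squaredL2Norm;
        for (int i = 0; i < coefficient.length; i++) {
            // Subgradient of the term above; Math.signum(0.0) is 0, so zero coefficients stay zero.
            double grad =
                    elasticNet * reg * Math.signum(coefficient[i])
                            + (1 - elasticNet) * reg * coefficient[i];
            coefficient[i] -= learningRate * grad;
        }
        return loss;
    }

    public static void main(String[] args) {
        double[] coefficient = {2.0, -1.0, 0.0};
        double loss = regularize(coefficient, 0.1, 0.5, 0.1);
        System.out.println("loss = " + loss);
        System.out.println("coefficient = " + Arrays.toString(coefficient));
    }
}
```

With `reg = 0.1`, `elasticNet = 0.5` and `learningRate = 0.1`, the loss for {2, -1, 0} is approximately 0.275 and the coefficient becomes approximately {1.985, -0.99, 0}.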





[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r865599918


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegressionModel.java:
##########
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.regression.linearregression;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/** A Model which predicts data using the model data computed by {@link LinearRegression}. */
+public class LinearRegressionModel
+        implements Model<LinearRegressionModel>,
+                LinearRegressionModelParams<LinearRegressionModel> {
+
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    private Table modelDataTable;
+
+    public LinearRegressionModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    @Override
+    public void save(String path) throws IOException {

Review Comment:
   I think it is reasonable to make this change. I have updated LinearRegression and LogisticRegression.
   
   For the other algorithms, we can make the same change in a follow-up PR.





[GitHub] [flink-ml] lindong28 commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

Posted by GitBox <gi...@apache.org>.
lindong28 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r865911454


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/Optimizer.java:
##########
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.streaming.api.datastream.DataStream;
+
+/**
+ * An optimizer modifies the weights of a machine learning model in order to find the optimal
+ * parameter configuration for that model. Examples of optimizers include stochastic gradient
+ * descent (SGD), L-BFGS, etc.
+ */
+@Internal
+public interface Optimizer {
+    /**
+     * Optimizes the given loss function using the initial model and the training data.

Review Comment:
   Yep. That works for me.
   
   If you plan to update `optimize(...)` to support unbounded inputs in the future, could we add a TODO here? Note that the current implementation only supports bounded data. If we have not added support for unbounded inputs before the release, we probably need to make sure this function throws a proper exception for them.
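For bounded data on a single task, the iteration that `optimize(...)` assembles can be approximated by the loop below. It is a sketch under several assumptions: a weighted squared loss `w * (c . x - y)^2 / 2` stands in for `LossFunc`; there is no all-reduce (in the operator, each task's feedback array is summed by `DataStreamUtils.allReduceSum`); regularization is omitted; and the stopping rule (average loss below `tol`, or `maxIter` rounds) is an assumption about `TerminateOnMaxIterOrTol`, whose code is not shown here. `LocalSgdSketch` is a hypothetical name.

```java
import java.util.Arrays;

public class LocalSgdSketch {

    /**
     * Computes the feedback array {gradientSum..., weightSum, lossSum} of samples [from, to) under
     * the weighted squared loss w * (c . x - y)^2 / 2.
     */
    static double[] computeFeedback(
            double[][] xs, double[] ys, double[] ws, double[] coef, int from, int to) {
        int dim = coef.length;
        double[] feedback = new double[dim + 2];
        for (int n = from; n < to; n++) {
            double pred = 0;
            for (int i = 0; i < dim; i++) {
                pred += coef[i] * xs[n][i];
            }
            double residual = pred - ys[n];
            for (int i = 0; i < dim; i++) {
                feedback[i] += ws[n] * residual * xs[n][i];
            }
            feedback[dim] += ws[n];
            feedback[dim + 1] += ws[n] * residual * residual / 2;
        }
        return feedback;
    }

    /** Runs mini-batch SGD over bounded, in-memory data and returns the trained coefficient. */
    static double[] optimize(
            double[][] xs, double[] ys, double[] ws, double[] initCoef,
            int maxIter, double learningRate, int batchSize, double tol) {
        double[] coef = initCoef.clone();
        int dim = coef.length;
        int offset = 0;
        for (int iter = 0; iter < maxIter; iter++) {
            int end = Math.min(offset + batchSize, xs.length);
            double[] feedback = computeFeedback(xs, ys, ws, coef, offset, end);
            offset = end >= xs.length ? 0 : end;
            double weightSum = feedback[dim];
            if (weightSum > 0) {
                // coef -= learningRate / weightSum * gradientSum, mirroring the BLAS.axpy call.
                for (int i = 0; i < dim; i++) {
                    coef[i] -= learningRate / weightSum * feedback[i];
                }
                // Assumed stopping rule: the batch's average loss has dropped below tol.
                if (feedback[dim + 1] / weightSum < tol) {
                    break;
                }
            }
        }
        return coef;
    }

    public static void main(String[] args) {
        // Noise-free data generated by y = 2 * x0 + 1 * x1, all with weight 1.
        double[][] xs = {{1, 0}, {0, 1}, {1, 1}, {2, 1}};
        double[] ys = {2, 1, 3, 5};
        double[] ws = {1, 1, 1, 1};
        double[] coef = optimize(xs, ys, ws, new double[2], 1000, 0.1, 4, 1e-10);
        System.out.println(Arrays.toString(coef));
    }
}
```

With these settings the printed coefficient should be close to [2.0, 1.0]. Because the whole loop runs over a finite array, the sketch also shows why the current design naturally assumes bounded input.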





[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r858415540


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/glm/LocalTrainer.java:
##########
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.glm;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.regression.linearregression.LinearRegression;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
+
+import java.util.Iterator;
+
+/**
+ * A local trainer is a trainer that uses a batch of training data to compute an update of the model
+ * locally.
+ *
+ * @param <T> Class type of training data.
+ */
+public abstract class LocalTrainer<T> implements CheckpointedFunction {

Review Comment:
   Thanks for the insightful comments. We probably need more discussion.
   
   > move CacheDataAndDoTrain to an independent class
   
   It is an independent class in the old PR. Do you mean merging it with LocalTrainer and WithRegularization?
   
   > merge LocalTrainer and WithRegularization and remove methods like getReg, as now they are only used internally
   
   It seems hard to merge these two classes, because the logic of `WithRegularization` is supposed to be used in `LocalTrainer#updateModel`. I have renamed `WithRegularization` to `RegularizationUtils` and removed methods like `withReg`. What do you think?
   
   > rename the merged class. A name like LocalTrainer might be too general to be associated with linear algorithms.
   
   I have renamed the class to `LocalLinearTrainer`, though I still do not think it is a very good name. We could discuss the naming further.
   
   > merge CacheDataAndDoTrain, LocalTrainer and WithRegularization.
   
   I did not merge `CacheDataAndDoTrain` with the other two for now, for the following reasons:
   - `CacheDataAndDoTrain` is more of an infrastructure component and involves distributed concepts, such as two-input operators.
   - `LocalTrainer` is a friendlier and cleaner concept for machine learning users, since it only involves local operations.
   
   > change the API and implementation of trainOnBatchData. I have the sense that each time only one data, instead of a batch of data, is enough. I think the API of org.apache.flink.api.common.functions.AggregateFunction is a good reference.
   
   I did not make this change, for the following reasons:
   - Mini-batch training is a common concept in machine learning.
   - Users may want to perform operations before/after training on each mini-batch.





[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r858420942


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/glm/GeneralLinearAlgo.java:
##########
@@ -0,0 +1,414 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.glm;

Review Comment:
   I have renamed package `glm` to `linear`.
   
   I have also marked all classes in this package as `@Internal` and checked the visibility of all of their methods.





[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r858230044


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java:
##########
@@ -18,51 +18,28 @@
 
 package org.apache.flink.ml.classification.logisticregression;
 
-import org.apache.flink.api.common.functions.RichMapFunction;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
-import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.api.java.typeutils.RowTypeInfo;
-import org.apache.flink.ml.api.Model;
-import org.apache.flink.ml.common.broadcast.BroadcastUtils;
-import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.common.glm.GeneralLinearAlgoModel;
 import org.apache.flink.ml.linalg.BLAS;
 import org.apache.flink.ml.linalg.DenseVector;
 import org.apache.flink.ml.linalg.Vectors;
-import org.apache.flink.ml.param.Param;
-import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.ReadWriteUtils;
-import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.table.api.Table;
 import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
-import org.apache.flink.table.api.internal.TableImpl;
 import org.apache.flink.types.Row;
-import org.apache.flink.util.Preconditions;
 
 import org.apache.commons.lang3.ArrayUtils;
 
 import java.io.IOException;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.Map;
+import java.io.Serializable;
 
-/** A Model which classifies data using the model data computed by {@link LogisticRegression}. */
-public class LogisticRegressionModel
-        implements Model<LogisticRegressionModel>,
-                LogisticRegressionModelParams<LogisticRegressionModel> {
-
-    private final Map<Param<?>, Object> paramMap = new HashMap<>();
-
-    private Table modelDataTable;
-
-    public LogisticRegressionModel() {
-        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
-    }
-
-    @Override
-    public Map<Param<?>, Object> getParamMap() {
-        return paramMap;
-    }
+/**
+ * A Model which classifies the input using the model data computed by {@link LogisticRegression}.
+ */
+public class LogisticRegressionModel extends GeneralLinearAlgoModel<LogisticRegressionModel>
+        implements LogisticRegressionModelParams<LogisticRegressionModel>, Serializable {
 
     @Override
     public void save(String path) throws IOException {

Review Comment:
   If the change described for ModelData classes is feasible, then methods like `LogisticRegressionModel.save()` and `LogisticRegression.createModel()` might also be moved to corresponding abstract classes.
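
Below is a minimal sketch of the pattern this suggests, in plain Java with no Flink dependencies. All class and method names are hypothetical, and `serializeModelData()` only stands in for `save()`. The shared serialization and model construction are written once in abstract base classes. Each concrete estimator or model supplies only what actually differs: which model class to instantiate, and how the dot product becomes a prediction.

```java
import java.util.Arrays;

// Hypothetical base model: logic that is identical for every linear model lives here once.
abstract class LinearModelBaseSketch {
    protected double[] coefficient;

    // Stand-in for save(): the serialized form depends only on the coefficient.
    String serializeModelData() {
        return getClass().getSimpleName() + Arrays.toString(coefficient);
    }

    // w . x, shared by all linear models.
    protected double dot(double[] features) {
        double sum = 0.0;
        for (int i = 0; i < features.length; i++) {
            sum += coefficient[i] * features[i];
        }
        return sum;
    }

    // The only model-specific part: how the dot product becomes a prediction.
    abstract double predict(double[] features);
}

// Hypothetical base estimator: createModel() is written once for all linear estimators.
abstract class LinearEstimatorBaseSketch<M extends LinearModelBaseSketch> {
    // Each concrete estimator only names the model class to instantiate.
    protected abstract M newModel();

    M createModel(double[] coefficient) {
        M model = newModel();
        model.coefficient = coefficient.clone();
        return model;
    }
}

class LinearModelSketch extends LinearModelBaseSketch {
    @Override
    double predict(double[] features) {
        return dot(features);
    }
}

class LogisticModelSketch extends LinearModelBaseSketch {
    @Override
    double predict(double[] features) {
        double probability = 1.0 / (1.0 + Math.exp(-dot(features)));
        return probability >= 0.5 ? 1.0 : 0.0;
    }
}

class LinearEstimatorSketch extends LinearEstimatorBaseSketch<LinearModelSketch> {
    @Override
    protected LinearModelSketch newModel() {
        return new LinearModelSketch();
    }
}

class LogisticEstimatorSketch extends LinearEstimatorBaseSketch<LogisticModelSketch> {
    @Override
    protected LogisticModelSketch newModel() {
        return new LogisticModelSketch();
    }
}
```

The trade-off is one small abstract factory method (`newModel()`) per estimator in exchange for a single copy of the shared logic.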





[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r858226344


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/glm/LocalTrainer.java:
##########
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.glm;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.regression.linearregression.LinearRegression;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
+
+import java.util.Iterator;
+
+/**
+ * A local trainer is a trainer that uses a batch of training data to compute a update of the model
+ * locally.
+ *
+ * @param <T> Class type of training data.
+ */
+public abstract class LocalTrainer<T> implements CheckpointedFunction {

Review Comment:
   For `LocalTrainer` and its subclasses, what do you think of the following changes?
   - move `CacheDataAndDoTrain` to an independent class
   - merge `CacheDataAndDoTrain`, `LocalTrainer` and `WithRegularization`
   - remove methods like `getReg`, as now they are only used internally
   - change the API and implementation of `trainOnBatchData`. I have the sense that processing one record at a time, instead of a batch of records, is enough. The API of `org.apache.flink.api.common.functions.AggregateFunction` could be a good reference.
   - rename the merged class. A name like `LocalTrainer` might be too general to be associated with linear algorithms.
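
A minimal sketch of what a record-at-a-time trainer shaped like the four methods of `AggregateFunction` (`createAccumulator`, `add`, `getResult`, `merge`) might look like. It is plain Java rather than an implementation of the Flink interface itself, every name in it is hypothetical, and the least-squares gradient is only an example loss.

```java
// Hypothetical trainer API mirroring the four methods of Flink's AggregateFunction.
interface RecordwiseTrainerSketch<IN, ACC, OUT> {
    ACC createAccumulator();

    ACC add(IN value, ACC accumulator);

    OUT getResult(ACC accumulator);

    ACC merge(ACC a, ACC b);
}

// Hypothetical training record: features, label and weight.
final class LabeledPointSketch {
    final double[] features;
    final double label;
    final double weight;

    LabeledPointSketch(double[] features, double label, double weight) {
        this.features = features;
        this.label = label;
        this.weight = weight;
    }
}

// Accumulator: weighted gradient sum plus total weight.
final class GradientAccSketch {
    final double[] gradientSum;
    double weightSum;

    GradientAccSketch(int dim) {
        gradientSum = new double[dim];
    }
}

// Least-squares trainer: the gradient of 0.5 * (w.x - y)^2 is (w.x - y) * x.
final class LeastSquareTrainerSketch
        implements RecordwiseTrainerSketch<LabeledPointSketch, GradientAccSketch, double[]> {
    private final double[] coefficient;

    LeastSquareTrainerSketch(double[] coefficient) {
        this.coefficient = coefficient;
    }

    @Override
    public GradientAccSketch createAccumulator() {
        return new GradientAccSketch(coefficient.length);
    }

    // Consumes exactly one record; no batch is buffered.
    @Override
    public GradientAccSketch add(LabeledPointSketch p, GradientAccSketch acc) {
        double dot = 0.0;
        for (int i = 0; i < coefficient.length; i++) {
            dot += coefficient[i] * p.features[i];
        }
        double multiplier = p.weight * (dot - p.label);
        for (int i = 0; i < coefficient.length; i++) {
            acc.gradientSum[i] += multiplier * p.features[i];
        }
        acc.weightSum += p.weight;
        return acc;
    }

    // Weighted average gradient over all records seen so far.
    @Override
    public double[] getResult(GradientAccSketch acc) {
        double[] gradient = new double[acc.gradientSum.length];
        if (acc.weightSum > 0) {
            for (int i = 0; i < gradient.length; i++) {
                gradient[i] = acc.gradientSum[i] / acc.weightSum;
            }
        }
        return gradient;
    }

    // Combines partial accumulators, e.g. from different subtasks.
    @Override
    public GradientAccSketch merge(GradientAccSketch a, GradientAccSketch b) {
        for (int i = 0; i < a.gradientSum.length; i++) {
            a.gradientSum[i] += b.gradientSum[i];
        }
        a.weightSum += b.weightSum;
        return a;
    }
}
```

Because `merge` combines partial accumulators, each parallel subtask could consume its records one by one and only the accumulators would need to be combined.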



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/glm/GeneralLinearAlgo.java:
##########
@@ -0,0 +1,414 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.glm;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.classification.logisticregression.BinaryLogisticTrainer;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.regression.linearregression.LinearRegressionTrainer;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Base class for general linear machine learning models.
+ *
+ * @param <E> Class type of {@link Estimator}.
+ * @param <M> Class type of {@link Model}.
+ */
+public abstract class GeneralLinearAlgo<

Review Comment:
   For the `Estimator` classes of linear algorithms, it might be better to make the following changes:
   - replace `getModelType` with `getLocalTrainer()`
   - remove `LinearModelType`
   - make the `LocalTrainer` subclasses private static classes of the `Estimator`s
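
A minimal sketch of the suggested structure, assuming the shared training loop only needs a per-record gradient from the trainer. All names are hypothetical, and the single full-batch gradient step stands in for the real iteration body. The base class calls `getLocalTrainer()` instead of switching on a model-type enum, and each trainer is a private static class of its estimator.

```java
// Hypothetical trainer interface: the per-record gradient of one loss.
interface LocalTrainerSketch {
    double[] gradient(double[] coefficient, double[] features, double label);
}

// Hypothetical shared estimator: the loop asks the subclass for its trainer
// instead of switching on a model-type enum.
abstract class LinearAlgoEstimatorSketch {
    protected abstract LocalTrainerSketch getLocalTrainer();

    // One full-batch gradient-descent step, written once for every linear estimator.
    double[] step(double[] coefficient, double[][] features, double[] labels, double learningRate) {
        LocalTrainerSketch trainer = getLocalTrainer();
        double[] next = coefficient.clone();
        for (int n = 0; n < features.length; n++) {
            double[] g = trainer.gradient(coefficient, features[n], labels[n]);
            for (int i = 0; i < next.length; i++) {
                next[i] -= learningRate * g[i] / features.length;
            }
        }
        return next;
    }

    static double dot(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            sum += a[i] * b[i];
        }
        return sum;
    }
}

class LinearRegressionSketch extends LinearAlgoEstimatorSketch {
    @Override
    protected LocalTrainerSketch getLocalTrainer() {
        return new LeastSquareTrainer();
    }

    // Private to the estimator: gradient of 0.5 * (w.x - y)^2.
    private static class LeastSquareTrainer implements LocalTrainerSketch {
        @Override
        public double[] gradient(double[] w, double[] x, double y) {
            double residual = dot(w, x) - y;
            double[] g = new double[x.length];
            for (int i = 0; i < x.length; i++) {
                g[i] = residual * x[i];
            }
            return g;
        }
    }
}

class LogisticRegressionSketch extends LinearAlgoEstimatorSketch {
    @Override
    protected LocalTrainerSketch getLocalTrainer() {
        return new BinaryLogisticTrainer();
    }

    // Private to the estimator: logistic-loss gradient for labels in {0, 1}.
    private static class BinaryLogisticTrainer implements LocalTrainerSketch {
        @Override
        public double[] gradient(double[] w, double[] x, double y) {
            double p = 1.0 / (1.0 + Math.exp(-dot(w, x)));
            double[] g = new double[x.length];
            for (int i = 0; i < x.length; i++) {
                g[i] = (p - y) * x[i];
            }
            return g;
        }
    }
}
```

Adding a new linear algorithm then means one new estimator with its own private trainer, with no shared enum to extend.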



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java:
##########
@@ -42,12 +43,9 @@
  * <p>This class also provides methods to convert model data from Table to Datastream, and classes
  * to save/load model data.
  */
-public class LogisticRegressionModelData {
-
-    public DenseVector coefficient;
-
+public class LogisticRegressionModelData extends GeneralLinearAlgoModelData {

Review Comment:
   `LogisticRegressionModelData` and `LinearRegressionModelData` seem to differ only in a type cast. If that is true, we can move most of the logic in this class to `GeneralLinearAlgoModelData`. In that case, `LinearRegressionModelData` only needs to specify a generic type `T`, or pass a `Class<T>` to the constructor or a super method of the abstract `GeneralLinearAlgoModelData`.
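
A minimal sketch, assuming the two model data classes really differ only in the concrete type they produce. All names are hypothetical, and the binary layout (length followed by doubles) is only an illustration, not meant to match Flink ML's actual encoding. A constructor reference plays the role that passing a `Class<T>` would, so no per-class cast is needed.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.function.Function;

// Hypothetical shared model data: the coefficient plus all encode/decode logic.
abstract class GeneralLinearModelDataSketch {
    final double[] coefficient;

    GeneralLinearModelDataSketch(double[] coefficient) {
        this.coefficient = coefficient;
    }

    // Written once: identical for every linear model.
    byte[] encode() {
        try {
            ByteArrayOutputStream bytes = new ByteArrayOutputStream();
            DataOutputStream out = new DataOutputStream(bytes);
            out.writeInt(coefficient.length);
            for (double c : coefficient) {
                out.writeDouble(c);
            }
            out.flush();
            return bytes.toByteArray();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    // Written once: the factory is the only thing each subclass contributes.
    static <T extends GeneralLinearModelDataSketch> T decode(
            byte[] data, Function<double[], T> factory) {
        try {
            DataInputStream in = new DataInputStream(new ByteArrayInputStream(data));
            double[] coefficient = new double[in.readInt()];
            for (int i = 0; i < coefficient.length; i++) {
                coefficient[i] = in.readDouble();
            }
            return factory.apply(coefficient);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}

class LinearRegressionModelDataSketch extends GeneralLinearModelDataSketch {
    LinearRegressionModelDataSketch(double[] coefficient) {
        super(coefficient);
    }
}

class LogisticRegressionModelDataSketch extends GeneralLinearModelDataSketch {
    LogisticRegressionModelDataSketch(double[] coefficient) {
        super(coefficient);
    }
}
```

Each subclass shrinks to a constructor, for example `GeneralLinearModelDataSketch.decode(bytes, LinearRegressionModelDataSketch::new)`.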



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/glm/GeneralLinearAlgo.java:
##########
@@ -0,0 +1,414 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.glm;

Review Comment:
   For introduced classes in this package:
   - Does `glm` stand for General Linear Model? I understand that the full package name might be too long to read, but `glm` might be too short to understand. Do you think there is a middle ground for the package name?
   - We can mark all classes in this package as `Internal`, and check if some `public` methods can be marked as `protected` or `private`.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java:
##########
@@ -18,51 +18,28 @@
 
 package org.apache.flink.ml.classification.logisticregression;
 
-import org.apache.flink.api.common.functions.RichMapFunction;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
-import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.api.java.typeutils.RowTypeInfo;
-import org.apache.flink.ml.api.Model;
-import org.apache.flink.ml.common.broadcast.BroadcastUtils;
-import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.common.glm.GeneralLinearAlgoModel;
 import org.apache.flink.ml.linalg.BLAS;
 import org.apache.flink.ml.linalg.DenseVector;
 import org.apache.flink.ml.linalg.Vectors;
-import org.apache.flink.ml.param.Param;
-import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.ReadWriteUtils;
-import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.table.api.Table;
 import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
-import org.apache.flink.table.api.internal.TableImpl;
 import org.apache.flink.types.Row;
-import org.apache.flink.util.Preconditions;
 
 import org.apache.commons.lang3.ArrayUtils;
 
 import java.io.IOException;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.Map;
+import java.io.Serializable;
 
-/** A Model which classifies data using the model data computed by {@link LogisticRegression}. */
-public class LogisticRegressionModel
-        implements Model<LogisticRegressionModel>,
-                LogisticRegressionModelParams<LogisticRegressionModel> {
-
-    private final Map<Param<?>, Object> paramMap = new HashMap<>();
-
-    private Table modelDataTable;
-
-    public LogisticRegressionModel() {
-        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
-    }
-
-    @Override
-    public Map<Param<?>, Object> getParamMap() {
-        return paramMap;
-    }
+/**
+ * A Model which classifies the input using the model data computed by {@link LogisticRegression}.
+ */
+public class LogisticRegressionModel extends GeneralLinearAlgoModel<LogisticRegressionModel>
+        implements LogisticRegressionModelParams<LogisticRegressionModel>, Serializable {
 
     @Override
     public void save(String path) throws IOException {

Review Comment:
   If the change described for ModelData classes is feasible, then methods like `LogisticRegressionModel.save()` and `LinearRegressionModel.createModel()` might also be moved to corresponding abstract classes.





[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r860631049


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LossFunc.java:
##########
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.lossfunc;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.linalg.DenseVector;
+
+import java.io.Serializable;
+
+/**
+ * A loss function is to compute the loss and gradient with the given coefficient and training data.
+ */
+@Internal
+public interface LossFunc extends Serializable {

Review Comment:
   `LossFunc` and `Optimizer` are two orthogonal concepts. For `LossFunc`, there could be logistic loss, least squares loss, hinge loss, etc. For `Optimizer`, there could be SGD, Adam, L-BFGS, etc.
   
   Given a loss function, we can use different optimizers. Given an optimizer, we can optimize different loss functions.
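
A minimal sketch of that separation in plain Java, with hypothetical names that are not the Flink ML classes: any `LossFuncSketch` can be handed to any `OptimizerSketch`. Only least squares, binary logistic loss (labels in {0, 1}) and plain SGD are shown.

```java
// Hypothetical loss abstraction: it says nothing about how it is minimized.
interface LossFuncSketch {
    double loss(double[] w, double[] x, double label);

    double[] gradient(double[] w, double[] x, double label);
}

// Hypothetical optimizer abstraction: it says nothing about which loss it minimizes.
interface OptimizerSketch {
    double[] optimize(double[] init, double[][] xs, double[] labels, LossFuncSketch lossFunc);
}

// Small dense-vector helpers shared by the losses below.
final class VecSketch {
    static double dot(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            sum += a[i] * b[i];
        }
        return sum;
    }

    static double[] scale(double[] x, double a) {
        double[] result = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            result[i] = a * x[i];
        }
        return result;
    }
}

// Least squares loss 0.5 * (w.x - y)^2, with gradient (w.x - y) * x.
final class LeastSquareLossSketch implements LossFuncSketch {
    @Override
    public double loss(double[] w, double[] x, double label) {
        double residual = VecSketch.dot(w, x) - label;
        return 0.5 * residual * residual;
    }

    @Override
    public double[] gradient(double[] w, double[] x, double label) {
        return VecSketch.scale(x, VecSketch.dot(w, x) - label);
    }
}

// Binary logistic loss for labels in {0, 1}: log(1 + e^z) - y * z with z = w.x.
final class BinaryLogisticLossSketch implements LossFuncSketch {
    @Override
    public double loss(double[] w, double[] x, double label) {
        double z = VecSketch.dot(w, x);
        // Numerically stable form of log(1 + e^z).
        return Math.max(z, 0.0) + Math.log1p(Math.exp(-Math.abs(z))) - label * z;
    }

    @Override
    public double[] gradient(double[] w, double[] x, double label) {
        double p = 1.0 / (1.0 + Math.exp(-VecSketch.dot(w, x)));
        return VecSketch.scale(x, p - label);
    }
}

// Plain SGD: one update per record, for a fixed number of epochs.
final class SgdSketch implements OptimizerSketch {
    private final double learningRate;
    private final int epochs;

    SgdSketch(double learningRate, int epochs) {
        this.learningRate = learningRate;
        this.epochs = epochs;
    }

    @Override
    public double[] optimize(double[] init, double[][] xs, double[] labels, LossFuncSketch lossFunc) {
        double[] w = init.clone();
        for (int epoch = 0; epoch < epochs; epoch++) {
            for (int n = 0; n < xs.length; n++) {
                double[] g = lossFunc.gradient(w, xs[n], labels[n]);
                for (int i = 0; i < w.length; i++) {
                    w[i] -= learningRate * g[i];
                }
            }
        }
        return w;
    }
}
```

Adding hinge loss or an Adam optimizer would then mean adding one class on one side, without touching the other.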
   
   





[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r860652423


##########
flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java:
##########
@@ -34,10 +34,16 @@ public static double asum(DenseVector x) {
     /** y += a * x . */
     public static void axpy(double a, Vector x, DenseVector y) {
         Preconditions.checkArgument(x.size() == y.size(), "Vector size mismatched.");
+        axpy(a, x, y, x.size());
+    }
+
+    /** y += a * x for the first k dimensions, with the other dimensions unchanged. */
+    public static void axpy(double a, Vector x, DenseVector y, int k) {

Review Comment:
   It is currently used for computing the loss and updating the model in `BinaryLogisticLoss`, `LeastSquareLoss` and `SGD`. It avoids creating a new double array instance, which improves performance.
   
   Plan B could be to remove this method from `BLAS` and add an internal utility function instead.
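
A minimal sketch of the semantics on plain `double[]` arrays. This is a dense-only simplification of the `Vector`/`DenseVector` signature in the diff, and the class name is hypothetical. The usage shown assumes a buffer that keeps extra bookkeeping slots after the first `k` entries, which is one way the k-variant lets an update skip allocating a temporary array.

```java
// Hypothetical dense-only stand-in for the two BLAS.axpy overloads in the diff.
final class AxpySketch {
    // y += a * x for the first k dimensions; y[k..] is left unchanged.
    static void axpy(double a, double[] x, double[] y, int k) {
        if (k > x.length || k > y.length) {
            throw new IllegalArgumentException("k exceeds vector size.");
        }
        for (int i = 0; i < k; i++) {
            y[i] += a * x[i];
        }
    }

    // Full-length variant, delegating to the k-variant as in the diff.
    static void axpy(double a, double[] x, double[] y) {
        if (x.length != y.length) {
            throw new IllegalArgumentException("Vector size mismatched.");
        }
        axpy(a, x, y, x.length);
    }
}
```

For example, with `buffer = {0.0, 0.0, 5.0}` holding a 2-dimensional gradient followed by one extra slot, `AxpySketch.axpy(2.0, new double[] {1.0, 3.0}, buffer, 2)` updates the gradient part in place and leaves the last slot at `5.0`.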





[GitHub] [flink-ml] zhipeng93 commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

Posted by GitBox <gi...@apache.org>.
zhipeng93 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r860681944


##########
flink-ml-lib/src/test/java/org/apache/flink/ml/regression/LinearRegressionTest.java:
##########
@@ -0,0 +1,240 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.regression;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.regression.linearregression.LinearRegression;
+import org.apache.flink.ml.regression.linearregression.LinearRegressionModel;
+import org.apache.flink.ml.regression.linearregression.LinearRegressionModelData;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+
+/** Tests {@link LinearRegression} and {@link LinearRegressionModel}. */
+public class LinearRegressionTest {
+
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private StreamExecutionEnvironment env;
+
+    private StreamTableEnvironment tEnv;
+
+    private static final List<Row> trainData =
+            Arrays.asList(
+                    Row.of(Vectors.dense(2, 1), 4.0, 1.0),
+                    Row.of(Vectors.dense(3, 2), 7.0, 1.0),
+                    Row.of(Vectors.dense(4, 3), 10.0, 1.0),
+                    Row.of(Vectors.dense(2, 4), 10.0, 1.0),
+                    Row.of(Vectors.dense(2, 2), 6.0, 1.0),
+                    Row.of(Vectors.dense(4, 3), 10.0, 1.0),
+                    Row.of(Vectors.dense(1, 2), 5.0, 1.0),
+                    Row.of(Vectors.dense(5, 3), 11.0, 1.0));
+
+    private static final double[] expectedCoefficient = new double[] {1.0, 2.0};
+
+    private static final double TOLERANCE = 1e-7;
+
+    private Table trainDataTable;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        Collections.shuffle(trainData);
+        trainDataTable =
+                tEnv.fromDataStream(
+                        env.fromCollection(
+                                trainData,
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(DenseVector.class),
+                                            Types.DOUBLE,
+                                            Types.DOUBLE
+                                        },
+                                        new String[] {"features", "label", "weight"})));
+    }
+
+    @SuppressWarnings("unchecked")
+    private void verifyPredictionResult(
+            Table output, String labelCol, String weightCol, String predictionCol)
+            throws Exception {
+        List<Row> predResult = IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        double lossSum = 0;
+        for (Row predictionRow : predResult) {
+            double label = (double) predictionRow.getField(labelCol);
+            double prediction = (double) predictionRow.getField(predictionCol);
+            double weight = (double) predictionRow.getField(weightCol);
+            lossSum += weight * Math.pow(label - prediction, 2);
+        }
+        assertTrue(lossSum < 1.0);

Review Comment:
   Sounds good. I have updated the PR.
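
For reference, every row of `trainData` in this test satisfies label = 1 * x1 + 2 * x2 exactly, which matches `expectedCoefficient`. Below is a minimal plain-Java sketch (no Flink; the class and method names are hypothetical) of the same weighted squared-error sum that `verifyPredictionResult` computes. With the expected coefficient the loss is exactly zero, so the `lossSum < 1.0` check leaves room for an approximate trained solution.

```java
// Hypothetical helper mirroring the loss sum in verifyPredictionResult.
final class WeightedLossSketch {
    // sum_n weight_n * (label_n - w . x_n)^2
    static double weightedSquaredError(
            double[][] features, double[] labels, double[] weights, double[] coefficient) {
        double lossSum = 0.0;
        for (int n = 0; n < features.length; n++) {
            double prediction = 0.0;
            for (int i = 0; i < coefficient.length; i++) {
                prediction += coefficient[i] * features[n][i];
            }
            lossSum += weights[n] * Math.pow(labels[n] - prediction, 2);
        }
        return lossSum;
    }
}
```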


