You are viewing a plain text version of this content. The canonical link for it is here.
Posted to issues@flink.apache.org by GitBox <gi...@apache.org> on 2022/04/29 07:58:57 UTC

[GitHub] [flink-ml] lindong28 commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

lindong28 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r861482738


##########
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java:
##########
@@ -67,6 +72,28 @@ public static <IN, OUT> DataStream<OUT> mapPartition(
                 .setParallelism(input.getParallelism());
     }
 
+    /**
+     * Applies a {@link ReduceFunction} on a bounded data stream. The output stream contains at most
+     * one stream record and its parallelism is one.
+     *
+     * @param input The input data stream.
+     * @param func The user defined reduce function.
+     * @param <T> The class type of the input.
+     * @return The result data stream.
+     */
+    public static <T> DataStream<T> reduce(DataStream<T> input, ReduceFunction<T> func) {

Review Comment:
   Would the method be more self-explanatory if we renamed it to `reduceAll`? This would make the name more consistent with `DataStream::windowAll` etc.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/SGD.java:
##########
@@ -0,0 +1,402 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol;
+import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.regression.linearregression.LinearRegression;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Stochastic Gradient Descent (SGD) is the mostly wide-used optimizer for optimizing machine
+ * learning models. It iteratively makes small adjustments to the machine learning model according
+ * to the gradient at each step, to decrease the error of the model.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Stochastic_gradient_descent.
+ */
+@Internal
+public class SGD implements Optimizer {
+    /** Params for SGD optimizer. */
+    private final SGDParams params;
+
+    public SGD(
+            int maxIter,
+            double learningRate,
+            int globalBatchSize,
+            double tol,
+            double reg,
+            double elasticNet) {
+        this.params = new SGDParams(maxIter, learningRate, globalBatchSize, tol, reg, elasticNet);
+    }
+
+    @Override
+    public DataStream<DenseVector> optimize(
+            DataStream<DenseVector> bcInitModel,
+            DataStream<LabeledPointWithWeight> trainData,
+            LossFunc lossFunc) {
+        DataStreamList resultList =
+                Iterations.iterateBoundedStreamsUntilTermination(
+                        DataStreamList.of(bcInitModel.map(modelVec -> modelVec.values)),
+                        ReplayableDataStreamList.notReplay(trainData.rebalance()),
+                        IterationConfig.newBuilder().build(),
+                        new TrainIterationBody(lossFunc, params));
+        return resultList.get(0);
+    }
+
+    /** The iteration implementation for training process. */
+    private static class TrainIterationBody implements IterationBody {
+        private final LossFunc lossFunc;
+        private final SGDParams params;
+
+        public TrainIterationBody(LossFunc lossFunc, SGDParams params) {
+            this.lossFunc = lossFunc;
+            this.params = params;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            // The variable stream at the first iteration is the initialized model data.
+            // In the following iterations, it contains: the model update, weightSum and
+            // lossSum.
+            DataStream<double[]> variableStream = variableStreams.get(0);
+            DataStream<LabeledPointWithWeight> trainData = dataStreams.get(0);
+            final OutputTag<DenseVector> modelDataOutputTag =
+                    new OutputTag<DenseVector>("MODEL_OUTPUT") {};
+
+            SingleOutputStreamOperator<double[]> modelUpdateAndWeightAndLoss =
+                    trainData
+                            .connect(variableStream)
+                            .transform(
+                                    "CacheDataAndDoTrain",
+                                    PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO,
+                                    new CacheDataAndDoTrain(lossFunc, params, modelDataOutputTag));
+
+            DataStreamList feedbackVariableStream =
+                    IterationBody.forEachRound(
+                            DataStreamList.of(modelUpdateAndWeightAndLoss),
+                            input -> {
+                                DataStream<double[]> feedback =
+                                        DataStreamUtils.allReduceSum(input.get(0));
+                                return DataStreamList.of(feedback);
+                            });
+
+            DataStream<Integer> terminationCriteria =
+                    feedbackVariableStream
+                            .get(0)
+                            .map(
+                                    reducedUpdateAndWeightAndLoss -> {
+                                        double[] value = (double[]) reducedUpdateAndWeightAndLoss;
+                                        return value[value.length - 1] / value[value.length - 2];
+                                    })
+                            .flatMap(new TerminateOnMaxIterOrTol(params.maxIter, params.tol));
+
+            return new IterationBodyResult(
+                    DataStreamList.of(feedbackVariableStream.get(0)),
+                    DataStreamList.of(
+                            modelUpdateAndWeightAndLoss.getSideOutput(modelDataOutputTag)),
+                    terminationCriteria);
+        }
+    }
+
+    /**
+     * A stream operator that caches the training data in the first iteration and updates the model
+     * iteratively. The first input is the training data, and the second input is the initialized
+     * model data or feedback of model update, weight and loss.
+     */
+    private static class CacheDataAndDoTrain extends AbstractStreamOperator<double[]>
+            implements TwoInputStreamOperator<LabeledPointWithWeight, double[], double[]>,
+                    IterationListener<double[]> {
+        /** The cached training data. */
+        private List<LabeledPointWithWeight> trainData;
+
+        private ListState<LabeledPointWithWeight> trainDataState;
+
+        /** The start index (offset) of the next mini-batch data for training. */
+        private int nextBatchOffset = 0;
+
+        private ListState<Integer> nextBatchOffsetState;
+
+        /** The model coefficient. */
+        private DenseVector coefficient;
+
+        private ListState<DenseVector> coefficientState;
+        /** The dimension of the coefficient. */
+        private int coeffiDim;
+
+        /**
+         * The double array to sync among all workers. For example, when training {@link
+         * LinearRegression}, this double array is consisted of {modelUpdate, weightSum, lossSum}.
+         */
+        private double[] feedbackArray;
+
+        private ListState<double[]> feedbackArrayState;
+        /** The outputTag to output the model data when iteration ends. */
+        private final OutputTag<DenseVector> modelDataOutputTag;
+        /** The batch size on this task. */
+        private int localBatchSize;
+        /** Optimizer-related parameters. */
+        private final SGDParams params;
+        /** The loss function to optimize. */
+        private final LossFunc lossFunc;
+
+        private CacheDataAndDoTrain(
+                LossFunc lossFunc, SGDParams params, OutputTag<DenseVector> modelDataOutputTag) {
+            this.lossFunc = lossFunc;
+            this.params = params;
+            this.modelDataOutputTag = modelDataOutputTag;
+        }
+
+        @Override
+        public void open() {
+            int numTasks = getRuntimeContext().getNumberOfParallelSubtasks();
+            int taskId = getRuntimeContext().getIndexOfThisSubtask();
+            localBatchSize = params.globalBatchSize / numTasks;
+            if (params.globalBatchSize % numTasks > taskId) {
+                localBatchSize++;
+            }
+        }
+
+        /**
+         * Gets the weight sum of the processed elements.
+         *
+         * @return The weight sum.
+         */
+        private double getWeightSum() {

Review Comment:
   Should it be `getWeightedSum()`?
   
   Could we update Java doc and variables used in this file so that we mention `weighted sum` instead of `weight sum`?



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/SGD.java:
##########
@@ -0,0 +1,402 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol;
+import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.regression.linearregression.LinearRegression;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Stochastic Gradient Descent (SGD) is the mostly wide-used optimizer for optimizing machine
+ * learning models. It iteratively makes small adjustments to the machine learning model according
+ * to the gradient at each step, to decrease the error of the model.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Stochastic_gradient_descent.
+ */
+@Internal
+public class SGD implements Optimizer {
+    /** Params for SGD optimizer. */
+    private final SGDParams params;
+
+    public SGD(
+            int maxIter,
+            double learningRate,
+            int globalBatchSize,
+            double tol,
+            double reg,
+            double elasticNet) {
+        this.params = new SGDParams(maxIter, learningRate, globalBatchSize, tol, reg, elasticNet);
+    }
+
+    @Override
+    public DataStream<DenseVector> optimize(
+            DataStream<DenseVector> bcInitModel,
+            DataStream<LabeledPointWithWeight> trainData,
+            LossFunc lossFunc) {
+        DataStreamList resultList =
+                Iterations.iterateBoundedStreamsUntilTermination(
+                        DataStreamList.of(bcInitModel.map(modelVec -> modelVec.values)),
+                        ReplayableDataStreamList.notReplay(trainData.rebalance()),
+                        IterationConfig.newBuilder().build(),
+                        new TrainIterationBody(lossFunc, params));
+        return resultList.get(0);
+    }
+
+    /** The iteration implementation for training process. */
+    private static class TrainIterationBody implements IterationBody {
+        private final LossFunc lossFunc;
+        private final SGDParams params;
+
+        public TrainIterationBody(LossFunc lossFunc, SGDParams params) {
+            this.lossFunc = lossFunc;
+            this.params = params;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            // The variable stream at the first iteration is the initialized model data.
+            // In the following iterations, it contains: the model update, weightSum and
+            // lossSum.
+            DataStream<double[]> variableStream = variableStreams.get(0);
+            DataStream<LabeledPointWithWeight> trainData = dataStreams.get(0);
+            final OutputTag<DenseVector> modelDataOutputTag =
+                    new OutputTag<DenseVector>("MODEL_OUTPUT") {};
+
+            SingleOutputStreamOperator<double[]> modelUpdateAndWeightAndLoss =
+                    trainData
+                            .connect(variableStream)
+                            .transform(
+                                    "CacheDataAndDoTrain",
+                                    PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO,
+                                    new CacheDataAndDoTrain(lossFunc, params, modelDataOutputTag));
+
+            DataStreamList feedbackVariableStream =
+                    IterationBody.forEachRound(
+                            DataStreamList.of(modelUpdateAndWeightAndLoss),
+                            input -> {
+                                DataStream<double[]> feedback =
+                                        DataStreamUtils.allReduceSum(input.get(0));
+                                return DataStreamList.of(feedback);
+                            });
+
+            DataStream<Integer> terminationCriteria =
+                    feedbackVariableStream
+                            .get(0)
+                            .map(
+                                    reducedUpdateAndWeightAndLoss -> {
+                                        double[] value = (double[]) reducedUpdateAndWeightAndLoss;
+                                        return value[value.length - 1] / value[value.length - 2];
+                                    })
+                            .flatMap(new TerminateOnMaxIterOrTol(params.maxIter, params.tol));
+
+            return new IterationBodyResult(
+                    DataStreamList.of(feedbackVariableStream.get(0)),
+                    DataStreamList.of(
+                            modelUpdateAndWeightAndLoss.getSideOutput(modelDataOutputTag)),
+                    terminationCriteria);
+        }
+    }
+
+    /**
+     * A stream operator that caches the training data in the first iteration and updates the model
+     * iteratively. The first input is the training data, and the second input is the initialized
+     * model data or feedback of model update, weight and loss.
+     */
+    private static class CacheDataAndDoTrain extends AbstractStreamOperator<double[]>
+            implements TwoInputStreamOperator<LabeledPointWithWeight, double[], double[]>,
+                    IterationListener<double[]> {
+        /** The cached training data. */
+        private List<LabeledPointWithWeight> trainData;
+
+        private ListState<LabeledPointWithWeight> trainDataState;
+
+        /** The start index (offset) of the next mini-batch data for training. */
+        private int nextBatchOffset = 0;
+
+        private ListState<Integer> nextBatchOffsetState;
+
+        /** The model coefficient. */
+        private DenseVector coefficient;
+
+        private ListState<DenseVector> coefficientState;
+        /** The dimension of the coefficient. */
+        private int coeffiDim;
+
+        /**
+         * The double array to sync among all workers. For example, when training {@link
+         * LinearRegression}, this double array is consisted of {modelUpdate, weightSum, lossSum}.
+         */
+        private double[] feedbackArray;
+
+        private ListState<double[]> feedbackArrayState;
+        /** The outputTag to output the model data when iteration ends. */
+        private final OutputTag<DenseVector> modelDataOutputTag;
+        /** The batch size on this task. */
+        private int localBatchSize;
+        /** Optimizer-related parameters. */
+        private final SGDParams params;
+        /** The loss function to optimize. */
+        private final LossFunc lossFunc;
+
+        private CacheDataAndDoTrain(
+                LossFunc lossFunc, SGDParams params, OutputTag<DenseVector> modelDataOutputTag) {
+            this.lossFunc = lossFunc;
+            this.params = params;
+            this.modelDataOutputTag = modelDataOutputTag;
+        }
+
+        @Override
+        public void open() {
+            int numTasks = getRuntimeContext().getNumberOfParallelSubtasks();
+            int taskId = getRuntimeContext().getIndexOfThisSubtask();
+            localBatchSize = params.globalBatchSize / numTasks;
+            if (params.globalBatchSize % numTasks > taskId) {
+                localBatchSize++;
+            }
+        }
+
+        /**
+         * Gets the weight sum of the processed elements.
+         *
+         * @return The weight sum.
+         */
+        private double getWeightSum() {
+            return feedbackArray[coeffiDim];
+        }
+
+        /**
+         * Sets the weight sum of the processed elements.
+         *
+         * @param weightSum The weight sum.
+         */
+        private void setWeightSum(double weightSum) {
+            feedbackArray[coeffiDim] = weightSum;
+        }
+
+        /**
+         * Gets the loss sum of the processed elements.
+         *
+         * @return The loss sum.
+         */
+        private double getLoss() {
+            return feedbackArray[coeffiDim + 1];
+        }
+
+        /**
+         * Sets the loss sum of the processed elements.
+         *
+         * @param loss The loss sum.
+         */
+        private void setLoss(double loss) {
+            feedbackArray[coeffiDim + 1] = loss;
+        }
+
+        @Override
+        public void onEpochWatermarkIncremented(
+                int epochWatermark, Context context, Collector<double[]> collector)
+                throws Exception {
+            if (epochWatermark == 0) {
+                coefficient = new DenseVector(feedbackArray);
+                coeffiDim = coefficient.size();
+                feedbackArray = new double[coefficient.size() + 2];
+            } else {
+                if (getWeightSum() > 0) {
+                    BLAS.axpy(
+                            -params.learningRate / getWeightSum(),
+                            new DenseVector(feedbackArray),
+                            coefficient,
+                            coeffiDim);
+                    double regLoss =
+                            RegularizationUtils.regularize(
+                                    coefficient,
+                                    params.reg,
+                                    params.elasticNet,
+                                    params.learningRate);
+                    setLoss(getLoss() + regLoss);
+                }
+            }
+
+            if (trainData == null) {
+                trainData = IteratorUtils.toList(trainDataState.get().iterator());
+            }
+
+            // TODO: supports efficient shuffle of training set on each partition.
+            if (trainData.size() > 0) {
+                List<LabeledPointWithWeight> miniBatchData =
+                        trainData.subList(
+                                nextBatchOffset,
+                                Math.min(nextBatchOffset + localBatchSize, trainData.size()));
+                nextBatchOffset += localBatchSize;
+                nextBatchOffset = nextBatchOffset >= trainData.size() ? 0 : nextBatchOffset;

Review Comment:
   In the previous implementation of `LogisticRegression`, each task randomly samples a local batch in each round. This PR changes this logic so that each task selects its local batch in a deterministic manner.
   
   Is the new approach strictly better than the other in terms of convergence rate and the converged accuracy? Which approach does Spark ML use for LogisticRegression and LinearRegression?
   



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/LogisticRegressionTest.java:
##########
@@ -186,6 +185,7 @@ public void testParam() {
         assertEquals(logisticRegression.getLearningRate(), 0.5, TOLERANCE);
         assertEquals(logisticRegression.getGlobalBatchSize(), 1000);
         assertEquals(logisticRegression.getReg(), 0.1, TOLERANCE);
+        assertEquals(logisticRegression.getElasticNet(), 0.5, TOLERANCE);

Review Comment:
   Now that we added the elasticNet param for `LogisticRegression`, do we need to test the case where `elasticNet` is different from its default value?
   
   According to the Java doc of this parameter, it seems that we might need to test three cases, e.g. with elasticNet = 0, 0.5, 1.
   
   We can add this as a TODO and do it in a separate PR if it involves non-trivial work.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/Optimizer.java:
##########
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.streaming.api.datastream.DataStream;
+
+/**
+ * An optimizer is a function to modify the weight of a machine learning model, which aims to find
+ * the optimal parameter configuration for a machine learning model. Examples of optimizers could be
+ * stochastic gradient descent (SGD), L-BFGS, etc.
+ */
+@Internal
+public interface Optimizer {
+    /**
+     * Optimize the given loss function using the init model and the training data.
+     *
+     * @param bcInitModel The broadcast init model. Note that each task contains one DenseVector as

Review Comment:
   nits: `broadcast init model` -> `broadcasted init model`
   
   Given that the Java doc of `DataStreamUtils::allReduceSum` says `each partition contains one double array`, would it be more consistent to use `partition` instead of `task` here?



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/Optimizer.java:
##########
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.streaming.api.datastream.DataStream;
+
+/**
+ * An optimizer is a function to modify the weight of a machine learning model, which aims to find
+ * the optimal parameter configuration for a machine learning model. Examples of optimizers could be
+ * stochastic gradient descent (SGD), L-BFGS, etc.
+ */
+@Internal
+public interface Optimizer {
+    /**
+     * Optimize the given loss function using the init model and the training data.
+     *
+     * @param bcInitModel The broadcast init model. Note that each task contains one DenseVector as
+     *     the model data and the model data on each task are exactly the same.
+     * @param trainData The training data.
+     * @param lossFunc The loss function to optimize.
+     * @return The fitted model. Note that the parallelism of the returned stream is one.
+     */
+    DataStream<DenseVector> optimize(
+            DataStream<DenseVector> bcInitModel,

Review Comment:
   `Model` might be confused with the model class in Flink ML. How about we consistently use `modelData` to refer to the model data in the Java doc and function signature?
   
   I am not sure we need to specifically mention `initial` in the function signature, since this word does not seem necessary to understand the behavior of this function.
   
   Instead of calling the `broadcast()` outside this method, how about we invoke the broadcast inside this method, so that both the caller code and the method signature could be simpler?
   



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/Optimizer.java:
##########
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.streaming.api.datastream.DataStream;
+
+/**
+ * An optimizer is a function to modify the weight of a machine learning model, which aims to find
+ * the optimal parameter configuration for a machine learning model. Examples of optimizers could be
+ * stochastic gradient descent (SGD), L-BFGS, etc.
+ */
+@Internal
+public interface Optimizer {
+    /**
+     * Optimize the given loss function using the init model and the training data.

Review Comment:
   nits: `Optimize the given loss` -> `Optimizes the given loss` to be consistent with other comments.
   
   Given that some algorithms in Flink ML could handle unbounded model data, would it be useful to explicitly mention that `trainData` must be bounded?



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/RegularizationUtils.java:
##########
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+
+/**
+ * A utility class for algorithms that need to handle regularization. The regularization term is
+ * defined as:
+ *
+ * <p>elasticNet * reg * norm1(coefficient) + (1 - elasticNet) * (reg/2) * (norm2(coefficient))^2
+ *
+ * <p>See https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.ElasticNet.html.
+ */
+@Internal
+class RegularizationUtils {
+
+    /**
+     * Regularize the model coefficient. The gradient of each dimension could be computed as:
+     * {elasticNet * reg * Math.sign(c_i) + (1 - elasticNet) * reg * c_i}. Here c_i is the value of
+     * coefficient at i-th dimension.
+     *
+     * @param coefficient The model coefficient.
+     * @param reg The reg param.
+     * @param elasticNet The elasticNet param.
+     * @param learningRate The learningRate param.
+     * @return The loss introduced by regularization.
+     */
+    public static double regularize(
+            DenseVector coefficient,
+            final double reg,
+            final double elasticNet,
+            final double learningRate) {
+
+        if (Double.compare(reg, 0) == 0) {
+            return 0;
+        } else {
+            if (Double.compare(elasticNet, 0) == 0) {
+                // Only L2 regularization.
+                double loss = reg / 2 * BLAS.norm2(coefficient);
+                BLAS.scal(1 - learningRate * reg, coefficient);
+                return loss;
+            } else if (Double.compare(elasticNet, 1) == 0) {
+                // Only L1 regularization.
+                double loss = 0;
+                double[] coefficientArray = coefficient.values;
+                for (int i = 0; i < coefficientArray.length; i++) {
+                    if (Double.compare(coefficientArray[i], 0) == 0) {
+                        continue;
+                    }
+                    loss += elasticNet * reg * Math.signum(coefficientArray[i]);
+                    coefficientArray[i] -=
+                            learningRate * elasticNet * reg * Math.signum(coefficientArray[i]);
+                }
+                return loss;
+            } else {
+                // Both L1 and L2 are not zero.
+                double loss = 0;
+                double[] coefficientArray = coefficient.values;
+                for (int i = 0; i < coefficientArray.length; i++) {
+                    loss +=
+                            elasticNet * reg * Math.signum(coefficientArray[i])
+                                    + (1 - elasticNet)
+                                            * (reg / 2)
+                                            * coefficientArray[i]
+                                            * coefficientArray[i];
+                    coefficientArray[i] =

Review Comment:
   nits: maybe change to `coefficientArray[i] -= ...` for consistency with the above code?



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/SGD.java:
##########
@@ -0,0 +1,402 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol;
+import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.regression.linearregression.LinearRegression;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Stochastic Gradient Descent (SGD) is the mostly wide-used optimizer for optimizing machine
+ * learning models. It iteratively makes small adjustments to the machine learning model according
+ * to the gradient at each step, to decrease the error of the model.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Stochastic_gradient_descent.
+ */
+@Internal
+public class SGD implements Optimizer {
+    /** Params for SGD optimizer. */
+    private final SGDParams params;
+
+    public SGD(
+            int maxIter,
+            double learningRate,
+            int globalBatchSize,
+            double tol,
+            double reg,
+            double elasticNet) {
+        this.params = new SGDParams(maxIter, learningRate, globalBatchSize, tol, reg, elasticNet);
+    }
+
+    @Override
+    public DataStream<DenseVector> optimize(
+            DataStream<DenseVector> bcInitModel,
+            DataStream<LabeledPointWithWeight> trainData,
+            LossFunc lossFunc) {
+        DataStreamList resultList =
+                Iterations.iterateBoundedStreamsUntilTermination(
+                        DataStreamList.of(bcInitModel.map(modelVec -> modelVec.values)),
+                        ReplayableDataStreamList.notReplay(trainData.rebalance()),
+                        IterationConfig.newBuilder().build(),
+                        new TrainIterationBody(lossFunc, params));
+        return resultList.get(0);
+    }
+
+    /** The iteration implementation for training process. */
+    private static class TrainIterationBody implements IterationBody {
+        private final LossFunc lossFunc;
+        private final SGDParams params;
+
+        public TrainIterationBody(LossFunc lossFunc, SGDParams params) {
+            this.lossFunc = lossFunc;
+            this.params = params;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            // The variable stream at the first iteration is the initialized model data.
+            // In the following iterations, it contains: the model update, weightSum and
+            // lossSum.
+            DataStream<double[]> variableStream = variableStreams.get(0);
+            DataStream<LabeledPointWithWeight> trainData = dataStreams.get(0);
+            final OutputTag<DenseVector> modelDataOutputTag =
+                    new OutputTag<DenseVector>("MODEL_OUTPUT") {};
+
+            SingleOutputStreamOperator<double[]> modelUpdateAndWeightAndLoss =
+                    trainData
+                            .connect(variableStream)
+                            .transform(
+                                    "CacheDataAndDoTrain",
+                                    PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO,
+                                    new CacheDataAndDoTrain(lossFunc, params, modelDataOutputTag));
+
+            DataStreamList feedbackVariableStream =
+                    IterationBody.forEachRound(
+                            DataStreamList.of(modelUpdateAndWeightAndLoss),
+                            input -> {
+                                DataStream<double[]> feedback =
+                                        DataStreamUtils.allReduceSum(input.get(0));
+                                return DataStreamList.of(feedback);
+                            });
+
+            DataStream<Integer> terminationCriteria =
+                    feedbackVariableStream
+                            .get(0)
+                            .map(
+                                    reducedUpdateAndWeightAndLoss -> {
+                                        double[] value = (double[]) reducedUpdateAndWeightAndLoss;
+                                        return value[value.length - 1] / value[value.length - 2];
+                                    })
+                            .flatMap(new TerminateOnMaxIterOrTol(params.maxIter, params.tol));
+
+            return new IterationBodyResult(
+                    DataStreamList.of(feedbackVariableStream.get(0)),
+                    DataStreamList.of(
+                            modelUpdateAndWeightAndLoss.getSideOutput(modelDataOutputTag)),
+                    terminationCriteria);
+        }
+    }
+
+    /**
+     * A stream operator that caches the training data in the first iteration and updates the model
+     * iteratively. The first input is the training data, and the second input is the initialized
+     * model data or feedback of model update, weight and loss.
+     */
+    private static class CacheDataAndDoTrain extends AbstractStreamOperator<double[]>
+            implements TwoInputStreamOperator<LabeledPointWithWeight, double[], double[]>,
+                    IterationListener<double[]> {
+        /** The cached training data. */
+        private List<LabeledPointWithWeight> trainData;
+
+        private ListState<LabeledPointWithWeight> trainDataState;
+
+        /** The start index (offset) of the next mini-batch data for training. */
+        private int nextBatchOffset = 0;
+
+        private ListState<Integer> nextBatchOffsetState;
+
+        /** The model coefficient. */
+        private DenseVector coefficient;
+
+        private ListState<DenseVector> coefficientState;
+        /** The dimension of the coefficient. */
+        private int coeffiDim;

Review Comment:
   nits: given that we already use `coefficient` instead of `coeffi` above, how about renaming this variable as `coefficientDim` for consistency?



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/SGD.java:
##########
@@ -0,0 +1,402 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol;
+import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.regression.linearregression.LinearRegression;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Stochastic Gradient Descent (SGD) is the mostly wide-used optimizer for optimizing machine
+ * learning models. It iteratively makes small adjustments to the machine learning model according
+ * to the gradient at each step, to decrease the error of the model.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Stochastic_gradient_descent.
+ */
+@Internal
+public class SGD implements Optimizer {
+    /** Params for SGD optimizer. */
+    private final SGDParams params;
+
+    public SGD(
+            int maxIter,
+            double learningRate,
+            int globalBatchSize,
+            double tol,
+            double reg,
+            double elasticNet) {
+        this.params = new SGDParams(maxIter, learningRate, globalBatchSize, tol, reg, elasticNet);
+    }
+
+    @Override
+    public DataStream<DenseVector> optimize(
+            DataStream<DenseVector> bcInitModel,
+            DataStream<LabeledPointWithWeight> trainData,
+            LossFunc lossFunc) {
+        DataStreamList resultList =
+                Iterations.iterateBoundedStreamsUntilTermination(
+                        DataStreamList.of(bcInitModel.map(modelVec -> modelVec.values)),
+                        ReplayableDataStreamList.notReplay(trainData.rebalance()),
+                        IterationConfig.newBuilder().build(),
+                        new TrainIterationBody(lossFunc, params));
+        return resultList.get(0);
+    }
+
+    /** The iteration implementation for training process. */
+    private static class TrainIterationBody implements IterationBody {
+        private final LossFunc lossFunc;
+        private final SGDParams params;
+
+        public TrainIterationBody(LossFunc lossFunc, SGDParams params) {
+            this.lossFunc = lossFunc;
+            this.params = params;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            // The variable stream at the first iteration is the initialized model data.
+            // In the following iterations, it contains: the model update, weightSum and
+            // lossSum.
+            DataStream<double[]> variableStream = variableStreams.get(0);
+            DataStream<LabeledPointWithWeight> trainData = dataStreams.get(0);
+            final OutputTag<DenseVector> modelDataOutputTag =
+                    new OutputTag<DenseVector>("MODEL_OUTPUT") {};
+
+            SingleOutputStreamOperator<double[]> modelUpdateAndWeightAndLoss =
+                    trainData
+                            .connect(variableStream)
+                            .transform(
+                                    "CacheDataAndDoTrain",
+                                    PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO,
+                                    new CacheDataAndDoTrain(lossFunc, params, modelDataOutputTag));
+
+            DataStreamList feedbackVariableStream =
+                    IterationBody.forEachRound(
+                            DataStreamList.of(modelUpdateAndWeightAndLoss),
+                            input -> {
+                                DataStream<double[]> feedback =
+                                        DataStreamUtils.allReduceSum(input.get(0));
+                                return DataStreamList.of(feedback);
+                            });
+
+            DataStream<Integer> terminationCriteria =
+                    feedbackVariableStream
+                            .get(0)
+                            .map(
+                                    reducedUpdateAndWeightAndLoss -> {
+                                        double[] value = (double[]) reducedUpdateAndWeightAndLoss;
+                                        return value[value.length - 1] / value[value.length - 2];
+                                    })
+                            .flatMap(new TerminateOnMaxIterOrTol(params.maxIter, params.tol));
+
+            return new IterationBodyResult(
+                    DataStreamList.of(feedbackVariableStream.get(0)),
+                    DataStreamList.of(
+                            modelUpdateAndWeightAndLoss.getSideOutput(modelDataOutputTag)),
+                    terminationCriteria);
+        }
+    }
+
+    /**
+     * A stream operator that caches the training data in the first iteration and updates the model
+     * iteratively. The first input is the training data, and the second input is the initialized
+     * model data or feedback of model update, weight and loss.

Review Comment:
   nits: `model update, weight and loss` -> `model update, weight, and loss`



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/RegularizationUtils.java:
##########
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+
+/**
+ * A utility class for algorithms that need to handle regularization. The regularization term is
+ * defined as:
+ *
+ * <p>elasticNet * reg * norm1(coefficient) + (1 - elasticNet) * (reg/2) * (norm2(coefficient))^2
+ *
+ * <p>See https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.ElasticNet.html.
+ */
+@Internal
+class RegularizationUtils {
+
+    /**
+     * Regularize the model coefficient. The gradient of each dimension could be computed as:
+     * {elasticNet * reg * Math.sign(c_i) + (1 - elasticNet) * reg * c_i}. Here c_i is the value of
+     * coefficient at i-th dimension.
+     *
+     * @param coefficient The model coefficient.
+     * @param reg The reg param.
+     * @param elasticNet The elasticNet param.
+     * @param learningRate The learningRate param.
+     * @return The loss introduced by regularization.
+     */
+    public static double regularize(
+            DenseVector coefficient,
+            final double reg,
+            final double elasticNet,
+            final double learningRate) {
+
+        if (Double.compare(reg, 0) == 0) {
+            return 0;
+        } else {
+            if (Double.compare(elasticNet, 0) == 0) {
+                // Only L2 regularization.
+                double loss = reg / 2 * BLAS.norm2(coefficient);
+                BLAS.scal(1 - learningRate * reg, coefficient);
+                return loss;
+            } else if (Double.compare(elasticNet, 1) == 0) {
+                // Only L1 regularization.
+                double loss = 0;
+                double[] coefficientArray = coefficient.values;
+                for (int i = 0; i < coefficientArray.length; i++) {
+                    if (Double.compare(coefficientArray[i], 0) == 0) {
+                        continue;
+                    }
+                    loss += elasticNet * reg * Math.signum(coefficientArray[i]);
+                    coefficientArray[i] -=
+                            learningRate * elasticNet * reg * Math.signum(coefficientArray[i]);
+                }
+                return loss;
+            } else {
+                // Both L1 and L2 are not zero.
+                double loss = 0;
+                double[] coefficientArray = coefficient.values;
+                for (int i = 0; i < coefficientArray.length; i++) {
+                    loss +=

Review Comment:
   Given that we checked `Double.compare(coefficientArray[i], 0) == 0` in the case where `elasticNet == 1`, do we also need to do the same check here?



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/BinaryLogisticLoss.java:
##########
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.lossfunc;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.ml.classification.logisticregression.LogisticRegression;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+
+/** The loss function for binary logistic loss. See {@link LogisticRegression} for example. */
+@Internal
+public class BinaryLogisticLoss implements LossFunc {

Review Comment:
   This is pretty nice and clean.
   
   Since this is part of infra and might be used by multiple algorithms in the future, do we need to add unit tests for `BinaryLogisticLoss` and `LeastSquareLoss`?



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/SGD.java:
##########
@@ -0,0 +1,402 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol;
+import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.regression.linearregression.LinearRegression;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Stochastic Gradient Descent (SGD) is the mostly wide-used optimizer for optimizing machine
+ * learning models. It iteratively makes small adjustments to the machine learning model according
+ * to the gradient at each step, to decrease the error of the model.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Stochastic_gradient_descent.
+ */
+@Internal
+public class SGD implements Optimizer {
+    /** Params for SGD optimizer. */
+    private final SGDParams params;
+
+    public SGD(
+            int maxIter,
+            double learningRate,
+            int globalBatchSize,
+            double tol,
+            double reg,
+            double elasticNet) {
+        this.params = new SGDParams(maxIter, learningRate, globalBatchSize, tol, reg, elasticNet);
+    }
+
+    @Override
+    public DataStream<DenseVector> optimize(
+            DataStream<DenseVector> bcInitModel,
+            DataStream<LabeledPointWithWeight> trainData,
+            LossFunc lossFunc) {
+        DataStreamList resultList =
+                Iterations.iterateBoundedStreamsUntilTermination(
+                        DataStreamList.of(bcInitModel.map(modelVec -> modelVec.values)),
+                        ReplayableDataStreamList.notReplay(trainData.rebalance()),
+                        IterationConfig.newBuilder().build(),
+                        new TrainIterationBody(lossFunc, params));
+        return resultList.get(0);
+    }
+
+    /** The iteration implementation for training process. */
+    private static class TrainIterationBody implements IterationBody {
+        private final LossFunc lossFunc;
+        private final SGDParams params;
+
+        public TrainIterationBody(LossFunc lossFunc, SGDParams params) {
+            this.lossFunc = lossFunc;
+            this.params = params;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            // The variable stream at the first iteration is the initialized model data.
+            // In the following iterations, it contains: the model update, weightSum and
+            // lossSum.
+            DataStream<double[]> variableStream = variableStreams.get(0);
+            DataStream<LabeledPointWithWeight> trainData = dataStreams.get(0);
+            final OutputTag<DenseVector> modelDataOutputTag =
+                    new OutputTag<DenseVector>("MODEL_OUTPUT") {};
+
+            SingleOutputStreamOperator<double[]> modelUpdateAndWeightAndLoss =
+                    trainData
+                            .connect(variableStream)
+                            .transform(
+                                    "CacheDataAndDoTrain",
+                                    PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO,
+                                    new CacheDataAndDoTrain(lossFunc, params, modelDataOutputTag));
+
+            DataStreamList feedbackVariableStream =
+                    IterationBody.forEachRound(
+                            DataStreamList.of(modelUpdateAndWeightAndLoss),
+                            input -> {
+                                DataStream<double[]> feedback =
+                                        DataStreamUtils.allReduceSum(input.get(0));
+                                return DataStreamList.of(feedback);
+                            });
+
+            DataStream<Integer> terminationCriteria =
+                    feedbackVariableStream
+                            .get(0)
+                            .map(
+                                    reducedUpdateAndWeightAndLoss -> {
+                                        double[] value = (double[]) reducedUpdateAndWeightAndLoss;
+                                        return value[value.length - 1] / value[value.length - 2];
+                                    })
+                            .flatMap(new TerminateOnMaxIterOrTol(params.maxIter, params.tol));
+
+            return new IterationBodyResult(
+                    DataStreamList.of(feedbackVariableStream.get(0)),
+                    DataStreamList.of(
+                            modelUpdateAndWeightAndLoss.getSideOutput(modelDataOutputTag)),
+                    terminationCriteria);
+        }
+    }
+
+    /**
+     * A stream operator that caches the training data in the first iteration and updates the model
+     * iteratively. The first input is the training data, and the second input is the initialized
+     * model data or feedback of model update, weight and loss.
+     */
+    private static class CacheDataAndDoTrain extends AbstractStreamOperator<double[]>
+            implements TwoInputStreamOperator<LabeledPointWithWeight, double[], double[]>,
+                    IterationListener<double[]> {
+        /** The cached training data. */
+        private List<LabeledPointWithWeight> trainData;
+
+        private ListState<LabeledPointWithWeight> trainDataState;
+
+        /** The start index (offset) of the next mini-batch data for training. */
+        private int nextBatchOffset = 0;
+
+        private ListState<Integer> nextBatchOffsetState;
+
+        /** The model coefficient. */
+        private DenseVector coefficient;
+
+        private ListState<DenseVector> coefficientState;
+        /** The dimension of the coefficient. */
+        private int coeffiDim;
+
+        /**
+         * The double array to sync among all workers. For example, when training {@link
+         * LinearRegression}, this double array is consisted of {modelUpdate, weightSum, lossSum}.
+         */
+        private double[] feedbackArray;
+
+        private ListState<double[]> feedbackArrayState;
+        /** The outputTag to output the model data when iteration ends. */
+        private final OutputTag<DenseVector> modelDataOutputTag;

Review Comment:
   nits: The code above has line breaks between variable declarations. Could we make the variable declaration code follow the same style?



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/SGD.java:
##########
@@ -0,0 +1,402 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol;
+import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.regression.linearregression.LinearRegression;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Stochastic Gradient Descent (SGD) is the mostly wide-used optimizer for optimizing machine
+ * learning models. It iteratively makes small adjustments to the machine learning model according
+ * to the gradient at each step, to decrease the error of the model.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Stochastic_gradient_descent.
+ */
+@Internal
+public class SGD implements Optimizer {
+    /** Params for SGD optimizer. */
+    private final SGDParams params;
+
+    public SGD(
+            int maxIter,
+            double learningRate,
+            int globalBatchSize,
+            double tol,
+            double reg,
+            double elasticNet) {
+        this.params = new SGDParams(maxIter, learningRate, globalBatchSize, tol, reg, elasticNet);
+    }
+
+    @Override
+    public DataStream<DenseVector> optimize(
+            DataStream<DenseVector> bcInitModel,
+            DataStream<LabeledPointWithWeight> trainData,
+            LossFunc lossFunc) {
+        DataStreamList resultList =
+                Iterations.iterateBoundedStreamsUntilTermination(
+                        DataStreamList.of(bcInitModel.map(modelVec -> modelVec.values)),
+                        ReplayableDataStreamList.notReplay(trainData.rebalance()),
+                        IterationConfig.newBuilder().build(),
+                        new TrainIterationBody(lossFunc, params));
+        return resultList.get(0);
+    }
+
+    /** The iteration implementation for training process. */
+    private static class TrainIterationBody implements IterationBody {
+        private final LossFunc lossFunc;
+        private final SGDParams params;
+
+        public TrainIterationBody(LossFunc lossFunc, SGDParams params) {
+            this.lossFunc = lossFunc;
+            this.params = params;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            // The variable stream at the first iteration is the initialized model data.
+            // In the following iterations, it contains: the model update, weightSum and
+            // lossSum.
+            DataStream<double[]> variableStream = variableStreams.get(0);
+            DataStream<LabeledPointWithWeight> trainData = dataStreams.get(0);
+            final OutputTag<DenseVector> modelDataOutputTag =
+                    new OutputTag<DenseVector>("MODEL_OUTPUT") {};
+
+            SingleOutputStreamOperator<double[]> modelUpdateAndWeightAndLoss =
+                    trainData
+                            .connect(variableStream)
+                            .transform(
+                                    "CacheDataAndDoTrain",
+                                    PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO,
+                                    new CacheDataAndDoTrain(lossFunc, params, modelDataOutputTag));
+
+            DataStreamList feedbackVariableStream =
+                    IterationBody.forEachRound(
+                            DataStreamList.of(modelUpdateAndWeightAndLoss),
+                            input -> {
+                                DataStream<double[]> feedback =
+                                        DataStreamUtils.allReduceSum(input.get(0));
+                                return DataStreamList.of(feedback);
+                            });
+
+            DataStream<Integer> terminationCriteria =
+                    feedbackVariableStream
+                            .get(0)
+                            .map(
+                                    reducedUpdateAndWeightAndLoss -> {
+                                        double[] value = (double[]) reducedUpdateAndWeightAndLoss;
+                                        return value[value.length - 1] / value[value.length - 2];
+                                    })
+                            .flatMap(new TerminateOnMaxIterOrTol(params.maxIter, params.tol));
+
+            return new IterationBodyResult(
+                    DataStreamList.of(feedbackVariableStream.get(0)),
+                    DataStreamList.of(
+                            modelUpdateAndWeightAndLoss.getSideOutput(modelDataOutputTag)),
+                    terminationCriteria);
+        }
+    }
+
+    /**
+     * A stream operator that caches the training data in the first iteration and updates the model
+     * iteratively. The first input is the training data, and the second input is the initialized
+     * model data or feedback of model update, weight and loss.
+     */
+    private static class CacheDataAndDoTrain extends AbstractStreamOperator<double[]>
+            implements TwoInputStreamOperator<LabeledPointWithWeight, double[], double[]>,
+                    IterationListener<double[]> {
+        /** The cached training data. */
+        private List<LabeledPointWithWeight> trainData;
+
+        private ListState<LabeledPointWithWeight> trainDataState;
+
+        /** The start index (offset) of the next mini-batch data for training. */
+        private int nextBatchOffset = 0;
+
+        private ListState<Integer> nextBatchOffsetState;
+
+        /** The model coefficient. */
+        private DenseVector coefficient;
+
+        private ListState<DenseVector> coefficientState;
+        /** The dimension of the coefficient. */
+        private int coeffiDim;
+
+        /**
+         * The double array to sync among all workers. For example, when training {@link
+         * LinearRegression}, this double array is consisted of {modelUpdate, weightSum, lossSum}.
+         */
+        private double[] feedbackArray;
+
+        private ListState<double[]> feedbackArrayState;
+        /** The outputTag to output the model data when iteration ends. */
+        private final OutputTag<DenseVector> modelDataOutputTag;
+        /** The batch size on this task. */
+        private int localBatchSize;
+        /** Optimizer-related parameters. */
+        private final SGDParams params;

Review Comment:
   nits: We typically declare final variables before non-final variables. Could we update the code style as appropriate?



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/SGD.java:
##########
@@ -0,0 +1,402 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol;
+import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.regression.linearregression.LinearRegression;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Stochastic Gradient Descent (SGD) is the mostly wide-used optimizer for optimizing machine
+ * learning models. It iteratively makes small adjustments to the machine learning model according
+ * to the gradient at each step, to decrease the error of the model.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Stochastic_gradient_descent.
+ */
+@Internal
+public class SGD implements Optimizer {
+    /** Params for SGD optimizer. */
+    private final SGDParams params;
+
+    public SGD(
+            int maxIter,
+            double learningRate,
+            int globalBatchSize,
+            double tol,
+            double reg,
+            double elasticNet) {
+        this.params = new SGDParams(maxIter, learningRate, globalBatchSize, tol, reg, elasticNet);
+    }
+
+    @Override
+    public DataStream<DenseVector> optimize(
+            DataStream<DenseVector> bcInitModel,
+            DataStream<LabeledPointWithWeight> trainData,
+            LossFunc lossFunc) {
+        DataStreamList resultList =
+                Iterations.iterateBoundedStreamsUntilTermination(
+                        DataStreamList.of(bcInitModel.map(modelVec -> modelVec.values)),
+                        ReplayableDataStreamList.notReplay(trainData.rebalance()),
+                        IterationConfig.newBuilder().build(),
+                        new TrainIterationBody(lossFunc, params));
+        return resultList.get(0);
+    }
+
+    /** The iteration implementation for training process. */
+    private static class TrainIterationBody implements IterationBody {
+        private final LossFunc lossFunc;
+        private final SGDParams params;
+
+        public TrainIterationBody(LossFunc lossFunc, SGDParams params) {
+            this.lossFunc = lossFunc;
+            this.params = params;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            // The variable stream at the first iteration is the initialized model data.
+            // In the following iterations, it contains: the model update, weightSum and
+            // lossSum.
+            DataStream<double[]> variableStream = variableStreams.get(0);
+            DataStream<LabeledPointWithWeight> trainData = dataStreams.get(0);
+            final OutputTag<DenseVector> modelDataOutputTag =
+                    new OutputTag<DenseVector>("MODEL_OUTPUT") {};
+
+            SingleOutputStreamOperator<double[]> modelUpdateAndWeightAndLoss =
+                    trainData
+                            .connect(variableStream)
+                            .transform(
+                                    "CacheDataAndDoTrain",
+                                    PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO,
+                                    new CacheDataAndDoTrain(lossFunc, params, modelDataOutputTag));
+
+            DataStreamList feedbackVariableStream =
+                    IterationBody.forEachRound(
+                            DataStreamList.of(modelUpdateAndWeightAndLoss),
+                            input -> {
+                                DataStream<double[]> feedback =
+                                        DataStreamUtils.allReduceSum(input.get(0));
+                                return DataStreamList.of(feedback);
+                            });
+
+            DataStream<Integer> terminationCriteria =
+                    feedbackVariableStream
+                            .get(0)
+                            .map(
+                                    reducedUpdateAndWeightAndLoss -> {
+                                        double[] value = (double[]) reducedUpdateAndWeightAndLoss;
+                                        return value[value.length - 1] / value[value.length - 2];
+                                    })
+                            .flatMap(new TerminateOnMaxIterOrTol(params.maxIter, params.tol));
+
+            return new IterationBodyResult(
+                    DataStreamList.of(feedbackVariableStream.get(0)),
+                    DataStreamList.of(
+                            modelUpdateAndWeightAndLoss.getSideOutput(modelDataOutputTag)),
+                    terminationCriteria);
+        }
+    }
+
+    /**
+     * A stream operator that caches the training data in the first iteration and updates the model
+     * iteratively. The first input is the training data, and the second input is the initialized
+     * model data or feedback of model update, weight and loss.
+     */
+    private static class CacheDataAndDoTrain extends AbstractStreamOperator<double[]>
+            implements TwoInputStreamOperator<LabeledPointWithWeight, double[], double[]>,
+                    IterationListener<double[]> {
+        /** The cached training data. */
+        private List<LabeledPointWithWeight> trainData;
+
+        private ListState<LabeledPointWithWeight> trainDataState;
+
+        /** The start index (offset) of the next mini-batch data for training. */
+        private int nextBatchOffset = 0;
+
+        private ListState<Integer> nextBatchOffsetState;
+
+        /** The model coefficient. */
+        private DenseVector coefficient;
+
+        private ListState<DenseVector> coefficientState;
+        /** The dimension of the coefficient. */
+        private int coeffiDim;
+
+        /**
+         * The double array to sync among all workers. For example, when training {@link
+         * LinearRegression}, this double array is consisted of {modelUpdate, weightSum, lossSum}.
+         */
+        private double[] feedbackArray;
+
+        private ListState<double[]> feedbackArrayState;
+        /** The outputTag to output the model data when iteration ends. */
+        private final OutputTag<DenseVector> modelDataOutputTag;
+        /** The batch size on this task. */
+        private int localBatchSize;
+        /** Optimizer-related parameters. */
+        private final SGDParams params;
+        /** The loss function to optimize. */
+        private final LossFunc lossFunc;
+
+        private CacheDataAndDoTrain(
+                LossFunc lossFunc, SGDParams params, OutputTag<DenseVector> modelDataOutputTag) {
+            this.lossFunc = lossFunc;
+            this.params = params;
+            this.modelDataOutputTag = modelDataOutputTag;
+        }
+
+        @Override
+        public void open() {
+            int numTasks = getRuntimeContext().getNumberOfParallelSubtasks();
+            int taskId = getRuntimeContext().getIndexOfThisSubtask();
+            localBatchSize = params.globalBatchSize / numTasks;
+            if (params.globalBatchSize % numTasks > taskId) {
+                localBatchSize++;
+            }
+        }
+
+        /**
+         * Gets the weight sum of the processed elements.
+         *
+         * @return The weight sum.
+         */
+        private double getWeightSum() {
+            return feedbackArray[coeffiDim];
+        }
+
+        /**
+         * Sets the weight sum of the processed elements.
+         *
+         * @param weightSum The weight sum.
+         */
+        private void setWeightSum(double weightSum) {
+            feedbackArray[coeffiDim] = weightSum;
+        }
+
+        /**
+         * Gets the loss sum of the processed elements.

Review Comment:
   nits: Should it be `loss of ...` or `total loss of ...`?
   
   It might be simpler to just remove the Javadoc here, as this is a private function whose logic is pretty simple.
   
   Same for `setLoss(...)`, `setWeightSum(...)` and `getWeightSum()`.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/SGD.java:
##########
@@ -0,0 +1,402 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol;
+import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.regression.linearregression.LinearRegression;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Stochastic Gradient Descent (SGD) is the mostly wide-used optimizer for optimizing machine
+ * learning models. It iteratively makes small adjustments to the machine learning model according
+ * to the gradient at each step, to decrease the error of the model.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Stochastic_gradient_descent.
+ */
+@Internal
+public class SGD implements Optimizer {
+    /** Params for SGD optimizer. */
+    private final SGDParams params;
+
+    public SGD(
+            int maxIter,
+            double learningRate,
+            int globalBatchSize,
+            double tol,
+            double reg,
+            double elasticNet) {
+        this.params = new SGDParams(maxIter, learningRate, globalBatchSize, tol, reg, elasticNet);
+    }
+
+    @Override
+    public DataStream<DenseVector> optimize(
+            DataStream<DenseVector> bcInitModel,
+            DataStream<LabeledPointWithWeight> trainData,
+            LossFunc lossFunc) {
+        DataStreamList resultList =
+                Iterations.iterateBoundedStreamsUntilTermination(
+                        DataStreamList.of(bcInitModel.map(modelVec -> modelVec.values)),
+                        ReplayableDataStreamList.notReplay(trainData.rebalance()),
+                        IterationConfig.newBuilder().build(),
+                        new TrainIterationBody(lossFunc, params));
+        return resultList.get(0);
+    }
+
+    /** The iteration implementation for training process. */
+    private static class TrainIterationBody implements IterationBody {
+        private final LossFunc lossFunc;
+        private final SGDParams params;
+
+        public TrainIterationBody(LossFunc lossFunc, SGDParams params) {
+            this.lossFunc = lossFunc;
+            this.params = params;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            // The variable stream at the first iteration is the initialized model data.
+            // In the following iterations, it contains: the model update, weightSum and
+            // lossSum.
+            DataStream<double[]> variableStream = variableStreams.get(0);
+            DataStream<LabeledPointWithWeight> trainData = dataStreams.get(0);
+            final OutputTag<DenseVector> modelDataOutputTag =
+                    new OutputTag<DenseVector>("MODEL_OUTPUT") {};
+
+            SingleOutputStreamOperator<double[]> modelUpdateAndWeightAndLoss =
+                    trainData
+                            .connect(variableStream)
+                            .transform(
+                                    "CacheDataAndDoTrain",
+                                    PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO,
+                                    new CacheDataAndDoTrain(lossFunc, params, modelDataOutputTag));
+
+            DataStreamList feedbackVariableStream =
+                    IterationBody.forEachRound(
+                            DataStreamList.of(modelUpdateAndWeightAndLoss),
+                            input -> {
+                                DataStream<double[]> feedback =
+                                        DataStreamUtils.allReduceSum(input.get(0));
+                                return DataStreamList.of(feedback);
+                            });
+
+            DataStream<Integer> terminationCriteria =
+                    feedbackVariableStream
+                            .get(0)
+                            .map(
+                                    reducedUpdateAndWeightAndLoss -> {
+                                        double[] value = (double[]) reducedUpdateAndWeightAndLoss;
+                                        return value[value.length - 1] / value[value.length - 2];
+                                    })
+                            .flatMap(new TerminateOnMaxIterOrTol(params.maxIter, params.tol));
+
+            return new IterationBodyResult(
+                    DataStreamList.of(feedbackVariableStream.get(0)),
+                    DataStreamList.of(
+                            modelUpdateAndWeightAndLoss.getSideOutput(modelDataOutputTag)),
+                    terminationCriteria);
+        }
+    }
+
+    /**
+     * A stream operator that caches the training data in the first iteration and updates the model
+     * iteratively. The first input is the training data, and the second input is the initialized
+     * model data or feedback of model update, weight and loss.
+     */
+    private static class CacheDataAndDoTrain extends AbstractStreamOperator<double[]>
+            implements TwoInputStreamOperator<LabeledPointWithWeight, double[], double[]>,
+                    IterationListener<double[]> {
+        /** The cached training data. */
+        private List<LabeledPointWithWeight> trainData;
+
+        private ListState<LabeledPointWithWeight> trainDataState;
+
+        /** The start index (offset) of the next mini-batch data for training. */
+        private int nextBatchOffset = 0;
+
+        private ListState<Integer> nextBatchOffsetState;
+
+        /** The model coefficient. */
+        private DenseVector coefficient;
+
+        private ListState<DenseVector> coefficientState;
+        /** The dimension of the coefficient. */
+        private int coeffiDim;
+
+        /**
+         * The double array to sync among all workers. For example, when training {@link
+         * LinearRegression}, this double array is consisted of {modelUpdate, weightSum, lossSum}.

Review Comment:
   nits: `is consisted of` -> `consists of`
   
   Would it be better to change `{modelUpdate, weightSum, lossSum}` to `[modelUpdate, weightSum, lossSum]` to indicate that this is an array?



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/regression/linearregression/LinearRegressionModel.java:
##########
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.regression.linearregression;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/** A Model which predicts data using the model data computed by {@link LinearRegression}. */
+public class LinearRegressionModel
+        implements Model<LinearRegressionModel>,
+                LinearRegressionModelParams<LinearRegressionModel> {
+
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    private Table modelDataTable;
+
+    public LinearRegressionModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    @Override
+    public void save(String path) throws IOException {

Review Comment:
   nits: Currently the order of save/load/fit/getParamMap is not consistent across algorithms. I am thinking it might be nice to make them consistent. We can do it in a separate PR.
   
   Since `fit()` is the most important method for each Estimator, how about we order them as fit, save, load, getParamMap?
   
   And for Model, we can order them as transform, setModelData, getModelData, save, load, getParamMap.
   
   What do you think?



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/SGD.java:
##########
@@ -0,0 +1,402 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol;
+import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.regression.linearregression.LinearRegression;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Stochastic Gradient Descent (SGD) is the most widely used optimizer for optimizing machine
+ * learning models. It iteratively makes small adjustments to the machine learning model according
+ * to the gradient at each step, to decrease the error of the model.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Stochastic_gradient_descent.
+ */
+@Internal
+public class SGD implements Optimizer {
+    /** Params for SGD optimizer. */
+    private final SGDParams params;
+
+    /**
+     * Creates an SGD optimizer; all arguments are bundled into a single {@link SGDParams}.
+     *
+     * @param maxIter Passed, together with {@code tol}, to the termination check of the iteration.
+     * @param learningRate Scales the model update applied in each epoch.
+     * @param globalBatchSize The batch size over all subtasks; it is split among the parallel
+     *     subtasks of the training operator.
+     * @param tol Passed, together with {@code maxIter}, to the termination check of the iteration.
+     * @param reg Passed to the regularization step applied after each model update.
+     * @param elasticNet Passed to the regularization step applied after each model update.
+     */
+    public SGD(
+            int maxIter,
+            double learningRate,
+            int globalBatchSize,
+            double tol,
+            double reg,
+            double elasticNet) {
+        this.params = new SGDParams(maxIter, learningRate, globalBatchSize, tol, reg, elasticNet);
+    }
+
+    /**
+     * Builds the bounded iteration that trains the model and returns its first output stream.
+     *
+     * @param bcInitModel The stream carrying the initial model; the {@code values} array of each
+     *     vector seeds the variable stream of the iteration.
+     * @param trainData The training data; it is rebalanced and passed as a non-replayed input.
+     * @param lossFunc The loss function handed to the iteration body.
+     * @return The first output stream of the iteration.
+     */
+    @Override
+    public DataStream<DenseVector> optimize(
+            DataStream<DenseVector> bcInitModel,
+            DataStream<LabeledPointWithWeight> trainData,
+            LossFunc lossFunc) {
+        // The iteration variable is the raw double[] behind the initial model vector.
+        DataStream<double[]> initModelValues = bcInitModel.map(modelVec -> modelVec.values);
+        DataStream<LabeledPointWithWeight> rebalancedTrainData = trainData.rebalance();
+
+        DataStreamList iterationOutputs =
+                Iterations.iterateBoundedStreamsUntilTermination(
+                        DataStreamList.of(initModelValues),
+                        ReplayableDataStreamList.notReplay(rebalancedTrainData),
+                        IterationConfig.newBuilder().build(),
+                        new TrainIterationBody(lossFunc, params));
+        return iterationOutputs.get(0);
+    }
+
+    /**
+     * The iteration implementation for the training process.
+     *
+     * <p>Each round connects the training data with the variable stream in a {@link
+     * CacheDataAndDoTrain} operator, sums the operator outputs with {@code allReduceSum}, and
+     * feeds the summed array back as the next variable stream.
+     */
+    private static class TrainIterationBody implements IterationBody {
+        /** The loss function passed on to the training operator. */
+        private final LossFunc lossFunc;
+        /** The optimizer parameters; also supplies maxIter and tol for the termination check. */
+        private final SGDParams params;
+
+        public TrainIterationBody(LossFunc lossFunc, SGDParams params) {
+            this.lossFunc = lossFunc;
+            this.params = params;
+        }
+
+        /**
+         * Wires up one iteration body.
+         *
+         * @param variableStreams The variable streams; only the first one is used.
+         * @param dataStreams The data streams; the first one is the training data.
+         * @return The result holding the feedback stream, the model data side output as the
+         *     output stream, and the termination criteria stream.
+         */
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            // The variable stream at the first iteration is the initialized model data.
+            // In the following iterations, it contains: the model update, weightSum and
+            // lossSum.
+            DataStream<double[]> variableStream = variableStreams.get(0);
+            DataStream<LabeledPointWithWeight> trainData = dataStreams.get(0);
+            // Side output through which the training operator emits the model data.
+            final OutputTag<DenseVector> modelDataOutputTag =
+                    new OutputTag<DenseVector>("MODEL_OUTPUT") {};
+
+            // First input: training data. Second input: initial model or feedback array.
+            SingleOutputStreamOperator<double[]> modelUpdateAndWeightAndLoss =
+                    trainData
+                            .connect(variableStream)
+                            .transform(
+                                    "CacheDataAndDoTrain",
+                                    PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO,
+                                    new CacheDataAndDoTrain(lossFunc, params, modelDataOutputTag));
+
+            // The all-reduce sum is built inside forEachRound, i.e., as a per-round sub-graph.
+            DataStreamList feedbackVariableStream =
+                    IterationBody.forEachRound(
+                            DataStreamList.of(modelUpdateAndWeightAndLoss),
+                            input -> {
+                                DataStream<double[]> feedback =
+                                        DataStreamUtils.allReduceSum(input.get(0));
+                                return DataStreamList.of(feedback);
+                            });
+
+            // The value checked for termination is the last array element divided by the
+            // second-to-last one, i.e., lossSum / weightSum given the layout described above.
+            DataStream<Integer> terminationCriteria =
+                    feedbackVariableStream
+                            .get(0)
+                            .map(
+                                    reducedUpdateAndWeightAndLoss -> {
+                                        double[] value = (double[]) reducedUpdateAndWeightAndLoss;
+                                        return value[value.length - 1] / value[value.length - 2];
+                                    })
+                            .flatMap(new TerminateOnMaxIterOrTol(params.maxIter, params.tol));
+
+            return new IterationBodyResult(
+                    DataStreamList.of(feedbackVariableStream.get(0)),
+                    DataStreamList.of(
+                            modelUpdateAndWeightAndLoss.getSideOutput(modelDataOutputTag)),
+                    terminationCriteria);
+        }
+    }
+
+    /**
+     * A stream operator that caches the training data in the first iteration and updates the model
+     * iteratively. The first input is the training data, and the second input is the initialized
+     * model data or feedback of model update, weight and loss.
+     */
+    private static class CacheDataAndDoTrain extends AbstractStreamOperator<double[]>
+            implements TwoInputStreamOperator<LabeledPointWithWeight, double[], double[]>,
+                    IterationListener<double[]> {
+        /** The cached training data. */
+        private List<LabeledPointWithWeight> trainData;
+
+        private ListState<LabeledPointWithWeight> trainDataState;
+
+        /** The start index (offset) of the next mini-batch data for training. */
+        private int nextBatchOffset = 0;
+
+        private ListState<Integer> nextBatchOffsetState;
+
+        /** The model coefficient. */
+        private DenseVector coefficient;
+
+        private ListState<DenseVector> coefficientState;
+        /** The dimension of the coefficient. */
+        private int coeffiDim;
+
+        /**
+         * The double array to sync among all workers. For example, when training {@link
+         * LinearRegression}, this double array consists of {modelUpdate, weightSum, lossSum}.
+         */
+        private double[] feedbackArray;
+
+        private ListState<double[]> feedbackArrayState;
+        /** The outputTag to output the model data when iteration ends. */
+        private final OutputTag<DenseVector> modelDataOutputTag;
+        /** The batch size on this task. */
+        private int localBatchSize;
+        /** Optimizer-related parameters. */
+        private final SGDParams params;
+        /** The loss function to optimize. */
+        private final LossFunc lossFunc;
+
+        /**
+         * Creates the training operator.
+         *
+         * @param lossFunc The loss function to optimize.
+         * @param params The optimizer-related parameters.
+         * @param modelDataOutputTag The output tag used to output the model data.
+         */
+        private CacheDataAndDoTrain(
+                LossFunc lossFunc, SGDParams params, OutputTag<DenseVector> modelDataOutputTag) {
+            this.lossFunc = lossFunc;
+            this.params = params;
+            this.modelDataOutputTag = modelDataOutputTag;
+        }
+
+        /**
+         * Computes the batch size of this subtask from the global batch size.
+         *
+         * <p>Every subtask gets {@code globalBatchSize / parallelism} elements, and the subtasks
+         * whose index is below {@code globalBatchSize % parallelism} get one more.
+         */
+        @Override
+        public void open() {
+            final int parallelism = getRuntimeContext().getNumberOfParallelSubtasks();
+            final int subtaskIndex = getRuntimeContext().getIndexOfThisSubtask();
+            final int baseSize = params.globalBatchSize / parallelism;
+            final int remainder = params.globalBatchSize % parallelism;
+            // Spread the remainder over the lowest-indexed subtasks, one element each.
+            localBatchSize = subtaskIndex < remainder ? baseSize + 1 : baseSize;
+        }
+
+        /**
+         * Gets the weight sum of the processed elements. It is read from {@code feedbackArray} at
+         * index {@code coeffiDim}.
+         *
+         * @return The weight sum.
+         */
+        private double getWeightSum() {
+            return feedbackArray[coeffiDim];
+        }
+
+        /**
+         * Sets the weight sum of the processed elements. It is written to {@code feedbackArray} at
+         * index {@code coeffiDim}.
+         *
+         * @param weightSum The weight sum.
+         */
+        private void setWeightSum(double weightSum) {
+            feedbackArray[coeffiDim] = weightSum;
+        }
+
+        /**
+         * Gets the loss sum of the processed elements. It is read from {@code feedbackArray} at
+         * index {@code coeffiDim + 1}.
+         *
+         * @return The loss sum.
+         */
+        private double getLoss() {
+            return feedbackArray[coeffiDim + 1];
+        }
+
+        /**
+         * Sets the loss sum of the processed elements. It is written to {@code feedbackArray} at
+         * index {@code coeffiDim + 1}.
+         *
+         * @param loss The loss sum.
+         */
+        private void setLoss(double loss) {
+            feedbackArray[coeffiDim + 1] = loss;
+        }
+
+        @Override
+        public void onEpochWatermarkIncremented(
+                int epochWatermark, Context context, Collector<double[]> collector)
+                throws Exception {
+            if (epochWatermark == 0) {
+                coefficient = new DenseVector(feedbackArray);
+                coeffiDim = coefficient.size();
+                feedbackArray = new double[coefficient.size() + 2];
+            } else {
+                if (getWeightSum() > 0) {
+                    BLAS.axpy(
+                            -params.learningRate / getWeightSum(),
+                            new DenseVector(feedbackArray),
+                            coefficient,
+                            coeffiDim);
+                    double regLoss =
+                            RegularizationUtils.regularize(
+                                    coefficient,
+                                    params.reg,
+                                    params.elasticNet,
+                                    params.learningRate);
+                    setLoss(getLoss() + regLoss);
+                }
+            }
+
+            if (trainData == null) {
+                trainData = IteratorUtils.toList(trainDataState.get().iterator());
+            }
+
+            // TODO: supports efficient shuffle of training set on each partition.
+            if (trainData.size() > 0) {
+                List<LabeledPointWithWeight> miniBatchData =
+                        trainData.subList(
+                                nextBatchOffset,
+                                Math.min(nextBatchOffset + localBatchSize, trainData.size()));

Review Comment:
   It seems that we could have `miniBatchData.size() < localBatchSize` even if `trainData.size() > localBatchSize`. Is this expected?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org