You are viewing a plain text version of this content. The canonical link for it is here.
Posted to issues@flink.apache.org by GitBox <gi...@apache.org> on 2022/04/26 03:41:18 UTC

[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #90: [FLINK-27093] Add Transformer and Estimator of LinearRegression

yunfengzhou-hub commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r858226344


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/glm/LocalTrainer.java:
##########
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.glm;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.regression.linearregression.LinearRegression;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
+
+import java.util.Iterator;
+
+/**
+ * A local trainer is a trainer that uses a batch of training data to compute a update of the model
+ * locally.
+ *
+ * @param <T> Class type of training data.
+ */
+public abstract class LocalTrainer<T> implements CheckpointedFunction {

Review Comment:
   For `LocalTrainer` and its subclasses, what do you think of the following changes?
   - move `CacheDataAndDoTrain` to an independent class
   - merge `CacheDataAndDoTrain`, `LocalTrainer` and `WithRegularization`
   - remove methods like `getReg`, as now they are only used internally
   - change the API and implementation of `trainOnBatchData`. I have the sense that processing a single data point at a time, instead of a batch of data, is enough. I think the API of `org.apache.flink.api.common.functions.AggregateFunction` is a good reference.
   - rename the merged class. A name like `LocalTrainer` might be too general to convey that it is specific to linear algorithms.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/glm/GeneralLinearAlgo.java:
##########
@@ -0,0 +1,414 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.glm;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.classification.logisticregression.BinaryLogisticTrainer;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.regression.linearregression.LinearRegressionTrainer;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Base class for general linear machine learning models.
+ *
+ * @param <E> Class type of {@link Estimator}.
+ * @param <M> Class type of {@link Model}.
+ */
+public abstract class GeneralLinearAlgo<

Review Comment:
   For the `Estimator` classes of linear algorithms, it might be better to make the following changes:
   - replace `getModelType` with `getLocalTrainer()`
   - remove `LinearModelType`
   - make the `LocalTrainer` subclasses private static classes of the corresponding `Estimator`s



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java:
##########
@@ -42,12 +43,9 @@
  * <p>This class also provides methods to convert model data from Table to Datastream, and classes
  * to save/load model data.
  */
-public class LogisticRegressionModelData {
-
-    public DenseVector coefficient;
-
+public class LogisticRegressionModelData extends GeneralLinearAlgoModelData {

Review Comment:
   `LogisticRegressionModelData` and `LinearRegressionModelData` seem to differ only in a type cast. If that is true, we can move most of the logic in this class to `GeneralLinearAlgoModelData`. In that case, `LinearRegressionModelData` only needs to specify a generic type `T`, or pass a `Class<T>` to the constructor or a super method of the abstract `GeneralLinearAlgoModelData`.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/glm/GeneralLinearAlgo.java:
##########
@@ -0,0 +1,414 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.glm;

Review Comment:
   For introduced classes in this package:
   - Does `glm` stand for General Linear Model? I understand that the full name might make the package name too long to read, but `glm` might be too short to understand. Do you think there is a middle ground for the package name?
   - We can mark all classes in this package as `Internal`, and check if some `public` methods can be marked as `protected` or `private`.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java:
##########
@@ -18,51 +18,28 @@
 
 package org.apache.flink.ml.classification.logisticregression;
 
-import org.apache.flink.api.common.functions.RichMapFunction;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
-import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.api.java.typeutils.RowTypeInfo;
-import org.apache.flink.ml.api.Model;
-import org.apache.flink.ml.common.broadcast.BroadcastUtils;
-import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.common.glm.GeneralLinearAlgoModel;
 import org.apache.flink.ml.linalg.BLAS;
 import org.apache.flink.ml.linalg.DenseVector;
 import org.apache.flink.ml.linalg.Vectors;
-import org.apache.flink.ml.param.Param;
-import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.ReadWriteUtils;
-import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.table.api.Table;
 import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
-import org.apache.flink.table.api.internal.TableImpl;
 import org.apache.flink.types.Row;
-import org.apache.flink.util.Preconditions;
 
 import org.apache.commons.lang3.ArrayUtils;
 
 import java.io.IOException;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.Map;
+import java.io.Serializable;
 
-/** A Model which classifies data using the model data computed by {@link LogisticRegression}. */
-public class LogisticRegressionModel
-        implements Model<LogisticRegressionModel>,
-                LogisticRegressionModelParams<LogisticRegressionModel> {
-
-    private final Map<Param<?>, Object> paramMap = new HashMap<>();
-
-    private Table modelDataTable;
-
-    public LogisticRegressionModel() {
-        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
-    }
-
-    @Override
-    public Map<Param<?>, Object> getParamMap() {
-        return paramMap;
-    }
+/**
+ * A Model which classifies the input using the model data computed by {@link LogisticRegression}.
+ */
+public class LogisticRegressionModel extends GeneralLinearAlgoModel<LogisticRegressionModel>
+        implements LogisticRegressionModelParams<LogisticRegressionModel>, Serializable {
 
     @Override
     public void save(String path) throws IOException {

Review Comment:
   If the change described for the ModelData classes is feasible, then methods like `LogisticRegressionModel.save()` and `LinearRegressionModel.createModel()` might also be moved to the corresponding abstract classes.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org