Posted to issues@flink.apache.org by GitBox <gi...@apache.org> on 2022/11/30 07:08:03 UTC

[GitHub] [flink-ml] jiangxin369 opened a new pull request, #187: [FLINK-29601] Add Estimator and Transformer for UnivariateFeatureSelector

jiangxin369 opened a new pull request, #187:
URL: https://github.com/apache/flink-ml/pull/187

   
   ## What is the purpose of the change
   
   Add Estimator and Transformer for UnivariateFeatureSelector
   
   ## Brief change log
   
  - Adds Transformer and Estimator implementations of UnivariateFeatureSelector in Java and Python (a usage sketch follows below)
  - Adds examples and documentation for UnivariateFeatureSelector
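
A minimal usage sketch of the new API (illustrative only: the parameter names come from the code quoted later in this thread, the `setXxx` setter names are assumed to follow the usual Flink ML convention for generated param setters, and the toy data is made up):

```java
import org.apache.flink.ml.feature.univariatefeatureselector.UnivariateFeatureSelector;
import org.apache.flink.ml.feature.univariatefeatureselector.UnivariateFeatureSelectorModel;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;

public class UnivariateFeatureSelectorSketch {
    public static void main(String[] args) {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        // Toy input: a categorical label column and a continuous feature vector column.
        DataStream<Row> inputStream =
                env.fromElements(
                        Row.of(1.0, Vectors.dense(1.0, 2.0, 3.0, 4.0)),
                        Row.of(1.0, Vectors.dense(1.1, 2.1, 3.1, 4.1)),
                        Row.of(2.0, Vectors.dense(9.0, 2.0, 3.0, 4.0)),
                        Row.of(2.0, Vectors.dense(9.1, 2.1, 3.1, 4.1)));
        Table inputTable = tEnv.fromDataStream(inputStream).as("label", "features");

        // Continuous features with a categorical label select the ANOVA F-test score function.
        UnivariateFeatureSelector selector =
                new UnivariateFeatureSelector()
                        .setFeaturesCol("features")
                        .setLabelCol("label")
                        .setFeatureType("continuous")
                        .setLabelType("categorical");

        // fit() learns which feature indices to keep; transform() appends the selected
        // sub-vector as the output column.
        UnivariateFeatureSelectorModel model = selector.fit(inputTable);
        Table output = model.transform(inputTable)[0];
        output.execute().print();
    }
}
```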
   
   ## Does this pull request potentially affect one of the following parts:
   
     - Dependencies (does it add or upgrade a dependency): (no)
     - The public API, i.e., is any changed class annotated with `@Public(Evolving)`: (no)
   
   ## Documentation
   
     - Does this pull request introduce a new feature? (yes)
     - If yes, how is the feature documented? (docs / JavaDocs)
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [flink-ml] lindong28 commented on a diff in pull request #187: [FLINK-29601] Add Estimator and Transformer for UnivariateFeatureSelector

Posted by GitBox <gi...@apache.org>.
lindong28 commented on code in PR #187:
URL: https://github.com/apache/flink-ml/pull/187#discussion_r1045234776


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/univariatefeatureselector/UnivariateFeatureSelectorModel.java:
##########
@@ -0,0 +1,201 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.univariatefeatureselector;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * A Model which transforms data using the model data computed by {@link UnivariateFeatureSelector}.
+ */
+public class UnivariateFeatureSelectorModel
+        implements Model<UnivariateFeatureSelectorModel>,
+                UnivariateFeatureSelectorModelParams<UnivariateFeatureSelectorModel> {
+
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public UnivariateFeatureSelectorModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public UnivariateFeatureSelectorModel setModelData(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        modelDataTable = inputs[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @SuppressWarnings({"unchecked", "rawtypes"})
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Row> data = tEnv.toDataStream(inputs[0]);
+        DataStream<UnivariateFeatureSelectorModelData> modelData =
+                UnivariateFeatureSelectorModelData.getModelDataStream(modelDataTable);
+
+        final String broadcastModelKey = "broadcastModelKey";
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), VectorTypeInfo.INSTANCE),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol()));
+
+        DataStream<Row> outputStream =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(data),
+                        Collections.singletonMap(broadcastModelKey, modelData),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(
+                                    new PredictOutputFunction(getFeaturesCol(), broadcastModelKey),
+                                    outputTypeInfo);
+                        });
+
+        return new Table[] {tEnv.fromDataStream(outputStream)};
+    }
+
+    /** This operator loads model data and predicts result. */
+    private static class PredictOutputFunction extends RichMapFunction<Row, Row> {
+
+        private final String inputCol;
+        private final String broadcastKey;
+        private int[] indices;
+
+        public PredictOutputFunction(String inputCol, String broadcastKey) {
+            this.inputCol = inputCol;
+            this.broadcastKey = broadcastKey;
+        }
+
+        @Override
+        public Row map(Row row) {
+            if (indices == null) {
+                UnivariateFeatureSelectorModelData modelData =
+                        (UnivariateFeatureSelectorModelData)
+                                getRuntimeContext().getBroadcastVariable(broadcastKey).get(0);
+                indices = Arrays.stream(modelData.indices).sorted().toArray();
+            }
+
+            if (indices.length == 0) {
+                return Row.join(row, Row.of(Vectors.dense()));
+            } else {
+                Vector inputVec = ((Vector) row.getField(inputCol));
+                Preconditions.checkArgument(
+                        inputVec.size() > indices[indices.length - 1],
+                        "Input %s features, but UnivariateFeatureSelector is "
+                                + "expecting at least %s features as input.",
+                        inputVec.size(),
+                        indices[indices.length - 1] + 1);
+                Vector outputVec = selectByIndices(inputVec, indices);
+                return Row.join(row, Row.of(outputVec));
+            }
+        }
+
+        /**
+         * Selects a subset of the vector based on the indices. Note that the input indices must be
+         * sorted in ascending order if the input vector is sparse.
+         */
+        private Vector selectByIndices(Vector vector, int[] selectedIndices) {

Review Comment:
   `DataStreamUtils#sample(...)` is also called by only two classes.
   
   It is OK for methods to assume that the input meets certain criteria (e.g. not null). Requiring the inputs to be sorted is just one such requirement, and we can document it properly in the util method's Javadoc. We can also name the parameter `sortedIndices` to make the requirement self-explanatory.
   
   The benefit of adding a util method is to reduce duplicated code. Is there any other downside to making it a util method?
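
The documented contract being proposed could read roughly as follows (a fragment, not the PR's code; the Javadoc wording is illustrative, and the simple loop body merely stands in for the sparse-aware implementation quoted further down in this thread):

```java
/**
 * Selects a subset of the vector entries according to the given indices.
 *
 * <p>NOTE: {@code sortedIndices} must be sorted in ascending order. For performance reasons
 * this is not checked at runtime.
 */
public static Vector selectByIndices(Vector vector, int[] sortedIndices) {
    // Placeholder body for the sketch; the PR's version adds a sparse-aware fast path,
    // which is what makes the sorted-order contract matter.
    DenseVector result = new DenseVector(sortedIndices.length);
    for (int i = 0; i < sortedIndices.length; i++) {
        result.set(i, vector.get(sortedIndices[i]));
    }
    return result;
}
```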
   





[GitHub] [flink-ml] lindong28 merged pull request #187: [FLINK-29601] Add Estimator and Transformer for UnivariateFeatureSelector

Posted by GitBox <gi...@apache.org>.
lindong28 merged PR #187:
URL: https://github.com/apache/flink-ml/pull/187




[GitHub] [flink-ml] jiangxin369 commented on a diff in pull request #187: [FLINK-29601] Add Estimator and Transformer for UnivariateFeatureSelector

Posted by GitBox <gi...@apache.org>.
jiangxin369 commented on code in PR #187:
URL: https://github.com/apache/flink-ml/pull/187#discussion_r1045197314


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/univariatefeatureselector/UnivariateFeatureSelectorModel.java:
##########
+        /**
+         * Selects a subset of the vector based on the indices. Note that the input indices must be
+         * sorted in ascending order if the input vector is sparse.
+         */
+        private Vector selectByIndices(Vector vector, int[] selectedIndices) {

Review Comment:
   I prefer to keep this method private: it requires the input indices to be sorted but, for performance, does not check this, so I don't want to expose it as a public method used by only two operators.
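
One possible middle ground, which the PR does not take, is to document the contract and guard it with a JVM assertion; assertions are disabled by default, so production code pays nothing, while test runs with `-ea` would still catch unsorted inputs. A hypothetical sketch:

```java
/** Hypothetical helper, not in the PR: returns true if the indices are in ascending order. */
private static boolean isSorted(int[] indices) {
    for (int i = 1; i < indices.length; i++) {
        if (indices[i - 1] > indices[i]) {
            return false;
        }
    }
    return true;
}

// At the top of selectByIndices(...), the precondition could then be asserted at no cost to
// production code:
//     assert isSorted(sortedIndices) : "sortedIndices must be sorted in ascending order";
```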





[GitHub] [flink-ml] lindong28 commented on a diff in pull request #187: [FLINK-29601] Add Estimator and Transformer for UnivariateFeatureSelector

Posted by GitBox <gi...@apache.org>.
lindong28 commented on code in PR #187:
URL: https://github.com/apache/flink-ml/pull/187#discussion_r1044092761


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/univariatefeatureselector/UnivariateFeatureSelector.java:
##########
@@ -0,0 +1,305 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.univariatefeatureselector;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.stats.anovatest.ANOVATest;
+import org.apache.flink.ml.stats.chisqtest.ChiSqTest;
+import org.apache.flink.ml.stats.fvaluetest.FValueTest;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.IntStream;
+
+/**
+ * An Estimator which selects features based on univariate statistical tests against labels.
+ *
+ * <p>Currently, Flink supports three Univariate Feature Selectors: chi-squared, ANOVA F-test and
+ * F-value. Users can choose a Univariate Feature Selector by setting `featureType` and `labelType`,
+ * and Flink will pick the score function based on the specified `featureType` and `labelType`.
+ *
+ * <p>The following combinations of `featureType` and `labelType` are supported:
+ *
+ * <ul>
+ *   <li>`featureType` `categorical` and `labelType` `categorical`: Flink uses chi-squared, i.e.
+ *       chi2 in sklearn.
+ *   <li>`featureType` `continuous` and `labelType` `categorical`: Flink uses ANOVA F-test,
+ *       f_classif in sklearn.

Review Comment:
   nits: add `i.e.` for consistency.
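
For reference, the combinations described in the Javadoc above map to score functions roughly as follows (a fragment: `selector` is assumed to be a `UnivariateFeatureSelector`, the setter names are assumed to follow the usual Flink ML convention, and the third combination is inferred from the F-value test exercised by the quoted tests):

```java
// categorical features + categorical label -> chi-squared test (chi2 in sklearn)
selector.setFeatureType("categorical").setLabelType("categorical");

// continuous features + categorical label  -> ANOVA F-test (f_classif in sklearn)
selector.setFeatureType("continuous").setLabelType("categorical");

// continuous features + continuous label   -> F-value test (f_regression in sklearn)
selector.setFeatureType("continuous").setLabelType("continuous");
```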



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/univariatefeatureselector/UnivariateFeatureSelectorModel.java:
##########
+        /**
+         * Selects a subset of the vector based on the indices. Note that the input indices must be
+         * sorted in ascending order if the input vector is sparse.
+         */
+        private Vector selectByIndices(Vector vector, int[] selectedIndices) {

Review Comment:
   Since this method looks pretty generic and it is re-used by `VarianceThresholdSelectorModel.java`, how about we move this method to `org.apache.flink.ml.common.util.VectorUtils` and make it static?
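
A sketch of what that refactoring might look like. The class and package names come from the comment above, the body is adapted from the method quoted later in this thread, and the tail of the sparse branch (which the quoted diff cuts off) is a reconstruction rather than the PR's exact code:

```java
package org.apache.flink.ml.common.util;

import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.SparseVector;
import org.apache.flink.ml.linalg.Vector;

import java.util.ArrayList;
import java.util.List;

/** Utility methods for vectors (sketch). */
public class VectorUtils {

    /**
     * Selects a subset of the vector entries according to the given indices.
     *
     * <p>NOTE: {@code sortedIndices} must be sorted in ascending order; this is not checked at
     * runtime.
     */
    public static Vector selectByIndices(Vector vector, int[] sortedIndices) {
        if (vector instanceof DenseVector) {
            DenseVector result = new DenseVector(sortedIndices.length);
            for (int i = 0; i < sortedIndices.length; i++) {
                result.set(i, vector.get(sortedIndices[i]));
            }
            return result;
        } else {
            SparseVector sparse = (SparseVector) vector;
            List<Integer> resultIndices = new ArrayList<>();
            List<Double> resultValues = new ArrayList<>();
            // Two-pointer merge over the sparse entries and the sorted selection.
            for (int i = 0, j = 0; i < sparse.indices.length && j < sortedIndices.length; ) {
                if (sparse.indices[i] == sortedIndices[j]) {
                    resultIndices.add(j++);
                    resultValues.add(sparse.values[i++]);
                } else if (sparse.indices[i] > sortedIndices[j]) {
                    j++;
                } else {
                    i++;
                }
            }
            // Reconstructed tail: pack the collected entries into a sparse vector of the new size.
            return new SparseVector(
                    sortedIndices.length,
                    resultIndices.stream().mapToInt(Integer::intValue).toArray(),
                    resultValues.stream().mapToDouble(Double::doubleValue).toArray());
        }
    }
}

// The two existing call sites (UnivariateFeatureSelectorModel and VarianceThresholdSelectorModel)
// would then drop their private copies and call:
//     Vector outputVec = VectorUtils.selectByIndices(inputVec, indices);
```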





[GitHub] [flink-ml] lindong28 commented on a diff in pull request #187: [FLINK-29601] Add Estimator and Transformer for UnivariateFeatureSelector

Posted by GitBox <gi...@apache.org>.
lindong28 commented on code in PR #187:
URL: https://github.com/apache/flink-ml/pull/187#discussion_r1045233434


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/univariatefeatureselector/UnivariateFeatureSelectorModel.java:
##########
+        /**
+         * Selects a subset of the vector based on the indices. Note that the input indices must be
+         * sorted in ascending order if the input vector is sparse.
+         */
+        private Vector selectByIndices(Vector vector, int[] selectedIndices) {
+            if (vector instanceof DenseVector) {
+                DenseVector resultVec = new DenseVector(selectedIndices.length);
+                for (int i = 0; i < selectedIndices.length; i++) {
+                    resultVec.set(i, vector.get(selectedIndices[i]));
+                }
+                return resultVec;
+            } else {
+                List<Integer> resultIndices = new ArrayList<>();
+                List<Double> resultValues = new ArrayList<>();
+
+                int[] indices = ((SparseVector) vector).indices;
+                for (int i = 0, j = 0; i < indices.length && j < selectedIndices.length; ) {
+                    if (indices[i] == selectedIndices[j]) {
+                        resultIndices.add(j++);
+                        resultValues.add(((SparseVector) vector).values[i++]);
+                    } else if (indices[i] > selectedIndices[j]) {
+                        j++;
+                    } else {
+                        i++;
+                    }
+                }
+                return new SparseVector(

Review Comment:
   The SparseVector instance created here is guaranteed to be dense. Would it be simpler to just return DenseVector here?
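
A minimal sketch of the simplification being suggested, written as a drop-in replacement for the private method above (it reuses that file's imports; the zero-initialized dense result carries the selected values at positions 0..sortedIndices.length-1):

```java
private Vector selectByIndices(Vector vector, int[] sortedIndices) {
    // Zero-initialized dense result of the selected size.
    DenseVector result = new DenseVector(sortedIndices.length);
    if (vector instanceof DenseVector) {
        for (int i = 0; i < sortedIndices.length; i++) {
            result.set(i, vector.get(sortedIndices[i]));
        }
    } else {
        SparseVector sparse = (SparseVector) vector;
        for (int i = 0, j = 0; i < sparse.indices.length && j < sortedIndices.length; ) {
            if (sparse.indices[i] == sortedIndices[j]) {
                result.set(j++, sparse.values[i++]);
            } else if (sparse.indices[i] > sortedIndices[j]) {
                j++; // the selected feature is zero in the sparse input
            } else {
                i++; // this non-zero entry was not selected
            }
        }
    }
    return result;
}
```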



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/feature/UnivariateFeatureSelectorTest.java:
##########
@@ -0,0 +1,783 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.feature.univariatefeatureselector.UnivariateFeatureSelector;
+import org.apache.flink.ml.feature.univariatefeatureselector.UnivariateFeatureSelectorModel;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
+import org.apache.flink.ml.util.TestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.CloseableIterator;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.exception.ExceptionUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.fail;
+
+/** Tests {@link UnivariateFeatureSelector} and {@link UnivariateFeatureSelectorModel}. */
+public class UnivariateFeatureSelectorTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table inputChiSqTable;
+    private Table inputANOVATable;
+    private Table inputFValueTable;
+
+    private static final double EPS = 1.0e-5;
+
+    private UnivariateFeatureSelector selectorWithChiSqTest;
+    private UnivariateFeatureSelector selectorWithANOVATest;
+    private UnivariateFeatureSelector selectorWithFValueTest;
+
+    private static final List<Row> INPUT_CHISQ_DATA =
+            Arrays.asList(
+                    Row.of(0.0, Vectors.dense(6.0, 7.0, 0.0, 7.0, 6.0, 0.0)),
+                    Row.of(1.0, Vectors.dense(0.0, 9.0, 6.0, 0.0, 5.0, 9.0)),
+                    Row.of(1.0, Vectors.dense(0.0, 9.0, 3.0, 0.0, 5.0, 5.0)),
+                    Row.of(1.0, Vectors.dense(0.0, 9.0, 8.0, 5.0, 6.0, 4.0).toSparse()),
+                    Row.of(2.0, Vectors.dense(8.0, 9.0, 6.0, 5.0, 4.0, 4.0).toSparse()),
+                    Row.of(2.0, Vectors.dense(8.0, 9.0, 6.0, 4.0, 0.0, 0.0).toSparse()));
+
+    private static final List<Row> INPUT_ANOVA_DATA =
+            Arrays.asList(
+                    Row.of(
+                            1,
+                            Vectors.dense(
+                                    4.65415496e-03,
+                                    1.03550567e-01,
+                                    -1.17358140e+00,
+                                    1.61408773e-01,
+                                    3.92492111e-01,
+                                    7.31240882e-01)),
+                    Row.of(
+                            1,
+                            Vectors.dense(
+                                    -9.01651741e-01,
+                                    -5.28905302e-01,
+                                    1.27636785e+00,
+                                    7.02154563e-01,
+                                    6.21348351e-01,
+                                    1.88397353e-01)),
+                    Row.of(
+                            1,
+                            Vectors.dense(
+                                    3.85692159e-01,
+                                    -9.04639637e-01,
+                                    5.09782604e-02,
+                                    8.40043971e-01,
+                                    7.45977857e-01,
+                                    8.78402288e-01)),
+                    Row.of(
+                            1,
+                            Vectors.dense(
+                                    1.36264353e+00,
+                                    2.62454094e-01,
+                                    7.96306202e-01,
+                                    6.14948000e-01,
+                                    7.44948187e-01,
+                                    9.74034830e-01)),
+                    Row.of(
+                            1,
+                            Vectors.dense(
+                                    9.65874070e-01,
+                                    2.52773665e+00,
+                                    -2.19380094e+00,
+                                    2.33408080e-01,
+                                    1.86340919e-01,
+                                    8.23390433e-01)),
+                    Row.of(
+                            2,
+                            Vectors.dense(
+                                    1.12324305e+01,
+                                    -2.77121515e-01,
+                                    1.12740513e-01,
+                                    2.35184013e-01,
+                                    3.46668895e-01,
+                                    9.38500782e-02)),
+                    Row.of(
+                            2,
+                            Vectors.dense(
+                                    1.06195839e+01,
+                                    -1.82891238e+00,
+                                    2.25085601e-01,
+                                    9.09979851e-01,
+                                    6.80257535e-02,
+                                    8.24017480e-01)),
+                    Row.of(
+                            2,
+                            Vectors.dense(
+                                    1.12806837e+01,
+                                    1.30686889e+00,
+                                    9.32839108e-02,
+                                    3.49784755e-01,
+                                    1.71322408e-02,
+                                    7.48465194e-02)),
+                    Row.of(
+                            2,
+                            Vectors.dense(
+                                    9.98689462e+00,
+                                    9.50808938e-01,
+                                    -2.90786359e-01,
+                                    2.31253009e-01,
+                                    7.46270968e-01,
+                                    1.60308169e-01)),
+                    Row.of(
+                            2,
+                            Vectors.dense(
+                                    1.08428551e+01,
+                                    -1.02749936e+00,
+                                    1.73951508e-01,
+                                    8.92482744e-02,
+                                    1.42651730e-01,
+                                    7.66751625e-01)),
+                    Row.of(
+                            3,
+                            Vectors.dense(
+                                            -1.98641448e+00,
+                                            1.12811990e+01,
+                                            -2.35246756e-01,
+                                            8.22809049e-01,
+                                            3.26739456e-01,
+                                            7.88268404e-01)
+                                    .toSparse()),
+                    Row.of(
+                            3,
+                            Vectors.dense(
+                                            -6.09864090e-01,
+                                            1.07346276e+01,
+                                            -2.18805509e-01,
+                                            7.33931213e-01,
+                                            1.42554396e-01,
+                                            7.11225605e-01)
+                                    .toSparse()),
+                    Row.of(
+                            3,
+                            Vectors.dense(
+                                            -1.58481268e+00,
+                                            9.19364039e+00,
+                                            -5.87490459e-02,
+                                            2.51532056e-01,
+                                            2.82729807e-01,
+                                            7.16245686e-01)
+                                    .toSparse()),
+                    Row.of(
+                            3,
+                            Vectors.dense(
+                                            -2.50949277e-01,
+                                            1.12815254e+01,
+                                            -6.94806734e-01,
+                                            5.93898886e-01,
+                                            5.68425656e-01,
+                                            8.49762330e-01)
+                                    .toSparse()),
+                    Row.of(
+                            3,
+                            Vectors.dense(
+                                            7.63485129e-01,
+                                            1.02605138e+01,
+                                            1.32617719e+00,
+                                            5.49682879e-01,
+                                            8.59931442e-01,
+                                            4.88677978e-02)
+                                    .toSparse()),
+                    Row.of(
+                            4,
+                            Vectors.dense(
+                                            9.34900015e-01,
+                                            4.11379043e-01,
+                                            8.65010205e+00,
+                                            9.23509168e-01,
+                                            1.16995043e-01,
+                                            5.91894106e-03)
+                                    .toSparse()),
+                    Row.of(
+                            4,
+                            Vectors.dense(
+                                            4.73734933e-01,
+                                            -1.48321181e+00,
+                                            9.73349621e+00,
+                                            4.09421563e-01,
+                                            5.09375719e-01,
+                                            5.93157850e-01)
+                                    .toSparse()),
+                    Row.of(
+                            4,
+                            Vectors.dense(
+                                            3.41470679e-01,
+                                            -6.88972582e-01,
+                                            9.60347938e+00,
+                                            3.62654055e-01,
+                                            2.43437468e-01,
+                                            7.13052838e-01)
+                                    .toSparse()),
+                    Row.of(
+                            4,
+                            Vectors.dense(
+                                            -5.29614251e-01,
+                                            -1.39262856e+00,
+                                            1.01354144e+01,
+                                            8.24123861e-01,
+                                            5.84074506e-01,
+                                            6.54461558e-01)
+                                    .toSparse()),
+                    Row.of(
+                            4,
+                            Vectors.dense(
+                                            -2.99454508e-01,
+                                            2.20457263e+00,
+                                            1.14586015e+01,
+                                            5.16336729e-01,
+                                            9.99776159e-01,
+                                            3.15769738e-01)
+                                    .toSparse()));
+
+    private static final List<Row> INPUT_FVALUE_DATA =
+            Arrays.asList(
+                    Row.of(
+                            0.52516321,
+                            Vectors.dense(
+                                    0.19151945,
+                                    0.62210877,
+                                    0.43772774,
+                                    0.78535858,
+                                    0.77997581,
+                                    0.27259261)),
+                    Row.of(
+                            0.88275782,
+                            Vectors.dense(
+                                    0.27646426,
+                                    0.80187218,
+                                    0.95813935,
+                                    0.87593263,
+                                    0.35781727,
+                                    0.50099513)),
+                    Row.of(
+                            0.67524507,
+                            Vectors.dense(
+                                    0.68346294,
+                                    0.71270203,
+                                    0.37025075,
+                                    0.56119619,
+                                    0.50308317,
+                                    0.01376845)),
+                    Row.of(
+                            0.76734745,
+                            Vectors.dense(
+                                    0.77282662,
+                                    0.88264119,
+                                    0.36488598,
+                                    0.61539618,
+                                    0.07538124,
+                                    0.36882401)),
+                    Row.of(
+                            0.73909458,
+                            Vectors.dense(
+                                    0.9331401,
+                                    0.65137814,
+                                    0.39720258,
+                                    0.78873014,
+                                    0.31683612,
+                                    0.56809865)),
+                    Row.of(
+                            0.83628141,
+                            Vectors.dense(
+                                    0.86912739,
+                                    0.43617342,
+                                    0.80214764,
+                                    0.14376682,
+                                    0.70426097,
+                                    0.70458131)),
+                    Row.of(
+                            0.65665506,
+                            Vectors.dense(
+                                    0.21879211,
+                                    0.92486763,
+                                    0.44214076,
+                                    0.90931596,
+                                    0.05980922,
+                                    0.18428708)),
+                    Row.of(
+                            0.58147135,
+                            Vectors.dense(
+                                    0.04735528,
+                                    0.67488094,
+                                    0.59462478,
+                                    0.53331016,
+                                    0.04332406,
+                                    0.56143308)),
+                    Row.of(
+                            0.35603443,
+                            Vectors.dense(
+                                    0.32966845,
+                                    0.50296683,
+                                    0.11189432,
+                                    0.60719371,
+                                    0.56594464,
+                                    0.00676406)),
+                    Row.of(
+                            0.94534373,
+                            Vectors.dense(
+                                    0.61744171,
+                                    0.91212289,
+                                    0.79052413,
+                                    0.99208147,
+                                    0.95880176,
+                                    0.79196414)),
+                    Row.of(
+                            0.57458887,
+                            Vectors.dense(
+                                            0.28525096,
+                                            0.62491671,
+                                            0.4780938,
+                                            0.19567518,
+                                            0.38231745,
+                                            0.05387369)
+                                    .toSparse()),
+                    Row.of(
+                            0.59026777,
+                            Vectors.dense(
+                                            0.45164841,
+                                            0.98200474,
+                                            0.1239427,
+                                            0.1193809,
+                                            0.73852306,
+                                            0.58730363)
+                                    .toSparse()),
+                    Row.of(
+                            0.29894977,
+                            Vectors.dense(
+                                            0.47163253,
+                                            0.10712682,
+                                            0.22921857,
+                                            0.89996519,
+                                            0.41675354,
+                                            0.53585166)
+                                    .toSparse()),
+                    Row.of(
+                            0.34056582,
+                            Vectors.dense(
+                                            0.00620852,
+                                            0.30064171,
+                                            0.43689317,
+                                            0.612149,
+                                            0.91819808,
+                                            0.62573667)
+                                    .toSparse()),
+                    Row.of(
+                            0.64476446,
+                            Vectors.dense(
+                                            0.70599757,
+                                            0.14983372,
+                                            0.74606341,
+                                            0.83100699,
+                                            0.63372577,
+                                            0.43830988)
+                                    .toSparse()),
+                    Row.of(
+                            0.53724782,
+                            Vectors.dense(
+                                            0.15257277,
+                                            0.56840962,
+                                            0.52822428,
+                                            0.95142876,
+                                            0.48035918,
+                                            0.50255956)
+                                    .toSparse()),
+                    Row.of(
+                            0.5173021,
+                            Vectors.dense(
+                                            0.53687819,
+                                            0.81920207,
+                                            0.05711564,
+                                            0.66942174,
+                                            0.76711663,
+                                            0.70811536)
+                                    .toSparse()),
+                    Row.of(
+                            0.94508275,
+                            Vectors.dense(
+                                            0.79686718,
+                                            0.55776083,
+                                            0.96583653,
+                                            0.1471569,
+                                            0.029647,
+                                            0.59389349)
+                                    .toSparse()),
+                    Row.of(
+                            0.57739736,
+                            Vectors.dense(
+                                            0.1140657,
+                                            0.95080985,
+                                            0.96583653,
+                                            0.19361869,
+                                            0.45781165,
+                                            0.92040257)
+                                    .toSparse()),
+                    Row.of(
+                            0.53877145,
+                            Vectors.dense(
+                                            0.87906916,
+                                            0.25261576,
+                                            0.34800879,
+                                            0.18258873,
+                                            0.90179605,
+                                            0.70652816)
+                                    .toSparse()));
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        env.getConfig().enableObjectReuse();
+        tEnv = StreamTableEnvironment.create(env);
+
+        selectorWithChiSqTest =
+                new UnivariateFeatureSelector()
+                        .setFeatureType("categorical")
+                        .setLabelType("categorical");
+        selectorWithANOVATest =
+                new UnivariateFeatureSelector()
+                        .setFeatureType("continuous")
+                        .setLabelType("categorical");
+        selectorWithFValueTest =
+                new UnivariateFeatureSelector()
+                        .setFeatureType("continuous")
+                        .setLabelType("continuous");
+        inputChiSqTable =
+                tEnv.fromDataStream(
+                                env.fromCollection(
+                                        INPUT_CHISQ_DATA,
+                                        Types.ROW(Types.DOUBLE, VectorTypeInfo.INSTANCE)))
+                        .as("label", "features");
+        inputANOVATable =
+                tEnv.fromDataStream(
+                                env.fromCollection(
+                                        INPUT_ANOVA_DATA,
+                                        Types.ROW(Types.INT, VectorTypeInfo.INSTANCE)))
+                        .as("label", "features");
+        inputFValueTable =
+                tEnv.fromDataStream(
+                                env.fromCollection(
+                                        INPUT_FVALUE_DATA,
+                                        Types.ROW(Types.DOUBLE, VectorTypeInfo.INSTANCE)))
+                        .as("label", "features");
+    }
+
+    private void transformAndVerify(
+            UnivariateFeatureSelector selector, Table table, int... expectedIndices)
+            throws Exception {
+        UnivariateFeatureSelectorModel model = selector.fit(table);
+        Table output = model.transform(table)[0];
+        verifyOutputResult(output, expectedIndices);
+    }
+
+    private void verifyOutputResult(Table table, int... expectedIndices) throws Exception {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) table).getTableEnvironment();
+        CloseableIterator<Row> rowIterator = tEnv.toDataStream(table).executeAndCollect();
+        while (rowIterator.hasNext()) {
+            Row row = rowIterator.next();
+            for (int i = 0; i < expectedIndices.length; i++) {
+                assertEquals(
+                        ((Vector) row.getField("features")).get(expectedIndices[i]),
+                        ((Vector) row.getField("output")).get(i),

Review Comment:
   Should we also verify that the given `table` only contains the selected features?
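   A minimal sketch of such a check, assuming it is enough to assert that the output vector has exactly one entry per expected index:
   ```java
   // In addition to comparing the selected values, assert that the output vector
   // contains no extra entries beyond the selected features.
   assertEquals(expectedIndices.length, ((Vector) row.getField("output")).size());
   ```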





[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #187: [FLINK-29601] Add Estimator and Transformer for UnivariateFeatureSelector

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #187:
URL: https://github.com/apache/flink-ml/pull/187#discussion_r1035870193


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/univariatefeatureselector/UnivariateFeatureSelectorModelData.java:
##########
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.univariatefeatureselector;
+
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.base.array.IntPrimitiveArraySerializer;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+
+import java.io.EOFException;
+import java.io.IOException;
+import java.io.OutputStream;
+
+/**
+ * Model data of {@link UnivariateFeatureSelectorModel}.
+ *
+ * <p>This class also provides methods to convert model data from Table to a data stream, and
+ * classes to save/load model data.
+ */
+public class UnivariateFeatureSelectorModelData {
+
+    public int[] indices;

Review Comment:
   Let's add a brief JavaDoc for the meaning of this variable.
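   One possible wording, assuming the field simply holds the indices of the features kept by the selector:
   ```java
   /** The indices of the selected features, computed by {@link UnivariateFeatureSelector}. */
   public int[] indices;
   ```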



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/univariatefeatureselector/UnivariateFeatureSelectorModelParams.java:
##########
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.univariatefeatureselector;
+
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasOutputCol;
+
+/**
+ * Params for {@link UnivariateFeatureSelectorModel}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface UnivariateFeatureSelectorModelParams<T>
+        extends HasFeaturesCol<T>, HasLabelCol<T>, HasOutputCol<T> {}

Review Comment:
   `UnivariateFeatureSelectorModel` does not need labelCol.
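   A sketch of the trimmed interface if `HasLabelCol<T>` is kept only on the estimator side (assuming `UnivariateFeatureSelectorParams` would then extend `HasLabelCol<T>` itself):
   ```java
   /**
    * Params for {@link UnivariateFeatureSelectorModel}.
    *
    * @param <T> The class type of this instance.
    */
   public interface UnivariateFeatureSelectorModelParams<T>
           extends HasFeaturesCol<T>, HasOutputCol<T> {}
   ```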



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/univariatefeatureselector/UnivariateFeatureSelectorModel.java:
##########
@@ -0,0 +1,168 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.univariatefeatureselector;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * A Model which transforms data using the model data computed by {@link UnivariateFeatureSelector}.
+ */
+public class UnivariateFeatureSelectorModel
+        implements Model<UnivariateFeatureSelectorModel>,
+                UnivariateFeatureSelectorModelParams<UnivariateFeatureSelectorModel> {
+
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public UnivariateFeatureSelectorModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public UnivariateFeatureSelectorModel setModelData(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        modelDataTable = inputs[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @SuppressWarnings({"unchecked", "rawtypes"})
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Row> data = tEnv.toDataStream(inputs[0]);
+        DataStream<UnivariateFeatureSelectorModelData> modelData =
+                UnivariateFeatureSelectorModelData.getModelDataStream(modelDataTable);
+
+        final String broadcastModelKey = "broadcastModelKey";
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(), DenseVectorTypeInfo.INSTANCE),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol()));
+
+        DataStream<Row> output =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(data),
+                        Collections.singletonMap(broadcastModelKey, modelData),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(
+                                    new PredictOutputFunction(getFeaturesCol(), broadcastModelKey),
+                                    outputTypeInfo);
+                        });
+
+        return new Table[] {tEnv.fromDataStream(output)};
+    }
+
+    /** This operator loads model data and predicts result. */
+    private static class PredictOutputFunction extends RichMapFunction<Row, Row> {
+
+        private final String inputCol;
+        private final String broadcastKey;
+        private int[] indices = null;
+
+        public PredictOutputFunction(String inputCol, String broadcastKey) {
+            this.inputCol = inputCol;
+            this.broadcastKey = broadcastKey;
+        }
+
+        @Override
+        public Row map(Row row) {
+            if (indices == null) {
+                UnivariateFeatureSelectorModelData modelData =
+                        (UnivariateFeatureSelectorModelData)
+                                getRuntimeContext().getBroadcastVariable(broadcastKey).get(0);
+                indices = modelData.indices;
+            }
+
+            if (indices.length == 0) {
+                return Row.join(row, Row.of(Vectors.dense()));
+            } else {
+                Vector inputVec = ((Vector) row.getField(inputCol));
+                Preconditions.checkArgument(
+                        inputVec.size() > indices[indices.length - 1],

Review Comment:
   Let's scan the indices to find the maximum index, or add checks and documentation to ensure that the indices in the model data are sorted.
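   A minimal sketch of the first option, assuming `indices` has already been checked to be non-empty and `java.util.Arrays` is imported:
   ```java
   // Scan for the largest selected index instead of relying on the indices being sorted.
   int maxIndex = Arrays.stream(indices).max().getAsInt();
   Preconditions.checkArgument(
           inputVec.size() > maxIndex,
           "Input %s features, but UnivariateFeatureSelector is "
                   + "expecting at least %s features as input.",
           inputVec.size(),
           maxIndex + 1);
   ```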



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/univariatefeatureselector/UnivariateFeatureSelector.java:
##########
@@ -0,0 +1,301 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.univariatefeatureselector;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.stats.anovatest.ANOVATest;
+import org.apache.flink.ml.stats.chisqtest.ChiSqTest;
+import org.apache.flink.ml.stats.fvaluetest.FValueTest;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.IntStream;
+
+/**
+ * An Estimator which selects features based on univariate statistical tests against labels.
+ *
+ * <p>Currently, Flink supports three Univariate Feature Selectors: chi-squared, ANOVA F-test and
+ * F-value. User can choose Univariate Feature Selector by setting `featureType` and `labelType`,
+ * and Flink will pick the score function based on the specified `featureType` and `labelType`.
+ *
+ * <p>The following combination of `featureType` and `labelType` are supported:
+ *
+ * <ul>
+ *   <li>`featureType` `categorical` and `labelType` `categorical`: Flink uses chi-squared, i.e.
+ *       chi2 in sklearn.
+ *   <li>`featureType` `continuous` and `labelType` `categorical`: Flink uses ANOVA F-test,
+ *       f_classif in sklearn.
+ *   <li>`featureType` `continuous` and `labelType` `continuous`: Flink uses F-value, i.e.
+ *       f_regression in sklearn.
+ * </ul>
+ *
+ * <p>The `UnivariateFeatureSelector` supports different selection modes:
+ *
+ * <ul>
+ *   <li>numTopFeatures: chooses a fixed number of top features according to a hypothesis.
+ *   <li>percentile: similar to numTopFeatures but chooses a fraction of all features instead of a
+ *       fixed number.
+ *   <li>fpr: chooses all features whose p-value are below a threshold, thus controlling the false
+ *       positive rate of selection.
+ *   <li>fdr: uses the <a
+ *       href="https://en.wikipedia.org/wiki/False_discovery_rate#Benjamini.E2.80.93Hochberg_procedure">
+ *       Benjamini-Hochberg procedure</a> to choose all features whose false discovery rate is below
+ *       a threshold.
+ *   <li>fwe: chooses all features whose p-values are below a threshold. The threshold is scaled by
+ *       1/numFeatures, thus controlling the family-wise error rate of selection.
+ * </ul>
+ *
+ * <p>By default, the selection mode is `numTopFeatures`.
+ */
+public class UnivariateFeatureSelector
+        implements Estimator<UnivariateFeatureSelector, UnivariateFeatureSelectorModel>,
+                UnivariateFeatureSelectorParams<UnivariateFeatureSelector> {
+
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public UnivariateFeatureSelector() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public UnivariateFeatureSelectorModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        final String featuresCol = getFeaturesCol();
+        final String labelCol = getLabelCol();
+        final String featureType = getFeatureType();
+        final String labelType = getLabelType();
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        Table output;
+        if (CATEGORICAL.equals(featureType) && CATEGORICAL.equals(labelType)) {
+            output =
+                    new ChiSqTest()
+                            .setFeaturesCol(featuresCol)
+                            .setLabelCol(labelCol)
+                            .setFlatten(true)
+                            .transform(inputs[0])[0];
+        } else if (CONTINUOUS.equals(featureType) && CATEGORICAL.equals(labelType)) {
+            output =
+                    new ANOVATest()
+                            .setFeaturesCol(featuresCol)
+                            .setLabelCol(labelCol)
+                            .setFlatten(true)
+                            .transform(inputs[0])[0];
+        } else if (CONTINUOUS.equals(featureType) && CONTINUOUS.equals(labelType)) {
+            output =
+                    new FValueTest()
+                            .setFeaturesCol(featuresCol)
+                            .setLabelCol(labelCol)
+                            .setFlatten(true)
+                            .transform(inputs[0])[0];
+        } else {
+            throw new IllegalArgumentException(
+                    String.format(
+                            "Unsupported combination: featureType=%s, labelType=%s.",
+                            featureType, labelType));
+        }
+        DataStream<UnivariateFeatureSelectorModelData> modelData =
+                tEnv.toDataStream(output)
+                        .transform(
+                                "selectIndicesFromPValues",
+                                TypeInformation.of(UnivariateFeatureSelectorModelData.class),
+                                new SelectIndicesFromPValuesOperator(
+                                        getSelectionMode(), getActualSelectionThreshold()))
+                        .setParallelism(1);
+        UnivariateFeatureSelectorModel model =
+                new UnivariateFeatureSelectorModel().setModelData(tEnv.fromDataStream(modelData));
+        ReadWriteUtils.updateExistingParams(model, getParamMap());
+        return model;
+    }
+
+    private double getActualSelectionThreshold() {
+        Double threshold = getSelectionThreshold();
+        if (threshold == null) {
+            String selectionMode = getSelectionMode();
+            if (NUM_TOP_FEATURES.equals(selectionMode)) {
+                threshold = 50.0;
+            } else if (PERCENTILE.equals(selectionMode)) {
+                threshold = 0.1;
+            } else {
+                threshold = 0.05;
+            }
+        } else {
+            if (NUM_TOP_FEATURES.equals(getSelectionMode())) {
+                Preconditions.checkArgument(
+                        threshold >= 1 && threshold.intValue() == threshold,
+                        "SelectionThreshold needs to be a positive Integer "
+                                + "for selection mode numTopFeatures, but got %s.",
+                        threshold);
+            } else {
+                Preconditions.checkArgument(
+                        threshold >= 0 && threshold <= 1,
+                        "SelectionThreshold needs to be in the range [0, 1] "
+                                + "for selection mode %s, but got %s.",
+                        getSelectionMode(),
+                        threshold);
+            }
+        }
+        return threshold;
+    }
+
+    private static class SelectIndicesFromPValuesOperator
+            extends AbstractStreamOperator<UnivariateFeatureSelectorModelData>
+            implements OneInputStreamOperator<Row, UnivariateFeatureSelectorModelData>,
+                    BoundedOneInput {
+        private String selectionMode;
+        private double threshold;
+
+        private List<Tuple2<Double, Integer>> pValuesAndIndices;
+        private ListState<List<Tuple2<Double, Integer>>> pValuesAndIndicesState;

Review Comment:
   `ListState<Tuple2<Double, Integer>>` might be better.
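   A rough sketch of that alternative, assuming the state is registered in `initializeState` and the in-memory `List<Tuple2<Double, Integer>>` is rebuilt from it on restore:
   ```java
   // Store the tuples element by element instead of snapshotting the whole list as a single value.
   pValuesAndIndicesState =
           context.getOperatorStateStore()
                   .getListState(
                           new ListStateDescriptor<>(
                                   "pValuesAndIndicesState",
                                   Types.TUPLE(Types.DOUBLE, Types.INT)));
   ```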



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/univariatefeatureselector/UnivariateFeatureSelectorModel.java:
##########
@@ -0,0 +1,168 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.univariatefeatureselector;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * A Model which transforms data using the model data computed by {@link UnivariateFeatureSelector}.
+ */
+public class UnivariateFeatureSelectorModel
+        implements Model<UnivariateFeatureSelectorModel>,
+                UnivariateFeatureSelectorModelParams<UnivariateFeatureSelectorModel> {
+
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public UnivariateFeatureSelectorModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public UnivariateFeatureSelectorModel setModelData(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        modelDataTable = inputs[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @SuppressWarnings({"unchecked", "rawtypes"})
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Row> data = tEnv.toDataStream(inputs[0]);
+        DataStream<UnivariateFeatureSelectorModelData> modelData =
+                UnivariateFeatureSelectorModelData.getModelDataStream(modelDataTable);
+
+        final String broadcastModelKey = "broadcastModelKey";
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(), DenseVectorTypeInfo.INSTANCE),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol()));
+
+        DataStream<Row> output =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(data),
+                        Collections.singletonMap(broadcastModelKey, modelData),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(
+                                    new PredictOutputFunction(getFeaturesCol(), broadcastModelKey),
+                                    outputTypeInfo);
+                        });
+
+        return new Table[] {tEnv.fromDataStream(output)};
+    }
+
+    /** This operator loads model data and predicts result. */
+    private static class PredictOutputFunction extends RichMapFunction<Row, Row> {
+
+        private final String inputCol;
+        private final String broadcastKey;
+        private int[] indices = null;
+
+        public PredictOutputFunction(String inputCol, String broadcastKey) {
+            this.inputCol = inputCol;
+            this.broadcastKey = broadcastKey;
+        }
+
+        @Override
+        public Row map(Row row) {
+            if (indices == null) {
+                UnivariateFeatureSelectorModelData modelData =
+                        (UnivariateFeatureSelectorModelData)
+                                getRuntimeContext().getBroadcastVariable(broadcastKey).get(0);
+                indices = modelData.indices;
+            }
+
+            if (indices.length == 0) {
+                return Row.join(row, Row.of(Vectors.dense()));
+            } else {
+                Vector inputVec = ((Vector) row.getField(inputCol));
+                Preconditions.checkArgument(
+                        inputVec.size() > indices[indices.length - 1],
+                        "Input %s features, but UnivariateFeatureSelector is "
+                                + "expecting at least %s features as input.",
+                        inputVec.size(),
+                        indices[indices.length - 1] + 1);
+                DenseVector outputVec = new DenseVector(indices.length);

Review Comment:
   It might be better to generate a sparse vector when the input is a sparse vector.
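   A rough sketch of how sparsity could be preserved, assuming flink-ml's `SparseVector(int n, int[] indices, double[] values)` constructor and a `java.util.Arrays` import; the output type info would then also need to accept sparse vectors rather than being fixed to `DenseVectorTypeInfo`:
   ```java
   if (inputVec instanceof SparseVector) {
       // Keep only the selected entries and drop explicit zeros.
       int[] outputIndices = new int[indices.length];
       double[] outputValues = new double[indices.length];
       int nnz = 0;
       for (int i = 0; i < indices.length; i++) {
           double value = inputVec.get(indices[i]);
           if (value != 0) {
               outputIndices[nnz] = i;
               outputValues[nnz++] = value;
           }
       }
       return Row.join(
               row,
               Row.of(
                       new SparseVector(
                               indices.length,
                               Arrays.copyOf(outputIndices, nnz),
                               Arrays.copyOf(outputValues, nnz))));
   }
   ```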



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/univariatefeatureselector/UnivariateFeatureSelectorModel.java:
##########
@@ -0,0 +1,168 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.univariatefeatureselector;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * A Model which transforms data using the model data computed by {@link UnivariateFeatureSelector}.
+ */
+public class UnivariateFeatureSelectorModel
+        implements Model<UnivariateFeatureSelectorModel>,
+                UnivariateFeatureSelectorModelParams<UnivariateFeatureSelectorModel> {
+
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public UnivariateFeatureSelectorModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public UnivariateFeatureSelectorModel setModelData(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        modelDataTable = inputs[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @SuppressWarnings({"unchecked", "rawtypes"})
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Row> data = tEnv.toDataStream(inputs[0]);
+        DataStream<UnivariateFeatureSelectorModelData> modelData =
+                UnivariateFeatureSelectorModelData.getModelDataStream(modelDataTable);
+
+        final String broadcastModelKey = "broadcastModelKey";
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(), DenseVectorTypeInfo.INSTANCE),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol()));
+
+        DataStream<Row> output =

Review Comment:
   Let's improve the variable names. It seems inconsistent that `inputs` is a `Table[]` while `output` is a `DataStream`.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/univariatefeatureselector/UnivariateFeatureSelector.java:
##########
@@ -0,0 +1,301 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.univariatefeatureselector;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.stats.anovatest.ANOVATest;
+import org.apache.flink.ml.stats.chisqtest.ChiSqTest;
+import org.apache.flink.ml.stats.fvaluetest.FValueTest;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.IntStream;
+
+/**
+ * An Estimator which selects features based on univariate statistical tests against labels.
+ *
+ * <p>Currently, Flink supports three Univariate Feature Selectors: chi-squared, ANOVA F-test and
+ * F-value. User can choose Univariate Feature Selector by setting `featureType` and `labelType`,
+ * and Flink will pick the score function based on the specified `featureType` and `labelType`.
+ *
+ * <p>The following combination of `featureType` and `labelType` are supported:
+ *
+ * <ul>
+ *   <li>`featureType` `categorical` and `labelType` `categorical`: Flink uses chi-squared, i.e.
+ *       chi2 in sklearn.
+ *   <li>`featureType` `continuous` and `labelType` `categorical`: Flink uses ANOVA F-test,
+ *       f_classif in sklearn.
+ *   <li>`featureType` `continuous` and `labelType` `continuous`: Flink uses F-value, i.e.
+ *       f_regression in sklearn.
+ * </ul>
+ *
+ * <p>The `UnivariateFeatureSelector` supports different selection modes:
+ *
+ * <ul>
+ *   <li>numTopFeatures: chooses a fixed number of top features according to a hypothesis.
+ *   <li>percentile: similar to numTopFeatures but chooses a fraction of all features instead of a
+ *       fixed number.
+ *   <li>fpr: chooses all features whose p-value are below a threshold, thus controlling the false
+ *       positive rate of selection.
+ *   <li>fdr: uses the <a
+ *       href="https://en.wikipedia.org/wiki/False_discovery_rate#Benjamini.E2.80.93Hochberg_procedure">
+ *       Benjamini-Hochberg procedure</a> to choose all features whose false discovery rate is below
+ *       a threshold.
+ *   <li>fwe: chooses all features whose p-values are below a threshold. The threshold is scaled by
+ *       1/numFeatures, thus controlling the family-wise error rate of selection.
+ * </ul>
+ *
+ * <p>By default, the selection mode is `numTopFeatures`.
+ */
+public class UnivariateFeatureSelector
+        implements Estimator<UnivariateFeatureSelector, UnivariateFeatureSelectorModel>,
+                UnivariateFeatureSelectorParams<UnivariateFeatureSelector> {
+
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public UnivariateFeatureSelector() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public UnivariateFeatureSelectorModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        final String featuresCol = getFeaturesCol();
+        final String labelCol = getLabelCol();
+        final String featureType = getFeatureType();
+        final String labelType = getLabelType();
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        Table output;
+        if (CATEGORICAL.equals(featureType) && CATEGORICAL.equals(labelType)) {
+            output =
+                    new ChiSqTest()
+                            .setFeaturesCol(featuresCol)
+                            .setLabelCol(labelCol)
+                            .setFlatten(true)
+                            .transform(inputs[0])[0];
+        } else if (CONTINUOUS.equals(featureType) && CATEGORICAL.equals(labelType)) {
+            output =
+                    new ANOVATest()
+                            .setFeaturesCol(featuresCol)
+                            .setLabelCol(labelCol)
+                            .setFlatten(true)
+                            .transform(inputs[0])[0];
+        } else if (CONTINUOUS.equals(featureType) && CONTINUOUS.equals(labelType)) {
+            output =
+                    new FValueTest()
+                            .setFeaturesCol(featuresCol)
+                            .setLabelCol(labelCol)
+                            .setFlatten(true)
+                            .transform(inputs[0])[0];
+        } else {
+            throw new IllegalArgumentException(
+                    String.format(
+                            "Unsupported combination: featureType=%s, labelType=%s.",
+                            featureType, labelType));
+        }
+        DataStream<UnivariateFeatureSelectorModelData> modelData =
+                tEnv.toDataStream(output)
+                        .transform(
+                                "selectIndicesFromPValues",
+                                TypeInformation.of(UnivariateFeatureSelectorModelData.class),
+                                new SelectIndicesFromPValuesOperator(
+                                        getSelectionMode(), getActualSelectionThreshold()))
+                        .setParallelism(1);
+        UnivariateFeatureSelectorModel model =
+                new UnivariateFeatureSelectorModel().setModelData(tEnv.fromDataStream(modelData));
+        ReadWriteUtils.updateExistingParams(model, getParamMap());
+        return model;
+    }
+
+    private double getActualSelectionThreshold() {
+        Double threshold = getSelectionThreshold();
+        if (threshold == null) {
+            String selectionMode = getSelectionMode();
+            if (NUM_TOP_FEATURES.equals(selectionMode)) {
+                threshold = 50.0;
+            } else if (PERCENTILE.equals(selectionMode)) {
+                threshold = 0.1;
+            } else {
+                threshold = 0.05;
+            }
+        } else {
+            if (NUM_TOP_FEATURES.equals(getSelectionMode())) {
+                Preconditions.checkArgument(
+                        threshold >= 1 && threshold.intValue() == threshold,
+                        "SelectionThreshold needs to be a positive Integer "
+                                + "for selection mode numTopFeatures, but got %s.",
+                        threshold);
+            } else {
+                Preconditions.checkArgument(
+                        threshold >= 0 && threshold <= 1,
+                        "SelectionThreshold needs to be in the range [0, 1] "
+                                + "for selection mode %s, but got %s.",
+                        getSelectionMode(),
+                        threshold);
+            }
+        }
+        return threshold;
+    }
+
+    private static class SelectIndicesFromPValuesOperator
+            extends AbstractStreamOperator<UnivariateFeatureSelectorModelData>
+            implements OneInputStreamOperator<Row, UnivariateFeatureSelectorModelData>,
+                    BoundedOneInput {
+        private String selectionMode;
+        private double threshold;
+
+        private List<Tuple2<Double, Integer>> pValuesAndIndices;
+        private ListState<List<Tuple2<Double, Integer>>> pValuesAndIndicesState;
+
+        public SelectIndicesFromPValuesOperator(String selectionMode, double threshold) {
+            this.selectionMode = selectionMode;
+            this.threshold = threshold;
+        }
+
+        @Override
+        public void endInput() {
+            List<Integer> indices = new ArrayList<>();
+
+            switch (selectionMode) {
+                case NUM_TOP_FEATURES:
+                    pValuesAndIndices.sort(Comparator.comparing(t -> t.f0));

Review Comment:
   Let's also compare by feature index when the p-values are equal, and add test cases that verify the selected indices in that situation.
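   A possible tie-breaker so that the selection stays deterministic when several p-values are equal:
   ```java
   // Order by p-value first, then by feature index.
   pValuesAndIndices.sort(
           Comparator.comparingDouble((Tuple2<Double, Integer> t) -> t.f0)
                   .thenComparingInt(t -> t.f1));
   ```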



##########
flink-ml-python/pyflink/ml/lib/feature/univariatefeatureselector.py:
##########
@@ -0,0 +1,215 @@
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+import typing
+from pyflink.ml.core.wrapper import JavaWithParams
+
+from pyflink.ml.lib.param import HasFeaturesCol, HasLabelCol, HasOutputCol
+
+from pyflink.ml.core.param import StringParam, ParamValidators, FloatParam
+
+from pyflink.ml.lib.feature.common import JavaFeatureModel, JavaFeatureEstimator

Review Comment:
   It might be better to reformat the code to reorder the imports.



##########
flink-ml-python/pyflink/ml/lib/feature/univariatefeatureselector.py:
##########
@@ -0,0 +1,215 @@
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+import typing
+from pyflink.ml.core.wrapper import JavaWithParams
+
+from pyflink.ml.lib.param import HasFeaturesCol, HasLabelCol, HasOutputCol
+
+from pyflink.ml.core.param import StringParam, ParamValidators, FloatParam
+
+from pyflink.ml.lib.feature.common import JavaFeatureModel, JavaFeatureEstimator
+
+
+class _UnivariateFeatureSelectorModelParams(
+    JavaWithParams,
+    HasFeaturesCol,
+    HasLabelCol,
+    HasOutputCol
+):
+    """
+    Params for :class `UnivariateFeatureSelectorModel`.
+    """
+    def __init__(self, java_params):
+        super(_UnivariateFeatureSelectorModelParams, self).__init__(java_params)
+
+
+class _UnivariateFeatureSelectorParams(_UnivariateFeatureSelectorModelParams):
+    """
+    Params for :class `UnivariateFeatureSelector`.
+    """
+
+    """
+    Supported options of the feature type.
+
+    <ul>
+        <li>categorical: the features are categorical data.
+        <li>continuous: the features are continuous data.
+    </ul>
+    """
+    FEATURE_TYPE: StringParam = StringParam(
+        "feature_type",
+        "The feature type.",
+        None,
+        ParamValidators.in_array(['categorical', 'continuous']))
+
+    """
+    Supported options of the label type.
+
+    <ul>
+        <li>categorical: the label is categorical data.
+        <li>continuous: the label is continuous data.
+    </ul>
+    """
+    LABEL_TYPE: StringParam = StringParam(
+        "label_type",
+        "The label type.",
+        None,
+        ParamValidators.in_array(['categorical', 'continuous']))
+
+    """
+    Supported options of the feature selection mode.
+
+    <ul>
+        <li>numTopFeatures: chooses a fixed number of top features according to a hypothesis.
+        <li>percentile: similar to numTopFeatures but chooses a fraction of all features
+            instead of a fixed number.
+        <li>fpr: chooses all features whose p-value are below a threshold, thus controlling the
+            false positive rate of selection.
+        <li>fdr: uses the <a href="https://en.wikipedia.org/wiki/False_discovery_rate#
+            Benjamini.E2.80.93Hochberg_procedure">Benjamini-Hochberg procedure</a> to choose
+            all features whose false discovery rate is below a threshold.
+        <li>fwe: chooses all features whose p-values are below a threshold. The threshold is
+            scaled by 1/numFeatures, thus controlling the family-wise error rate of selection.
+    </ul>
+    """
+    SELECTION_MODE: StringParam = StringParam(
+        "selection_mode",
+        "The feature selection mode.",
+        "numTopFeatures",
+        ParamValidators.in_array(['numTopFeatures', 'percentile', 'fpr', 'fdr', 'fwe']))

Review Comment:
   This validator can be removed. Same for the other Python params.



##########
docs/content/docs/operators/feature/univariategeatureselector.md:
##########
@@ -0,0 +1,220 @@
+---
+title: "Univariate Feature Selector"
+weight: 1
+type: docs
+aliases:
+- /operators/feature/univariatefeatureselector.html
+---
+
+<!--
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements.  See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership.  The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License.  You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied.  See the License for the
+specific language governing permissions and limitations
+under the License.
+-->
+
+## Univariate Feature Selector
+Univariate Feature Selector is an algorithm that selects features based on 
+univariate statistical tests against labels.
+
+Currently, Flink supports three Univariate Feature Selectors: chi-squared, 
+ANOVA F-test and F-value. User can choose Univariate Feature Selector by 
+setting `featureType` and `labelType`, and Flink will pick the score function
+based on the specified `featureType` and `labelType`.
+
+The following combination of `featureType` and `labelType` are supported:
+
+<ul>
+    <li>`featureType` `categorical` and `labelType` `categorical`: Flink uses 
+        chi-squared, i.e. chi2 in sklearn.
+    <li>`featureType` `continuous` and `labelType` `categorical`: Flink uses 
+        ANOVA F-test, f_classif in sklearn.
+    <li>`featureType` `continuous` and `labelType` `continuous`: Flink uses 
+        F-value, i.e. f_regression in sklearn.
+</ul>
+

Review Comment:
   It is enough to use plain markdown list syntax like the following, instead of HTML tags.
   ```markdown
   - `featureType` `categorical` and `labelType` `categorical`: Flink uses chi-squared, i.e. chi2 in sklearn.
   - `featureType` `continuous` and `labelType` `categorical`: Flink uses ANOVA F-test, f_classif in sklearn.
   - `featureType` `continuous` and `labelType` `continuous`: Flink uses F-value, i.e. f_regression in sklearn.
   ```



##########
docs/content/docs/operators/feature/univariategeatureselector.md:
##########
@@ -0,0 +1,220 @@
+---
+title: "Univariate Feature Selector"
+weight: 1
+type: docs
+aliases:
+- /operators/feature/univariatefeatureselector.html
+---
+
+<!--
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements.  See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership.  The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License.  You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied.  See the License for the
+specific language governing permissions and limitations
+under the License.
+-->
+
+## Univariate Feature Selector
+Univariate Feature Selector is an algorithm that selects features based on 
+univariate statistical tests against labels.
+
+Currently, Flink supports three Univariate Feature Selectors: chi-squared, 
+ANOVA F-test and F-value. User can choose Univariate Feature Selector by 
+setting `featureType` and `labelType`, and Flink will pick the score function
+based on the specified `featureType` and `labelType`.
+
+The following combination of `featureType` and `labelType` are supported:
+
+<ul>
+    <li>`featureType` `categorical` and `labelType` `categorical`: Flink uses 
+        chi-squared, i.e. chi2 in sklearn.
+    <li>`featureType` `continuous` and `labelType` `categorical`: Flink uses 
+        ANOVA F-test, f_classif in sklearn.
+    <li>`featureType` `continuous` and `labelType` `continuous`: Flink uses 
+        F-value, i.e. f_regression in sklearn.
+</ul>
+
+Univariate Feature Selector supports different selection modes:
+
+<ul>
+    <li>numTopFeatures: chooses a fixed number of top features according to a 
+        hypothesis.
+    <li>percentile: similar to numTopFeatures but chooses a fraction of all 
+        features instead of a fixed number.
+    <li>fpr: chooses all features whose p-value are below a threshold, thus 
+        controlling the false positive rate of selection.
+    <li>fdr: uses the <a href="https://en.wikipedia.org/wiki/False_discovery_rate#
+        Benjamini.E2.80.93Hochberg_procedure">Benjamini-Hochberg procedure</a> to 
+        choose all features whose false discovery rate is below a threshold.
+    <li>fwe: chooses all features whose p-values are below a threshold. The 
+        threshold is scaled by 1/numFeatures, thus controlling the family-wise 
+        error rate of selection.
+</ul>
+
+By default, the selection mode is `numTopFeatures`.
+
+### Input Columns
+
+| Param name  | Type   | Default      | Description            |
+|:------------|:-------|:-------------|:-----------------------|
+| featuresCol | Vector | `"features"` | Feature vector.        |
+| labelCol    | Number | `"label"`    | Label of the features. |
+
+### Output Columns
+
+| Param name | Type   | Default    | Description        |
+|:-----------|:-------|:-----------|:-------------------|
+| outputCol  | Vector | `"output"` | Selected features. |
+
+### Parameters
+
+Below are the parameters required by `UnivariateFeatureSelectorModel`.
+
+| Key         | Default      | Type   | Required | Description             |
+|-------------|--------------|--------|----------|-------------------------|
+| featuresCol | `"features"` | String | no       | Features column name.   |
+| labelCol    | `"label"`    | String | no       | Label column name.      |
+| outputCol   | `"output"`   | String | no       | Output column name.     |
+
+In addition to the parameters above, `UnivariateFeatureSelector` also needs the parameters below.
+
+| Key                | Default            | Type    | Required | Description                                                                                                                                                                                                                                                                                                                              |
+| ------------------ | ------------------ | ------- | -------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| featureType        | `null`             | String  | yes      | The feature type. Supported values: 'categorical', 'continuous'.                                                                                                                                                                                                                                                                         |
+| labelType          | `null`             | String  | yes      | The label type. Supported values: 'categorical', 'continuous'.                                                                                                                                                                                                                                                                           |
+| selectionMode      | `"numTopFeatures"` | String  | no       | The feature selection mode. Supported values: 'numTopFeatures', 'percentile', 'fpr', 'fdr', 'fwe'.                                                                                                                                                                                                                                       |
+| selectionThreshold | `null`             | Number  | no       | The upper bound of the features that the selector will select. If not set, it will be replaced with a meaningful value according to the selection mode at runtime. When the mode is numTopFeatures, it will be replaced with 50; when the mode is percentile, it will be replaced with 0.1; otherwise, it will be replaced with 0.05. |
+
+### Examples
+
+{{< tabs examples >}}
+
+{{< tab "Java">}}
+
+```java
+import org.apache.flink.ml.feature.univariatefeatureselector.UnivariateFeatureSelector;
+import org.apache.flink.ml.feature.univariatefeatureselector.UnivariateFeatureSelectorModel;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.CloseableIterator;
+
+/**
+ * Simple program that trains a {@link UnivariateFeatureSelector} model and uses it for feature
+ * selection.
+ */
+public class UnivariateFeatureSelectorExample {
+    public static void main(String[] args) {
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+        // Generates input training and prediction data.
+        DataStream<Row> trainStream =
+                env.fromElements(
+                        Row.of(Vectors.dense(1.7, 4.4, 7.6, 5.8, 9.6, 2.3), 3.0),
+                        Row.of(Vectors.dense(8.8, 7.3, 5.7, 7.3, 2.2, 4.1), 2.0),
+                        Row.of(Vectors.dense(1.2, 9.5, 2.5, 3.1, 8.7, 2.5), 1.0),
+                        Row.of(Vectors.dense(3.7, 9.2, 6.1, 4.1, 7.5, 3.8), 2.0),
+                        Row.of(Vectors.dense(8.9, 5.2, 7.8, 8.3, 5.2, 3.0), 4.0),
+                        Row.of(Vectors.dense(7.9, 8.5, 9.2, 4.0, 9.4, 2.1), 4.0));
+        Table trainTable = tEnv.fromDataStream(trainStream).as("features", "label");
+
+        // Creates a UnivariateFeatureSelector object and initializes its parameters.
+        UnivariateFeatureSelector univariateFeatureSelector =
+                new UnivariateFeatureSelector()
+                        .setFeatureType("continuous")
+                        .setLabelType("categorical")
+                        .setSelectionThreshold(1);
+
+        // Trains the UnivariateFeatureSelector model.
+        UnivariateFeatureSelectorModel model = univariateFeatureSelector.fit(trainTable);
+
+        // Uses the UnivariateFeatureSelector model for feature selection.
+        Table outputTable = model.transform(trainTable)[0];
+
+        // Extracts and displays the results.
+        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
+            Row row = it.next();
+            DenseVector inputValue =
+                    (DenseVector) row.getField(univariateFeatureSelector.getFeaturesCol());
+            DenseVector outputValue =
+                    (DenseVector) row.getField(univariateFeatureSelector.getOutputCol());
+            System.out.printf("Input Value: %-15s\tOutput Value: %s\n", inputValue, outputValue);
+        }
+    }
+}
+
+```
+
+{{< /tab>}}
+
+{{< tab "Python">}}
+
+```python
+# Simple program that creates a UnivariateFeatureSelector instance and uses it for feature
+# selection.
+
+from pyflink.common import Types
+from pyflink.datastream import StreamExecutionEnvironment
+from pyflink.ml.lib.feature.univariatefeatureselector import UnivariateFeatureSelector
+from pyflink.table import StreamTableEnvironment
+
+from pyflink.ml.core.linalg import Vectors, DenseVectorTypeInfo
+
+env = StreamExecutionEnvironment.get_execution_environment()
+
+t_env = StreamTableEnvironment.create(env)
+
+# Generates input training and prediction data.
+input_table = t_env.from_data_stream(
+    env.from_collection([
+        (Vectors.dense(1.7, 4.4, 7.6, 5.8, 9.6, 2.3), 3.0,),
+        (Vectors.dense(8.8, 7.3, 5.7, 7.3, 2.2, 4.1), 2.0,),
+        (Vectors.dense(1.2, 9.5, 2.5, 3.1, 8.7, 2.5), 1.0,),
+        (Vectors.dense(3.7, 9.2, 6.1, 4.1, 7.5, 3.8), 2.0,),
+        (Vectors.dense(8.9, 5.2, 7.8, 8.3, 5.2, 3.0), 4.0,),
+        (Vectors.dense(7.9, 8.5, 9.2, 4.0, 9.4, 2.1), 4.0,),
+    ],
+        type_info=Types.ROW_NAMED(
+            ['features', 'label'],
+            [DenseVectorTypeInfo(), Types.FLOAT()])
+    ))
+
+# Creates a UnivariateFeatureSelector object and initializes its parameters.
+univariate_feature_selector = UnivariateFeatureSelector() \
+    .set_features_col('features') \
+    .set_label_col('label') \

Review Comment:
   The Python and Java examples are not aligned: the Python example calls `set_label_col` while the Java example does not.
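
One way to align the two, sketched here purely for illustration (this assumes the fix is to add the explicit column setters on the Java side rather than drop them from the Python side; both values equal the defaults, so the example's behavior would not change):

```java
// Hypothetical aligned version of the Java example: the explicit
// setFeaturesCol/setLabelCol calls mirror the Python example and are
// no-ops relative to the defaults ("features" and "label").
UnivariateFeatureSelector univariateFeatureSelector =
        new UnivariateFeatureSelector()
                .setFeaturesCol("features")
                .setLabelCol("label")
                .setFeatureType("continuous")
                .setLabelType("categorical")
                .setSelectionThreshold(1);
```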



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/feature/UnivariateFeatureSelectorTest.java:
##########
@@ -0,0 +1,771 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.feature.univariatefeatureselector.UnivariateFeatureSelector;
+import org.apache.flink.ml.feature.univariatefeatureselector.UnivariateFeatureSelectorModel;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.TestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.CloseableIterator;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.exception.ExceptionUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import static org.apache.flink.ml.feature.univariatefeatureselector.UnivariateFeatureSelectorParams.CATEGORICAL;
+import static org.apache.flink.ml.feature.univariatefeatureselector.UnivariateFeatureSelectorParams.CONTINUOUS;
+import static org.apache.flink.ml.feature.univariatefeatureselector.UnivariateFeatureSelectorParams.FPR;
+import static org.apache.flink.ml.feature.univariatefeatureselector.UnivariateFeatureSelectorParams.NUM_TOP_FEATURES;
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.fail;
+
+/** Tests {@link UnivariateFeatureSelector} and {@link UnivariateFeatureSelectorModel}. */
+public class UnivariateFeatureSelectorTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table inputChiSqTable;
+    private Table inputANOVATable;
+    private Table inputFValueTable;
+
+    private static final double EPS = 1.0e-5;
+
+    private UnivariateFeatureSelector selectorWithChiSqTest;
+    private UnivariateFeatureSelector selectorWithANOVATest;
+    private UnivariateFeatureSelector selectorWithFValueTest;
+
+    private static final List<Row> INPUT_CHISQ_DATA =
+            Arrays.asList(
+                    Row.of(0.0, Vectors.dense(6.0, 7.0, 0.0, 7.0, 6.0, 0.0)),
+                    Row.of(1.0, Vectors.dense(0.0, 9.0, 6.0, 0.0, 5.0, 9.0)),
+                    Row.of(1.0, Vectors.dense(0.0, 9.0, 3.0, 0.0, 5.0, 5.0)),
+                    Row.of(1.0, Vectors.dense(0.0, 9.0, 8.0, 5.0, 6.0, 4.0)),
+                    Row.of(2.0, Vectors.dense(8.0, 9.0, 6.0, 5.0, 4.0, 4.0)),
+                    Row.of(2.0, Vectors.dense(8.0, 9.0, 6.0, 4.0, 0.0, 0.0)));
+
+    private static final List<Row> INPUT_ANOVA_DATA =
+            Arrays.asList(
+                    Row.of(
+                            1,
+                            Vectors.dense(
+                                    4.65415496e-03,
+                                    1.03550567e-01,
+                                    -1.17358140e+00,
+                                    1.61408773e-01,
+                                    3.92492111e-01,
+                                    7.31240882e-01)),
+                    Row.of(
+                            1,
+                            Vectors.dense(
+                                    -9.01651741e-01,
+                                    -5.28905302e-01,
+                                    1.27636785e+00,
+                                    7.02154563e-01,
+                                    6.21348351e-01,
+                                    1.88397353e-01)),
+                    Row.of(
+                            1,
+                            Vectors.dense(
+                                    3.85692159e-01,
+                                    -9.04639637e-01,
+                                    5.09782604e-02,
+                                    8.40043971e-01,
+                                    7.45977857e-01,
+                                    8.78402288e-01)),
+                    Row.of(
+                            1,
+                            Vectors.dense(
+                                    1.36264353e+00,
+                                    2.62454094e-01,
+                                    7.96306202e-01,
+                                    6.14948000e-01,
+                                    7.44948187e-01,
+                                    9.74034830e-01)),
+                    Row.of(
+                            1,
+                            Vectors.dense(
+                                    9.65874070e-01,
+                                    2.52773665e+00,
+                                    -2.19380094e+00,
+                                    2.33408080e-01,
+                                    1.86340919e-01,
+                                    8.23390433e-01)),
+                    Row.of(
+                            2,
+                            Vectors.dense(
+                                    1.12324305e+01,
+                                    -2.77121515e-01,
+                                    1.12740513e-01,
+                                    2.35184013e-01,
+                                    3.46668895e-01,
+                                    9.38500782e-02)),
+                    Row.of(
+                            2,
+                            Vectors.dense(
+                                    1.06195839e+01,
+                                    -1.82891238e+00,
+                                    2.25085601e-01,
+                                    9.09979851e-01,
+                                    6.80257535e-02,
+                                    8.24017480e-01)),
+                    Row.of(
+                            2,
+                            Vectors.dense(
+                                    1.12806837e+01,
+                                    1.30686889e+00,
+                                    9.32839108e-02,
+                                    3.49784755e-01,
+                                    1.71322408e-02,
+                                    7.48465194e-02)),
+                    Row.of(
+                            2,
+                            Vectors.dense(
+                                    9.98689462e+00,
+                                    9.50808938e-01,
+                                    -2.90786359e-01,
+                                    2.31253009e-01,
+                                    7.46270968e-01,
+                                    1.60308169e-01)),
+                    Row.of(
+                            2,
+                            Vectors.dense(
+                                    1.08428551e+01,
+                                    -1.02749936e+00,
+                                    1.73951508e-01,
+                                    8.92482744e-02,
+                                    1.42651730e-01,
+                                    7.66751625e-01)),
+                    Row.of(
+                            3,
+                            Vectors.dense(
+                                    -1.98641448e+00,
+                                    1.12811990e+01,
+                                    -2.35246756e-01,
+                                    8.22809049e-01,
+                                    3.26739456e-01,
+                                    7.88268404e-01)),
+                    Row.of(
+                            3,
+                            Vectors.dense(
+                                    -6.09864090e-01,
+                                    1.07346276e+01,
+                                    -2.18805509e-01,
+                                    7.33931213e-01,
+                                    1.42554396e-01,
+                                    7.11225605e-01)),
+                    Row.of(
+                            3,
+                            Vectors.dense(
+                                    -1.58481268e+00,
+                                    9.19364039e+00,
+                                    -5.87490459e-02,
+                                    2.51532056e-01,
+                                    2.82729807e-01,
+                                    7.16245686e-01)),
+                    Row.of(
+                            3,
+                            Vectors.dense(
+                                    -2.50949277e-01,
+                                    1.12815254e+01,
+                                    -6.94806734e-01,
+                                    5.93898886e-01,
+                                    5.68425656e-01,
+                                    8.49762330e-01)),
+                    Row.of(
+                            3,
+                            Vectors.dense(
+                                    7.63485129e-01,
+                                    1.02605138e+01,
+                                    1.32617719e+00,
+                                    5.49682879e-01,
+                                    8.59931442e-01,
+                                    4.88677978e-02)),
+                    Row.of(
+                            4,
+                            Vectors.dense(
+                                    9.34900015e-01,
+                                    4.11379043e-01,
+                                    8.65010205e+00,
+                                    9.23509168e-01,
+                                    1.16995043e-01,
+                                    5.91894106e-03)),
+                    Row.of(
+                            4,
+                            Vectors.dense(
+                                    4.73734933e-01,
+                                    -1.48321181e+00,
+                                    9.73349621e+00,
+                                    4.09421563e-01,
+                                    5.09375719e-01,
+                                    5.93157850e-01)),
+                    Row.of(
+                            4,
+                            Vectors.dense(
+                                    3.41470679e-01,
+                                    -6.88972582e-01,
+                                    9.60347938e+00,
+                                    3.62654055e-01,
+                                    2.43437468e-01,
+                                    7.13052838e-01)),
+                    Row.of(
+                            4,
+                            Vectors.dense(
+                                    -5.29614251e-01,
+                                    -1.39262856e+00,
+                                    1.01354144e+01,
+                                    8.24123861e-01,
+                                    5.84074506e-01,
+                                    6.54461558e-01)),
+                    Row.of(
+                            4,
+                            Vectors.dense(
+                                    -2.99454508e-01,
+                                    2.20457263e+00,
+                                    1.14586015e+01,
+                                    5.16336729e-01,
+                                    9.99776159e-01,
+                                    3.15769738e-01)));
+
+    private static final List<Row> INPUT_FVALUE_DATA =
+            Arrays.asList(
+                    Row.of(
+                            0.52516321,
+                            Vectors.dense(
+                                    0.19151945,
+                                    0.62210877,
+                                    0.43772774,
+                                    0.78535858,
+                                    0.77997581,
+                                    0.27259261)),
+                    Row.of(
+                            0.88275782,
+                            Vectors.dense(
+                                    0.27646426,
+                                    0.80187218,
+                                    0.95813935,
+                                    0.87593263,
+                                    0.35781727,
+                                    0.50099513)),
+                    Row.of(
+                            0.67524507,
+                            Vectors.dense(
+                                    0.68346294,
+                                    0.71270203,
+                                    0.37025075,
+                                    0.56119619,
+                                    0.50308317,
+                                    0.01376845)),
+                    Row.of(
+                            0.76734745,
+                            Vectors.dense(
+                                    0.77282662,
+                                    0.88264119,
+                                    0.36488598,
+                                    0.61539618,
+                                    0.07538124,
+                                    0.36882401)),
+                    Row.of(
+                            0.73909458,
+                            Vectors.dense(
+                                    0.9331401,
+                                    0.65137814,
+                                    0.39720258,
+                                    0.78873014,
+                                    0.31683612,
+                                    0.56809865)),
+                    Row.of(
+                            0.83628141,
+                            Vectors.dense(
+                                    0.86912739,
+                                    0.43617342,
+                                    0.80214764,
+                                    0.14376682,
+                                    0.70426097,
+                                    0.70458131)),
+                    Row.of(
+                            0.65665506,
+                            Vectors.dense(
+                                    0.21879211,
+                                    0.92486763,
+                                    0.44214076,
+                                    0.90931596,
+                                    0.05980922,
+                                    0.18428708)),
+                    Row.of(
+                            0.58147135,
+                            Vectors.dense(
+                                    0.04735528,
+                                    0.67488094,
+                                    0.59462478,
+                                    0.53331016,
+                                    0.04332406,
+                                    0.56143308)),
+                    Row.of(
+                            0.35603443,
+                            Vectors.dense(
+                                    0.32966845,
+                                    0.50296683,
+                                    0.11189432,
+                                    0.60719371,
+                                    0.56594464,
+                                    0.00676406)),
+                    Row.of(
+                            0.94534373,
+                            Vectors.dense(
+                                    0.61744171,
+                                    0.91212289,
+                                    0.79052413,
+                                    0.99208147,
+                                    0.95880176,
+                                    0.79196414)),
+                    Row.of(
+                            0.57458887,
+                            Vectors.dense(
+                                    0.28525096,
+                                    0.62491671,
+                                    0.4780938,
+                                    0.19567518,
+                                    0.38231745,
+                                    0.05387369)),
+                    Row.of(
+                            0.59026777,
+                            Vectors.dense(
+                                    0.45164841,
+                                    0.98200474,
+                                    0.1239427,
+                                    0.1193809,
+                                    0.73852306,
+                                    0.58730363)),
+                    Row.of(
+                            0.29894977,
+                            Vectors.dense(
+                                    0.47163253,
+                                    0.10712682,
+                                    0.22921857,
+                                    0.89996519,
+                                    0.41675354,
+                                    0.53585166)),
+                    Row.of(
+                            0.34056582,
+                            Vectors.dense(
+                                    0.00620852,
+                                    0.30064171,
+                                    0.43689317,
+                                    0.612149,
+                                    0.91819808,
+                                    0.62573667)),
+                    Row.of(
+                            0.64476446,
+                            Vectors.dense(
+                                    0.70599757,
+                                    0.14983372,
+                                    0.74606341,
+                                    0.83100699,
+                                    0.63372577,
+                                    0.43830988)),
+                    Row.of(
+                            0.53724782,
+                            Vectors.dense(
+                                    0.15257277,
+                                    0.56840962,
+                                    0.52822428,
+                                    0.95142876,
+                                    0.48035918,
+                                    0.50255956)),
+                    Row.of(
+                            0.5173021,
+                            Vectors.dense(
+                                    0.53687819,
+                                    0.81920207,
+                                    0.05711564,
+                                    0.66942174,
+                                    0.76711663,
+                                    0.70811536)),
+                    Row.of(
+                            0.94508275,
+                            Vectors.dense(
+                                    0.79686718,
+                                    0.55776083,
+                                    0.96583653,
+                                    0.1471569,
+                                    0.029647,
+                                    0.59389349)),
+                    Row.of(
+                            0.57739736,
+                            Vectors.dense(
+                                    0.1140657,
+                                    0.95080985,
+                                    0.96583653,
+                                    0.19361869,
+                                    0.45781165,
+                                    0.92040257)),
+                    Row.of(
+                            0.53877145,
+                            Vectors.dense(
+                                    0.87906916,
+                                    0.25261576,
+                                    0.34800879,
+                                    0.18258873,
+                                    0.90179605,
+                                    0.70652816)));
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        env.getConfig().enableObjectReuse();
+        tEnv = StreamTableEnvironment.create(env);
+
+        selectorWithChiSqTest =
+                new UnivariateFeatureSelector()
+                        .setFeatureType("categorical")
+                        .setLabelType("categorical");
+        selectorWithANOVATest =
+                new UnivariateFeatureSelector()
+                        .setFeatureType("continuous")
+                        .setLabelType("categorical");
+        selectorWithFValueTest =
+                new UnivariateFeatureSelector()
+                        .setFeatureType("continuous")
+                        .setLabelType("continuous");
+        inputChiSqTable =
+                tEnv.fromDataStream(env.fromCollection(INPUT_CHISQ_DATA)).as("label", "features");
+        inputANOVATable =
+                tEnv.fromDataStream(env.fromCollection(INPUT_ANOVA_DATA)).as("label", "features");
+        inputFValueTable =
+                tEnv.fromDataStream(env.fromCollection(INPUT_FVALUE_DATA)).as("label", "features");
+    }
+
+    private void transformAndVerify(
+            UnivariateFeatureSelector selector, Table table, int... expectedIndices)
+            throws Exception {
+        UnivariateFeatureSelectorModel model = selector.fit(table);
+        Table output = model.transform(table)[0];
+        verify(output, expectedIndices);
+    }
+
+    private void verify(Table table, int... expectedIndices) throws Exception {

Review Comment:
   It might be better to change this method's signature to the following.
   ```java
   private void verifyOutputResult(Table outputTable, int... expectedIndices) throws Exception
   ```
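
Applied to the helper, the renamed method might look like this sketch (the body simply reuses the logic of the existing `verify(...)` implementation quoted later in this thread; "features" and "output" are the default column names used by these tests):

```java
private void verifyOutputResult(Table outputTable, int... expectedIndices) throws Exception {
    StreamTableEnvironment tEnv =
            (StreamTableEnvironment) ((TableImpl) outputTable).getTableEnvironment();
    // Collect the transformed rows and check that each selected feature equals
    // the original feature at the expected index.
    CloseableIterator<Row> rowIterator = tEnv.toDataStream(outputTable).executeAndCollect();
    while (rowIterator.hasNext()) {
        Row row = rowIterator.next();
        for (int i = 0; i < expectedIndices.length; i++) {
            assertEquals(
                    ((Vector) row.getField("features")).get(expectedIndices[i]),
                    ((Vector) row.getField("output")).get(i),
                    EPS);
        }
    }
}
```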



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/feature/UnivariateFeatureSelectorTest.java:
##########
@@ -0,0 +1,771 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.feature.univariatefeatureselector.UnivariateFeatureSelector;
+import org.apache.flink.ml.feature.univariatefeatureselector.UnivariateFeatureSelectorModel;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.TestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.CloseableIterator;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.exception.ExceptionUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import static org.apache.flink.ml.feature.univariatefeatureselector.UnivariateFeatureSelectorParams.CATEGORICAL;
+import static org.apache.flink.ml.feature.univariatefeatureselector.UnivariateFeatureSelectorParams.CONTINUOUS;
+import static org.apache.flink.ml.feature.univariatefeatureselector.UnivariateFeatureSelectorParams.FPR;
+import static org.apache.flink.ml.feature.univariatefeatureselector.UnivariateFeatureSelectorParams.NUM_TOP_FEATURES;
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.fail;
+
+/** Tests {@link UnivariateFeatureSelector} and {@link UnivariateFeatureSelectorModel}. */
+public class UnivariateFeatureSelectorTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table inputChiSqTable;
+    private Table inputANOVATable;
+    private Table inputFValueTable;
+
+    private static final double EPS = 1.0e-5;
+
+    private UnivariateFeatureSelector selectorWithChiSqTest;
+    private UnivariateFeatureSelector selectorWithANOVATest;
+    private UnivariateFeatureSelector selectorWithFValueTest;
+
+    private static final List<Row> INPUT_CHISQ_DATA =
+            Arrays.asList(
+                    Row.of(0.0, Vectors.dense(6.0, 7.0, 0.0, 7.0, 6.0, 0.0)),
+                    Row.of(1.0, Vectors.dense(0.0, 9.0, 6.0, 0.0, 5.0, 9.0)),
+                    Row.of(1.0, Vectors.dense(0.0, 9.0, 3.0, 0.0, 5.0, 5.0)),
+                    Row.of(1.0, Vectors.dense(0.0, 9.0, 8.0, 5.0, 6.0, 4.0)),
+                    Row.of(2.0, Vectors.dense(8.0, 9.0, 6.0, 5.0, 4.0, 4.0)),
+                    Row.of(2.0, Vectors.dense(8.0, 9.0, 6.0, 4.0, 0.0, 0.0)));
+
+    private static final List<Row> INPUT_ANOVA_DATA =
+            Arrays.asList(
+                    Row.of(
+                            1,
+                            Vectors.dense(
+                                    4.65415496e-03,
+                                    1.03550567e-01,
+                                    -1.17358140e+00,
+                                    1.61408773e-01,
+                                    3.92492111e-01,
+                                    7.31240882e-01)),
+                    Row.of(
+                            1,
+                            Vectors.dense(
+                                    -9.01651741e-01,
+                                    -5.28905302e-01,
+                                    1.27636785e+00,
+                                    7.02154563e-01,
+                                    6.21348351e-01,
+                                    1.88397353e-01)),
+                    Row.of(
+                            1,
+                            Vectors.dense(
+                                    3.85692159e-01,
+                                    -9.04639637e-01,
+                                    5.09782604e-02,
+                                    8.40043971e-01,
+                                    7.45977857e-01,
+                                    8.78402288e-01)),
+                    Row.of(
+                            1,
+                            Vectors.dense(
+                                    1.36264353e+00,
+                                    2.62454094e-01,
+                                    7.96306202e-01,
+                                    6.14948000e-01,
+                                    7.44948187e-01,
+                                    9.74034830e-01)),
+                    Row.of(
+                            1,
+                            Vectors.dense(
+                                    9.65874070e-01,
+                                    2.52773665e+00,
+                                    -2.19380094e+00,
+                                    2.33408080e-01,
+                                    1.86340919e-01,
+                                    8.23390433e-01)),
+                    Row.of(
+                            2,
+                            Vectors.dense(
+                                    1.12324305e+01,
+                                    -2.77121515e-01,
+                                    1.12740513e-01,
+                                    2.35184013e-01,
+                                    3.46668895e-01,
+                                    9.38500782e-02)),
+                    Row.of(
+                            2,
+                            Vectors.dense(
+                                    1.06195839e+01,
+                                    -1.82891238e+00,
+                                    2.25085601e-01,
+                                    9.09979851e-01,
+                                    6.80257535e-02,
+                                    8.24017480e-01)),
+                    Row.of(
+                            2,
+                            Vectors.dense(
+                                    1.12806837e+01,
+                                    1.30686889e+00,
+                                    9.32839108e-02,
+                                    3.49784755e-01,
+                                    1.71322408e-02,
+                                    7.48465194e-02)),
+                    Row.of(
+                            2,
+                            Vectors.dense(
+                                    9.98689462e+00,
+                                    9.50808938e-01,
+                                    -2.90786359e-01,
+                                    2.31253009e-01,
+                                    7.46270968e-01,
+                                    1.60308169e-01)),
+                    Row.of(
+                            2,
+                            Vectors.dense(
+                                    1.08428551e+01,
+                                    -1.02749936e+00,
+                                    1.73951508e-01,
+                                    8.92482744e-02,
+                                    1.42651730e-01,
+                                    7.66751625e-01)),
+                    Row.of(
+                            3,
+                            Vectors.dense(
+                                    -1.98641448e+00,
+                                    1.12811990e+01,
+                                    -2.35246756e-01,
+                                    8.22809049e-01,
+                                    3.26739456e-01,
+                                    7.88268404e-01)),
+                    Row.of(
+                            3,
+                            Vectors.dense(
+                                    -6.09864090e-01,
+                                    1.07346276e+01,
+                                    -2.18805509e-01,
+                                    7.33931213e-01,
+                                    1.42554396e-01,
+                                    7.11225605e-01)),
+                    Row.of(
+                            3,
+                            Vectors.dense(
+                                    -1.58481268e+00,
+                                    9.19364039e+00,
+                                    -5.87490459e-02,
+                                    2.51532056e-01,
+                                    2.82729807e-01,
+                                    7.16245686e-01)),
+                    Row.of(
+                            3,
+                            Vectors.dense(
+                                    -2.50949277e-01,
+                                    1.12815254e+01,
+                                    -6.94806734e-01,
+                                    5.93898886e-01,
+                                    5.68425656e-01,
+                                    8.49762330e-01)),
+                    Row.of(
+                            3,
+                            Vectors.dense(
+                                    7.63485129e-01,
+                                    1.02605138e+01,
+                                    1.32617719e+00,
+                                    5.49682879e-01,
+                                    8.59931442e-01,
+                                    4.88677978e-02)),
+                    Row.of(
+                            4,
+                            Vectors.dense(
+                                    9.34900015e-01,
+                                    4.11379043e-01,
+                                    8.65010205e+00,
+                                    9.23509168e-01,
+                                    1.16995043e-01,
+                                    5.91894106e-03)),
+                    Row.of(
+                            4,
+                            Vectors.dense(
+                                    4.73734933e-01,
+                                    -1.48321181e+00,
+                                    9.73349621e+00,
+                                    4.09421563e-01,
+                                    5.09375719e-01,
+                                    5.93157850e-01)),
+                    Row.of(
+                            4,
+                            Vectors.dense(
+                                    3.41470679e-01,
+                                    -6.88972582e-01,
+                                    9.60347938e+00,
+                                    3.62654055e-01,
+                                    2.43437468e-01,
+                                    7.13052838e-01)),
+                    Row.of(
+                            4,
+                            Vectors.dense(
+                                    -5.29614251e-01,
+                                    -1.39262856e+00,
+                                    1.01354144e+01,
+                                    8.24123861e-01,
+                                    5.84074506e-01,
+                                    6.54461558e-01)),
+                    Row.of(
+                            4,
+                            Vectors.dense(
+                                    -2.99454508e-01,
+                                    2.20457263e+00,
+                                    1.14586015e+01,
+                                    5.16336729e-01,
+                                    9.99776159e-01,
+                                    3.15769738e-01)));
+
+    private static final List<Row> INPUT_FVALUE_DATA =
+            Arrays.asList(
+                    Row.of(
+                            0.52516321,
+                            Vectors.dense(
+                                    0.19151945,
+                                    0.62210877,
+                                    0.43772774,
+                                    0.78535858,
+                                    0.77997581,
+                                    0.27259261)),
+                    Row.of(
+                            0.88275782,
+                            Vectors.dense(
+                                    0.27646426,
+                                    0.80187218,
+                                    0.95813935,
+                                    0.87593263,
+                                    0.35781727,
+                                    0.50099513)),
+                    Row.of(
+                            0.67524507,
+                            Vectors.dense(
+                                    0.68346294,
+                                    0.71270203,
+                                    0.37025075,
+                                    0.56119619,
+                                    0.50308317,
+                                    0.01376845)),
+                    Row.of(
+                            0.76734745,
+                            Vectors.dense(
+                                    0.77282662,
+                                    0.88264119,
+                                    0.36488598,
+                                    0.61539618,
+                                    0.07538124,
+                                    0.36882401)),
+                    Row.of(
+                            0.73909458,
+                            Vectors.dense(
+                                    0.9331401,
+                                    0.65137814,
+                                    0.39720258,
+                                    0.78873014,
+                                    0.31683612,
+                                    0.56809865)),
+                    Row.of(
+                            0.83628141,
+                            Vectors.dense(
+                                    0.86912739,
+                                    0.43617342,
+                                    0.80214764,
+                                    0.14376682,
+                                    0.70426097,
+                                    0.70458131)),
+                    Row.of(
+                            0.65665506,
+                            Vectors.dense(
+                                    0.21879211,
+                                    0.92486763,
+                                    0.44214076,
+                                    0.90931596,
+                                    0.05980922,
+                                    0.18428708)),
+                    Row.of(
+                            0.58147135,
+                            Vectors.dense(
+                                    0.04735528,
+                                    0.67488094,
+                                    0.59462478,
+                                    0.53331016,
+                                    0.04332406,
+                                    0.56143308)),
+                    Row.of(
+                            0.35603443,
+                            Vectors.dense(
+                                    0.32966845,
+                                    0.50296683,
+                                    0.11189432,
+                                    0.60719371,
+                                    0.56594464,
+                                    0.00676406)),
+                    Row.of(
+                            0.94534373,
+                            Vectors.dense(
+                                    0.61744171,
+                                    0.91212289,
+                                    0.79052413,
+                                    0.99208147,
+                                    0.95880176,
+                                    0.79196414)),
+                    Row.of(
+                            0.57458887,
+                            Vectors.dense(
+                                    0.28525096,
+                                    0.62491671,
+                                    0.4780938,
+                                    0.19567518,
+                                    0.38231745,
+                                    0.05387369)),
+                    Row.of(
+                            0.59026777,
+                            Vectors.dense(
+                                    0.45164841,
+                                    0.98200474,
+                                    0.1239427,
+                                    0.1193809,
+                                    0.73852306,
+                                    0.58730363)),
+                    Row.of(
+                            0.29894977,
+                            Vectors.dense(
+                                    0.47163253,
+                                    0.10712682,
+                                    0.22921857,
+                                    0.89996519,
+                                    0.41675354,
+                                    0.53585166)),
+                    Row.of(
+                            0.34056582,
+                            Vectors.dense(
+                                    0.00620852,
+                                    0.30064171,
+                                    0.43689317,
+                                    0.612149,
+                                    0.91819808,
+                                    0.62573667)),
+                    Row.of(
+                            0.64476446,
+                            Vectors.dense(
+                                    0.70599757,
+                                    0.14983372,
+                                    0.74606341,
+                                    0.83100699,
+                                    0.63372577,
+                                    0.43830988)),
+                    Row.of(
+                            0.53724782,
+                            Vectors.dense(
+                                    0.15257277,
+                                    0.56840962,
+                                    0.52822428,
+                                    0.95142876,
+                                    0.48035918,
+                                    0.50255956)),
+                    Row.of(
+                            0.5173021,
+                            Vectors.dense(
+                                    0.53687819,
+                                    0.81920207,
+                                    0.05711564,
+                                    0.66942174,
+                                    0.76711663,
+                                    0.70811536)),
+                    Row.of(
+                            0.94508275,
+                            Vectors.dense(
+                                    0.79686718,
+                                    0.55776083,
+                                    0.96583653,
+                                    0.1471569,
+                                    0.029647,
+                                    0.59389349)),
+                    Row.of(
+                            0.57739736,
+                            Vectors.dense(
+                                    0.1140657,
+                                    0.95080985,
+                                    0.96583653,
+                                    0.19361869,
+                                    0.45781165,
+                                    0.92040257)),
+                    Row.of(
+                            0.53877145,
+                            Vectors.dense(
+                                    0.87906916,
+                                    0.25261576,
+                                    0.34800879,
+                                    0.18258873,
+                                    0.90179605,
+                                    0.70652816)));
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        env.getConfig().enableObjectReuse();
+        tEnv = StreamTableEnvironment.create(env);
+
+        selectorWithChiSqTest =
+                new UnivariateFeatureSelector()
+                        .setFeatureType("categorical")
+                        .setLabelType("categorical");
+        selectorWithANOVATest =
+                new UnivariateFeatureSelector()
+                        .setFeatureType("continuous")
+                        .setLabelType("categorical");
+        selectorWithFValueTest =
+                new UnivariateFeatureSelector()
+                        .setFeatureType("continuous")
+                        .setLabelType("continuous");
+        inputChiSqTable =
+                tEnv.fromDataStream(env.fromCollection(INPUT_CHISQ_DATA)).as("label", "features");
+        inputANOVATable =
+                tEnv.fromDataStream(env.fromCollection(INPUT_ANOVA_DATA)).as("label", "features");
+        inputFValueTable =
+                tEnv.fromDataStream(env.fromCollection(INPUT_FVALUE_DATA)).as("label", "features");
+    }
+
+    private void transformAndVerify(
+            UnivariateFeatureSelector selector, Table table, int... expectedIndices)
+            throws Exception {
+        UnivariateFeatureSelectorModel model = selector.fit(table);
+        Table output = model.transform(table)[0];
+        verify(output, expectedIndices);
+    }
+
+    private void verify(Table table, int... expectedIndices) throws Exception {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) table).getTableEnvironment();
+        CloseableIterator<Row> rowIterator = tEnv.toDataStream(table).executeAndCollect();
+        while (rowIterator.hasNext()) {
+            Row row = rowIterator.next();
+            for (int i = 0; i < expectedIndices.length; i++) {
+                assertEquals(
+                        ((Vector) row.getField("features")).get(expectedIndices[i]),
+                        ((Vector) row.getField("output")).get(i),
+                        EPS);
+            }
+        }
+    }
+
+    @Test
+    public void testParam() {
+        UnivariateFeatureSelector selector = new UnivariateFeatureSelector();
+        assertEquals("features", selector.getFeaturesCol());
+        assertEquals("label", selector.getLabelCol());
+        assertEquals("output", selector.getOutputCol());
+        try {
+            selector.getFeatureType();
+            fail();

Review Comment:
   It is not typical practice for `testParam` to cover parameters that must be set() before get() can be called.
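
For illustration, the usual shape of `testParam` is roughly the following sketch: check defaults first, then verify set/get round-trips. The setter names are assumed to follow the standard flink-ml parameter naming convention, and the must-set-before-get check could move to a dedicated test if it is still wanted.

```java
// Hypothetical testParam sketch: verify default values, then set every
// parameter and verify the getters return the new values.
UnivariateFeatureSelector selector = new UnivariateFeatureSelector();
assertEquals("features", selector.getFeaturesCol());
assertEquals("label", selector.getLabelCol());
assertEquals("output", selector.getOutputCol());
assertEquals(NUM_TOP_FEATURES, selector.getSelectionMode());

selector.setFeaturesCol("test_features")
        .setLabelCol("test_label")
        .setOutputCol("test_output")
        .setFeatureType(CONTINUOUS)
        .setLabelType(CATEGORICAL)
        .setSelectionMode(FPR)
        .setSelectionThreshold(0.01);

assertEquals("test_features", selector.getFeaturesCol());
assertEquals("test_label", selector.getLabelCol());
assertEquals("test_output", selector.getOutputCol());
assertEquals(CONTINUOUS, selector.getFeatureType());
assertEquals(CATEGORICAL, selector.getLabelType());
assertEquals(FPR, selector.getSelectionMode());
assertEquals(0.01, selector.getSelectionThreshold(), EPS);
```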



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/feature/UnivariateFeatureSelectorTest.java:
##########
@@ -0,0 +1,771 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.feature.univariatefeatureselector.UnivariateFeatureSelector;
+import org.apache.flink.ml.feature.univariatefeatureselector.UnivariateFeatureSelectorModel;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.TestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.CloseableIterator;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.exception.ExceptionUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import static org.apache.flink.ml.feature.univariatefeatureselector.UnivariateFeatureSelectorParams.CATEGORICAL;
+import static org.apache.flink.ml.feature.univariatefeatureselector.UnivariateFeatureSelectorParams.CONTINUOUS;
+import static org.apache.flink.ml.feature.univariatefeatureselector.UnivariateFeatureSelectorParams.FPR;
+import static org.apache.flink.ml.feature.univariatefeatureselector.UnivariateFeatureSelectorParams.NUM_TOP_FEATURES;
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.fail;
+
+/** Tests {@link UnivariateFeatureSelector} and {@link UnivariateFeatureSelectorModel}. */
+public class UnivariateFeatureSelectorTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table inputChiSqTable;
+    private Table inputANOVATable;
+    private Table inputFValueTable;
+
+    private static final double EPS = 1.0e-5;
+
+    private UnivariateFeatureSelector selectorWithChiSqTest;
+    private UnivariateFeatureSelector selectorWithANOVATest;
+    private UnivariateFeatureSelector selectorWithFValueTest;
+
+    private static final List<Row> INPUT_CHISQ_DATA =
+            Arrays.asList(
+                    Row.of(0.0, Vectors.dense(6.0, 7.0, 0.0, 7.0, 6.0, 0.0)),
+                    Row.of(1.0, Vectors.dense(0.0, 9.0, 6.0, 0.0, 5.0, 9.0)),
+                    Row.of(1.0, Vectors.dense(0.0, 9.0, 3.0, 0.0, 5.0, 5.0)),
+                    Row.of(1.0, Vectors.dense(0.0, 9.0, 8.0, 5.0, 6.0, 4.0)),
+                    Row.of(2.0, Vectors.dense(8.0, 9.0, 6.0, 5.0, 4.0, 4.0)),
+                    Row.of(2.0, Vectors.dense(8.0, 9.0, 6.0, 4.0, 0.0, 0.0)));
+
+    private static final List<Row> INPUT_ANOVA_DATA =
+            Arrays.asList(
+                    Row.of(
+                            1,
+                            Vectors.dense(
+                                    4.65415496e-03,
+                                    1.03550567e-01,
+                                    -1.17358140e+00,
+                                    1.61408773e-01,
+                                    3.92492111e-01,
+                                    7.31240882e-01)),
+                    Row.of(
+                            1,
+                            Vectors.dense(
+                                    -9.01651741e-01,
+                                    -5.28905302e-01,
+                                    1.27636785e+00,
+                                    7.02154563e-01,
+                                    6.21348351e-01,
+                                    1.88397353e-01)),
+                    Row.of(
+                            1,
+                            Vectors.dense(
+                                    3.85692159e-01,
+                                    -9.04639637e-01,
+                                    5.09782604e-02,
+                                    8.40043971e-01,
+                                    7.45977857e-01,
+                                    8.78402288e-01)),
+                    Row.of(
+                            1,
+                            Vectors.dense(
+                                    1.36264353e+00,
+                                    2.62454094e-01,
+                                    7.96306202e-01,
+                                    6.14948000e-01,
+                                    7.44948187e-01,
+                                    9.74034830e-01)),
+                    Row.of(
+                            1,
+                            Vectors.dense(
+                                    9.65874070e-01,
+                                    2.52773665e+00,
+                                    -2.19380094e+00,
+                                    2.33408080e-01,
+                                    1.86340919e-01,
+                                    8.23390433e-01)),
+                    Row.of(
+                            2,
+                            Vectors.dense(
+                                    1.12324305e+01,
+                                    -2.77121515e-01,
+                                    1.12740513e-01,
+                                    2.35184013e-01,
+                                    3.46668895e-01,
+                                    9.38500782e-02)),
+                    Row.of(
+                            2,
+                            Vectors.dense(
+                                    1.06195839e+01,
+                                    -1.82891238e+00,
+                                    2.25085601e-01,
+                                    9.09979851e-01,
+                                    6.80257535e-02,
+                                    8.24017480e-01)),
+                    Row.of(
+                            2,
+                            Vectors.dense(
+                                    1.12806837e+01,
+                                    1.30686889e+00,
+                                    9.32839108e-02,
+                                    3.49784755e-01,
+                                    1.71322408e-02,
+                                    7.48465194e-02)),
+                    Row.of(
+                            2,
+                            Vectors.dense(
+                                    9.98689462e+00,
+                                    9.50808938e-01,
+                                    -2.90786359e-01,
+                                    2.31253009e-01,
+                                    7.46270968e-01,
+                                    1.60308169e-01)),
+                    Row.of(
+                            2,
+                            Vectors.dense(
+                                    1.08428551e+01,
+                                    -1.02749936e+00,
+                                    1.73951508e-01,
+                                    8.92482744e-02,
+                                    1.42651730e-01,
+                                    7.66751625e-01)),
+                    Row.of(
+                            3,
+                            Vectors.dense(
+                                    -1.98641448e+00,
+                                    1.12811990e+01,
+                                    -2.35246756e-01,
+                                    8.22809049e-01,
+                                    3.26739456e-01,
+                                    7.88268404e-01)),
+                    Row.of(
+                            3,
+                            Vectors.dense(
+                                    -6.09864090e-01,
+                                    1.07346276e+01,
+                                    -2.18805509e-01,
+                                    7.33931213e-01,
+                                    1.42554396e-01,
+                                    7.11225605e-01)),
+                    Row.of(
+                            3,
+                            Vectors.dense(
+                                    -1.58481268e+00,
+                                    9.19364039e+00,
+                                    -5.87490459e-02,
+                                    2.51532056e-01,
+                                    2.82729807e-01,
+                                    7.16245686e-01)),
+                    Row.of(
+                            3,
+                            Vectors.dense(
+                                    -2.50949277e-01,
+                                    1.12815254e+01,
+                                    -6.94806734e-01,
+                                    5.93898886e-01,
+                                    5.68425656e-01,
+                                    8.49762330e-01)),
+                    Row.of(
+                            3,
+                            Vectors.dense(
+                                    7.63485129e-01,
+                                    1.02605138e+01,
+                                    1.32617719e+00,
+                                    5.49682879e-01,
+                                    8.59931442e-01,
+                                    4.88677978e-02)),
+                    Row.of(
+                            4,
+                            Vectors.dense(
+                                    9.34900015e-01,
+                                    4.11379043e-01,
+                                    8.65010205e+00,
+                                    9.23509168e-01,
+                                    1.16995043e-01,
+                                    5.91894106e-03)),
+                    Row.of(
+                            4,
+                            Vectors.dense(
+                                    4.73734933e-01,
+                                    -1.48321181e+00,
+                                    9.73349621e+00,
+                                    4.09421563e-01,
+                                    5.09375719e-01,
+                                    5.93157850e-01)),
+                    Row.of(
+                            4,
+                            Vectors.dense(
+                                    3.41470679e-01,
+                                    -6.88972582e-01,
+                                    9.60347938e+00,
+                                    3.62654055e-01,
+                                    2.43437468e-01,
+                                    7.13052838e-01)),
+                    Row.of(
+                            4,
+                            Vectors.dense(
+                                    -5.29614251e-01,
+                                    -1.39262856e+00,
+                                    1.01354144e+01,
+                                    8.24123861e-01,
+                                    5.84074506e-01,
+                                    6.54461558e-01)),
+                    Row.of(
+                            4,
+                            Vectors.dense(
+                                    -2.99454508e-01,
+                                    2.20457263e+00,
+                                    1.14586015e+01,
+                                    5.16336729e-01,
+                                    9.99776159e-01,
+                                    3.15769738e-01)));
+
+    private static final List<Row> INPUT_FVALUE_DATA =
+            Arrays.asList(
+                    Row.of(
+                            0.52516321,
+                            Vectors.dense(
+                                    0.19151945,
+                                    0.62210877,
+                                    0.43772774,
+                                    0.78535858,
+                                    0.77997581,
+                                    0.27259261)),
+                    Row.of(
+                            0.88275782,
+                            Vectors.dense(
+                                    0.27646426,
+                                    0.80187218,
+                                    0.95813935,
+                                    0.87593263,
+                                    0.35781727,
+                                    0.50099513)),
+                    Row.of(
+                            0.67524507,
+                            Vectors.dense(
+                                    0.68346294,
+                                    0.71270203,
+                                    0.37025075,
+                                    0.56119619,
+                                    0.50308317,
+                                    0.01376845)),
+                    Row.of(
+                            0.76734745,
+                            Vectors.dense(
+                                    0.77282662,
+                                    0.88264119,
+                                    0.36488598,
+                                    0.61539618,
+                                    0.07538124,
+                                    0.36882401)),
+                    Row.of(
+                            0.73909458,
+                            Vectors.dense(
+                                    0.9331401,
+                                    0.65137814,
+                                    0.39720258,
+                                    0.78873014,
+                                    0.31683612,
+                                    0.56809865)),
+                    Row.of(
+                            0.83628141,
+                            Vectors.dense(
+                                    0.86912739,
+                                    0.43617342,
+                                    0.80214764,
+                                    0.14376682,
+                                    0.70426097,
+                                    0.70458131)),
+                    Row.of(
+                            0.65665506,
+                            Vectors.dense(
+                                    0.21879211,
+                                    0.92486763,
+                                    0.44214076,
+                                    0.90931596,
+                                    0.05980922,
+                                    0.18428708)),
+                    Row.of(
+                            0.58147135,
+                            Vectors.dense(
+                                    0.04735528,
+                                    0.67488094,
+                                    0.59462478,
+                                    0.53331016,
+                                    0.04332406,
+                                    0.56143308)),
+                    Row.of(
+                            0.35603443,
+                            Vectors.dense(
+                                    0.32966845,
+                                    0.50296683,
+                                    0.11189432,
+                                    0.60719371,
+                                    0.56594464,
+                                    0.00676406)),
+                    Row.of(
+                            0.94534373,
+                            Vectors.dense(
+                                    0.61744171,
+                                    0.91212289,
+                                    0.79052413,
+                                    0.99208147,
+                                    0.95880176,
+                                    0.79196414)),
+                    Row.of(
+                            0.57458887,
+                            Vectors.dense(
+                                    0.28525096,
+                                    0.62491671,
+                                    0.4780938,
+                                    0.19567518,
+                                    0.38231745,
+                                    0.05387369)),
+                    Row.of(
+                            0.59026777,
+                            Vectors.dense(
+                                    0.45164841,
+                                    0.98200474,
+                                    0.1239427,
+                                    0.1193809,
+                                    0.73852306,
+                                    0.58730363)),
+                    Row.of(
+                            0.29894977,
+                            Vectors.dense(
+                                    0.47163253,
+                                    0.10712682,
+                                    0.22921857,
+                                    0.89996519,
+                                    0.41675354,
+                                    0.53585166)),
+                    Row.of(
+                            0.34056582,
+                            Vectors.dense(
+                                    0.00620852,
+                                    0.30064171,
+                                    0.43689317,
+                                    0.612149,
+                                    0.91819808,
+                                    0.62573667)),
+                    Row.of(
+                            0.64476446,
+                            Vectors.dense(
+                                    0.70599757,
+                                    0.14983372,
+                                    0.74606341,
+                                    0.83100699,
+                                    0.63372577,
+                                    0.43830988)),
+                    Row.of(
+                            0.53724782,
+                            Vectors.dense(
+                                    0.15257277,
+                                    0.56840962,
+                                    0.52822428,
+                                    0.95142876,
+                                    0.48035918,
+                                    0.50255956)),
+                    Row.of(
+                            0.5173021,
+                            Vectors.dense(
+                                    0.53687819,
+                                    0.81920207,
+                                    0.05711564,
+                                    0.66942174,
+                                    0.76711663,
+                                    0.70811536)),
+                    Row.of(
+                            0.94508275,
+                            Vectors.dense(
+                                    0.79686718,
+                                    0.55776083,
+                                    0.96583653,
+                                    0.1471569,
+                                    0.029647,
+                                    0.59389349)),
+                    Row.of(
+                            0.57739736,
+                            Vectors.dense(
+                                    0.1140657,
+                                    0.95080985,
+                                    0.96583653,
+                                    0.19361869,
+                                    0.45781165,
+                                    0.92040257)),
+                    Row.of(
+                            0.53877145,
+                            Vectors.dense(
+                                    0.87906916,
+                                    0.25261576,
+                                    0.34800879,
+                                    0.18258873,
+                                    0.90179605,
+                                    0.70652816)));
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        env.getConfig().enableObjectReuse();
+        tEnv = StreamTableEnvironment.create(env);
+
+        selectorWithChiSqTest =
+                new UnivariateFeatureSelector()
+                        .setFeatureType("categorical")
+                        .setLabelType("categorical");
+        selectorWithANOVATest =
+                new UnivariateFeatureSelector()
+                        .setFeatureType("continuous")
+                        .setLabelType("categorical");
+        selectorWithFValueTest =
+                new UnivariateFeatureSelector()
+                        .setFeatureType("continuous")
+                        .setLabelType("continuous");
+        inputChiSqTable =
+                tEnv.fromDataStream(env.fromCollection(INPUT_CHISQ_DATA)).as("label", "features");
+        inputANOVATable =
+                tEnv.fromDataStream(env.fromCollection(INPUT_ANOVA_DATA)).as("label", "features");
+        inputFValueTable =
+                tEnv.fromDataStream(env.fromCollection(INPUT_FVALUE_DATA)).as("label", "features");
+    }
+
+    private void transformAndVerify(
+            UnivariateFeatureSelector selector, Table table, int... expectedIndices)
+            throws Exception {
+        UnivariateFeatureSelectorModel model = selector.fit(table);
+        Table output = model.transform(table)[0];
+        verify(output, expectedIndices);
+    }
+
+    private void verify(Table table, int... expectedIndices) throws Exception {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) table).getTableEnvironment();
+        CloseableIterator<Row> rowIterator = tEnv.toDataStream(table).executeAndCollect();
+        while (rowIterator.hasNext()) {
+            Row row = rowIterator.next();
+            for (int i = 0; i < expectedIndices.length; i++) {
+                assertEquals(
+                        ((Vector) row.getField("features")).get(expectedIndices[i]),
+                        ((Vector) row.getField("output")).get(i),
+                        EPS);
+            }
+        }
+    }
+
+    @Test
+    public void testParam() {
+        UnivariateFeatureSelector selector = new UnivariateFeatureSelector();
+        assertEquals("features", selector.getFeaturesCol());
+        assertEquals("label", selector.getLabelCol());
+        assertEquals("output", selector.getOutputCol());
+        try {
+            selector.getFeatureType();
+            fail();
+        } catch (Throwable e) {
+            assertEquals("Parameter featureType's value should not be null", e.getMessage());
+        }
+        try {
+            selector.getLabelType();
+            fail();
+        } catch (Throwable e) {
+            assertEquals("Parameter labelType's value should not be null", e.getMessage());
+        }
+        assertEquals(NUM_TOP_FEATURES, selector.getSelectionMode());
+        assertNull(selector.getSelectionThreshold());
+
+        selector.setFeaturesCol("test_features")
+                .setLabelCol("test_label")
+                .setOutputCol("test_output")
+                .setFeatureType(CONTINUOUS)
+                .setLabelType(CATEGORICAL)
+                .setSelectionMode(FPR)
+                .setSelectionThreshold(0.01);
+
+        assertEquals("test_features", selector.getFeaturesCol());
+        assertEquals("test_label", selector.getLabelCol());
+        assertEquals("test_output", selector.getOutputCol());
+        assertEquals(CONTINUOUS, selector.getFeatureType());
+        assertEquals(CATEGORICAL, selector.getLabelType());
+        assertEquals(FPR, selector.getSelectionMode());
+        assertEquals(0.01, selector.getSelectionThreshold(), EPS);
+    }
+
+    @Test
+    public void testIncompatibleSelectionModeAndThreshold() {
+        UnivariateFeatureSelector selector =
+                new UnivariateFeatureSelector()
+                        .setFeatureType(CONTINUOUS)
+                        .setLabelType(CATEGORICAL)
+                        .setSelectionThreshold(50.1);
+
+        try {
+            selector.fit(inputANOVATable);
+            fail();
+        } catch (Throwable e) {
+            assertEquals(
+                    "SelectionThreshold needs to be a positive Integer "
+                            + "for selection mode numTopFeatures, but got 50.1.",
+                    e.getMessage());
+        }
+        try {
+            selector.setSelectionMode(FPR).setSelectionThreshold(1.1).fit(inputANOVATable);
+            fail();
+        } catch (Throwable e) {
+            assertEquals(
+                    "SelectionThreshold needs to be in the range [0, 1] "
+                            + "for selection mode fpr, but got 1.1.",
+                    e.getMessage());
+        }
+    }
+
+    @Test
+    public void testOutputSchema() {
+        Table tempTable = inputANOVATable.as("test_label", "test_features");
+        UnivariateFeatureSelector selector =
+                new UnivariateFeatureSelector()
+                        .setLabelCol("test_label")
+                        .setFeaturesCol("test_features")
+                        .setOutputCol("test_output")
+                        .setFeatureType("continuous")
+                        .setLabelType("categorical");
+
+        UnivariateFeatureSelectorModel model = selector.fit(tempTable);
+        Table output = model.transform(tempTable)[0];
+        assertEquals(
+                Arrays.asList("test_label", "test_features", "test_output"),
+                output.getResolvedSchema().getColumnNames());
+    }
+
+    @Test
+    public void testFitTransformWithNumTopFeatures() throws Exception {
+        transformAndVerify(selectorWithChiSqTest.setSelectionThreshold(2), inputChiSqTable, 0, 1);
+        transformAndVerify(selectorWithANOVATest.setSelectionThreshold(2), inputANOVATable, 0, 2);
+        transformAndVerify(selectorWithFValueTest.setSelectionThreshold(2), inputFValueTable, 2, 0);
+    }
+
+    @Test
+    public void testFitTransformWithPercentile() throws Exception {
+        transformAndVerify(
+                selectorWithChiSqTest.setSelectionMode("percentile").setSelectionThreshold(0.17),
+                inputChiSqTable,
+                0);
+        transformAndVerify(
+                selectorWithANOVATest.setSelectionMode("percentile").setSelectionThreshold(0.17),
+                inputANOVATable,
+                0);
+        transformAndVerify(
+                selectorWithFValueTest.setSelectionMode("percentile").setSelectionThreshold(0.17),
+                inputFValueTable,
+                2);
+    }
+
+    @Test
+    public void testFitTransformWithFPR() throws Exception {
+        transformAndVerify(
+                selectorWithChiSqTest.setSelectionMode("fpr").setSelectionThreshold(0.02),

Review Comment:
   Let's use either `FPR` or `"fpr"` consistently across all the test cases and examples.
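
   For illustration, a minimal sketch of the two interchangeable forms, assuming (as the static imports in this test suggest) that `FPR` and `CATEGORICAL` are the string constants `"fpr"` and `"categorical"`:

   ```java
   // Constant form, via the static imports of UnivariateFeatureSelectorParams.FPR
   // and CATEGORICAL already present in this test file.
   UnivariateFeatureSelector byConstant =
           new UnivariateFeatureSelector()
                   .setFeatureType(CATEGORICAL)
                   .setLabelType(CATEGORICAL)
                   .setSelectionMode(FPR)
                   .setSelectionThreshold(0.02);

   // Equivalent string-literal form; whichever style is chosen should be used
   // throughout the tests and examples.
   UnivariateFeatureSelector byLiteral =
           new UnivariateFeatureSelector()
                   .setFeatureType("categorical")
                   .setLabelType("categorical")
                   .setSelectionMode("fpr")
                   .setSelectionThreshold(0.02);
   ```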





[GitHub] [flink-ml] lindong28 commented on pull request #187: [FLINK-29601] Add Estimator and Transformer for UnivariateFeatureSelector

Posted by GitBox <gi...@apache.org>.
lindong28 commented on PR #187:
URL: https://github.com/apache/flink-ml/pull/187#issuecomment-1348091384

   Thanks for the update! LGTM.




[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #187: [FLINK-29601] Add Estimator and Transformer for UnivariateFeatureSelector

Posted by GitBox <gi...@apache.org>.
yunfengzhou-hub commented on code in PR #187:
URL: https://github.com/apache/flink-ml/pull/187#discussion_r1041657706


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/univariatefeatureselector/UnivariateFeatureSelector.java:
##########
@@ -0,0 +1,303 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.univariatefeatureselector;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.stats.anovatest.ANOVATest;
+import org.apache.flink.ml.stats.chisqtest.ChiSqTest;
+import org.apache.flink.ml.stats.fvaluetest.FValueTest;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.IntStream;
+
+/**
+ * An Estimator which selects features based on univariate statistical tests against labels.
+ *
+ * <p>Currently, Flink supports three Univariate Feature Selectors: chi-squared, ANOVA F-test and
+ * F-value. Users can choose a Univariate Feature Selector by setting `featureType` and `labelType`,
+ * and Flink will pick the score function based on the specified `featureType` and `labelType`.
+ *
+ * <p>The following combinations of `featureType` and `labelType` are supported:
+ *
+ * <ul>
+ *   <li>`featureType` `categorical` and `labelType` `categorical`: Flink uses chi-squared, i.e.
+ *       chi2 in sklearn.
+ *   <li>`featureType` `continuous` and `labelType` `categorical`: Flink uses ANOVA F-test, i.e.
+ *       f_classif in sklearn.
+ *   <li>`featureType` `continuous` and `labelType` `continuous`: Flink uses F-value, i.e.
+ *       f_regression in sklearn.
+ * </ul>
+ *
+ * <p>The `UnivariateFeatureSelector` supports different selection modes:
+ *
+ * <ul>
+ *   <li>numTopFeatures: chooses a fixed number of top features according to a hypothesis.
+ *   <li>percentile: similar to numTopFeatures but chooses a fraction of all features instead of a
+ *       fixed number.
+ *   <li>fpr: chooses all features whose p-values are below a threshold, thus controlling the false
+ *       positive rate of selection.
+ *   <li>fdr: uses the <a
+ *       href="https://en.wikipedia.org/wiki/False_discovery_rate#Benjamini.E2.80.93Hochberg_procedure">
+ *       Benjamini-Hochberg procedure</a> to choose all features whose false discovery rate is below
+ *       a threshold.
+ *   <li>fwe: chooses all features whose p-values are below a threshold. The threshold is scaled by
+ *       1/numFeatures, thus controlling the family-wise error rate of selection.
+ * </ul>
+ *
+ * <p>By default, the selection mode is `numTopFeatures`.
+ */
+public class UnivariateFeatureSelector
+        implements Estimator<UnivariateFeatureSelector, UnivariateFeatureSelectorModel>,
+                UnivariateFeatureSelectorParams<UnivariateFeatureSelector> {
+
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public UnivariateFeatureSelector() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public UnivariateFeatureSelectorModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        final String featuresCol = getFeaturesCol();
+        final String labelCol = getLabelCol();
+        final String featureType = getFeatureType();
+        final String labelType = getLabelType();
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        Table output;
+        if (CATEGORICAL.equals(featureType) && CATEGORICAL.equals(labelType)) {
+            output =
+                    new ChiSqTest()
+                            .setFeaturesCol(featuresCol)
+                            .setLabelCol(labelCol)
+                            .setFlatten(true)
+                            .transform(inputs[0])[0];
+        } else if (CONTINUOUS.equals(featureType) && CATEGORICAL.equals(labelType)) {
+            output =
+                    new ANOVATest()
+                            .setFeaturesCol(featuresCol)
+                            .setLabelCol(labelCol)
+                            .setFlatten(true)
+                            .transform(inputs[0])[0];
+        } else if (CONTINUOUS.equals(featureType) && CONTINUOUS.equals(labelType)) {
+            output =
+                    new FValueTest()
+                            .setFeaturesCol(featuresCol)
+                            .setLabelCol(labelCol)
+                            .setFlatten(true)
+                            .transform(inputs[0])[0];
+        } else {
+            throw new IllegalArgumentException(
+                    String.format(
+                            "Unsupported combination: featureType=%s, labelType=%s.",
+                            featureType, labelType));
+        }
+        DataStream<UnivariateFeatureSelectorModelData> modelData =
+                tEnv.toDataStream(output)
+                        .transform(
+                                "selectIndicesFromPValues",
+                                TypeInformation.of(UnivariateFeatureSelectorModelData.class),
+                                new SelectIndicesFromPValuesOperator(
+                                        getSelectionMode(), getActualSelectionThreshold()))
+                        .setParallelism(1);
+        UnivariateFeatureSelectorModel model =
+                new UnivariateFeatureSelectorModel().setModelData(tEnv.fromDataStream(modelData));
+        ReadWriteUtils.updateExistingParams(model, getParamMap());
+        return model;
+    }
+
+    private double getActualSelectionThreshold() {
+        Double threshold = getSelectionThreshold();
+        if (threshold == null) {
+            String selectionMode = getSelectionMode();
+            if (NUM_TOP_FEATURES.equals(selectionMode)) {
+                threshold = 50.0;
+            } else if (PERCENTILE.equals(selectionMode)) {
+                threshold = 0.1;
+            } else {
+                threshold = 0.05;
+            }
+        } else {
+            if (NUM_TOP_FEATURES.equals(getSelectionMode())) {
+                Preconditions.checkArgument(
+                        threshold >= 1 && threshold.intValue() == threshold,
+                        "SelectionThreshold needs to be a positive Integer "
+                                + "for selection mode numTopFeatures, but got %s.",
+                        threshold);
+            } else {
+                Preconditions.checkArgument(
+                        threshold >= 0 && threshold <= 1,
+                        "SelectionThreshold needs to be in the range [0, 1] "
+                                + "for selection mode %s, but got %s.",
+                        getSelectionMode(),
+                        threshold);
+            }
+        }
+        return threshold;
+    }
+
+    private static class SelectIndicesFromPValuesOperator
+            extends AbstractStreamOperator<UnivariateFeatureSelectorModelData>
+            implements OneInputStreamOperator<Row, UnivariateFeatureSelectorModelData>,
+                    BoundedOneInput {
+        private final String selectionMode;
+        private final double threshold;
+
+        private List<Tuple2<Double, Integer>> pValuesAndIndices;
+        private ListState<Tuple2<Double, Integer>> pValuesAndIndicesState;
+
+        public SelectIndicesFromPValuesOperator(String selectionMode, double threshold) {
+            this.selectionMode = selectionMode;
+            this.threshold = threshold;
+        }
+
+        @Override
+        public void endInput() {
+            List<Integer> indices = new ArrayList<>();
+
+            switch (selectionMode) {
+                case NUM_TOP_FEATURES:
+                    pValuesAndIndices.sort(
+                            Comparator.comparingDouble((Tuple2<Double, Integer> t) -> t.f0)
+                                    .thenComparingInt(t -> t.f1));
+                    IntStream.range(0, Math.min(pValuesAndIndices.size(), (int) threshold))
+                            .forEach(i -> indices.add(pValuesAndIndices.get(i).f1));
+                    break;
+                case PERCENTILE:
+                    pValuesAndIndices.sort(
+                            Comparator.comparingDouble((Tuple2<Double, Integer> t) -> t.f0)
+                                    .thenComparingInt(t -> t.f1));
+                    IntStream.range(
+                                    0,
+                                    Math.min(
+                                            pValuesAndIndices.size(),
+                                            (int) (pValuesAndIndices.size() * threshold)))
+                            .forEach(i -> indices.add(pValuesAndIndices.get(i).f1));
+                    break;
+                case FPR:
+                    pValuesAndIndices.stream()
+                            .filter(x -> x.f0 < threshold)
+                            .forEach(x -> indices.add(x.f1));
+                    break;
+                case FDR:
+                    pValuesAndIndices.sort(
+                            Comparator.comparingDouble((Tuple2<Double, Integer> t) -> t.f0)
+                                    .thenComparingInt(t -> t.f1));
+
+                    int maxIndex = -1;
+                    for (int i = 0; i < pValuesAndIndices.size(); i++) {
+                        if (pValuesAndIndices.get(i).f0
+                                < (threshold / pValuesAndIndices.size()) * (i + 1)) {
+                            maxIndex = Math.max(maxIndex, i);
+                        }
+                    }
+                    if (maxIndex >= 0) {
+                        pValuesAndIndices.sort(Comparator.comparing(t -> t.f0));

Review Comment:
   Let's compare p-values and then indices in the FDR strategy as well, as the other selection modes do.
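
   A minimal sketch of that change, assuming the same `Tuple2<Double, Integer>` layout (p-value, feature index) and imports already used in this operator; sharing one comparator keeps the FDR branch's tie-breaking consistent with the other selection modes:

   ```java
   // Order by p-value first, then by feature index so ties break deterministically.
   Comparator<Tuple2<Double, Integer>> byPValueThenIndex =
           Comparator.comparingDouble((Tuple2<Double, Integer> t) -> t.f0)
                   .thenComparingInt(t -> t.f1);

   // In the FDR branch, the second sort (currently Comparator.comparing(t -> t.f0))
   // would then become:
   pValuesAndIndices.sort(byPValueThenIndex);
   ```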



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/variancethresholdselector/VarianceThresholdSelector.java:
##########
@@ -63,11 +63,11 @@ public VarianceThresholdSelectorModel fit(Table... inputs) {
         final String inputCol = getInputCol();
         StreamTableEnvironment tEnv =
                 (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
-        DataStream<DenseVector> inputData =
+        DataStream<Vector> inputData =

Review Comment:
   Let's describe the changes made to `VarianceThresholdSelector` in this PR's description.





[GitHub] [flink-ml] jiangxin369 commented on a diff in pull request #187: [FLINK-29601] Add Estimator and Transformer for UnivariateFeatureSelector

Posted by GitBox <gi...@apache.org>.
jiangxin369 commented on code in PR #187:
URL: https://github.com/apache/flink-ml/pull/187#discussion_r1037760329


##########
docs/content/docs/operators/feature/univariategeatureselector.md:
##########
@@ -0,0 +1,220 @@
+---
+title: "Univariate Feature Selector"
+weight: 1
+type: docs
+aliases:
+- /operators/feature/univariatefeatureselector.html
+---
+
+<!--
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements.  See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership.  The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License.  You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied.  See the License for the
+specific language governing permissions and limitations
+under the License.
+-->
+
+## Univariate Feature Selector
+Univariate Feature Selector is an algorithm that selects features based on 
+univariate statistical tests against labels.
+
+Currently, Flink supports three Univariate Feature Selectors: chi-squared, 
+ANOVA F-test and F-value. Users can choose a Univariate Feature Selector by
+setting `featureType` and `labelType`, and Flink will pick the score function
+based on the specified `featureType` and `labelType`.
+
+The following combinations of `featureType` and `labelType` are supported:
+
+<ul>
+    <li>`featureType` `categorical` and `labelType` `categorical`: Flink uses 
+        chi-squared, i.e. chi2 in sklearn.
+    <li>`featureType` `continuous` and `labelType` `categorical`: Flink uses 
+        ANOVA F-test, i.e. f_classif in sklearn.
+    <li>`featureType` `continuous` and `labelType` `continuous`: Flink uses 
+        F-value, i.e. f_regression in sklearn.
+</ul>
+

Review Comment:
   Sure, but why not keep it consistent with the Javadoc, just like `VectorAssembler` and `FeatureHasher`?
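
   For readers following this thread, a short usage sketch based only on the API shown in the diffs above; the class name, table, and column names are illustrative, and the default selection mode `numTopFeatures` is assumed:

   ```java
   import org.apache.flink.ml.feature.univariatefeatureselector.UnivariateFeatureSelector;
   import org.apache.flink.ml.feature.univariatefeatureselector.UnivariateFeatureSelectorModel;
   import org.apache.flink.table.api.Table;

   public class UnivariateFeatureSelectorUsage {

       // Selects the two features most associated with a categorical label using
       // the ANOVA F-test (continuous features, categorical label). trainTable is
       // assumed to have a "label" column and a "features" vector column; the
       // default selection mode numTopFeatures is used, so only the threshold is set.
       public static Table selectTopTwoFeatures(Table trainTable) {
           UnivariateFeatureSelector selector =
                   new UnivariateFeatureSelector()
                           .setFeaturesCol("features")
                           .setLabelCol("label")
                           .setOutputCol("output")
                           .setFeatureType("continuous")
                           .setLabelType("categorical")
                           .setSelectionThreshold(2);
           UnivariateFeatureSelectorModel model = selector.fit(trainTable);
           return model.transform(trainTable)[0];
       }
   }
   ```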


