You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by li...@apache.org on 2022/11/28 14:09:02 UTC

[flink-ml] branch master updated: [FLINK-30159] Add Transformer for ANOVATest

This is an automated email from the ASF dual-hosted git repository.

lindong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git


The following commit(s) were added to refs/heads/master by this push:
     new d6a5b42  [FLINK-30159] Add Transformer for ANOVATest
d6a5b42 is described below

commit d6a5b42aae2a1a84789dfe0a35c7e97461b5d418
Author: JiangXin <ji...@alibaba-inc.com>
AuthorDate: Mon Nov 28 22:08:56 2022 +0800

    [FLINK-30159] Add Transformer for ANOVATest
    
    This closes #180.
---
 .../param/HasFlatten.java}                         |  14 +-
 .../apache/flink/ml/stats/anovatest/ANOVATest.java | 287 +++++++++++++++
 .../ANOVATestParams.java}                          |  23 +-
 .../apache/flink/ml/stats/chisqtest/ChiSqTest.java |   3 +-
 .../flink/ml/stats/chisqtest/ChiSqTestParams.java  |  19 +-
 .../org/apache/flink/ml/stats/ANOVATestTest.java   | 410 +++++++++++++++++++++
 .../pyflink/ml/core/tests/test_param.py            |  16 +-
 flink-ml-python/pyflink/ml/lib/param.py            |  23 +-
 flink-ml-python/pyflink/ml/lib/stats/chisqtest.py  |   2 +-
 .../ml/lib/tests/test_ml_lib_completeness.py       |  14 +-
 10 files changed, 760 insertions(+), 51 deletions(-)

diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/chisqtest/ChiSqTestParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasFlatten.java
similarity index 79%
copy from flink-ml-lib/src/main/java/org/apache/flink/ml/stats/chisqtest/ChiSqTestParams.java
copy to flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasFlatten.java
index e9d3ec9..899e7a5 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/chisqtest/ChiSqTestParams.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasFlatten.java
@@ -16,19 +16,15 @@
  * limitations under the License.
  */
 
-package org.apache.flink.ml.stats.chisqtest;
+package org.apache.flink.ml.common.param;
 
-import org.apache.flink.ml.common.param.HasFeaturesCol;
-import org.apache.flink.ml.common.param.HasLabelCol;
 import org.apache.flink.ml.param.BooleanParam;
 import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.WithParams;
+
+/** Interface for the shared flatten param. */
+public interface HasFlatten<T> extends WithParams<T> {
 
-/**
- * Params for {@link ChiSqTest}.
- *
- * @param <T> The class type of this instance.
- */
-public interface ChiSqTestParams<T> extends HasFeaturesCol<T>, HasLabelCol<T> {
     Param<Boolean> FLATTEN =
             new BooleanParam(
                     "flatten",
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/anovatest/ANOVATest.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/anovatest/ANOVATest.java
new file mode 100644
index 0000000..b58fdc8
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/anovatest/ANOVATest.java
@@ -0,0 +1,287 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.stats.anovatest;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.param.HasFlatten;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.math3.distribution.FDistribution;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.IntStream;
+
+/**
+ * An AlgoOperator which implements the ANOVA test algorithm.
+ *
+ * <p>See <a href="https://en.wikipedia.org/wiki/Analysis_of_variance">Wikipedia</a> for more
+ * information on ANOVA test.
+ *
+ * <p>The input of this algorithm is a table containing a labelColumn of numerical type and a
+ * featuresColumn of vector type. Each index in the input vector represents a feature to be tested,
+ * so all input vectors must have the same size. The label must take at least two distinct values,
+ * and there must be more samples than distinct label values.
+ *
+ * <p>By default, the output of this algorithm is a table containing a single row with the
+ * following columns, each of which has one value per feature.
+ *
+ * <ul>
+ *   <li>"pValues": vector
+ *   <li>"degreesOfFreedom": long array
+ *   <li>"fValues": vector
+ * </ul>
+ *
+ * <p>The output of this algorithm can be flattened to multiple rows by setting {@link
+ * HasFlatten#FLATTEN} to true, which would contain the following columns:
+ *
+ * <ul>
+ *   <li>"featureIndex": int
+ *   <li>"pValue": double
+ *   <li>"degreeOfFreedom": long
+ *   <li>"fValue": double
+ * </ul>
+ *
+ * <p>The reported degree of freedom of a feature is the sum of its between-class and within-class
+ * degrees of freedom, i.e. {@code numSamples - 1}.
+ */
+public class ANOVATest implements AlgoOperator<ANOVATest>, ANOVATestParams<ANOVATest> {
+
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public ANOVATest() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        final String featuresCol = getFeaturesCol();
+        final String labelCol = getLabelCol();
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<Tuple2<Vector, Double>> inputData =
+                tEnv.toDataStream(inputs[0])
+                        .map(
+                                (MapFunction<Row, Tuple2<Vector, Double>>)
+                                        row -> {
+                                            Number number = (Number) row.getField(labelCol);
+                                            Preconditions.checkNotNull(
+                                                    number, "Input data must contain label value.");
+                                            Vector features = (Vector) row.getField(featuresCol);
+                                            // Fail fast with a clear message instead of an NPE
+                                            // inside the aggregator.
+                                            Preconditions.checkNotNull(
+                                                    features,
+                                                    "Input data must contain features value.");
+                                            return new Tuple2<>(features, number.doubleValue());
+                                        })
+                        .returns(Types.TUPLE(VectorTypeInfo.INSTANCE, Types.DOUBLE));
+        DataStream<List<Row>> streamWithANOVA =
+                DataStreamUtils.aggregate(inputData, new ANOVAAggregator());
+        return new Table[] {convertToTable(tEnv, streamWithANOVA, getFlatten())};
+    }
+
+    /**
+     * Computes the p-value, F-value and degree of freedom of each input feature.
+     *
+     * <p>The accumulator holds one entry per feature: f0 is the sum of the feature values, f1 is
+     * the sum of their squares, and f2 maps each distinct label to the (sum of feature values,
+     * number of samples) observed for that label. The result contains one row per feature of the
+     * form (featureIndex, pValue, degreeOfFreedom, fValue).
+     */
+    private static class ANOVAAggregator
+            implements AggregateFunction<
+                    Tuple2<Vector, Double>,
+                    Tuple3<Double, Double, HashMap<Double, Tuple2<Double, Long>>>[],
+                    List<Row>> {
+        @Override
+        @SuppressWarnings("unchecked")
+        public Tuple3<Double, Double, HashMap<Double, Tuple2<Double, Long>>>[] createAccumulator() {
+            // An empty array marks an accumulator that has not seen any input yet; it is sized
+            // lazily because the number of features is only known once the first vector arrives.
+            return new Tuple3[0];
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public Tuple3<Double, Double, HashMap<Double, Tuple2<Double, Long>>>[] add(
+                Tuple2<Vector, Double> featuresAndLabel,
+                Tuple3<Double, Double, HashMap<Double, Tuple2<Double, Long>>>[] acc) {
+            Vector features = featuresAndLabel.f0;
+            double label = featuresAndLabel.f1;
+            int numOfFeatures = features.size();
+            if (acc.length == 0) {
+                acc = new Tuple3[numOfFeatures];
+                for (int i = 0; i < numOfFeatures; i++) {
+                    acc[i] = Tuple3.of(0.0, 0.0, new HashMap<>());
+                }
+            }
+            // A longer vector would otherwise fail with an unexplained index error, and a shorter
+            // one would silently skew the statistics of the trailing features.
+            Preconditions.checkArgument(
+                    numOfFeatures == acc.length,
+                    "Input vectors must have the same size, expected %s but got %s.",
+                    acc.length,
+                    numOfFeatures);
+            for (int i = 0; i < numOfFeatures; i++) {
+                double featureValue = features.get(i);
+                acc[i].f0 += featureValue;
+                acc[i].f1 += featureValue * featureValue;
+
+                Tuple2<Double, Long> classSummary = acc[i].f2.get(label);
+                if (classSummary == null) {
+                    acc[i].f2.put(label, Tuple2.of(featureValue, 1L));
+                } else {
+                    classSummary.f0 += featureValue;
+                    classSummary.f1 += 1L;
+                }
+            }
+            return acc;
+        }
+
+        @Override
+        public List<Row> getResult(
+                Tuple3<Double, Double, HashMap<Double, Tuple2<Double, Long>>>[] acc) {
+            List<Row> results = new ArrayList<>();
+            for (int i = 0; i < acc.length; i++) {
+                Tuple3<Double, Long, Double> resultOfANOVA =
+                        computeANOVA(acc[i].f0, acc[i].f1, acc[i].f2);
+                results.add(Row.of(i, resultOfANOVA.f0, resultOfANOVA.f1, resultOfANOVA.f2));
+            }
+            return results;
+        }
+
+        @Override
+        public Tuple3<Double, Double, HashMap<Double, Tuple2<Double, Long>>>[] merge(
+                Tuple3<Double, Double, HashMap<Double, Tuple2<Double, Long>>>[] acc1,
+                Tuple3<Double, Double, HashMap<Double, Tuple2<Double, Long>>>[] acc2) {
+            if (acc1.length == 0) {
+                return acc2;
+            }
+            if (acc2.length == 0) {
+                return acc1;
+            }
+            // Partial accumulators built from vectors of different sizes cannot be combined.
+            Preconditions.checkArgument(
+                    acc1.length == acc2.length,
+                    "Input vectors must have the same size, expected %s but got %s.",
+                    acc2.length,
+                    acc1.length);
+            IntStream.range(0, acc1.length)
+                    .forEach(
+                            i -> {
+                                acc2[i].f0 += acc1[i].f0;
+                                acc2[i].f1 += acc1[i].f1;
+                                acc1[i].f2.forEach(
+                                        (k, v) -> {
+                                            Tuple2<Double, Long> classSummary = acc2[i].f2.get(k);
+                                            if (classSummary == null) {
+                                                acc2[i].f2.put(k, v);
+                                            } else {
+                                                classSummary.f0 += v.f0;
+                                                classSummary.f1 += v.f1;
+                                            }
+                                        });
+                            });
+            return acc2;
+        }
+
+        /**
+         * Runs a one-way ANOVA F-test for a single feature.
+         *
+         * @param sum Sum of the feature values over all samples.
+         * @param sumOfSq Sum of the squared feature values over all samples.
+         * @param summary Maps each distinct label to the (sum of feature values, number of samples)
+         *     observed for that label.
+         * @return A tuple of (p-value, degree of freedom, F-value), where the degree of freedom is
+         *     the sum of the between-class and within-class degrees of freedom.
+         * @throws IllegalArgumentException if there are fewer than two distinct labels, or no more
+         *     samples than distinct labels.
+         */
+        private static Tuple3<Double, Long, Double> computeANOVA(
+                double sum, double sumOfSq, HashMap<Double, Tuple2<Double, Long>> summary) {
+            long numOfClasses = summary.size();
+
+            long numOfSamples = summary.values().stream().mapToLong(t -> t.f1).sum();
+
+            double sqSum = sum * sum;
+
+            double ssTot = sumOfSq - sqSum / numOfSamples;
+
+            double totalSqSum = 0;
+            for (Tuple2<Double, Long> t : summary.values()) {
+                totalSqSum += t.f0 * t.f0 / t.f1;
+            }
+
+            double sumOfSqBetween = totalSqSum - (sqSum / numOfSamples);
+
+            double sumOfSqWithin = ssTot - sumOfSqBetween;
+
+            // A positive between-class degree of freedom requires at least two distinct labels.
+            long degreeOfFreedomBetween = numOfClasses - 1;
+            Preconditions.checkArgument(
+                    degreeOfFreedomBetween > 0, "Num of classes should be positive.");
+
+            long degreeOfFreedomWithin = numOfSamples - numOfClasses;
+            Preconditions.checkArgument(
+                    degreeOfFreedomWithin > 0,
+                    "Num of samples should be greater than num of classes.");
+
+            double meanSqBetween = sumOfSqBetween / degreeOfFreedomBetween;
+
+            double meanSqWithin = sumOfSqWithin / degreeOfFreedomWithin;
+
+            double fValue = meanSqBetween / meanSqWithin;
+
+            FDistribution fd = new FDistribution(degreeOfFreedomBetween, degreeOfFreedomWithin);
+            double pValue = 1 - fd.cumulativeProbability(fValue);
+
+            long degreeOfFreedom = degreeOfFreedomBetween + degreeOfFreedomWithin;
+
+            return Tuple3.of(pValue, degreeOfFreedom, fValue);
+        }
+    }
+
+    /**
+     * Converts the per-feature result rows into the output table.
+     *
+     * @param tEnv The table environment used to create the output table.
+     * @param datastream A stream of lists, each holding one (featureIndex, pValue,
+     *     degreeOfFreedom, fValue) row per feature.
+     * @param flatten If true, emits one row per feature; otherwise packs all features into a single
+     *     row of vectors and a long array.
+     * @return The output table.
+     */
+    private Table convertToTable(
+            StreamTableEnvironment tEnv, DataStream<List<Row>> datastream, boolean flatten) {
+        if (flatten) {
+            DataStream<Row> output =
+                    datastream
+                            .flatMap(
+                                    (FlatMapFunction<List<Row>, Row>)
+                                            (list, collector) -> list.forEach(collector::collect))
+                            .setParallelism(1)
+                            .returns(Types.ROW(Types.INT, Types.DOUBLE, Types.LONG, Types.DOUBLE));
+            return tEnv.fromDataStream(output)
+                    .as("featureIndex", "pValue", "degreeOfFreedom", "fValue");
+        } else {
+            DataStream<Tuple3<DenseVector, long[], DenseVector>> output =
+                    datastream.map(
+                            new MapFunction<List<Row>, Tuple3<DenseVector, long[], DenseVector>>() {
+                                @Override
+                                public Tuple3<DenseVector, long[], DenseVector> map(
+                                        List<Row> rows) {
+                                    int numOfFeatures = rows.size();
+                                    DenseVector pValues = new DenseVector(numOfFeatures);
+                                    DenseVector fValues = new DenseVector(numOfFeatures);
+                                    long[] degrees = new long[numOfFeatures];
+
+                                    for (int i = 0; i < numOfFeatures; i++) {
+                                        Row row = rows.get(i);
+                                        pValues.set(i, (double) row.getField(1));
+                                        degrees[i] = (long) row.getField(2);
+                                        fValues.set(i, (double) row.getField(3));
+                                    }
+                                    return Tuple3.of(pValues, degrees, fValues);
+                                }
+                            });
+            return tEnv.fromDataStream(output).as("pValues", "degreesOfFreedom", "fValues");
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static ANOVATest load(StreamTableEnvironment tEnv, String path) throws IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/chisqtest/ChiSqTestParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/anovatest/ANOVATestParams.java
similarity index 60%
copy from flink-ml-lib/src/main/java/org/apache/flink/ml/stats/chisqtest/ChiSqTestParams.java
copy to flink-ml-lib/src/main/java/org/apache/flink/ml/stats/anovatest/ANOVATestParams.java
index e9d3ec9..2ac44a8 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/chisqtest/ChiSqTestParams.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/anovatest/ANOVATestParams.java
@@ -16,30 +16,15 @@
  * limitations under the License.
  */
 
-package org.apache.flink.ml.stats.chisqtest;
+package org.apache.flink.ml.stats.anovatest;
 
 import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasFlatten;
 import org.apache.flink.ml.common.param.HasLabelCol;
-import org.apache.flink.ml.param.BooleanParam;
-import org.apache.flink.ml.param.Param;
 
 /**
- * Params for {@link ChiSqTest}.
+ * Params for {@link ANOVATest}.
  *
  * @param <T> The class type of this instance.
  */
-public interface ChiSqTestParams<T> extends HasFeaturesCol<T>, HasLabelCol<T> {
-    Param<Boolean> FLATTEN =
-            new BooleanParam(
-                    "flatten",
-                    "If false, the returned table contains only a single row, otherwise, one row per feature.",
-                    false);
-
-    default boolean getFlatten() {
-        return get(FLATTEN);
-    }
-
-    default T setFlatten(boolean value) {
-        return set(FLATTEN, value);
-    }
-}
+public interface ANOVATestParams<T> extends HasFeaturesCol<T>, HasLabelCol<T>, HasFlatten<T> {}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/chisqtest/ChiSqTest.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/chisqtest/ChiSqTest.java
index 7e625ff..ea479db 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/chisqtest/ChiSqTest.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/chisqtest/ChiSqTest.java
@@ -32,6 +32,7 @@ import org.apache.flink.api.java.typeutils.RowTypeInfo;
 import org.apache.flink.iteration.operator.OperatorStateUtils;
 import org.apache.flink.ml.api.AlgoOperator;
 import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.param.HasFlatten;
 import org.apache.flink.ml.linalg.DenseVector;
 import org.apache.flink.ml.linalg.Vector;
 import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
@@ -87,7 +88,7 @@ import java.util.stream.Collectors;
  * </ul>
  *
  * <p>The output of this algorithm can be flattened to multiple rows by setting {@link
- * ChiSqTestParams#FLATTEN}, which would contain the following columns:
+ * HasFlatten#FLATTEN} to true, which would contain the following columns:
  *
  * <ul>
  *   <li>"featureIndex": int
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/chisqtest/ChiSqTestParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/chisqtest/ChiSqTestParams.java
index e9d3ec9..4ad1cb4 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/chisqtest/ChiSqTestParams.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/chisqtest/ChiSqTestParams.java
@@ -19,27 +19,12 @@
 package org.apache.flink.ml.stats.chisqtest;
 
 import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasFlatten;
 import org.apache.flink.ml.common.param.HasLabelCol;
-import org.apache.flink.ml.param.BooleanParam;
-import org.apache.flink.ml.param.Param;
 
 /**
  * Params for {@link ChiSqTest}.
  *
  * @param <T> The class type of this instance.
  */
-public interface ChiSqTestParams<T> extends HasFeaturesCol<T>, HasLabelCol<T> {
-    Param<Boolean> FLATTEN =
-            new BooleanParam(
-                    "flatten",
-                    "If false, the returned table contains only a single row, otherwise, one row per feature.",
-                    false);
-
-    default boolean getFlatten() {
-        return get(FLATTEN);
-    }
-
-    default T setFlatten(boolean value) {
-        return set(FLATTEN, value);
-    }
-}
+public interface ChiSqTestParams<T> extends HasFeaturesCol<T>, HasLabelCol<T>, HasFlatten<T> {}
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/stats/ANOVATestTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/stats/ANOVATestTest.java
new file mode 100644
index 0000000..0abdec7
--- /dev/null
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/stats/ANOVATestTest.java
@@ -0,0 +1,410 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.stats;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.stats.anovatest.ANOVATest;
+import org.apache.flink.ml.util.TestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+/** Tests the {@link ANOVATest}. */
+public class ANOVATestTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamTableEnvironment tEnv;
+    private Table denseInputTable;
+    private Table sparseInputTable;
+
+    private static final double EPS = 1.0e-5;
+    private static final List<Row> DENSE_INPUT_DATA =
+            Arrays.asList(
+                    Row.of(
+                            3,
+                            Vectors.dense(
+                                    0.85956061,
+                                    0.1645695,
+                                    0.48347596,
+                                    0.92102727,
+                                    0.42855644,
+                                    0.05746009)),
+                    Row.of(
+                            2,
+                            Vectors.dense(
+                                    0.92500743,
+                                    0.65760154,
+                                    0.13295284,
+                                    0.53344893,
+                                    0.8994776,
+                                    0.24836496)),
+                    Row.of(
+                            1,
+                            Vectors.dense(
+                                    0.03017182,
+                                    0.07244715,
+                                    0.87416449,
+                                    0.55843035,
+                                    0.91604736,
+                                    0.63346045)),
+                    Row.of(
+                            5,
+                            Vectors.dense(
+                                    0.28325261,
+                                    0.36536881,
+                                    0.09223386,
+                                    0.37251258,
+                                    0.34742278,
+                                    0.70517077)),
+                    Row.of(
+                            4,
+                            Vectors.dense(
+                                    0.64850904,
+                                    0.04090877,
+                                    0.21173176,
+                                    0.00148992,
+                                    0.13897166,
+                                    0.21182539)),
+                    Row.of(
+                            4,
+                            Vectors.dense(
+                                    0.02609493,
+                                    0.44608735,
+                                    0.23910531,
+                                    0.95449222,
+                                    0.90763182,
+                                    0.8624905)),
+                    Row.of(
+                            5,
+                            Vectors.dense(
+                                    0.09158744,
+                                    0.97745235,
+                                    0.41150139,
+                                    0.45830467,
+                                    0.52590925,
+                                    0.29441554)),
+                    Row.of(
+                            4,
+                            Vectors.dense(
+                                    0.97211594,
+                                    0.1814442,
+                                    0.30340642,
+                                    0.17445413,
+                                    0.52756958,
+                                    0.02069296)),
+                    Row.of(
+                            2,
+                            Vectors.dense(
+                                    0.06354593,
+                                    0.63527231,
+                                    0.49620335,
+                                    0.0141264,
+                                    0.62722219,
+                                    0.63497507)),
+                    Row.of(
+                            1,
+                            Vectors.dense(
+                                    0.10814149,
+                                    0.8296426,
+                                    0.51775217,
+                                    0.57068344,
+                                    0.54633305,
+                                    0.12714921)),
+                    Row.of(
+                            1,
+                            Vectors.dense(
+                                    0.72731796,
+                                    0.94010124,
+                                    0.45007811,
+                                    0.87650674,
+                                    0.53735565,
+                                    0.49568415)),
+                    Row.of(
+                            2,
+                            Vectors.dense(
+                                    0.41827208,
+                                    0.85100628,
+                                    0.38685271,
+                                    0.60689503,
+                                    0.21784097,
+                                    0.91294433)),
+                    Row.of(
+                            3,
+                            Vectors.dense(
+                                    0.65843656,
+                                    0.5880859,
+                                    0.18862706,
+                                    0.856398,
+                                    0.18029327,
+                                    0.94851926)),
+                    Row.of(
+                            4,
+                            Vectors.dense(
+                                    0.3841634,
+                                    0.25138793,
+                                    0.96746644,
+                                    0.77048045,
+                                    0.44685196,
+                                    0.19813854)),
+                    Row.of(
+                            5,
+                            Vectors.dense(
+                                    0.65982267,
+                                    0.23024125,
+                                    0.13598434,
+                                    0.60144265,
+                                    0.57848927,
+                                    0.85623564)),
+                    Row.of(
+                            1,
+                            Vectors.dense(
+                                    0.35764189,
+                                    0.47623815,
+                                    0.5459232,
+                                    0.79508298,
+                                    0.14462443,
+                                    0.01802919)),
+                    Row.of(
+                            5,
+                            Vectors.dense(
+                                    0.38532153,
+                                    0.90614554,
+                                    0.86629571,
+                                    0.13988735,
+                                    0.32062385,
+                                    0.00179492)),
+                    Row.of(
+                            3,
+                            Vectors.dense(
+                                    0.2142368,
+                                    0.28306022,
+                                    0.59481646,
+                                    0.42567028,
+                                    0.52207663,
+                                    0.78082401)),
+                    Row.of(
+                            1,
+                            Vectors.dense(
+                                    0.20788283,
+                                    0.76861782,
+                                    0.59595468,
+                                    0.62103642,
+                                    0.17781246,
+                                    0.77655345)),
+                    Row.of(
+                            1,
+                            Vectors.dense(
+                                    0.1751708,
+                                    0.4547537,
+                                    0.46187865,
+                                    0.79781199,
+                                    0.05104487,
+                                    0.42406092)));
+
+    private static final List<Row> SPARSE_INPUT_DATA =
+            Arrays.asList(
+                    Row.of(3, Vectors.dense(6.0, 7.0, 0.0, 7.0, 6.0, 0.0, 0.0).toSparse()),
+                    Row.of(1, Vectors.dense(0.0, 9.0, 6.0, 0.0, 5.0, 9.0, 0.0).toSparse()),
+                    Row.of(3, Vectors.dense(0.0, 9.0, 3.0, 0.0, 5.0, 5.0, 0.0).toSparse()),
+                    Row.of(2, Vectors.dense(0.0, 9.0, 8.0, 5.0, 6.0, 4.0, 0.0).toSparse()),
+                    Row.of(2, Vectors.dense(8.0, 9.0, 6.0, 5.0, 4.0, 4.0, 0.0).toSparse()),
+                    Row.of(3, Vectors.dense(Double.NaN, 9.0, 6.0, 4.0, 0.0, 0.0, 0.0).toSparse()));
+
+    private static final Row EXPECTED_OUTPUT_DENSE =
+            Row.of( // Columns: pValues, degreesOfFreedom, fValues.
+                    Vectors.dense(
+                            0.64137831, 0.14830724, 0.69858474, 0.28038169, 0.86759161, 0.81608606),
+                    new long[] {19, 19, 19, 19, 19, 19},
+                    Vectors.dense(
+                            0.64110932, 1.98689258, 0.55499714, 1.40340562, 0.30881722, 0.3848595));
+
+    private static final List<Row> EXPECTED_FLATTENED_OUTPUT_DENSE =
+            Arrays.asList( // Columns: featureIndex, pValue, degreeOfFreedom, fValue.
+                    Row.of(0, 0.64137831, 19, 0.64110932),
+                    Row.of(1, 0.14830724, 19, 1.98689258),
+                    Row.of(2, 0.69858474, 19, 0.55499714),
+                    Row.of(3, 0.28038169, 19, 1.40340562),
+                    Row.of(4, 0.86759161, 19, 0.30881722),
+                    Row.of(5, 0.81608606, 19, 0.3848595));
+
+    private static final Row EXPECTED_OUTPUT_SPARSE =
+            Row.of( // Features 0 (contains NaN) and 6 (all zeros) are expected to yield NaN.
+                    Vectors.dense(
+                            Double.NaN,
+                            0.71554175,
+                            0.34278574,
+                            0.45824059,
+                            0.84633632,
+                            0.15673368,
+                            Double.NaN),
+                    new long[] {5, 5, 5, 5, 5, 5, 5},
+                    Vectors.dense(
+                            Double.NaN, 0.375, 1.5625, 1.02364865, 0.17647059, 3.66, Double.NaN));
+
+    private static final List<Row> EXPECTED_FLATTENED_OUTPUT_SPARSE =
+            Arrays.asList( // One row per feature; the NaN rows correspond to features 0 and 6.
+                    Row.of(0, Double.NaN, 5, Double.NaN),
+                    Row.of(1, 0.71554175, 5, 0.375),
+                    Row.of(2, 0.34278574, 5, 1.5625),
+                    Row.of(3, 0.45824059, 5, 1.02364865),
+                    Row.of(4, 0.84633632, 5, 0.17647059),
+                    Row.of(5, 0.15673368, 5, 3.66),
+                    Row.of(6, Double.NaN, 5, Double.NaN));
+
+    @Before
+    public void before() {
+        Configuration conf = new Configuration();
+        conf.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(conf);
+        env.getConfig().enableObjectReuse();
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        env.enableCheckpointing(100);
+        env.setParallelism(4);
+        tEnv = StreamTableEnvironment.create(env);
+        DataStream<Row> denseRows = env.fromCollection(DENSE_INPUT_DATA);
+        denseInputTable = tEnv.fromDataStream(denseRows).as("label", "features");
+        DataStream<Row> sparseRows = env.fromCollection(SPARSE_INPUT_DATA);
+        sparseInputTable = tEnv.fromDataStream(sparseRows).as("label", "features");
+    }
+
+    /** Collects the flattened output and compares it with {@code expected} row by row. */
+    private static void verifyFlattenTransformationResult(Table output, List<Row> expected)
+            throws Exception {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment();
+        List<Row> results = IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        assertEquals(expected.size(), results.size());
+        // Sort a copy: `expected` may be a shared static fixture or an immutable list.
+        List<Row> sortedExpected = new java.util.ArrayList<>(expected);
+        results.sort(Comparator.comparing(r -> String.valueOf(r.getField(0))));
+        sortedExpected.sort(Comparator.comparing(r -> String.valueOf(r.getField(0))));
+        for (int i = 0; i < sortedExpected.size(); i++) {
+            assertEquals(sortedExpected.get(i).getArity(), results.get(i).getArity());
+            for (int j = 0; j < sortedExpected.get(i).getArity(); j++) {
+                assertEquals(
+                        Double.valueOf(sortedExpected.get(i).getField(j).toString()),
+                        Double.valueOf(results.get(i).getField(j).toString()),
+                        EPS);
+            }
+        }
+    }
+
+    /** Collects the single output row and checks pValues, degreesOfFreedom and fValues. */
+    private static void verifyTransformationResult(Table output, Row expected) throws Exception {
+        TableImpl outputImpl = (TableImpl) output;
+        StreamTableEnvironment env = (StreamTableEnvironment) outputImpl.getTableEnvironment();
+        List<Row> rows = IteratorUtils.toList(env.toDataStream(output).executeAndCollect());
+        assertEquals(1, rows.size());
+
+        Row actual = rows.get(0);
+        assertEquals(3, actual.getArity());
+        // Vector columns are compared with tolerance EPS; the long[] column must match exactly.
+        double[] expectedPValues = ((Vector) expected.getField(0)).toArray();
+        double[] actualPValues = ((Vector) actual.getField(0)).toArray();
+        assertArrayEquals(expectedPValues, actualPValues, EPS);
+        long[] expectedDof = (long[]) expected.getField(1);
+        assertArrayEquals(expectedDof, (long[]) actual.getField(1));
+        double[] expectedFValues = ((Vector) expected.getField(2)).toArray();
+        double[] actualFValues = ((Vector) actual.getField(2)).toArray();
+        assertArrayEquals(expectedFValues, actualFValues, EPS);
+    }
+
+    @Test
+    public void testParam() {
+        ANOVATest anova = new ANOVATest();
+        // Default values.
+        assertEquals("features", anova.getFeaturesCol());
+        assertEquals("label", anova.getLabelCol());
+        assertFalse(anova.getFlatten());
+        // Values after overriding every param.
+        anova.setLabelCol("test_label").setFeaturesCol("test_features").setFlatten(true);
+        assertEquals("test_label", anova.getLabelCol());
+        assertEquals("test_features", anova.getFeaturesCol());
+        assertTrue(anova.getFlatten());
+    }
+
+    @Test
+    public void testOutputSchema() {
+        // NOTE(review): input lacks test_* columns; presumably OK since only the schema is read.
+        ANOVATest anovaTest =
+                new ANOVATest().setFeaturesCol("test_features").setLabelCol("test_label");
+        List<String> columns =
+                anovaTest.transform(denseInputTable)[0].getResolvedSchema().getColumnNames();
+        assertEquals(Arrays.asList("pValues", "degreesOfFreedom", "fValues"), columns);
+
+        anovaTest.setFlatten(true);
+        List<String> flatCols =
+                anovaTest.transform(denseInputTable)[0].getResolvedSchema().getColumnNames();
+        assertEquals(
+                Arrays.asList("featureIndex", "pValue", "degreeOfFreedom", "fValue"), flatCols);
+    }
+
+    @Test
+    public void testTransform() throws Exception {
+        // One default-configured instance is applied to both the dense and sparse inputs.
+        ANOVATest anova = new ANOVATest();
+
+        Table dense = anova.transform(denseInputTable)[0];
+        verifyTransformationResult(dense, EXPECTED_OUTPUT_DENSE);
+        Table sparse = anova.transform(sparseInputTable)[0];
+        verifyTransformationResult(sparse, EXPECTED_OUTPUT_SPARSE);
+    }
+
+    @Test
+    public void testTransformWithFlatten() throws Exception {
+        ANOVATest anova = new ANOVATest().setFlatten(true);
+
+        // With flatten enabled, each feature becomes its own output row.
+        Table dense = anova.transform(denseInputTable)[0];
+        verifyFlattenTransformationResult(dense, EXPECTED_FLATTENED_OUTPUT_DENSE);
+        Table sparse = anova.transform(sparseInputTable)[0];
+        verifyFlattenTransformationResult(sparse, EXPECTED_FLATTENED_OUTPUT_SPARSE);
+    }
+
+    @Test
+    public void testSaveLoadAndTransform() throws Exception {
+        ANOVATest original = new ANOVATest();
+        String savePath = tempFolder.newFolder().getAbsolutePath();
+        ANOVATest reloaded = TestUtils.saveAndReload(tEnv, original, savePath);
+        // The reloaded stage must reproduce the fresh instance's dense results.
+        verifyTransformationResult(reloaded.transform(denseInputTable)[0], EXPECTED_OUTPUT_DENSE);
+    }
+}
diff --git a/flink-ml-python/pyflink/ml/core/tests/test_param.py b/flink-ml-python/pyflink/ml/core/tests/test_param.py
index 4b1e40e..7213fb9 100644
--- a/flink-ml-python/pyflink/ml/core/tests/test_param.py
+++ b/flink-ml-python/pyflink/ml/core/tests/test_param.py
@@ -22,7 +22,7 @@ from pyflink.ml.core.param import Param
 from pyflink.ml.lib.param import HasDistanceMeasure, HasFeaturesCol, HasGlobalBatchSize, \
     HasHandleInvalid, HasInputCols, HasLabelCol, HasLearningRate, HasMaxIter, HasMultiClass, \
     HasOutputCols, HasPredictionCol, HasRawPredictionCol, HasReg, HasSeed, HasTol, HasWeightCol, \
-    HasWindows, HasRelativeError
+    HasWindows, HasRelativeError, HasFlatten
 
 from pyflink.ml.core.windows import GlobalWindows, CountTumblingWindows
 
@@ -30,7 +30,7 @@ from pyflink.ml.core.windows import GlobalWindows, CountTumblingWindows
 class TestParams(HasDistanceMeasure, HasFeaturesCol, HasGlobalBatchSize, HasHandleInvalid,
                  HasInputCols, HasLabelCol, HasLearningRate, HasMaxIter, HasMultiClass,
                  HasOutputCols, HasPredictionCol, HasRawPredictionCol, HasReg, HasSeed, HasTol,
-                 HasWeightCol, HasWindows, HasRelativeError):
+                 HasWeightCol, HasWindows, HasRelativeError, HasFlatten):
     def __init__(self):
         self._param_map = {}
 
@@ -227,3 +227,15 @@ class ParamTests(unittest.TestCase):
 
         param.set_relative_error(0.1)
         self.assertEqual(param.get_relative_error(), 0.1)
+
+    def test_flatten(self):
+        param = TestParams()
+        flatten = param.FLATTEN
+        self.assertEqual(flatten.name, "flatten")
+        self.assertEqual(flatten.description,
+                         "If false, the returned table contains only a "
+                         "single row, otherwise, one row per feature.")
+        self.assertFalse(flatten.default_value)
+
+        param.set_flatten(True)
+        self.assertTrue(param.get_flatten())
diff --git a/flink-ml-python/pyflink/ml/lib/param.py b/flink-ml-python/pyflink/ml/lib/param.py
index 4ca3aa0..e230e04 100644
--- a/flink-ml-python/pyflink/ml/lib/param.py
+++ b/flink-ml-python/pyflink/ml/lib/param.py
@@ -19,7 +19,7 @@ from abc import ABC
 from typing import Tuple
 
 from pyflink.ml.core.param import WithParams, Param, ParamValidators, StringParam, IntParam, \
-    StringArrayParam, FloatParam, WindowsParam
+    StringArrayParam, FloatParam, WindowsParam, BooleanParam
 from pyflink.ml.core.windows import Windows, GlobalWindows
 
 
@@ -560,3 +560,24 @@ class HasRelativeError(WithParams, ABC):
     @property
     def relative_error(self):
         return self.get(self.RELATIVE_ERROR)
+
+
+class HasFlatten(WithParams, ABC):
+    """
+    Interface for shared flatten param.
+    """
+    FLATTEN: Param[bool] = BooleanParam(
+        "flatten",
+        "If false, the returned table contains only a single row, otherwise, one row per feature.",
+        False
+    )
+
+    def set_flatten(self, value: bool):
+        return self.set(self.FLATTEN, value)
+
+    def get_flatten(self) -> bool:
+        return self.get(self.FLATTEN)
+
+    @property
+    def flatten(self):
+        return self.get(self.FLATTEN)
diff --git a/flink-ml-python/pyflink/ml/lib/stats/chisqtest.py b/flink-ml-python/pyflink/ml/lib/stats/chisqtest.py
index 9c29379..744d5a4 100644
--- a/flink-ml-python/pyflink/ml/lib/stats/chisqtest.py
+++ b/flink-ml-python/pyflink/ml/lib/stats/chisqtest.py
@@ -65,7 +65,7 @@ class ChiSqTest(JavaStatsAlgoOperator, _ChiSqTestParams):
     - "statistics": vector
 
     The output of this algorithm can be flattened to multiple rows by setting
-    _ChiSqTestParams#FLATTEN, which would contain the following columns:
+    HasFlatten#FLATTEN to True, in which case the output contains the following columns:
 
     - "featureIndex": int
     - "pValue": double
diff --git a/flink-ml-python/pyflink/ml/lib/tests/test_ml_lib_completeness.py b/flink-ml-python/pyflink/ml/lib/tests/test_ml_lib_completeness.py
index be50063..b8ab9f0 100644
--- a/flink-ml-python/pyflink/ml/lib/tests/test_ml_lib_completeness.py
+++ b/flink-ml-python/pyflink/ml/lib/tests/test_ml_lib_completeness.py
@@ -21,6 +21,7 @@ import os
 import pkgutil
 import unittest
 from abc import abstractmethod
+from typing import List
 
 from pyflink.java_gateway import get_gateway
 from pyflink.ml.tests.test_utils import PyFlinkMLTestCase
@@ -51,8 +52,11 @@ class MLLibTest(PyFlinkMLTestCase):
             FLINK_ML_LIB_SOURCE_PATH, "target", "flink-ml-lib-*SNAPSHOT.jar"))[0]
 
         StageAnalyzer = get_gateway().jvm.org.apache.flink.ml.util.StageAnalyzer
+        module_path = 'org.apache.flink.ml.{0}'.format(self.module_name())
+        excluded_stages = list(map(lambda x: f'{module_path}.{x}', self.exclude_java_stage()))
         return sorted([stage for stage in StageAnalyzer.analyzeLibJars(ml_lib_jar)
-                       if 'org.apache.flink.ml.{0}.'.format(self.module_name()) in stage])
+                       if module_path in stage
+                       and stage not in excluded_stages])
 
     def _load_python_stages(self):
         modules = [importlib.import_module('.'.join([self.module().__name__, file_name]))
@@ -81,6 +85,9 @@ class MLLibTest(PyFlinkMLTestCase):
     def module(self):
         pass
 
+    def exclude_java_stage(self):
+        return []
+
 
 class ClassificationCompletenessTest(CompletenessTest, MLLibTest):
     def module_name(self):
@@ -140,6 +147,11 @@ class StatsCompletenessTest(CompletenessTest, MLLibTest):
         from pyflink.ml.lib import stats
         return stats
 
+    def exclude_java_stage(self) -> List[str]:
+        return [
+            "anovatest.ANOVATest",
+        ]
+
 
 if __name__ == "__main__":
     try: