You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by li...@apache.org on 2022/11/28 14:09:02 UTC
[flink-ml] branch master updated: [FLINK-30159] Add Transformer for ANOVATest
This is an automated email from the ASF dual-hosted git repository.
lindong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push:
new d6a5b42 [FLINK-30159] Add Transformer for ANOVATest
d6a5b42 is described below
commit d6a5b42aae2a1a84789dfe0a35c7e97461b5d418
Author: JiangXin <ji...@alibaba-inc.com>
AuthorDate: Mon Nov 28 22:08:56 2022 +0800
[FLINK-30159] Add Transformer for ANOVATest
This closes #180.
---
.../param/HasFlatten.java} | 14 +-
.../apache/flink/ml/stats/anovatest/ANOVATest.java | 287 +++++++++++++++
.../ANOVATestParams.java} | 23 +-
.../apache/flink/ml/stats/chisqtest/ChiSqTest.java | 3 +-
.../flink/ml/stats/chisqtest/ChiSqTestParams.java | 19 +-
.../org/apache/flink/ml/stats/ANOVATestTest.java | 410 +++++++++++++++++++++
.../pyflink/ml/core/tests/test_param.py | 16 +-
flink-ml-python/pyflink/ml/lib/param.py | 23 +-
flink-ml-python/pyflink/ml/lib/stats/chisqtest.py | 2 +-
.../ml/lib/tests/test_ml_lib_completeness.py | 14 +-
10 files changed, 760 insertions(+), 51 deletions(-)
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/chisqtest/ChiSqTestParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasFlatten.java
similarity index 79%
copy from flink-ml-lib/src/main/java/org/apache/flink/ml/stats/chisqtest/ChiSqTestParams.java
copy to flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasFlatten.java
index e9d3ec9..899e7a5 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/chisqtest/ChiSqTestParams.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasFlatten.java
@@ -16,19 +16,15 @@
* limitations under the License.
*/
-package org.apache.flink.ml.stats.chisqtest;
+package org.apache.flink.ml.common.param;
-import org.apache.flink.ml.common.param.HasFeaturesCol;
-import org.apache.flink.ml.common.param.HasLabelCol;
import org.apache.flink.ml.param.BooleanParam;
import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.WithParams;
+
+/** Interface for the shared flatten param. */
+public interface HasFlatten<T> extends WithParams<T> {
-/**
- * Params for {@link ChiSqTest}.
- *
- * @param <T> The class type of this instance.
- */
-public interface ChiSqTestParams<T> extends HasFeaturesCol<T>, HasLabelCol<T> {
Param<Boolean> FLATTEN =
new BooleanParam(
"flatten",
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/anovatest/ANOVATest.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/anovatest/ANOVATest.java
new file mode 100644
index 0000000..b58fdc8
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/anovatest/ANOVATest.java
@@ -0,0 +1,287 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.stats.anovatest;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.param.HasFlatten;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.math3.distribution.FDistribution;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.IntStream;
+
+/**
+ * An AlgoOperator which implements the ANOVA test algorithm.
+ *
+ * <p>See <a href="https://en.wikipedia.org/wiki/Analysis_of_variance">Wikipedia</a> for more
+ * information on ANOVA test.
+ *
+ * <p>The input of this algorithm is a table containing a labelColumn of numerical type and a
+ * featuresColumn of vector type. Each index in the input vector represents a feature to be tested.
+ * By default, the output of this algorithm is a table containing a single row with the following
+ * columns, each of which has one value per feature.
+ *
+ * <ul>
+ * <li>"pValues": vector
+ * <li>"degreesOfFreedom": long array
+ * <li>"fValues": vector
+ * </ul>
+ *
+ * <p>The output of this algorithm can be flattened to multiple rows by setting {@link
+ * HasFlatten#FLATTEN} to true, which would contain the following columns:
+ *
+ * <ul>
+ * <li>"featureIndex": int
+ * <li>"pValue": double
+ * <li>"degreeOfFreedom": long
+ * <li>"fValue": double
+ * </ul>
+ */
+public class ANOVATest implements AlgoOperator<ANOVATest>, ANOVATestParams<ANOVATest> {
+
+ private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+ public ANOVATest() {
+ ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+ }
+
+ @Override
+ public Table[] transform(Table... inputs) {
+ Preconditions.checkArgument(inputs.length == 1);
+
+ final String featuresCol = getFeaturesCol();
+ final String labelCol = getLabelCol();
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+ DataStream<Tuple2<Vector, Double>> inputData =
+ tEnv.toDataStream(inputs[0])
+ .map(
+ (MapFunction<Row, Tuple2<Vector, Double>>)
+ row -> {
+ Number number = (Number) row.getField(labelCol);
+ Preconditions.checkNotNull(
+ number, "Input data must contain label value.");
+ return new Tuple2<>(
+ ((Vector) row.getField(featuresCol)),
+ number.doubleValue());
+ })
+ .returns(Types.TUPLE(VectorTypeInfo.INSTANCE, Types.DOUBLE));
+ DataStream<List<Row>> streamWithANOVA =
+ DataStreamUtils.aggregate(inputData, new ANOVAAggregator());
+ return new Table[] {convertToTable(tEnv, streamWithANOVA, getFlatten())};
+ }
+
+ /** Computes the p-value, F-value and degrees of freedom for each input feature. */
+ @SuppressWarnings("unchecked")
+ private static class ANOVAAggregator
+ implements AggregateFunction<
+ Tuple2<Vector, Double>,
+ Tuple3<Double, Double, HashMap<Double, Tuple2<Double, Long>>>[],
+ List<Row>> {
+ @Override
+ public Tuple3<Double, Double, HashMap<Double, Tuple2<Double, Long>>>[] createAccumulator() {
+ return new Tuple3[0];
+ }
+
+ @Override
+ public Tuple3<Double, Double, HashMap<Double, Tuple2<Double, Long>>>[] add(
+ Tuple2<Vector, Double> featuresAndLabel,
+ Tuple3<Double, Double, HashMap<Double, Tuple2<Double, Long>>>[] acc) {
+ Vector features = featuresAndLabel.f0;
+ double label = featuresAndLabel.f1;
+ int numOfFeatures = features.size();
+ if (acc.length == 0) {
+ acc = new Tuple3[features.size()];
+ for (int i = 0; i < numOfFeatures; i++) {
+ acc[i] = Tuple3.of(0.0, 0.0, new HashMap<>());
+ }
+ }
+ for (int i = 0; i < numOfFeatures; i++) {
+ double featureValue = features.get(i);
+ acc[i].f0 += featureValue;
+ acc[i].f1 += featureValue * featureValue;
+
+ if (acc[i].f2.containsKey(label)) {
+ acc[i].f2.get(label).f0 += featureValue;
+ acc[i].f2.get(label).f1 += 1L;
+ } else {
+ acc[i].f2.put(label, Tuple2.of(featureValue, 1L));
+ }
+ }
+ return acc;
+ }
+
+ @Override
+ public List<Row> getResult(
+ Tuple3<Double, Double, HashMap<Double, Tuple2<Double, Long>>>[] acc) {
+ List<Row> results = new ArrayList<>();
+ for (int i = 0; i < acc.length; i++) {
+ Tuple3<Double, Long, Double> resultOfANOVA =
+ computeANOVA(acc[i].f0, acc[i].f1, acc[i].f2);
+ results.add(Row.of(i, resultOfANOVA.f0, resultOfANOVA.f1, resultOfANOVA.f2));
+ }
+ return results;
+ }
+
+ @Override
+ public Tuple3<Double, Double, HashMap<Double, Tuple2<Double, Long>>>[] merge(
+ Tuple3<Double, Double, HashMap<Double, Tuple2<Double, Long>>>[] acc1,
+ Tuple3<Double, Double, HashMap<Double, Tuple2<Double, Long>>>[] acc2) {
+ if (acc1.length == 0) {
+ return acc2;
+ }
+ if (acc2.length == 0) {
+ return acc1;
+ }
+ IntStream.range(0, acc1.length)
+ .forEach(
+ i -> {
+ acc2[i].f0 += acc1[i].f0;
+ acc2[i].f1 += acc1[i].f1;
+ acc1[i].f2.forEach(
+ (k, v) -> {
+ if (acc2[i].f2.containsKey(k)) {
+ acc2[i].f2.get(k).f0 += v.f0;
+ acc2[i].f2.get(k).f1 += v.f1;
+ } else {
+ acc2[i].f2.put(k, v);
+ }
+ });
+ });
+ return acc2;
+ }
+
+ private Tuple3<Double, Long, Double> computeANOVA(
+ double sum, double sumOfSq, HashMap<Double, Tuple2<Double, Long>> summary) {
+ long numOfClasses = summary.size();
+
+ long numOfSamples = summary.values().stream().mapToLong(t -> t.f1).sum();
+
+ double sqSum = sum * sum;
+
+ double ssTot = sumOfSq - sqSum / numOfSamples;
+
+ double totalSqSum = 0;
+ for (Tuple2<Double, Long> t : summary.values()) {
+ totalSqSum += t.f0 * t.f0 / t.f1;
+ }
+
+ double sumOfSqBetween = totalSqSum - (sqSum / numOfSamples);
+
+ double sumOfSqWithin = ssTot - sumOfSqBetween;
+
+ long degreeOfFreedomBetween = numOfClasses - 1;
+ Preconditions.checkArgument(
+ degreeOfFreedomBetween > 0, "Num of classes should be positive.");
+
+ long degreeOfFreedomWithin = numOfSamples - numOfClasses;
+ Preconditions.checkArgument(
+ degreeOfFreedomWithin > 0,
+ "Num of samples should be greater than num of classes.");
+
+ double meanSqBetween = sumOfSqBetween / degreeOfFreedomBetween;
+
+ double meanSqWithin = sumOfSqWithin / degreeOfFreedomWithin;
+
+ double fValue = meanSqBetween / meanSqWithin;
+
+ FDistribution fd = new FDistribution(degreeOfFreedomBetween, degreeOfFreedomWithin);
+ double pValue = 1 - fd.cumulativeProbability(fValue);
+
+ long degreeOfFreedom = degreeOfFreedomBetween + degreeOfFreedomWithin;
+
+ return Tuple3.of(pValue, degreeOfFreedom, fValue);
+ }
+ }
+
+ private Table convertToTable(
+ StreamTableEnvironment tEnv, DataStream<List<Row>> datastream, boolean flatten) {
+ if (flatten) {
+ DataStream<Row> output =
+ datastream
+ .flatMap(
+ (FlatMapFunction<List<Row>, Row>)
+ (list, collector) -> list.forEach(collector::collect))
+ .setParallelism(1)
+ .returns(Types.ROW(Types.INT, Types.DOUBLE, Types.LONG, Types.DOUBLE));
+ return tEnv.fromDataStream(output)
+ .as("featureIndex", "pValue", "degreeOfFreedom", "fValue");
+ } else {
+ DataStream<Tuple3<DenseVector, long[], DenseVector>> output =
+ datastream.map(
+ new MapFunction<List<Row>, Tuple3<DenseVector, long[], DenseVector>>() {
+ @Override
+ public Tuple3<DenseVector, long[], DenseVector> map(
+ List<Row> rows) {
+ int numOfFeatures = rows.size();
+ DenseVector pValues = new DenseVector(numOfFeatures);
+ DenseVector fValues = new DenseVector(numOfFeatures);
+ long[] degrees = new long[numOfFeatures];
+
+ for (int i = 0; i < numOfFeatures; i++) {
+ Row row = rows.get(i);
+ pValues.set(i, (double) row.getField(1));
+ degrees[i] = (long) row.getField(2);
+ fValues.set(i, (double) row.getField(3));
+ }
+ return Tuple3.of(pValues, degrees, fValues);
+ }
+ });
+ return tEnv.fromDataStream(output).as("pValues", "degreesOfFreedom", "fValues");
+ }
+ }
+
+ @Override
+ public void save(String path) throws IOException {
+ ReadWriteUtils.saveMetadata(this, path);
+ }
+
+ public static ANOVATest load(StreamTableEnvironment tEnv, String path) throws IOException {
+ return ReadWriteUtils.loadStageParam(path);
+ }
+
+ @Override
+ public Map<Param<?>, Object> getParamMap() {
+ return paramMap;
+ }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/chisqtest/ChiSqTestParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/anovatest/ANOVATestParams.java
similarity index 60%
copy from flink-ml-lib/src/main/java/org/apache/flink/ml/stats/chisqtest/ChiSqTestParams.java
copy to flink-ml-lib/src/main/java/org/apache/flink/ml/stats/anovatest/ANOVATestParams.java
index e9d3ec9..2ac44a8 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/chisqtest/ChiSqTestParams.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/anovatest/ANOVATestParams.java
@@ -16,30 +16,15 @@
* limitations under the License.
*/
-package org.apache.flink.ml.stats.chisqtest;
+package org.apache.flink.ml.stats.anovatest;
import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasFlatten;
import org.apache.flink.ml.common.param.HasLabelCol;
-import org.apache.flink.ml.param.BooleanParam;
-import org.apache.flink.ml.param.Param;
/**
- * Params for {@link ChiSqTest}.
+ * Params for {@link ANOVATest}.
*
* @param <T> The class type of this instance.
*/
-public interface ChiSqTestParams<T> extends HasFeaturesCol<T>, HasLabelCol<T> {
- Param<Boolean> FLATTEN =
- new BooleanParam(
- "flatten",
- "If false, the returned table contains only a single row, otherwise, one row per feature.",
- false);
-
- default boolean getFlatten() {
- return get(FLATTEN);
- }
-
- default T setFlatten(boolean value) {
- return set(FLATTEN, value);
- }
-}
+public interface ANOVATestParams<T> extends HasFeaturesCol<T>, HasLabelCol<T>, HasFlatten<T> {}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/chisqtest/ChiSqTest.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/chisqtest/ChiSqTest.java
index 7e625ff..ea479db 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/chisqtest/ChiSqTest.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/chisqtest/ChiSqTest.java
@@ -32,6 +32,7 @@ import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.iteration.operator.OperatorStateUtils;
import org.apache.flink.ml.api.AlgoOperator;
import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.param.HasFlatten;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
@@ -87,7 +88,7 @@ import java.util.stream.Collectors;
* </ul>
*
* <p>The output of this algorithm can be flattened to multiple rows by setting {@link
- * ChiSqTestParams#FLATTEN}, which would contain the following columns:
+ * HasFlatten#FLATTEN} to true, which would contain the following columns:
*
* <ul>
* <li>"featureIndex": int
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/chisqtest/ChiSqTestParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/chisqtest/ChiSqTestParams.java
index e9d3ec9..4ad1cb4 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/chisqtest/ChiSqTestParams.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/chisqtest/ChiSqTestParams.java
@@ -19,27 +19,12 @@
package org.apache.flink.ml.stats.chisqtest;
import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasFlatten;
import org.apache.flink.ml.common.param.HasLabelCol;
-import org.apache.flink.ml.param.BooleanParam;
-import org.apache.flink.ml.param.Param;
/**
* Params for {@link ChiSqTest}.
*
* @param <T> The class type of this instance.
*/
-public interface ChiSqTestParams<T> extends HasFeaturesCol<T>, HasLabelCol<T> {
- Param<Boolean> FLATTEN =
- new BooleanParam(
- "flatten",
- "If false, the returned table contains only a single row, otherwise, one row per feature.",
- false);
-
- default boolean getFlatten() {
- return get(FLATTEN);
- }
-
- default T setFlatten(boolean value) {
- return set(FLATTEN, value);
- }
-}
+public interface ChiSqTestParams<T> extends HasFeaturesCol<T>, HasLabelCol<T>, HasFlatten<T> {}
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/stats/ANOVATestTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/stats/ANOVATestTest.java
new file mode 100644
index 0000000..0abdec7
--- /dev/null
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/stats/ANOVATestTest.java
@@ -0,0 +1,410 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.stats;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.stats.anovatest.ANOVATest;
+import org.apache.flink.ml.util.TestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+/** Tests the {@link ANOVATest}. */
+public class ANOVATestTest extends AbstractTestBase {
+ @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+ private StreamTableEnvironment tEnv;
+ private Table denseInputTable;
+ private Table sparseInputTable;
+
+ private static final double EPS = 1.0e-5;
+ private static final List<Row> DENSE_INPUT_DATA =
+ Arrays.asList(
+ Row.of(
+ 3,
+ Vectors.dense(
+ 0.85956061,
+ 0.1645695,
+ 0.48347596,
+ 0.92102727,
+ 0.42855644,
+ 0.05746009)),
+ Row.of(
+ 2,
+ Vectors.dense(
+ 0.92500743,
+ 0.65760154,
+ 0.13295284,
+ 0.53344893,
+ 0.8994776,
+ 0.24836496)),
+ Row.of(
+ 1,
+ Vectors.dense(
+ 0.03017182,
+ 0.07244715,
+ 0.87416449,
+ 0.55843035,
+ 0.91604736,
+ 0.63346045)),
+ Row.of(
+ 5,
+ Vectors.dense(
+ 0.28325261,
+ 0.36536881,
+ 0.09223386,
+ 0.37251258,
+ 0.34742278,
+ 0.70517077)),
+ Row.of(
+ 4,
+ Vectors.dense(
+ 0.64850904,
+ 0.04090877,
+ 0.21173176,
+ 0.00148992,
+ 0.13897166,
+ 0.21182539)),
+ Row.of(
+ 4,
+ Vectors.dense(
+ 0.02609493,
+ 0.44608735,
+ 0.23910531,
+ 0.95449222,
+ 0.90763182,
+ 0.8624905)),
+ Row.of(
+ 5,
+ Vectors.dense(
+ 0.09158744,
+ 0.97745235,
+ 0.41150139,
+ 0.45830467,
+ 0.52590925,
+ 0.29441554)),
+ Row.of(
+ 4,
+ Vectors.dense(
+ 0.97211594,
+ 0.1814442,
+ 0.30340642,
+ 0.17445413,
+ 0.52756958,
+ 0.02069296)),
+ Row.of(
+ 2,
+ Vectors.dense(
+ 0.06354593,
+ 0.63527231,
+ 0.49620335,
+ 0.0141264,
+ 0.62722219,
+ 0.63497507)),
+ Row.of(
+ 1,
+ Vectors.dense(
+ 0.10814149,
+ 0.8296426,
+ 0.51775217,
+ 0.57068344,
+ 0.54633305,
+ 0.12714921)),
+ Row.of(
+ 1,
+ Vectors.dense(
+ 0.72731796,
+ 0.94010124,
+ 0.45007811,
+ 0.87650674,
+ 0.53735565,
+ 0.49568415)),
+ Row.of(
+ 2,
+ Vectors.dense(
+ 0.41827208,
+ 0.85100628,
+ 0.38685271,
+ 0.60689503,
+ 0.21784097,
+ 0.91294433)),
+ Row.of(
+ 3,
+ Vectors.dense(
+ 0.65843656,
+ 0.5880859,
+ 0.18862706,
+ 0.856398,
+ 0.18029327,
+ 0.94851926)),
+ Row.of(
+ 4,
+ Vectors.dense(
+ 0.3841634,
+ 0.25138793,
+ 0.96746644,
+ 0.77048045,
+ 0.44685196,
+ 0.19813854)),
+ Row.of(
+ 5,
+ Vectors.dense(
+ 0.65982267,
+ 0.23024125,
+ 0.13598434,
+ 0.60144265,
+ 0.57848927,
+ 0.85623564)),
+ Row.of(
+ 1,
+ Vectors.dense(
+ 0.35764189,
+ 0.47623815,
+ 0.5459232,
+ 0.79508298,
+ 0.14462443,
+ 0.01802919)),
+ Row.of(
+ 5,
+ Vectors.dense(
+ 0.38532153,
+ 0.90614554,
+ 0.86629571,
+ 0.13988735,
+ 0.32062385,
+ 0.00179492)),
+ Row.of(
+ 3,
+ Vectors.dense(
+ 0.2142368,
+ 0.28306022,
+ 0.59481646,
+ 0.42567028,
+ 0.52207663,
+ 0.78082401)),
+ Row.of(
+ 1,
+ Vectors.dense(
+ 0.20788283,
+ 0.76861782,
+ 0.59595468,
+ 0.62103642,
+ 0.17781246,
+ 0.77655345)),
+ Row.of(
+ 1,
+ Vectors.dense(
+ 0.1751708,
+ 0.4547537,
+ 0.46187865,
+ 0.79781199,
+ 0.05104487,
+ 0.42406092)));
+
+ private static final List<Row> SPARSE_INPUT_DATA =
+ Arrays.asList(
+ Row.of(3, Vectors.dense(6.0, 7.0, 0.0, 7.0, 6.0, 0.0, 0.0).toSparse()),
+ Row.of(1, Vectors.dense(0.0, 9.0, 6.0, 0.0, 5.0, 9.0, 0.0).toSparse()),
+ Row.of(3, Vectors.dense(0.0, 9.0, 3.0, 0.0, 5.0, 5.0, 0.0).toSparse()),
+ Row.of(2, Vectors.dense(0.0, 9.0, 8.0, 5.0, 6.0, 4.0, 0.0).toSparse()),
+ Row.of(2, Vectors.dense(8.0, 9.0, 6.0, 5.0, 4.0, 4.0, 0.0).toSparse()),
+ Row.of(3, Vectors.dense(Double.NaN, 9.0, 6.0, 4.0, 0.0, 0.0, 0.0).toSparse()));
+
+ private static final Row EXPECTED_OUTPUT_DENSE =
+ Row.of(
+ Vectors.dense(
+ 0.64137831, 0.14830724, 0.69858474, 0.28038169, 0.86759161, 0.81608606),
+ new long[] {19, 19, 19, 19, 19, 19},
+ Vectors.dense(
+ 0.64110932, 1.98689258, 0.55499714, 1.40340562, 0.30881722, 0.3848595));
+
+ private static final List<Row> EXPECTED_FLATTENED_OUTPUT_DENSE =
+ Arrays.asList(
+ Row.of(0, 0.64137831, 19, 0.64110932),
+ Row.of(1, 0.14830724, 19, 1.98689258),
+ Row.of(2, 0.69858474, 19, 0.55499714),
+ Row.of(3, 0.28038169, 19, 1.40340562),
+ Row.of(4, 0.86759161, 19, 0.30881722),
+ Row.of(5, 0.81608606, 19, 0.3848595));
+
+ private static final Row EXPECTED_OUTPUT_SPARSE =
+ Row.of(
+ Vectors.dense(
+ Double.NaN,
+ 0.71554175,
+ 0.34278574,
+ 0.45824059,
+ 0.84633632,
+ 0.15673368,
+ Double.NaN),
+ new long[] {5, 5, 5, 5, 5, 5, 5},
+ Vectors.dense(
+ Double.NaN, 0.375, 1.5625, 1.02364865, 0.17647059, 3.66, Double.NaN));
+
+ private static final List<Row> EXPECTED_FLATTENED_OUTPUT_SPARSE =
+ Arrays.asList(
+ Row.of(0, Double.NaN, 5, Double.NaN),
+ Row.of(1, 0.71554175, 5, 0.375),
+ Row.of(2, 0.34278574, 5, 1.5625),
+ Row.of(3, 0.45824059, 5, 1.02364865),
+ Row.of(4, 0.84633632, 5, 0.17647059),
+ Row.of(5, 0.15673368, 5, 3.66),
+ Row.of(6, Double.NaN, 5, Double.NaN));
+
+ @Before
+ public void before() {
+ Configuration config = new Configuration();
+ config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+ StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+ env.setParallelism(4);
+ env.enableCheckpointing(100);
+ env.setRestartStrategy(RestartStrategies.noRestart());
+ env.getConfig().enableObjectReuse();
+ tEnv = StreamTableEnvironment.create(env);
+ denseInputTable =
+ tEnv.fromDataStream(env.fromCollection(DENSE_INPUT_DATA)).as("label", "features");
+ sparseInputTable =
+ tEnv.fromDataStream(env.fromCollection(SPARSE_INPUT_DATA)).as("label", "features");
+ }
+
+ private static void verifyFlattenTransformationResult(Table output, List<Row> expected)
+ throws Exception {
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment();
+ DataStream<Row> outputDataStream = tEnv.toDataStream(output);
+ List<Row> results = IteratorUtils.toList(outputDataStream.executeAndCollect());
+ assertEquals(expected.size(), results.size());
+
+ results.sort(Comparator.comparing(r -> String.valueOf(r.getField(0))));
+ expected.sort(Comparator.comparing(r -> String.valueOf(r.getField(0))));
+
+ for (int i = 0; i < expected.size(); i++) {
+ assertEquals(expected.get(i).getArity(), results.get(i).getArity());
+ for (int j = 0; j < expected.get(i).getArity(); j++) {
+ assertEquals(
+ Double.valueOf(expected.get(i).getField(j).toString()),
+ Double.valueOf(results.get(i).getField(j).toString()),
+ EPS);
+ }
+ }
+ }
+
+ private static void verifyTransformationResult(Table output, Row expected) throws Exception {
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment();
+ DataStream<Row> outputDataStream = tEnv.toDataStream(output);
+ List<Row> results = IteratorUtils.toList(outputDataStream.executeAndCollect());
+ assertEquals(1, results.size());
+
+ Row result = results.get(0);
+ assertEquals(3, result.getArity());
+ assertArrayEquals(
+ ((Vector) expected.getField(0)).toArray(),
+ ((Vector) result.getField(0)).toArray(),
+ EPS);
+ assertArrayEquals((long[]) expected.getField(1), (long[]) result.getField(1));
+ assertArrayEquals(
+ ((Vector) expected.getField(2)).toArray(),
+ ((Vector) result.getField(2)).toArray(),
+ EPS);
+ }
+
+ @Test
+ public void testParam() {
+ ANOVATest anovaTest = new ANOVATest();
+ assertEquals("label", anovaTest.getLabelCol());
+ assertEquals("features", anovaTest.getFeaturesCol());
+ assertFalse(anovaTest.getFlatten());
+
+ anovaTest.setLabelCol("test_label").setFeaturesCol("test_features").setFlatten(true);
+
+ assertEquals("test_features", anovaTest.getFeaturesCol());
+ assertEquals("test_label", anovaTest.getLabelCol());
+ assertTrue(anovaTest.getFlatten());
+ }
+
+ @Test
+ public void testOutputSchema() {
+ ANOVATest anovaTest =
+ new ANOVATest().setFeaturesCol("test_features").setLabelCol("test_label");
+ Table output = anovaTest.transform(denseInputTable)[0];
+ assertEquals(
+ Arrays.asList("pValues", "degreesOfFreedom", "fValues"),
+ output.getResolvedSchema().getColumnNames());
+
+ anovaTest.setFlatten(true);
+ output = anovaTest.transform(denseInputTable)[0];
+ assertEquals(
+ Arrays.asList("featureIndex", "pValue", "degreeOfFreedom", "fValue"),
+ output.getResolvedSchema().getColumnNames());
+ }
+
+ @Test
+ public void testTransform() throws Exception {
+ ANOVATest anovaTest = new ANOVATest();
+
+ Table denseOutput = anovaTest.transform(denseInputTable)[0];
+ verifyTransformationResult(denseOutput, EXPECTED_OUTPUT_DENSE);
+
+ Table sparseOutput = anovaTest.transform(sparseInputTable)[0];
+ verifyTransformationResult(sparseOutput, EXPECTED_OUTPUT_SPARSE);
+ }
+
+ @Test
+ public void testTransformWithFlatten() throws Exception {
+ ANOVATest anovaTest = new ANOVATest().setFlatten(true);
+
+ Table denseOutput = anovaTest.transform(denseInputTable)[0];
+ verifyFlattenTransformationResult(denseOutput, EXPECTED_FLATTENED_OUTPUT_DENSE);
+
+ Table sparseOutput = anovaTest.transform(sparseInputTable)[0];
+ verifyFlattenTransformationResult(sparseOutput, EXPECTED_FLATTENED_OUTPUT_SPARSE);
+ }
+
+ @Test
+ public void testSaveLoadAndTransform() throws Exception {
+ ANOVATest anovaTest = new ANOVATest();
+ ANOVATest loadedANOVATest =
+ TestUtils.saveAndReload(tEnv, anovaTest, tempFolder.newFolder().getAbsolutePath());
+ Table output = loadedANOVATest.transform(denseInputTable)[0];
+ verifyTransformationResult(output, EXPECTED_OUTPUT_DENSE);
+ }
+}
diff --git a/flink-ml-python/pyflink/ml/core/tests/test_param.py b/flink-ml-python/pyflink/ml/core/tests/test_param.py
index 4b1e40e..7213fb9 100644
--- a/flink-ml-python/pyflink/ml/core/tests/test_param.py
+++ b/flink-ml-python/pyflink/ml/core/tests/test_param.py
@@ -22,7 +22,7 @@ from pyflink.ml.core.param import Param
from pyflink.ml.lib.param import HasDistanceMeasure, HasFeaturesCol, HasGlobalBatchSize, \
HasHandleInvalid, HasInputCols, HasLabelCol, HasLearningRate, HasMaxIter, HasMultiClass, \
HasOutputCols, HasPredictionCol, HasRawPredictionCol, HasReg, HasSeed, HasTol, HasWeightCol, \
- HasWindows, HasRelativeError
+ HasWindows, HasRelativeError, HasFlatten
from pyflink.ml.core.windows import GlobalWindows, CountTumblingWindows
@@ -30,7 +30,7 @@ from pyflink.ml.core.windows import GlobalWindows, CountTumblingWindows
class TestParams(HasDistanceMeasure, HasFeaturesCol, HasGlobalBatchSize, HasHandleInvalid,
HasInputCols, HasLabelCol, HasLearningRate, HasMaxIter, HasMultiClass,
HasOutputCols, HasPredictionCol, HasRawPredictionCol, HasReg, HasSeed, HasTol,
- HasWeightCol, HasWindows, HasRelativeError):
+ HasWeightCol, HasWindows, HasRelativeError, HasFlatten):
def __init__(self):
self._param_map = {}
@@ -227,3 +227,15 @@ class ParamTests(unittest.TestCase):
param.set_relative_error(0.1)
self.assertEqual(param.get_relative_error(), 0.1)
+
+ def test_flatten(self):
+ param = TestParams()
+ flatten = param.FLATTEN
+ self.assertEqual(flatten.name, "flatten")
+ self.assertEqual(flatten.description,
+ "If false, the returned table contains only a "
+ "single row, otherwise, one row per feature.")
+ self.assertFalse(flatten.default_value)
+
+ param.set_flatten(True)
+ self.assertTrue(param.get_flatten())
diff --git a/flink-ml-python/pyflink/ml/lib/param.py b/flink-ml-python/pyflink/ml/lib/param.py
index 4ca3aa0..e230e04 100644
--- a/flink-ml-python/pyflink/ml/lib/param.py
+++ b/flink-ml-python/pyflink/ml/lib/param.py
@@ -19,7 +19,7 @@ from abc import ABC
from typing import Tuple
from pyflink.ml.core.param import WithParams, Param, ParamValidators, StringParam, IntParam, \
- StringArrayParam, FloatParam, WindowsParam
+ StringArrayParam, FloatParam, WindowsParam, BooleanParam
from pyflink.ml.core.windows import Windows, GlobalWindows
@@ -560,3 +560,24 @@ class HasRelativeError(WithParams, ABC):
@property
def relative_error(self):
return self.get(self.RELATIVE_ERROR)
+
+
+class HasFlatten(WithParams, ABC):
+ """
+ Interface for the shared flatten param.
+ """
+ FLATTEN: Param[bool] = BooleanParam(
+ "flatten",
+ "If false, the returned table contains only a single row, otherwise, one row per feature.",
+ False
+ )
+
+ def set_flatten(self, value: bool):
+ return self.set(self.FLATTEN, value)
+
+ def get_flatten(self) -> bool:
+ return self.get(self.FLATTEN)
+
+ @property
+ def flatten(self):
+ return self.get(self.FLATTEN)
diff --git a/flink-ml-python/pyflink/ml/lib/stats/chisqtest.py b/flink-ml-python/pyflink/ml/lib/stats/chisqtest.py
index 9c29379..744d5a4 100644
--- a/flink-ml-python/pyflink/ml/lib/stats/chisqtest.py
+++ b/flink-ml-python/pyflink/ml/lib/stats/chisqtest.py
@@ -65,7 +65,7 @@ class ChiSqTest(JavaStatsAlgoOperator, _ChiSqTestParams):
- "statistics": vector
The output of this algorithm can be flattened to multiple rows by setting
- _ChiSqTestParams#FLATTEN, which would contain the following columns:
+ HasFlatten#FLATTEN to True, which would contain the following columns:
- "featureIndex": int
- "pValue": double
diff --git a/flink-ml-python/pyflink/ml/lib/tests/test_ml_lib_completeness.py b/flink-ml-python/pyflink/ml/lib/tests/test_ml_lib_completeness.py
index be50063..b8ab9f0 100644
--- a/flink-ml-python/pyflink/ml/lib/tests/test_ml_lib_completeness.py
+++ b/flink-ml-python/pyflink/ml/lib/tests/test_ml_lib_completeness.py
@@ -21,6 +21,7 @@ import os
import pkgutil
import unittest
from abc import abstractmethod
+from typing import List
from pyflink.java_gateway import get_gateway
from pyflink.ml.tests.test_utils import PyFlinkMLTestCase
@@ -51,8 +52,11 @@ class MLLibTest(PyFlinkMLTestCase):
FLINK_ML_LIB_SOURCE_PATH, "target", "flink-ml-lib-*SNAPSHOT.jar"))[0]
StageAnalyzer = get_gateway().jvm.org.apache.flink.ml.util.StageAnalyzer
+ module_path = 'org.apache.flink.ml.{0}'.format(self.module_name())
+ excluded_stages = list(map(lambda x: f'{module_path}.{x}', self.exclude_java_stage()))
return sorted([stage for stage in StageAnalyzer.analyzeLibJars(ml_lib_jar)
- if 'org.apache.flink.ml.{0}.'.format(self.module_name()) in stage])
+ if module_path in stage
+ and stage not in excluded_stages])
def _load_python_stages(self):
modules = [importlib.import_module('.'.join([self.module().__name__, file_name]))
@@ -81,6 +85,9 @@ class MLLibTest(PyFlinkMLTestCase):
def module(self):
pass
+ def exclude_java_stage(self):
+ return []
+
class ClassificationCompletenessTest(CompletenessTest, MLLibTest):
def module_name(self):
@@ -140,6 +147,11 @@ class StatsCompletenessTest(CompletenessTest, MLLibTest):
from pyflink.ml.lib import stats
return stats
+ def exclude_java_stage(self) -> List[str]:
+ return [
+ "anovatest.ANOVATest",
+ ]
+
if __name__ == "__main__":
try: