You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by li...@apache.org on 2022/11/30 03:45:38 UTC
[flink-ml] branch master updated: [FLINK-30160] Add Transformer for FValueTest
This is an automated email from the ASF dual-hosted git repository.
lindong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push:
new af3b42e [FLINK-30160] Add Transformer for FValueTest
af3b42e is described below
commit af3b42ee94156f4319666d4a20d0e9c1412cd34f
Author: JiangXin <ji...@alibaba-inc.com>
AuthorDate: Wed Nov 30 11:45:33 2022 +0800
[FLINK-30160] Add Transformer for FValueTest
This closes #182.
---
.../flink/ml/stats/fvaluetest/FValueTest.java | 370 ++++++++++++++++++
.../ml/stats/fvaluetest/FValueTestParams.java | 30 ++
.../org/apache/flink/ml/stats/FValueTestTest.java | 428 +++++++++++++++++++++
.../ml/lib/tests/test_ml_lib_completeness.py | 1 +
4 files changed, 829 insertions(+)
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/fvaluetest/FValueTest.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/fvaluetest/FValueTest.java
new file mode 100644
index 0000000..aaba3d4
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/fvaluetest/FValueTest.java
@@ -0,0 +1,370 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.stats.fvaluetest;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.tuple.Tuple5;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.param.HasFlatten;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.math3.distribution.FDistribution;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An AlgoOperator which implements the F-value test algorithm.
+ *
+ * <p>The input of this algorithm is a table containing a labelColumn of numerical type and a
+ * featuresColumn of vector type. Each index in the input vector represents a feature to be tested.
+ * By default, the output of this algorithm is a table containing a single row with the following
+ * columns, each of which has one value per feature.
+ *
+ * <ul>
+ * <li>"pValues": vector
+ * <li>"degreesOfFreedom": long array
+ * <li>"fValues": vector
+ * </ul>
+ *
+ * <p>The output of this algorithm can be flattened to multiple rows by setting {@link
+ * HasFlatten#FLATTEN} to true, which would contain the following columns:
+ *
+ * <ul>
+ * <li>"featureIndex": int
+ * <li>"pValue": double
+ * <li>"degreeOfFreedom": long
+ * <li>"fValue": double
+ * </ul>
+ */
+public class FValueTest implements AlgoOperator<FValueTest>, FValueTestParams<FValueTest> {
+
+ private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+ /** Creates an FValueTest and fills its param map through {@link ParamUtils}. */
+ public FValueTest() {
+ ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+ }
+
+    /**
+     * Builds the F-value test pipeline on top of the single input table.
+     *
+     * <p>The pipeline (1) extracts (features, label) pairs, (2) aggregates them into a summary
+     * tuple with {@link SummaryAggregator}, (3) broadcasts that summary to compute one vector per
+     * partition with {@link CalCovarianceOperator}, (4) reduces those vectors into one with
+     * {@code BLAS.axpy}, and (5) broadcasts the summary again to {@link CalFValueOperator}, whose
+     * per-feature output is converted into the result table.
+     *
+     * @param inputs exactly one table containing the configured features and label columns
+     * @return a single-element array holding the result table; its layout depends on the
+     *     {@code flatten} param
+     * @throws IllegalArgumentException if the number of input tables is not one
+     */
+    // Each warning key must be its own array element: the single string "unchecked, rawtypes"
+    // matches no known key and therefore suppresses nothing.
+    @SuppressWarnings({"unchecked", "rawtypes"})
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        final String featuresCol = getFeaturesCol();
+        final String labelCol = getLabelCol();
+        final String broadcastSummaryKey = "broadcastSummaryKey";
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<Tuple2<Vector, Double>> inputData =
+                tEnv.toDataStream(inputs[0])
+                        .map(
+                                (MapFunction<Row, Tuple2<Vector, Double>>)
+                                        row -> {
+                                            Number number = (Number) row.getField(labelCol);
+                                            Preconditions.checkNotNull(
+                                                    number, "Input data must contain label value.");
+                                            // Fail here with a clear message instead of later in
+                                            // the pipeline when the null vector is first touched.
+                                            Vector features = (Vector) row.getField(featuresCol);
+                                            Preconditions.checkNotNull(
+                                                    features,
+                                                    "Input data must contain features value.");
+                                            return new Tuple2<>(features, number.doubleValue());
+                                        })
+                        .returns(Types.TUPLE(VectorTypeInfo.INSTANCE, Types.DOUBLE));
+
+        // (count, label mean, label std, feature means, feature stds) over the whole input.
+        DataStream<Tuple5<Long, Double, Double, DenseVector, DenseVector>> summaries =
+                DataStreamUtils.aggregate(inputData, new SummaryAggregator());
+
+        DataStream<DenseVector> covarianceInEachPartition =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(inputData),
+                        Collections.singletonMap(broadcastSummaryKey, summaries),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return DataStreamUtils.mapPartition(
+                                    input, new CalCovarianceOperator(broadcastSummaryKey));
+                        });
+
+        // Combine the per-partition vectors into one.
+        DataStream<DenseVector> reducedCovariance =
+                DataStreamUtils.reduce(
+                        covarianceInEachPartition,
+                        (ReduceFunction<DenseVector>)
+                                (sums1, sums2) -> {
+                                    BLAS.axpy(1.0, sums1, sums2);
+                                    return sums2;
+                                });
+
+        DataStream result =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(reducedCovariance),
+                        Collections.singletonMap(broadcastSummaryKey, summaries),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return DataStreamUtils.mapPartition(
+                                    input, new CalFValueOperator(broadcastSummaryKey));
+                        });
+
+        return new Table[] {convertToTable(tEnv, result, getFlatten())};
+    }
+
+    /**
+     * Turns the per-feature result stream into the output table.
+     *
+     * @param tEnv table environment used to create the table
+     * @param dataStream one (featureIndex, pValue, degreeOfFreedom, fValue) tuple per feature
+     * @param flatten when true, each tuple becomes its own row; otherwise each partition's tuples
+     *     are packed, in arrival order, into a single (pValues, degreesOfFreedom, fValues) row
+     * @return the output table with the column names matching the chosen layout
+     */
+    private Table convertToTable(
+            StreamTableEnvironment tEnv,
+            DataStream<Tuple4<Integer, Double, Long, Double>> dataStream,
+            boolean flatten) {
+        if (!flatten) {
+            DataStream<Tuple3<DenseVector, long[], DenseVector>> packed =
+                    DataStreamUtils.mapPartition(
+                            dataStream,
+                            new MapPartitionFunction<
+                                    Tuple4<Integer, Double, Long, Double>,
+                                    Tuple3<DenseVector, long[], DenseVector>>() {
+                                @Override
+                                public void mapPartition(
+                                        Iterable<Tuple4<Integer, Double, Long, Double>> iterable,
+                                        Collector<Tuple3<DenseVector, long[], DenseVector>>
+                                                collector) {
+                                    List<Tuple4<Integer, Double, Long, Double>> perFeature =
+                                            IteratorUtils.toList(iterable.iterator());
+                                    final int featureCount = perFeature.size();
+
+                                    DenseVector pValues = new DenseVector(featureCount);
+                                    long[] degrees = new long[featureCount];
+                                    DenseVector fValues = new DenseVector(featureCount);
+
+                                    // Slots are filled by arrival position, not by tuple.f0.
+                                    int position = 0;
+                                    for (Tuple4<Integer, Double, Long, Double> entry :
+                                            perFeature) {
+                                        pValues.set(position, entry.f1);
+                                        degrees[position] = entry.f2;
+                                        fValues.set(position, entry.f3);
+                                        position++;
+                                    }
+                                    collector.collect(Tuple3.of(pValues, degrees, fValues));
+                                }
+                            });
+            return tEnv.fromDataStream(packed).as("pValues", "degreesOfFreedom", "fValues");
+        }
+        return tEnv.fromDataStream(dataStream)
+                .as("featureIndex", "pValue", "degreeOfFreedom", "fValue");
+    }
+
+    /**
+     * For one partition, sums the products of label deviation and feature deviation (both taken
+     * from the broadcast means) for every feature and scales the sums by 1 / (count - 1), where
+     * count is the total sample count from the broadcast summary.
+     */
+    private static class CalCovarianceOperator
+            extends RichMapPartitionFunction<Tuple2<Vector, Double>, DenseVector> {
+
+        private final String broadcastKey;
+
+        private CalCovarianceOperator(String broadcastKey) {
+            this.broadcastKey = broadcastKey;
+        }
+
+        @Override
+        public void mapPartition(
+                Iterable<Tuple2<Vector, Double>> iterable, Collector<DenseVector> collector) {
+            Tuple5<Long, Double, Double, DenseVector, DenseVector> stats =
+                    (Tuple5<Long, Double, Double, DenseVector, DenseVector>)
+                            getRuntimeContext().getBroadcastVariable(broadcastKey).get(0);
+
+            final DenseVector featureMeans = stats.f3;
+            final int dim = featureMeans.size();
+            DenseVector partial = new DenseVector(dim);
+
+            for (Tuple2<Vector, Double> sample : iterable) {
+                Vector features = sample.f0;
+                Preconditions.checkArgument(
+                        features.size() == dim,
+                        "Input %s features, but FValueTest is expecting %s features.",
+                        features.size(),
+                        dim);
+
+                double labelDeviation = sample.f1 - stats.f1;
+                if (labelDeviation == 0) {
+                    // A zero label deviation contributes nothing to any feature's sum.
+                    continue;
+                }
+                for (int k = 0; k < dim; k++) {
+                    partial.values[k] += labelDeviation * (features.get(k) - featureMeans.get(k));
+                }
+            }
+
+            BLAS.scal(1.0 / (stats.f0 - 1), partial);
+            collector.collect(partial);
+        }
+    }
+
+    /**
+     * Computes the p-value, F-value and degrees of freedom of every feature.
+     *
+     * <p>The input is the reduced covariance vector (one entry per feature); only its first
+     * element is used. The broadcast summary supplies the sample count (f0), the label standard
+     * deviation (f2) and the per-feature standard deviations (f4). For each feature the
+     * correlation is {@code covariance / (labelStd * featureStd)}, the F-value is {@code corr^2 /
+     * (1 - corr^2) * degreesOfFreedom} with {@code degreesOfFreedom = count - 2}, and the p-value
+     * is {@code 1 - F(1, degreesOfFreedom).cumulativeProbability(fValue)}.
+     */
+    private static class CalFValueOperator
+            extends RichMapPartitionFunction<DenseVector, Tuple4<Integer, Double, Long, Double>> {
+
+        private final String broadcastKey;
+
+        private CalFValueOperator(String broadcastKey) {
+            this.broadcastKey = broadcastKey;
+        }
+
+        @Override
+        public void mapPartition(
+                Iterable<DenseVector> iterable,
+                Collector<Tuple4<Integer, Double, Long, Double>> collector) {
+            // Read the first element through a single iterator and keep it in a local rather
+            // than in a mutable field of the function.
+            DenseVector covariances = null;
+            for (DenseVector candidate : iterable) {
+                covariances = candidate;
+                break;
+            }
+            if (covariances == null) {
+                // A subtask that received no covariance vector has nothing to emit; previously
+                // this case failed with a NullPointerException.
+                return;
+            }
+
+            Tuple5<Long, Double, Double, DenseVector, DenseVector> summaries =
+                    (Tuple5<Long, Double, Double, DenseVector, DenseVector>)
+                            getRuntimeContext().getBroadcastVariable(broadcastKey).get(0);
+            int expectedNumOfFeatures = summaries.f4.size();
+
+            Preconditions.checkArgument(
+                    covariances.size() == expectedNumOfFeatures,
+                    "Input %s features, but FValueTest is expecting %s features.",
+                    covariances.size(),
+                    expectedNumOfFeatures);
+
+            final long numSamples = summaries.f0;
+            final long degreesOfFreedom = numSamples - 2;
+            // FDistribution rejects non-positive degrees of freedom; report the actual cause.
+            Preconditions.checkArgument(
+                    degreesOfFreedom > 0,
+                    "FValueTest requires at least 3 samples, but got %s.",
+                    numSamples);
+
+            FDistribution fDistribution = new FDistribution(1, degreesOfFreedom);
+            for (int i = 0; i < expectedNumOfFeatures; i++) {
+                double covariance = covariances.get(i);
+                double corr = covariance / (summaries.f2 * summaries.f4.get(i));
+                double fValue = corr * corr / (1 - corr * corr) * degreesOfFreedom;
+                double pValue = 1.0 - fDistribution.cumulativeProbability(fValue);
+                collector.collect(Tuple4.of(i, pValue, degreesOfFreedom, fValue));
+            }
+        }
+    }
+
+ /** Computes the num, mean, and standard deviation of the input label and features. */
+ private static class SummaryAggregator
+ implements AggregateFunction<
+ Tuple2<Vector, Double>,
+ Tuple5<Long, Double, Double, DenseVector, DenseVector>,
+ Tuple5<Long, Double, Double, DenseVector, DenseVector>> {
+
+ @Override
+ public Tuple5<Long, Double, Double, DenseVector, DenseVector> createAccumulator() {
+ return Tuple5.of(
+ 0L, 0.0, 0.0, new DenseVector(new double[0]), new DenseVector(new double[0]));
+ }
+
+ @Override
+ public Tuple5<Long, Double, Double, DenseVector, DenseVector> add(
+ Tuple2<Vector, Double> featuresAndLabel,
+ Tuple5<Long, Double, Double, DenseVector, DenseVector> summary) {
+ Vector features = featuresAndLabel.f0;
+ double label = featuresAndLabel.f1;
+
+ if (summary.f0 == 0) {
+ summary.f3 = new DenseVector(features.size());
+ summary.f4 = new DenseVector(features.size());
+ }
+ summary.f0 += 1L;
+ summary.f1 += label;
+ summary.f2 += label * label;
+
+ BLAS.axpy(1.0, features, summary.f3);
+ for (int i = 0; i < features.size(); i++) {
+ summary.f4.values[i] += features.get(i) * features.get(i);
+ }
+ return summary;
+ }
+
+ @Override
+ public Tuple5<Long, Double, Double, DenseVector, DenseVector> getResult(
+ Tuple5<Long, Double, Double, DenseVector, DenseVector> summary) {
+ final long numRows = summary.f0;
+ Preconditions.checkState(numRows > 0, "The training set is empty.");
+ int numOfFeatures = summary.f3.size();
+
+ double labelMean = summary.f1 / numRows;
+ Tuple5<Long, Double, Double, DenseVector, DenseVector> result =
+ Tuple5.of(
+ numRows,
+ labelMean,
+ Math.sqrt(
+ (summary.f2 / numRows - labelMean * labelMean)
+ * numRows
+ / (numRows - 1)),
+ new DenseVector(numOfFeatures),
+ new DenseVector(numOfFeatures));
+ for (int i = 0; i < summary.f3.size(); i++) {
+ double mean = summary.f3.get(i) / numRows;
+ result.f3.values[i] = mean;
+ result.f4.values[i] =
+ Math.sqrt(
+ (summary.f4.get(i) / numRows - mean * mean)
+ * numRows
+ / (numRows - 1));
+ }
+ return result;
+ }
+
+ @Override
+ public Tuple5<Long, Double, Double, DenseVector, DenseVector> merge(
+ Tuple5<Long, Double, Double, DenseVector, DenseVector> summary1,
+ Tuple5<Long, Double, Double, DenseVector, DenseVector> summary2) {
+ if (summary1.f0 == 0) {
+ return summary2;
+ }
+ if (summary2.f0 == 0) {
+ return summary1;
+ }
+ summary2.f0 += summary1.f0;
+ summary2.f1 += summary1.f1;
+ summary2.f2 += summary1.f2;
+ BLAS.axpy(1, summary1.f3, summary2.f3);
+ BLAS.axpy(1, summary1.f4, summary2.f4);
+ return summary2;
+ }
+ }
+
+ /**
+ * Saves this operator's metadata under the given path via {@link ReadWriteUtils}.
+ *
+ * @param path target path passed to {@code ReadWriteUtils.saveMetadata}
+ * @throws IOException declared by the underlying save call
+ */
+ @Override
+ public void save(String path) throws IOException {
+ ReadWriteUtils.saveMetadata(this, path);
+ }
+
+ /**
+ * Loads an FValueTest from the given path via {@code ReadWriteUtils.loadStageParam}.
+ *
+ * @param tEnv not used by this method body
+ * @param path path passed to {@code ReadWriteUtils.loadStageParam}
+ * @return the loaded FValueTest
+ * @throws IOException declared by the underlying load call
+ */
+ public static FValueTest load(StreamTableEnvironment tEnv, String path) throws IOException {
+ return ReadWriteUtils.loadStageParam(path);
+ }
+
+ /** Returns the backing param map itself (not a copy). */
+ @Override
+ public Map<Param<?>, Object> getParamMap() {
+ return paramMap;
+ }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/fvaluetest/FValueTestParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/fvaluetest/FValueTestParams.java
new file mode 100644
index 0000000..1c0239a
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/fvaluetest/FValueTestParams.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.stats.fvaluetest;
+
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasFlatten;
+import org.apache.flink.ml.common.param.HasLabelCol;
+
+/**
+ * Params for {@link FValueTest}.
+ *
+ * <p>Declares no params of its own; it only combines {@link HasFeaturesCol}, {@link HasLabelCol}
+ * and {@link HasFlatten}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface FValueTestParams<T> extends HasFeaturesCol<T>, HasLabelCol<T>, HasFlatten<T> {}
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/stats/FValueTestTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/stats/FValueTestTest.java
new file mode 100644
index 0000000..5a31c06
--- /dev/null
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/stats/FValueTestTest.java
@@ -0,0 +1,428 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.stats;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.stats.fvaluetest.FValueTest;
+import org.apache.flink.ml.util.TestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+/** Tests the {@link FValueTest}. */
+public class FValueTestTest extends AbstractTestBase {
+ @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+ private StreamTableEnvironment tEnv;
+ private Table denseInputTable;
+ private Table sparseInputTable;
+
+ private static final double EPS = 1.0e-5;
+ private static final List<Row> DENSE_INPUT_DATA =
+ Arrays.asList(
+ Row.of(
+ 0.19775997,
+ Vectors.dense(
+ 0.15266373,
+ 0.30235661,
+ 0.06203641,
+ 0.45986034,
+ 0.83525338,
+ 0.92699705)),
+ Row.of(
+ 0.66009772,
+ Vectors.dense(
+ 0.72698898,
+ 0.76849622,
+ 0.26920507,
+ 0.64402929,
+ 0.09337326,
+ 0.07968589)),
+ Row.of(
+ 0.80865842,
+ Vectors.dense(
+ 0.58961375,
+ 0.34334054,
+ 0.98887615,
+ 0.62647321,
+ 0.68177928,
+ 0.55225681)),
+ Row.of(
+ 0.34142582,
+ Vectors.dense(
+ 0.26886006,
+ 0.37325939,
+ 0.2229281,
+ 0.1864426,
+ 0.39064809,
+ 0.19316241)),
+ Row.of(
+ 0.84756607,
+ Vectors.dense(
+ 0.61091093,
+ 0.88280845,
+ 0.62233882,
+ 0.25311894,
+ 0.17993031,
+ 0.81640447)),
+ Row.of(
+ 0.53360225,
+ Vectors.dense(
+ 0.22537162,
+ 0.51685714,
+ 0.51849582,
+ 0.60037494,
+ 0.53262048,
+ 0.01331005)),
+ Row.of(
+ 0.90053371,
+ Vectors.dense(
+ 0.52409726,
+ 0.89588471,
+ 0.76990129,
+ 0.1228517,
+ 0.29587269,
+ 0.61202358)),
+ Row.of(
+ 0.78779561,
+ Vectors.dense(
+ 0.72613812,
+ 0.46349747,
+ 0.76911037,
+ 0.19163103,
+ 0.55786672,
+ 0.55077816)),
+ Row.of(
+ 0.51604647,
+ Vectors.dense(
+ 0.47222549,
+ 0.79188496,
+ 0.11524968,
+ 0.6813039,
+ 0.36233361,
+ 0.34420889)),
+ Row.of(
+ 0.35325637,
+ Vectors.dense(
+ 0.44951875,
+ 0.02694226,
+ 0.41524769,
+ 0.9222317,
+ 0.09120557,
+ 0.31512178)),
+ Row.of(
+ 0.51408926,
+ Vectors.dense(
+ 0.52802224,
+ 0.32806203,
+ 0.44891554,
+ 0.01633442,
+ 0.0970269,
+ 0.69258857)),
+ Row.of(
+ 0.84489897,
+ Vectors.dense(
+ 0.83594341,
+ 0.42432199,
+ 0.8487743,
+ 0.54679121,
+ 0.35410346,
+ 0.72724968)),
+ Row.of(
+ 0.55342816,
+ Vectors.dense(
+ 0.09385168,
+ 0.8928588,
+ 0.33625828,
+ 0.89183268,
+ 0.296849,
+ 0.30164829)),
+ Row.of(
+ 0.89405683,
+ Vectors.dense(
+ 0.80624061,
+ 0.83760997,
+ 0.63428133,
+ 0.3113273,
+ 0.02944858,
+ 0.39977732)),
+ Row.of(
+ 0.54588131,
+ Vectors.dense(
+ 0.51817346,
+ 0.00738845,
+ 0.77494778,
+ 0.8544712,
+ 0.13153282,
+ 0.28767364)),
+ Row.of(
+ 0.96038024,
+ Vectors.dense(
+ 0.32658881,
+ 0.90655956,
+ 0.99955954,
+ 0.77088429,
+ 0.04284752,
+ 0.96525111)),
+ Row.of(
+ 0.71349698,
+ Vectors.dense(
+ 0.97521246,
+ 0.2025168,
+ 0.67985305,
+ 0.46534506,
+ 0.92001748,
+ 0.72820735)),
+ Row.of(
+ 0.43456735,
+ Vectors.dense(
+ 0.24585653,
+ 0.01953996,
+ 0.70598881,
+ 0.77448287,
+ 0.4729746,
+ 0.80146736)),
+ Row.of(
+ 0.52462506,
+ Vectors.dense(
+ 0.17539792,
+ 0.72016934,
+ 0.3678759,
+ 0.53209295,
+ 0.29719397,
+ 0.37429151)),
+ Row.of(
+ 0.43074793,
+ Vectors.dense(
+ 0.72810013,
+ 0.39850784,
+ 0.1058295,
+ 0.39858265,
+ 0.52196395,
+ 0.1060125)));
+
+ private static final List<Row> SPARSE_INPUT_DATA =
+ Arrays.asList(
+ Row.of(4.6, Vectors.dense(6.0, 7.0, 0.0, 7.0, 6.0, 0.0, 0.0).toSparse()),
+ Row.of(6.6, Vectors.dense(0.0, 9.0, 6.0, 0.0, 5.0, 9.0, 0.0).toSparse()),
+ Row.of(5.1, Vectors.dense(0.0, 9.0, 3.0, 0.0, 5.0, 5.0, 0.0).toSparse()),
+ Row.of(7.6, Vectors.dense(0.0, 9.0, 8.0, 5.0, 6.0, 4.0, 0.0).toSparse()),
+ Row.of(9.0, Vectors.dense(8.0, 9.0, 6.0, 5.0, 4.0, 4.0, 0.0).toSparse()),
+ Row.of(
+ 9.0,
+ Vectors.dense(Double.NaN, 9.0, 6.0, 4.0, 0.0, 0.0, 0.0).toSparse()));
+
+ private static final Row EXPECTED_OUTPUT_DENSE =
+ Row.of(
+ Vectors.dense(
+ 1.73658700e-02,
+ 1.49916659e-02,
+ 1.12697153e-04,
+ 4.26990301e-01,
+ 2.75911201e-01,
+ 1.93549275e-01),
+ new long[] {18, 18, 18, 18, 18, 18},
+ Vectors.dense(
+ 6.86260598,
+ 7.23175589,
+ 24.11424725,
+ 0.6605354,
+ 1.26266286,
+ 1.82421406));
+
+ private static final List<Row> EXPECTED_FLATTENED_OUTPUT_DENSE =
+ Arrays.asList(
+ Row.of(0, 1.73658700e-02, 18, 6.86260598),
+ Row.of(1, 1.49916659e-02, 18, 7.23175589),
+ Row.of(2, 1.12697153e-04, 18, 24.11424725),
+ Row.of(3, 4.26990301e-01, 18, 0.6605354),
+ Row.of(4, 2.75911201e-01, 18, 1.26266286),
+ Row.of(5, 1.93549275e-01, 18, 1.82421406));
+
+ private static final Row EXPECTED_OUTPUT_SPARSE =
+ Row.of(
+ Vectors.dense(
+ Double.NaN,
+ 0.19167161,
+ 0.06506426,
+ 0.75183662,
+ 0.16111045,
+ 0.89090362,
+ Double.NaN),
+ new long[] {4, 4, 4, 4, 4, 4, 4},
+ Vectors.dense(
+ Double.NaN,
+ 2.46254817,
+ 6.37164347,
+ 0.1147488,
+ 2.94816821,
+ 0.02134755,
+ Double.NaN));
+
+ private static final List<Row> EXPECTED_FLATTENED_OUTPUT_SPARSE =
+ Arrays.asList(
+ Row.of(0, Double.NaN, 4, Double.NaN),
+ Row.of(1, 0.19167161, 4, 2.46254817),
+ Row.of(2, 0.06506426, 4, 6.37164347),
+ Row.of(3, 0.75183662, 4, 0.1147488),
+ Row.of(4, 0.16111045, 4, 2.94816821),
+ Row.of(5, 0.89090362, 4, 0.02134755),
+ Row.of(6, Double.NaN, 4, Double.NaN));
+
+ /**
+ * Creates a stream environment with parallelism 4, checkpointing every 100 (also after tasks
+ * finish), no restart strategy and object reuse enabled, then builds the dense and sparse
+ * input tables with columns "label" and "features".
+ */
+ @Before
+ public void before() {
+ Configuration config = new Configuration();
+ config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+ StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+ env.setParallelism(4);
+ env.enableCheckpointing(100);
+ env.setRestartStrategy(RestartStrategies.noRestart());
+ env.getConfig().enableObjectReuse();
+ tEnv = StreamTableEnvironment.create(env);
+ denseInputTable =
+ tEnv.fromDataStream(env.fromCollection(DENSE_INPUT_DATA)).as("label", "features");
+ sparseInputTable =
+ tEnv.fromDataStream(env.fromCollection(SPARSE_INPUT_DATA)).as("label", "features");
+ }
+
+    /**
+     * Executes the flattened output table and compares it row by row with the expected rows.
+     *
+     * <p>Both sides are ordered by the numeric feature index in field 0 before comparison, and
+     * every field is compared as a double with tolerance {@code EPS}. The expected list is
+     * copied first so that the caller's list (a shared static fixture) is never reordered.
+     *
+     * @param output flattened result table of FValueTest
+     * @param expected expected rows, in any order; left unmodified
+     * @throws Exception if executing the job fails
+     */
+    private static void verifyFlattenTransformationResult(Table output, List<Row> expected)
+            throws Exception {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment();
+        DataStream<Row> outputDataStream = tEnv.toDataStream(output);
+        List<Row> results = IteratorUtils.toList(outputDataStream.executeAndCollect());
+        assertEquals(expected.size(), results.size());
+
+        // Order by the integer index itself; ordering by its string form would place "10"
+        // before "2" once there are more than ten features.
+        Comparator<Row> byFeatureIndex =
+                Comparator.comparingInt(r -> ((Number) r.getField(0)).intValue());
+        List<Row> sortedExpected = new java.util.ArrayList<>(expected);
+        sortedExpected.sort(byFeatureIndex);
+        results.sort(byFeatureIndex);
+
+        for (int i = 0; i < sortedExpected.size(); i++) {
+            Row expectedRow = sortedExpected.get(i);
+            Row actualRow = results.get(i);
+            assertEquals(expectedRow.getArity(), actualRow.getArity());
+            for (int j = 0; j < expectedRow.getArity(); j++) {
+                // All fields are numeric (int/long/double), so compare them as doubles directly
+                // instead of round-tripping through their string form.
+                assertEquals(
+                        ((Number) expectedRow.getField(j)).doubleValue(),
+                        ((Number) actualRow.getField(j)).doubleValue(),
+                        EPS);
+            }
+        }
+    }
+
+    /**
+     * Executes the non-flattened output table and checks that it holds exactly one row of arity
+     * three whose p-value vector, degrees-of-freedom array and F-value vector match the expected
+     * row (vectors within {@code EPS}, degrees of freedom exactly).
+     *
+     * @param output non-flattened result table of FValueTest
+     * @param expected expected (pValues, degreesOfFreedom, fValues) row
+     * @throws Exception if executing the job fails
+     */
+    private static void verifyTransformationResult(Table output, Row expected) throws Exception {
+        StreamTableEnvironment tableEnv =
+                (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment();
+        List<Row> collected =
+                IteratorUtils.toList(tableEnv.toDataStream(output).executeAndCollect());
+        assertEquals(1, collected.size());
+
+        Row actual = collected.get(0);
+        assertEquals(3, actual.getArity());
+
+        double[] expectedPValues = ((Vector) expected.getField(0)).toArray();
+        double[] actualPValues = ((Vector) actual.getField(0)).toArray();
+        assertArrayEquals(expectedPValues, actualPValues, EPS);
+
+        long[] expectedDegrees = (long[]) expected.getField(1);
+        long[] actualDegrees = (long[]) actual.getField(1);
+        assertArrayEquals(expectedDegrees, actualDegrees);
+
+        double[] expectedFValues = ((Vector) expected.getField(2)).toArray();
+        double[] actualFValues = ((Vector) actual.getField(2)).toArray();
+        assertArrayEquals(expectedFValues, actualFValues, EPS);
+    }
+
+ /** Checks the default param values and that values set through the setters are read back. */
+ @Test
+ public void testParam() {
+ FValueTest fValueTest = new FValueTest();
+ assertEquals("label", fValueTest.getLabelCol());
+ assertEquals("features", fValueTest.getFeaturesCol());
+ assertFalse(fValueTest.getFlatten());
+
+ fValueTest.setLabelCol("test_label").setFeaturesCol("test_features").setFlatten(true);
+
+ assertEquals("test_features", fValueTest.getFeaturesCol());
+ assertEquals("test_label", fValueTest.getLabelCol());
+ assertTrue(fValueTest.getFlatten());
+ }
+
+ /**
+ * Checks the output column names in the default and in the flattened layout.
+ *
+ * <p>NOTE(review): the configured column names ("test_features", "test_label") differ from the
+ * input table's columns ("label", "features"). Only the schema is inspected and no job is
+ * executed here, so the mismatch presumably never surfaces; confirm this is intended.
+ */
+ @Test
+ public void testOutputSchema() {
+ FValueTest fValueTest =
+ new FValueTest().setFeaturesCol("test_features").setLabelCol("test_label");
+ Table output = fValueTest.transform(denseInputTable)[0];
+ assertEquals(
+ Arrays.asList("pValues", "degreesOfFreedom", "fValues"),
+ output.getResolvedSchema().getColumnNames());
+
+ fValueTest.setFlatten(true);
+ output = fValueTest.transform(denseInputTable)[0];
+ assertEquals(
+ Arrays.asList("featureIndex", "pValue", "degreeOfFreedom", "fValue"),
+ output.getResolvedSchema().getColumnNames());
+ }
+
+ /** Runs the non-flattened transform on the dense and the sparse input and checks the results. */
+ @Test
+ public void testTransform() throws Exception {
+ FValueTest fValueTest = new FValueTest();
+
+ Table denseOutput = fValueTest.transform(denseInputTable)[0];
+ verifyTransformationResult(denseOutput, EXPECTED_OUTPUT_DENSE);
+
+ Table sparseOutput = fValueTest.transform(sparseInputTable)[0];
+ verifyTransformationResult(sparseOutput, EXPECTED_OUTPUT_SPARSE);
+ }
+
+ /** Runs the flattened transform on the dense and the sparse input and checks the results. */
+ @Test
+ public void testTransformWithFlatten() throws Exception {
+ FValueTest fValueTest = new FValueTest().setFlatten(true);
+
+ Table denseOutput = fValueTest.transform(denseInputTable)[0];
+ verifyFlattenTransformationResult(denseOutput, EXPECTED_FLATTENED_OUTPUT_DENSE);
+
+ Table sparseOutput = fValueTest.transform(sparseInputTable)[0];
+ verifyFlattenTransformationResult(sparseOutput, EXPECTED_FLATTENED_OUTPUT_SPARSE);
+ }
+
+ /** Saves and reloads a default FValueTest, then checks its output on the dense input. */
+ @Test
+ public void testSaveLoadAndTransform() throws Exception {
+ FValueTest fValueTest = new FValueTest();
+ FValueTest loadedFValueTest =
+ TestUtils.saveAndReload(tEnv, fValueTest, tempFolder.newFolder().getAbsolutePath());
+ Table output = loadedFValueTest.transform(denseInputTable)[0];
+ verifyTransformationResult(output, EXPECTED_OUTPUT_DENSE);
+ }
+}
diff --git a/flink-ml-python/pyflink/ml/lib/tests/test_ml_lib_completeness.py b/flink-ml-python/pyflink/ml/lib/tests/test_ml_lib_completeness.py
index b8ab9f0..67f8519 100644
--- a/flink-ml-python/pyflink/ml/lib/tests/test_ml_lib_completeness.py
+++ b/flink-ml-python/pyflink/ml/lib/tests/test_ml_lib_completeness.py
@@ -150,6 +150,7 @@ class StatsCompletenessTest(CompletenessTest, MLLibTest):
def exclude_java_stage(self) -> List[str]:
return [
"anovatest.ANOVATest",
+ "fvaluetest.FValueTest",
]