Posted to commits@flink.apache.org by zh...@apache.org on 2022/08/04 01:11:07 UTC

[flink-ml] branch master updated: [FLINK-28571] Add AlgoOperator for Chi-squared test

This is an automated email from the ASF dual-hosted git repository.

zhangzp pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git


The following commit(s) were added to refs/heads/master by this push:
     new df58dbe  [FLINK-28571] Add AlgoOperator for Chi-squared test
df58dbe is described below

commit df58dbee68131c967ce927825cd710280cfda1ce
Author: taosiyuan163 <ta...@163.com>
AuthorDate: Thu Aug 4 09:11:03 2022 +0800

    [FLINK-28571] Add AlgoOperator for Chi-squared test
    
    This closes #132.
---
 .../apache/flink/ml/stats/chisqtest/ChiSqTest.java | 657 +++++++++++++++++++++
 .../flink/ml/stats/chisqtest/ChiSqTestParams.java  |  29 +
 .../org/apache/flink/ml/stats/ChiSqTestTest.java   | 199 +++++++
 3 files changed, 885 insertions(+)

diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/chisqtest/ChiSqTest.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/chisqtest/ChiSqTest.java
new file mode 100644
index 0000000..732989d
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/chisqtest/ChiSqTest.java
@@ -0,0 +1,657 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.stats.chisqtest;
+
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeHint;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.math3.distribution.ChiSquaredDistribution;
+
+import java.io.IOException;
+import java.math.BigDecimal;
+import java.math.RoundingMode;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+
+/**
+ * An AlgoOperator which implements the Chi-square test algorithm.
+ *
+ * <p>The chi-square test computes statistics of independence of variables in a contingency
+ * table, e.g., the p-value and the number of degrees of freedom (DOF) for each input feature.
+ * The contingency table is constructed from the observed categorical values.
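+ *
+ * <p>Concretely, for each input column the expected count of a cell is {@code categoryMargin *
+ * labelMargin / sampleSize}, the statistic is the sum over all cells of {@code (observed -
+ * expected)^2 / expected}, the degrees of freedom are {@code (numCategories - 1) * (numLabels -
+ * 1)}, and the p-value is {@code 1 - CDF(statistic)} under the chi-squared distribution with that
+ * many degrees of freedom.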
+ *
+ * <p>See: http://en.wikipedia.org/wiki/Chi-squared_test.
+ */
+public class ChiSqTest implements AlgoOperator<ChiSqTest>, ChiSqTestParams<ChiSqTest> {
+
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public ChiSqTest() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
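+        // Overview of the pipeline built below:
+        //   1. Extract (column, featureValue, label) triples from the input rows.
+        //   2. Count the observed frequency of each distinct triple.
+        //   3. Fill in zero counts for (column, featureValue, label) combinations that never occur.
+        //   4. Aggregate the categorical margins and the label margins.
+        //   5. With the margins broadcast, compute each cell's contribution to the chi-squared
+        //      statistic and the per-column degrees of freedom.
+        //   6. Sum the statistics and derive the p-value and degrees of freedom for every column.
+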
+        final String bcCategoricalMarginsKey = "bcCategoricalMarginsKey";
+        final String bcLabelMarginsKey = "bcLabelMarginsKey";
+
+        final String[] inputCols = getInputCols();
+        String labelCol = getLabelCol();
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        SingleOutputStreamOperator<Tuple3<String, Object, Object>> colAndFeatureAndLabel =
+                tEnv.toDataStream(inputs[0])
+                        .flatMap(new ExtractColAndFeatureAndLabel(inputCols, labelCol));
+
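+        // Keying by the tuple's hash code ensures that identical (column, featureValue, label)
+        // triples are counted by the same subtask.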
+        DataStream<Tuple4<String, Object, Object, Long>> observedFreq =
+                colAndFeatureAndLabel
+                        .keyBy(Tuple3::hashCode)
+                        .transform(
+                                "GenerateObservedFrequencies",
+                                TypeInformation.of(
+                                        new TypeHint<Tuple4<String, Object, Object, Long>>() {}),
+                                new GenerateObservedFrequencies());
+
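+        // FillFrequencyTable must observe all frequencies to add the zero counts for missing
+        // combinations, hence parallelism 1.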
+        SingleOutputStreamOperator<Tuple4<String, Object, Object, Long>> filledObservedFreq =
+                observedFreq
+                        .transform(
+                                "filledObservedFreq",
+                                Types.TUPLE(
+                                        Types.STRING,
+                                        Types.GENERIC(Object.class),
+                                        Types.GENERIC(Object.class),
+                                        Types.LONG),
+                                new FillFrequencyTable())
+                        .setParallelism(1);
+
+        DataStream<Tuple3<String, Object, Long>> categoricalMargins =
+                observedFreq
+                        .keyBy(tuple -> new Tuple2<>(tuple.f0, tuple.f1).hashCode())
+                        .transform(
+                                "AggregateCategoricalMargins",
+                                TypeInformation.of(new TypeHint<Tuple3<String, Object, Long>>() {}),
+                                new AggregateCategoricalMargins());
+
+        DataStream<Tuple3<String, Object, Long>> labelMargins =
+                observedFreq
+                        .keyBy(tuple -> new Tuple2<>(tuple.f0, tuple.f2).hashCode())
+                        .transform(
+                                "AggregateLabelMargins",
+                                TypeInformation.of(new TypeHint<Tuple3<String, Object, Long>>() {}),
+                                new AggregateLabelMargins());
+
+        Function<List<DataStream<?>>, DataStream<Tuple3<String, Double, Integer>>> function =
+                dataStreams -> {
+                    DataStream stream = dataStreams.get(0);
+                    return stream.map(new ChiSqFunc(bcCategoricalMarginsKey, bcLabelMarginsKey));
+                };
+
+        HashMap<String, DataStream<?>> bcMap =
+                new HashMap<String, DataStream<?>>() {
+                    {
+                        put(bcCategoricalMarginsKey, categoricalMargins);
+                        put(bcLabelMarginsKey, labelMargins);
+                    }
+                };
+
+        DataStream<Tuple3<String, Double, Integer>> categoricalStatistics =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(filledObservedFreq), bcMap, function);
+
+        SingleOutputStreamOperator<Row> chiSqTestResult =
+                categoricalStatistics
+                        .transform(
+                                "chiSqTestResult",
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            Types.STRING, Types.DOUBLE, Types.DOUBLE, Types.INT
+                                        },
+                                        new String[] {
+                                            "column", "pValue", "statistic", "degreesOfFreedom"
+                                        }),
+                                new AggregateChiSqFunc())
+                        .setParallelism(1);
+
+        return new Table[] {tEnv.fromDataStream(chiSqTestResult)};
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static ChiSqTest load(StreamTableEnvironment tEnv, String path) throws IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class ExtractColAndFeatureAndLabel
+            extends RichFlatMapFunction<Row, Tuple3<String, Object, Object>> {
+        private final String[] inputCols;
+        private final String labelCol;
+
+        public ExtractColAndFeatureAndLabel(String[] inputCols, String labelCol) {
+            this.inputCols = inputCols;
+            this.labelCol = labelCol;
+        }
+
+        @Override
+        public void flatMap(Row row, Collector<Tuple3<String, Object, Object>> collector) {
+
+            Object label = row.getFieldAs(labelCol);
+
+            for (String colName : inputCols) {
+                Object value = row.getField(colName);
+                collector.collect(new Tuple3<>(colName, value, label));
+            }
+        }
+    }
+
+    /**
+     * Computes the observed frequency of each (column, featureValue, label) combination. An output
+     * record (columnA, featureValueB, labelC, countD) indicates that the feature value
+     * {featureValueB} with label {labelC} in column {columnA} appears {countD} times in the input
+     * table.
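+     *
+     * <p>For example, the record {@code ("f1", 5, 1.0, 3L)} would indicate that feature value
+     * {@code 5} appears together with label {@code 1.0} three times in column {@code f1}.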
+     */
+    private static class GenerateObservedFrequencies
+            extends AbstractStreamOperator<Tuple4<String, Object, Object, Long>>
+            implements OneInputStreamOperator<
+                            Tuple3<String, Object, Object>, Tuple4<String, Object, Object, Long>>,
+                    BoundedOneInput {
+
+        private HashMap<Tuple3<String, Object, Object>, Long> cntMap = new HashMap<>();
+        private ListState<HashMap<Tuple3<String, Object, Object>, Long>> cntMapState;
+
+        @Override
+        public void endInput() {
+            for (Tuple3<String, Object, Object> key : cntMap.keySet()) {
+                Long count = cntMap.get(key);
+                output.collect(new StreamRecord<>(new Tuple4<>(key.f0, key.f1, key.f2, count)));
+            }
+            cntMapState.clear();
+        }
+
+        @Override
+        public void processElement(StreamRecord<Tuple3<String, Object, Object>> element) {
+
+            Tuple3<String, Object, Object> colAndCategoryAndLabel = element.getValue();
+            cntMap.compute(colAndCategoryAndLabel, (k, v) -> (v == null ? 1 : v + 1));
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            cntMapState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "cntMapState",
+                                            TypeInformation.of(
+                                                    new TypeHint<
+                                                            HashMap<
+                                                                    Tuple3<String, Object, Object>,
+                                                                    Long>>() {})));
+
+            OperatorStateUtils.getUniqueElement(cntMapState, "cntMapState")
+                    .ifPresent(x -> cntMap = x);
+        }
+
+        @Override
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            cntMapState.update(Collections.singletonList(cntMap));
+        }
+    }
+
+    /**
+     * Fills the frequency table by setting the frequency of missing elements (i.e., missing
+     * combinations of column, featureValue and labelValue) to zero.
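+     *
+     * <p>For example, if some label value never co-occurs with a given (column, featureValue) pair
+     * in the input, a record with count {@code 0} is emitted for that combination.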
+     */
+    private static class FillFrequencyTable
+            extends AbstractStreamOperator<Tuple4<String, Object, Object, Long>>
+            implements OneInputStreamOperator<
+                            Tuple4<String, Object, Object, Long>,
+                            Tuple4<String, Object, Object, Long>>,
+                    BoundedOneInput {
+
+        private HashMap<Tuple2<String, Object>, ArrayList<Tuple2<Object, Long>>> valuesMap =
+                new HashMap<>();
+        private HashSet<Object> distinctLabels = new HashSet<>();
+
+        private ListState<HashMap<Tuple2<String, Object>, ArrayList<Tuple2<Object, Long>>>>
+                valuesMapState;
+        private ListState<HashSet<Object>> distinctLabelsState;
+
+        @Override
+        public void endInput() {
+
+            for (Map.Entry<Tuple2<String, Object>, ArrayList<Tuple2<Object, Long>>> entry :
+                    valuesMap.entrySet()) {
+                ArrayList<Tuple2<Object, Long>> labelAndCountList = entry.getValue();
+                Tuple2<String, Object> categoricalKey = entry.getKey();
+
+                List<Object> existingLabels =
+                        labelAndCountList.stream().map(v -> v.f0).collect(Collectors.toList());
+
+                for (Object label : distinctLabels) {
+                    if (!existingLabels.contains(label)) {
+                        Tuple2<Object, Long> generatedLabelCount = new Tuple2<>(label, 0L);
+                        labelAndCountList.add(generatedLabelCount);
+                    }
+                }
+
+                for (Tuple2<Object, Long> labelAndCount : labelAndCountList) {
+                    output.collect(
+                            new StreamRecord<>(
+                                    new Tuple4<>(
+                                            categoricalKey.f0,
+                                            categoricalKey.f1,
+                                            labelAndCount.f0,
+                                            labelAndCount.f1)));
+                }
+            }
+
+            valuesMapState.clear();
+            distinctLabelsState.clear();
+        }
+
+        @Override
+        public void processElement(StreamRecord<Tuple4<String, Object, Object, Long>> element) {
+            Tuple4<String, Object, Object, Long> colAndCategoryAndLabelAndCount =
+                    element.getValue();
+            Tuple2<String, Object> key =
+                    new Tuple2<>(
+                            colAndCategoryAndLabelAndCount.f0, colAndCategoryAndLabelAndCount.f1);
+            Tuple2<Object, Long> labelAndCount =
+                    new Tuple2<>(
+                            colAndCategoryAndLabelAndCount.f2, colAndCategoryAndLabelAndCount.f3);
+            ArrayList<Tuple2<Object, Long>> labelAndCountList = valuesMap.get(key);
+
+            if (labelAndCountList == null) {
+                ArrayList<Tuple2<Object, Long>> value = new ArrayList<>();
+                value.add(labelAndCount);
+                valuesMap.put(key, value);
+            } else {
+                labelAndCountList.add(labelAndCount);
+            }
+
+            distinctLabels.add(colAndCategoryAndLabelAndCount.f2);
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            valuesMapState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "valuesMapState",
+                                            TypeInformation.of(
+                                                    new TypeHint<
+                                                            HashMap<
+                                                                    Tuple2<String, Object>,
+                                                                    ArrayList<
+                                                                            Tuple2<
+                                                                                    Object,
+                                                                                    Long>>>>() {})));
+            distinctLabelsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "distinctLabelsState",
+                                            TypeInformation.of(
+                                                    new TypeHint<HashSet<Object>>() {})));
+
+            OperatorStateUtils.getUniqueElement(valuesMapState, "valuesMapState")
+                    .ifPresent(x -> valuesMap = x);
+
+            OperatorStateUtils.getUniqueElement(distinctLabelsState, "distinctLabelsState")
+                    .ifPresent(x -> distinctLabels = x);
+        }
+
+        @Override
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            valuesMapState.update(Collections.singletonList(valuesMap));
+            distinctLabelsState.update(Collections.singletonList(distinctLabels));
+        }
+    }
+
+    /**
+     * Computes the marginal sum of each (column, category) pair by summing the observed
+     * frequencies over all labels.
+     */
+    private static class AggregateCategoricalMargins
+            extends AbstractStreamOperator<Tuple3<String, Object, Long>>
+            implements OneInputStreamOperator<
+                            Tuple4<String, Object, Object, Long>, Tuple3<String, Object, Long>>,
+                    BoundedOneInput {
+
+        private HashMap<Tuple2<String, Object>, Long> categoricalMarginsMap = new HashMap<>();
+
+        private ListState<HashMap<Tuple2<String, Object>, Long>> categoricalMarginsMapState;
+
+        @Override
+        public void endInput() {
+            for (Tuple2<String, Object> key : categoricalMarginsMap.keySet()) {
+                Long categoricalMargin = categoricalMarginsMap.get(key);
+                output.collect(new StreamRecord<>(new Tuple3<>(key.f0, key.f1, categoricalMargin)));
+            }
+            categoricalMarginsMap.clear();
+        }
+
+        @Override
+        public void processElement(StreamRecord<Tuple4<String, Object, Object, Long>> element) {
+
+            Tuple4<String, Object, Object, Long> colAndCategoryAndLabelAndCnt = element.getValue();
+            Tuple2<String, Object> key =
+                    new Tuple2<>(colAndCategoryAndLabelAndCnt.f0, colAndCategoryAndLabelAndCnt.f1);
+            Long observedFreq = colAndCategoryAndLabelAndCnt.f3;
+            categoricalMarginsMap.compute(
+                    key, (k, v) -> (v == null ? observedFreq : v + observedFreq));
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            categoricalMarginsMapState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "categoricalMarginsMapState",
+                                            TypeInformation.of(
+                                                    new TypeHint<
+                                                            HashMap<
+                                                                    Tuple2<String, Object>,
+                                                                    Long>>() {})));
+
+            OperatorStateUtils.getUniqueElement(
+                            categoricalMarginsMapState, "categoricalMarginsMapState")
+                    .ifPresent(x -> categoricalMarginsMap = x);
+        }
+
+        @Override
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            categoricalMarginsMapState.update(Collections.singletonList(categoricalMarginsMap));
+        }
+    }
+
+    /**
+     * Computes the marginal sum of each (column, label) pair by summing the observed frequencies
+     * over all categories.
+     */
+    private static class AggregateLabelMargins
+            extends AbstractStreamOperator<Tuple3<String, Object, Long>>
+            implements OneInputStreamOperator<
+                            Tuple4<String, Object, Object, Long>, Tuple3<String, Object, Long>>,
+                    BoundedOneInput {
+
+        private HashMap<Tuple2<String, Object>, Long> labelMarginsMap = new HashMap<>();
+        private ListState<HashMap<Tuple2<String, Object>, Long>> labelMarginsMapState;
+
+        @Override
+        public void endInput() {
+
+            for (Tuple2<String, Object> key : labelMarginsMap.keySet()) {
+                Long labelMargin = labelMarginsMap.get(key);
+                output.collect(new StreamRecord<>(new Tuple3<>(key.f0, key.f1, labelMargin)));
+            }
+            labelMarginsMapState.clear();
+        }
+
+        @Override
+        public void processElement(StreamRecord<Tuple4<String, Object, Object, Long>> element) {
+
+            Tuple4<String, Object, Object, Long> colAndFeatureAndLabelAndCnt = element.getValue();
+            Long observedFreq = colAndFeatureAndLabelAndCnt.f3;
+            Tuple2<String, Object> key =
+                    new Tuple2<>(colAndFeatureAndLabelAndCnt.f0, colAndFeatureAndLabelAndCnt.f2);
+
+            labelMarginsMap.compute(key, (k, v) -> (v == null ? observedFreq : v + observedFreq));
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            labelMarginsMapState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "labelMarginsMapState",
+                                            TypeInformation.of(
+                                                    new TypeHint<
+                                                            HashMap<
+                                                                    Tuple2<String, Object>,
+                                                                    Long>>() {})));
+
+            OperatorStateUtils.getUniqueElement(labelMarginsMapState, "labelMarginsMapState")
+                    .ifPresent(x -> labelMarginsMap = x);
+        }
+
+        @Override
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            labelMarginsMapState.update(Collections.singletonList(labelMarginsMap));
+        }
+    }
+
+    /**
+     * Conducts Pearson's independence test by computing each record's contribution to the
+     * chi-squared statistic, together with the degrees of freedom of its column.
+     */
+    private static class ChiSqFunc
+            extends RichMapFunction<
+                    Tuple4<String, Object, Object, Long>, Tuple3<String, Double, Integer>> {
+
+        private final String bcCategoricalMarginsKey;
+        private final String bcLabelMarginsKey;
+        private final Map<Tuple2<String, Object>, Long> categoricalMargins = new HashMap<>();
+        private final Map<Tuple2<String, Object>, Long> labelMargins = new HashMap<>();
+
+        double sampleSize = 0;
+        int numLabels = 0;
+        HashMap<String, Integer> col2NumCategories = new HashMap<>();
+
+        public ChiSqFunc(String bcCategoricalMarginsKey, String bcLabelMarginsKey) {
+            this.bcCategoricalMarginsKey = bcCategoricalMarginsKey;
+            this.bcLabelMarginsKey = bcLabelMarginsKey;
+        }
+
+        @Override
+        public Tuple3<String, Double, Integer> map(Tuple4<String, Object, Object, Long> v) {
+            if (categoricalMargins.isEmpty()) {
+                List<Tuple3<String, Object, Long>> categoricalMarginList =
+                        getRuntimeContext().getBroadcastVariable(bcCategoricalMarginsKey);
+                List<Tuple3<String, Object, Long>> labelMarginList =
+                        getRuntimeContext().getBroadcastVariable(bcLabelMarginsKey);
+
+                for (Tuple3<String, Object, Long> colAndFeatureAndCount : categoricalMarginList) {
+                    String theColName = colAndFeatureAndCount.f0;
+                    col2NumCategories.merge(theColName, 1, Integer::sum);
+                }
+
+                numLabels = (int) labelMarginList.stream().map(x -> x.f1).distinct().count();
+
+                for (Tuple3<String, Object, Long> colAndFeatureAndCount : categoricalMarginList) {
+                    categoricalMargins.put(
+                            new Tuple2<>(colAndFeatureAndCount.f0, colAndFeatureAndCount.f1),
+                            colAndFeatureAndCount.f2);
+                }
+
+                Map<String, Double> sampleSizeCount = new HashMap<>();
+                String tmpKey = null;
+
+                for (Tuple3<String, Object, Long> colAndLabelAndCount : labelMarginList) {
+                    String col = colAndLabelAndCount.f0;
+
+                    if (tmpKey == null) {
+                        tmpKey = col;
+                        sampleSizeCount.put(col, 0D);
+                    }
+
+                    sampleSizeCount.computeIfPresent(
+                            col, (k, count) -> count + colAndLabelAndCount.f2);
+                    labelMargins.put(
+                            new Tuple2<>(col, colAndLabelAndCount.f1), colAndLabelAndCount.f2);
+                }
+
+                Optional<Double> sampleSizeOpt =
+                        sampleSizeCount.values().stream().reduce(Double::sum);
+                Preconditions.checkArgument(sampleSizeOpt.isPresent());
+                sampleSize = sampleSizeOpt.get();
+            }
+
+            String colName = v.f0;
+            // Degrees of freedom
+            int dof = (col2NumCategories.get(colName) - 1) * (numLabels - 1);
+
+            Tuple2<String, Object> category = new Tuple2<>(v.f0, v.f1);
+
+            Tuple2<String, Object> colAndLabelKey = new Tuple2<>(v.f0, v.f2);
+            Long theCategoricalMargin = categoricalMargins.get(category);
+            Long theLabelMargin = labelMargins.get(colAndLabelKey);
+            Long observed = v.f3;
+
+            double expected = (double) (theLabelMargin * theCategoricalMargin) / sampleSize;
+            double categoricalStatistic = pearsonFunc(observed, expected);
+
+            return new Tuple3<>(colName, categoricalStatistic, dof);
+        }
+
+        // Pearson's chi-squared test: http://en.wikipedia.org/wiki/Pearson%27s_chi-squared_test
+        private double pearsonFunc(double observed, double expected) {
+            double dev = observed - expected;
+            return dev * dev / expected;
+        }
+    }
+
+    /**
+     * Computes Pearson's chi-squared statistic, the p-value and the number of degrees of freedom
+     * for every feature across the input DataStream.
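+     *
+     * <p>The partial statistics of each column are summed, and the p-value is derived as {@code 1
+     * - cumulativeProbability(statistic)} under the chi-squared distribution with the column's
+     * degrees of freedom; if the degrees of freedom are zero, the statistic is set to {@code 0}
+     * and the p-value to {@code 1}.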
+     */
+    private static class AggregateChiSqFunc extends AbstractStreamOperator<Row>
+            implements OneInputStreamOperator<Tuple3<String, Double, Integer>, Row>,
+                    BoundedOneInput {
+
+        private HashMap<String, Tuple2<Double, Integer>> col2Statistic = new HashMap<>();
+        private ListState<HashMap<String, Tuple2<Double, Integer>>> col2StatisticState;
+
+        @Override
+        public void endInput() {
+
+            for (Map.Entry<String, Tuple2<Double, Integer>> entry : col2Statistic.entrySet()) {
+                String colName = entry.getKey();
+                Tuple2<Double, Integer> statisticAndDof = entry.getValue();
+                Double statistic = statisticAndDof.f0;
+                Integer dof = statisticAndDof.f1;
+                double pValue = 1;
+                if (dof == 0) {
+                    statistic = 0D;
+                } else {
+                    pValue = 1.0 - new ChiSquaredDistribution(dof).cumulativeProbability(statistic);
+                }
+
+                double pValueScaled =
+                        new BigDecimal(pValue).setScale(11, RoundingMode.HALF_UP).doubleValue();
+                double statisticScaled =
+                        new BigDecimal(statistic).setScale(11, RoundingMode.HALF_UP).doubleValue();
+
+                output.collect(
+                        new StreamRecord<>(Row.of(colName, pValueScaled, statisticScaled, dof)));
+            }
+        }
+
+        @Override
+        public void processElement(StreamRecord<Tuple3<String, Double, Integer>> element) {
+            Tuple3<String, Double, Integer> colAndStatisticAndDof = element.getValue();
+            String colName = colAndStatisticAndDof.f0;
+            Double partialStatistic = colAndStatisticAndDof.f1;
+            Integer dof = colAndStatisticAndDof.f2;
+
+            col2Statistic.merge(
+                    colName,
+                    new Tuple2<>(partialStatistic, dof),
+                    (thisOne, otherOne) -> {
+                        thisOne.f0 += otherOne.f0;
+                        return thisOne;
+                    });
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            col2StatisticState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "col2StatisticState",
+                                            TypeInformation.of(
+                                                    new TypeHint<
+                                                            HashMap<
+                                                                    String,
+                                                                    Tuple2<
+                                                                            Double,
+                                                                            Integer>>>() {})));
+
+            OperatorStateUtils.getUniqueElement(col2StatisticState, "col2StatisticState")
+                    .ifPresent(x -> col2Statistic = x);
+        }
+
+        @Override
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            col2StatisticState.update(Collections.singletonList(col2Statistic));
+        }
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/chisqtest/ChiSqTestParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/chisqtest/ChiSqTestParams.java
new file mode 100644
index 0000000..882bfc9
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/stats/chisqtest/ChiSqTestParams.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.stats.chisqtest;
+
+import org.apache.flink.ml.common.param.HasInputCols;
+import org.apache.flink.ml.common.param.HasLabelCol;
+
+/**
+ * Params for {@link ChiSqTest}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface ChiSqTestParams<T> extends HasInputCols<T>, HasLabelCol<T> {}
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/stats/ChiSqTestTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/stats/ChiSqTestTest.java
new file mode 100644
index 0000000..57c5693
--- /dev/null
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/stats/ChiSqTestTest.java
@@ -0,0 +1,199 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.stats;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.stats.chisqtest.ChiSqTest;
+import org.apache.flink.ml.util.TestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+
+/** Tests the {@link ChiSqTest}. */
+public class ChiSqTestTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamTableEnvironment tEnv;
+    private Table inputTableWithDoubleLabel;
+    private Table inputTableWithIntegerLabel;
+    private Table inputTableWithStringLabel;
+
+    private final List<Row> samplesWithDoubleLabel =
+            Arrays.asList(
+                    Row.of(0., 5, 1.),
+                    Row.of(2., 6, 2.),
+                    Row.of(1., 7, 2.),
+                    Row.of(1., 5, 4.),
+                    Row.of(0., 5, 1.),
+                    Row.of(2., 6, 2.),
+                    Row.of(1., 7, 2.),
+                    Row.of(1., 5, 4.),
+                    Row.of(2., 5, 1.),
+                    Row.of(0., 5, 2.),
+                    Row.of(0., 5, 2.),
+                    Row.of(1., 9, 4.),
+                    Row.of(1., 9, 3.));
+
+    private final List<Row> expectedChiSqTestResultWithDoubleLabel =
+            Arrays.asList(
+                    Row.of("f1", 0.03419350755, 13.61904761905, 6),
+                    Row.of("f2", 0.24220177737, 7.94444444444, 6));
+
+    private final List<Row> samplesWithIntegerLabel =
+            Arrays.asList(
+                    Row.of(33, 5, "a"),
+                    Row.of(44, 6, "b"),
+                    Row.of(55, 7, "b"),
+                    Row.of(11, 5, "b"),
+                    Row.of(11, 5, "a"),
+                    Row.of(33, 6, "c"),
+                    Row.of(22, 7, "c"),
+                    Row.of(66, 5, "d"),
+                    Row.of(77, 5, "d"),
+                    Row.of(88, 5, "f"),
+                    Row.of(77, 5, "h"),
+                    Row.of(44, 9, "h"),
+                    Row.of(11, 9, "j"));
+
+    private final List<Row> expectedChiSqTestResultWithIntegerLabel =
+            Arrays.asList(
+                    Row.of("f1", 0.35745138256, 22.75, 21),
+                    Row.of("f2", 0.39934987096, 43.69444444444, 42));
+
+    private final List<Row> samplesWithStringLabel =
+            Arrays.asList(
+                    Row.of("v1", 11, 21.22),
+                    Row.of("v1", 33, 22.33),
+                    Row.of("v2", 22, 32.44),
+                    Row.of("v3", 11, 54.22),
+                    Row.of("v3", 33, 22.22),
+                    Row.of("v2", 22, 22.22),
+                    Row.of("v4", 55, 22.22),
+                    Row.of("v5", 11, 41.11),
+                    Row.of("v6", 55, 14.41),
+                    Row.of("v7", 13, 25.55),
+                    Row.of("v8", 14, 25.55),
+                    Row.of("v9", 14, 44.44),
+                    Row.of("v9", 14, 31.11));
+
+    private final List<Row> expectedChiSqTestResultWithStringLabel =
+            Arrays.asList(
+                    Row.of("f1", 0.06672255089, 54.16666666667, 40),
+                    Row.of("f2", 0.42335512313, 73.66666666667, 72));
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        inputTableWithDoubleLabel =
+                tEnv.fromDataStream(env.fromCollection(samplesWithDoubleLabel))
+                        .as("label", "f1", "f2");
+        inputTableWithIntegerLabel =
+                tEnv.fromDataStream(env.fromCollection(samplesWithIntegerLabel))
+                        .as("label", "f1", "f2");
+        inputTableWithStringLabel =
+                tEnv.fromDataStream(env.fromCollection(samplesWithStringLabel))
+                        .as("label", "f1", "f2");
+    }
+
+    private static void verifyPredictionResult(Table output, List<Row> expected) throws Exception {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment();
+        DataStream<Row> outputDataStream = tEnv.toDataStream(output);
+
+        List<Row> result = IteratorUtils.toList(outputDataStream.executeAndCollect());
+
+        compareResultCollections(
+                expected,
+                result,
+                (row1, row2) -> {
+                    if (!row1.equals(row2)) {
+                        return 1;
+                    } else {
+                        return 0;
+                    }
+                });
+    }
+
+    @Test
+    public void testParam() {
+        ChiSqTest chiSqTest = new ChiSqTest();
+        assertEquals("label", chiSqTest.getLabelCol());
+
+        chiSqTest.setInputCols("f1", "f2").setLabelCol("click");
+        assertArrayEquals(new String[] {"f1", "f2"}, chiSqTest.getInputCols());
+        assertEquals("click", chiSqTest.getLabelCol());
+    }
+
+    @Test
+    public void testOutputSchema() {
+        ChiSqTest chiSqTest = new ChiSqTest().setInputCols("f1", "f2").setLabelCol("label");
+
+        Table output = chiSqTest.transform(inputTableWithDoubleLabel)[0];
+        assertEquals(
+                Arrays.asList("column", "pValue", "statistic", "degreesOfFreedom"),
+                output.getResolvedSchema().getColumnNames());
+    }
+
+    @Test
+    public void testTransform() throws Exception {
+        ChiSqTest chiSqTest = new ChiSqTest().setInputCols("f1", "f2").setLabelCol("label");
+
+        Table output1 = chiSqTest.transform(inputTableWithDoubleLabel)[0];
+        verifyPredictionResult(output1, expectedChiSqTestResultWithDoubleLabel);
+
+        Table output2 = chiSqTest.transform(inputTableWithIntegerLabel)[0];
+        verifyPredictionResult(output2, expectedChiSqTestResultWithIntegerLabel);
+
+        Table output3 = chiSqTest.transform(inputTableWithStringLabel)[0];
+        verifyPredictionResult(output3, expectedChiSqTestResultWithStringLabel);
+    }
+
+    @Test
+    public void testSaveLoadAndTransform() throws Exception {
+        ChiSqTest chiSqTest = new ChiSqTest().setInputCols("f1", "f2").setLabelCol("label");
+
+        ChiSqTest loadedChiSqTest =
+                TestUtils.saveAndReload(tEnv, chiSqTest, tempFolder.newFolder().getAbsolutePath());
+        Table output1 = loadedChiSqTest.transform(inputTableWithDoubleLabel)[0];
+        verifyPredictionResult(output1, expectedChiSqTestResultWithDoubleLabel);
+    }
+}