Posted to commits@flink.apache.org by zh...@apache.org on 2022/08/17 06:28:18 UTC

[flink-ml] branch master updated: [FLINK-28803] Add Transformer and Estimator for KBinsDiscretizer

This is an automated email from the ASF dual-hosted git repository.

zhangzp pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git


The following commit(s) were added to refs/heads/master by this push:
     new be546ad  [FLINK-28803] Add Transformer and Estimator for KBinsDiscretizer
be546ad is described below

commit be546ad47887bdf2bbcf66b78d33bbbc5cdf643b
Author: Zhipeng Zhang <zh...@gmail.com>
AuthorDate: Wed Aug 17 14:28:14 2022 +0800

    [FLINK-28803] Add Transformer and Estimator for KBinsDiscretizer
    
    This closes #139.
---
 .../examples/feature/KBinsDiscretizerExample.java  |  71 +++++
 .../feature/kbinsdiscretizer/KBinsDiscretizer.java | 341 +++++++++++++++++++++
 .../kbinsdiscretizer/KBinsDiscretizerModel.java    | 172 +++++++++++
 .../KBinsDiscretizerModelData.java                 | 128 ++++++++
 .../KBinsDiscretizerModelParams.java               |  29 ++
 .../kbinsdiscretizer/KBinsDiscretizerParams.java   |  85 +++++
 .../ml/feature/minmaxscaler/MinMaxScaler.java      |   2 +-
 .../flink/ml/feature/KBinsDiscretizerTest.java     | 285 +++++++++++++++++
 .../ml/feature/kbinsdiscreteizer_example.py        |  75 +++++
 .../pyflink/ml/lib/feature/kbinsdiscretizer.py     | 168 ++++++++++
 .../ml/lib/feature/tests/test_kbinsdiscretizer.py  | 172 +++++++++++
 11 files changed, 1527 insertions(+), 1 deletion(-)

diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/KBinsDiscretizerExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/KBinsDiscretizerExample.java
new file mode 100644
index 0000000..1478f8c
--- /dev/null
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/KBinsDiscretizerExample.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.examples.feature;
+
+import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizer;
+import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizerModel;
+import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizerParams;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.CloseableIterator;
+
+/** Simple program that trains a KBinsDiscretizer model and uses it for feature engineering. */
+public class KBinsDiscretizerExample {
+    public static void main(String[] args) {
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+        // Generates input data.
+        DataStream<Row> inputStream =
+                env.fromElements(
+                        Row.of(Vectors.dense(1, 10, 0)),
+                        Row.of(Vectors.dense(1, 10, 0)),
+                        Row.of(Vectors.dense(1, 10, 0)),
+                        Row.of(Vectors.dense(4, 10, 0)),
+                        Row.of(Vectors.dense(5, 10, 0)),
+                        Row.of(Vectors.dense(6, 10, 0)),
+                        Row.of(Vectors.dense(7, 10, 0)),
+                        Row.of(Vectors.dense(10, 10, 0)),
+                        Row.of(Vectors.dense(13, 10, 3)));
+        Table inputTable = tEnv.fromDataStream(inputStream).as("input");
+
+        // Creates a KBinsDiscretizer object and initializes its parameters.
+        KBinsDiscretizer kBinsDiscretizer =
+                new KBinsDiscretizer().setNumBins(3).setStrategy(KBinsDiscretizerParams.UNIFORM);
+
+        // Trains the KBinsDiscretizer Model.
+        KBinsDiscretizerModel model = kBinsDiscretizer.fit(inputTable);
+
+        // Uses the KBinsDiscretizer Model for predictions.
+        Table outputTable = model.transform(inputTable)[0];
+
+        // Extracts and displays the results.
+        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
+            Row row = it.next();
+            DenseVector inputValue = (DenseVector) row.getField(kBinsDiscretizer.getInputCol());
+            DenseVector outputValue = (DenseVector) row.getField(kBinsDiscretizer.getOutputCol());
+            System.out.printf("Input Value: %s\tOutput Value: %s\n", inputValue, outputValue);
+        }
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizer.java
new file mode 100644
index 0000000..accae29
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizer.java
@@ -0,0 +1,341 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.kbinsdiscretizer;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.feature.minmaxscaler.MinMaxScaler.MinMaxReduceFunctionOperator;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * An Estimator which implements discretization (also known as quantization or binning) to transform
+ * continuous features into discrete ones. The output values are in [0, numBins).
+ *
+ * <p>KBinsDiscretizer supports three binning strategies, which can be selected via {@link
+ * KBinsDiscretizerParams#STRATEGY}. If the strategy is set to {@link KBinsDiscretizerParams#KMEANS}
+ * or {@link KBinsDiscretizerParams#QUANTILE}, users should further set {@link
+ * KBinsDiscretizerParams#SUB_SAMPLES} for better performance.
+ *
+ * <p>There are several corner cases for different inputs as listed below:
+ *
+ * <ul>
+ *   <li>When the input values of one column are all the same, they are all mapped to the same bin
+ *       (i.e., the zero-th bin). In this case the corresponding bin edges are `{Double.MIN_VALUE,
+ *       Double.MAX_VALUE}`.
+ *   <li>When the number of distinct values of one column is less than the specified number of bins
+ *       and the {@link KBinsDiscretizerParams#STRATEGY} is set to {@link
+ *       KBinsDiscretizerParams#KMEANS}, we switch to {@link KBinsDiscretizerParams#UNIFORM}.
+ *   <li>When the width of one output bin is zero, i.e., the left edge equals the right edge of the
+ *       bin, we remove it.
+ * </ul>
+ */
+public class KBinsDiscretizer
+        implements Estimator<KBinsDiscretizer, KBinsDiscretizerModel>,
+                KBinsDiscretizerParams<KBinsDiscretizer> {
+    private static final Logger LOG = LoggerFactory.getLogger(KBinsDiscretizer.class);
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public KBinsDiscretizer() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public KBinsDiscretizerModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        String inputCol = getInputCol();
+        String strategy = getStrategy();
+        int numBins = getNumBins();
+
+        DataStream<DenseVector> inputData =
+                tEnv.toDataStream(inputs[0])
+                        .map(
+                                (MapFunction<Row, DenseVector>)
+                                        value -> ((Vector) value.getField(inputCol)).toDense());
+
+        DataStream<DenseVector> preprocessedData;
+        if (strategy.equals(UNIFORM)) {
+            preprocessedData =
+                    inputData
+                            .transform(
+                                    "reduceInEachPartition",
+                                    inputData.getType(),
+                                    new MinMaxReduceFunctionOperator())
+                            .transform(
+                                    "reduceInFinalPartition",
+                                    inputData.getType(),
+                                    new MinMaxReduceFunctionOperator())
+                            .setParallelism(1);
+        } else {
+            preprocessedData =
+                    DataStreamUtils.sample(
+                            inputData, getSubSamples(), getClass().getName().hashCode());
+        }
+
+        DataStream<KBinsDiscretizerModelData> modelData =
+                DataStreamUtils.mapPartition(
+                        preprocessedData,
+                        new MapPartitionFunction<DenseVector, KBinsDiscretizerModelData>() {
+                            @Override
+                            public void mapPartition(
+                                    Iterable<DenseVector> iterable,
+                                    Collector<KBinsDiscretizerModelData> collector) {
+                                List<DenseVector> list = new ArrayList<>();
+                                iterable.iterator().forEachRemaining(list::add);
+
+                                if (list.size() == 0) {
+                                    throw new RuntimeException("The training set is empty.");
+                                }
+
+                                double[][] binEdges;
+                                switch (strategy) {
+                                    case UNIFORM:
+                                        binEdges = findBinEdgesWithUniformStrategy(list, numBins);
+                                        break;
+                                    case QUANTILE:
+                                        binEdges = findBinEdgesWithQuantileStrategy(list, numBins);
+                                        break;
+                                    case KMEANS:
+                                        binEdges = findBinEdgesWithKMeansStrategy(list, numBins);
+                                        break;
+                                    default:
+                                        throw new UnsupportedOperationException(
+                                                "Unsupported "
+                                                        + STRATEGY
+                                                        + " type: "
+                                                        + strategy
+                                                        + ".");
+                                }
+
+                                collector.collect(new KBinsDiscretizerModelData(binEdges));
+                            }
+                        });
+        modelData.getTransformation().setParallelism(1);
+
+        KBinsDiscretizerModel model =
+                new KBinsDiscretizerModel().setModelData(tEnv.fromDataStream(modelData));
+        ReadWriteUtils.updateExistingParams(model, getParamMap());
+        return model;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static KBinsDiscretizer load(StreamTableEnvironment tEnv, String path)
+            throws IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+
+    private static double[][] findBinEdgesWithUniformStrategy(
+            List<DenseVector> input, int numBins) {
+        DenseVector minVector = input.get(0);
+        DenseVector maxVector = input.get(1);
+        int numColumns = minVector.size();
+        double[][] binEdges = new double[numColumns][];
+
+        for (int columnId = 0; columnId < numColumns; columnId++) {
+            double min = minVector.get(columnId);
+            double max = maxVector.get(columnId);
+            if (min == max) {
+                LOG.warn("Feature " + columnId + " is constant and the output will all be zero.");
+                binEdges[columnId] = new double[] {Double.MIN_VALUE, Double.MAX_VALUE};
+            } else {
+                double width = (max - min) / numBins;
+                binEdges[columnId] = new double[numBins + 1];
+                binEdges[columnId][0] = min;
+                for (int edgeId = 1; edgeId < numBins + 1; edgeId++) {
+                    binEdges[columnId][edgeId] = binEdges[columnId][edgeId - 1] + width;
+                }
+            }
+        }
+        return binEdges;
+    }
+
+    private static double[][] findBinEdgesWithQuantileStrategy(
+            List<DenseVector> input, int numBins) {
+        int numColumns = input.get(0).size();
+        int numData = input.size();
+        double[][] binEdges = new double[numColumns][];
+        double[] features = new double[numData];
+
+        for (int columnId = 0; columnId < numColumns; columnId++) {
+            for (int i = 0; i < numData; i++) {
+                features[i] = input.get(i).get(columnId);
+            }
+            Arrays.sort(features);
+
+            if (features[0] == features[numData - 1]) {
+                LOG.warn("Feature " + columnId + " is constant and the output will all be zero.");
+                binEdges[columnId] = new double[] {Double.MIN_VALUE, Double.MAX_VALUE};
+            } else {
+                double width = 1.0 * features.length / numBins;
+                double[] tempBinEdges = new double[numBins + 1];
+
+                for (int binEdgeId = 0; binEdgeId < numBins; binEdgeId++) {
+                    tempBinEdges[binEdgeId] = features[(int) (binEdgeId * width)];
+                }
+                tempBinEdges[numBins] = features[numData - 1];
+
+                // Removes bins that are empty, i.e., the left edge equals the right edge.
+                Set<Double> edges = new HashSet<>(numBins);
+                for (double edge : tempBinEdges) {
+                    edges.add(edge);
+                }
+
+                binEdges[columnId] = edges.stream().mapToDouble(Double::doubleValue).toArray();
+                Arrays.sort(binEdges[columnId]);
+            }
+        }
+        return binEdges;
+    }
+
+    private static double[][] findBinEdgesWithKMeansStrategy(List<DenseVector> input, int numBins) {
+        int numColumns = input.get(0).size();
+        int numData = input.size();
+        double[][] binEdges = new double[numColumns][numBins + 1];
+        double[] features = new double[numData];
+
+        double[] kMeansCentroids = new double[numBins];
+        double[] sumByCluster = new double[numBins];
+
+        for (int columnId = 0; columnId < numColumns; columnId++) {
+            for (int i = 0; i < numData; i++) {
+                features[i] = input.get(i).get(columnId);
+            }
+            Arrays.sort(features);
+
+            if (features[0] == features[numData - 1]) {
+                LOG.warn("Feature " + columnId + " is constant and the output will all be zero.");
+                binEdges[columnId] = new double[] {Double.MIN_VALUE, Double.MAX_VALUE};
+            } else {
+                // Checks whether there are more than {numBins} distinct feature values in this
+                // column.
+                // If the number of distinct values is less than {numBins + 1}, we do not need to
+                // conduct KMeans. Instead, we fall back to {@link
+                // KBinsDiscretizerParams#UNIFORM} for binning.
+                Set<Double> distinctFeatureValues = new HashSet<>(numBins + 1);
+                for (double feature : features) {
+                    distinctFeatureValues.add(feature);
+                    if (distinctFeatureValues.size() >= numBins + 1) {
+                        break;
+                    }
+                }
+                if (distinctFeatureValues.size() <= numBins) {
+                    double min = features[0];
+                    double max = features[features.length - 1];
+                    double width = (max - min) / numBins;
+                    binEdges[columnId] = new double[numBins + 1];
+                    binEdges[columnId][0] = min;
+                    for (int edgeId = 1; edgeId < numBins + 1; edgeId++) {
+                        binEdges[columnId][edgeId] = binEdges[columnId][edgeId - 1] + width;
+                    }
+                    continue;
+                } else {
+                    // Conducts KMeans here.
+                    double width = 1.0 * features.length / numBins;
+                    for (int clusterId = 0; clusterId < numBins; clusterId++) {
+                        kMeansCentroids[clusterId] = features[(int) (clusterId * width)];
+                    }
+
+                    // Default values for KMeans.
+                    final double tolerance = 1e-4;
+                    final int maxIterations = 300;
+
+                    double oldLoss = Double.MAX_VALUE;
+                    double relativeLoss = Double.MAX_VALUE;
+                    int iter = 0;
+                    int[] countByCluster = new int[numBins];
+                    while (iter < maxIterations && relativeLoss > tolerance) {
+                        double loss = 0;
+                        for (double featureValue : features) {
+                            double minDistance = Math.abs(kMeansCentroids[0] - featureValue);
+                            int clusterId = 0;
+                            for (int i = 1; i < kMeansCentroids.length; i++) {
+                                double distance = Math.abs(kMeansCentroids[i] - featureValue);
+                                if (distance < minDistance) {
+                                    minDistance = distance;
+                                    clusterId = i;
+                                }
+                            }
+                            countByCluster[clusterId]++;
+                            sumByCluster[clusterId] += featureValue;
+                            loss += minDistance;
+                        }
+
+                        // Updates cluster.
+                        for (int clusterId = 0; clusterId < kMeansCentroids.length; clusterId++) {
+                            kMeansCentroids[clusterId] =
+                                    sumByCluster[clusterId] / countByCluster[clusterId];
+                        }
+                        loss /= features.length;
+                        relativeLoss = Math.abs(loss - oldLoss);
+                        oldLoss = loss;
+                        iter++;
+                        Arrays.fill(sumByCluster, 0);
+                        Arrays.fill(countByCluster, 0);
+                    }
+
+                    Arrays.sort(kMeansCentroids);
+                    binEdges[columnId] = new double[numBins + 1];
+                    binEdges[columnId][0] = features[0];
+                    binEdges[columnId][numBins] = features[features.length - 1];
+                    for (int binEdgeId = 1; binEdgeId < numBins; binEdgeId++) {
+                        binEdges[columnId][binEdgeId] =
+                                (kMeansCentroids[binEdgeId - 1] + kMeansCentroids[binEdgeId]) / 2;
+                    }
+                }
+            }
+        }
+        return binEdges;
+    }
+}
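
Note (editorial sketch, not part of this commit): the class Javadoc above recommends setting SUB_SAMPLES when the quantile or kmeans strategy is used. Assuming the imports shown in KBinsDiscretizerExample.java and a placeholder `inputTable` with a vector column named "input", a minimal configuration might look like:

    KBinsDiscretizer discretizer =
            new KBinsDiscretizer()
                    .setInputCol("input")
                    .setOutputCol("output")
                    .setStrategy(KBinsDiscretizerParams.QUANTILE)
                    .setNumBins(4)
                    .setSubSamples(10000);
    // Fits the estimator on the input table and discretizes the same table.
    Table binned = discretizer.fit(inputTable).transform(inputTable)[0];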
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizerModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizerModel.java
new file mode 100644
index 0000000..7053e31
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizerModel.java
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.kbinsdiscretizer;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * A Model which transforms continuous features into discrete features using the model data computed
+ * by {@link KBinsDiscretizer}.
+ *
+ * <p>A feature value `v` is mapped to the bin with edges `{left, right}` if `v` lies in
+ * `[left, right)`. If `v` does not fall into any of the bins, it is mapped to the closest bin. For
+ * example, suppose the bin edges of one column are `{-1, 0, 1}`, which defines two bins `{-1, 0}`
+ * and `{0, 1}`. In this case, -2 is mapped into the 0-th bin, 0 is mapped into the 1-st bin, and 2
+ * is mapped into the 1-st bin.
+ */
+public class KBinsDiscretizerModel
+        implements Model<KBinsDiscretizerModel>,
+                KBinsDiscretizerModelParams<KBinsDiscretizerModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public KBinsDiscretizerModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<Row> inputData = tEnv.toDataStream(inputs[0]);
+        DataStream<KBinsDiscretizerModelData> modelData =
+                KBinsDiscretizerModelData.getModelDataStream(modelDataTable);
+
+        final String broadcastModelKey = "broadcastModelKey";
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(),
+                                TypeInformation.of(DenseVector.class)),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol()));
+
+        DataStream<Row> output =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(inputData),
+                        Collections.singletonMap(broadcastModelKey, modelData),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(
+                                    new FindBinFunction(getInputCol(), broadcastModelKey),
+                                    outputTypeInfo);
+                        });
+        return new Table[] {tEnv.fromDataStream(output)};
+    }
+
+    @Override
+    public KBinsDiscretizerModel setModelData(Table... inputs) {
+        modelDataTable = inputs[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+        ReadWriteUtils.saveModelData(
+                KBinsDiscretizerModelData.getModelDataStream(modelDataTable),
+                path,
+                new KBinsDiscretizerModelData.ModelDataEncoder());
+    }
+
+    public static KBinsDiscretizerModel load(StreamTableEnvironment tEnv, String path)
+            throws IOException {
+        KBinsDiscretizerModel model = ReadWriteUtils.loadStageParam(path);
+        Table modelDataTable =
+                ReadWriteUtils.loadModelData(
+                        tEnv, path, new KBinsDiscretizerModelData.ModelDataDecoder());
+        return model.setModelData(modelDataTable);
+    }
+
+    private static class FindBinFunction extends RichMapFunction<Row, Row> {
+        private final String inputCol;
+        private final String broadcastKey;
+        /** Model data used to find bins for each feature. */
+        private double[][] binEdges;
+
+        public FindBinFunction(String inputCol, String broadcastKey) {
+            this.inputCol = inputCol;
+            this.broadcastKey = broadcastKey;
+        }
+
+        @Override
+        public Row map(Row row) {
+            if (binEdges == null) {
+                KBinsDiscretizerModelData modelData =
+                        (KBinsDiscretizerModelData)
+                                getRuntimeContext().getBroadcastVariable(broadcastKey).get(0);
+                binEdges = modelData.binEdges;
+            }
+            DenseVector inputVec = ((Vector) row.getField(inputCol)).toDense();
+            DenseVector outputVec = inputVec.clone();
+            for (int i = 0; i < inputVec.size(); i++) {
+                double targetFeature = inputVec.get(i);
+                int index = Arrays.binarySearch(binEdges[i], targetFeature);
+                if (index < 0) {
+                    // Computes the index to insert.
+                    index = -index - 1;
+                    // Puts it in the left bin.
+                    index--;
+                }
+                // Handles the boundary.
+                index = Math.min(index, (binEdges[i].length - 2));
+                index = Math.max(index, 0);
+
+                outputVec.set(i, index);
+            }
+            return Row.join(row, Row.of(outputVec));
+        }
+    }
+}
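
Note (editorial sketch, not part of this commit): the bin lookup described in the model Javadoc can be reproduced in isolation. The standalone program below mirrors the binary-search logic of FindBinFunction for the edges `{-1, 0, 1}` used in the Javadoc example:

    import java.util.Arrays;

    public class BinLookupSketch {
        public static void main(String[] args) {
            double[] edges = {-1, 0, 1};
            for (double v : new double[] {-2, 0, 2}) {
                int index = Arrays.binarySearch(edges, v);
                if (index < 0) {
                    index = -index - 1; // insertion point of v
                    index--;            // fall back to the bin on the left
                }
                // Clamps out-of-range values into the first or last bin.
                index = Math.min(index, edges.length - 2);
                index = Math.max(index, 0);
                System.out.println(v + " -> bin " + index); // -2 -> 0, 0 -> 1, 2 -> 1
            }
        }
    }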
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizerModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizerModelData.java
new file mode 100644
index 0000000..8dd402a
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizerModelData.java
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.kbinsdiscretizer;
+
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.api.common.typeutils.base.array.DoublePrimitiveArraySerializer;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+
+import java.io.EOFException;
+import java.io.IOException;
+import java.io.OutputStream;
+
+/**
+ * Model data of {@link KBinsDiscretizerModel}.
+ *
+ * <p>This class also provides methods to convert model data from a Table to a data stream, and
+ * classes to save/load model data.
+ */
+public class KBinsDiscretizerModelData {
+    /**
+     * The bin edges for each column, e.g., binEdges[0] contains the edges of the feature at the
+     * 0-th dimension.
+     */
+    public double[][] binEdges;
+
+    public KBinsDiscretizerModelData() {}
+
+    public KBinsDiscretizerModelData(double[][] binEdges) {
+        this.binEdges = binEdges;
+    }
+
+    /**
+     * Converts the model data table to a data stream.
+     *
+     * @param modelDataTable The table containing the model data.
+     * @return The model data as a data stream.
+     */
+    public static DataStream<KBinsDiscretizerModelData> getModelDataStream(Table modelDataTable) {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) modelDataTable).getTableEnvironment();
+        return tEnv.toDataStream(modelDataTable)
+                .map(x -> new KBinsDiscretizerModelData((double[][]) x.getField(0)));
+    }
+
+    /** Encoder for {@link KBinsDiscretizerModelData}. */
+    public static class ModelDataEncoder implements Encoder<KBinsDiscretizerModelData> {
+
+        @Override
+        public void encode(KBinsDiscretizerModelData modelData, OutputStream outputStream)
+                throws IOException {
+            DataOutputView dataOutputView = new DataOutputViewStreamWrapper(outputStream);
+            IntSerializer.INSTANCE.serialize(modelData.binEdges.length, dataOutputView);
+
+            DoublePrimitiveArraySerializer doubleArraySerializer =
+                    DoublePrimitiveArraySerializer.INSTANCE;
+            for (double[] binEdge : modelData.binEdges) {
+                doubleArraySerializer.serialize(binEdge, dataOutputView);
+            }
+        }
+    }
+
+    /** Decoder for {@link KBinsDiscretizerModelData}. */
+    public static class ModelDataDecoder extends SimpleStreamFormat<KBinsDiscretizerModelData> {
+        @Override
+        public Reader<KBinsDiscretizerModelData> createReader(
+                Configuration config, FSDataInputStream stream) {
+            return new Reader<KBinsDiscretizerModelData>() {
+
+                @Override
+                public KBinsDiscretizerModelData read() throws IOException {
+                    DataInputView source = new DataInputViewStreamWrapper(stream);
+                    try {
+                        int numColumns = IntSerializer.INSTANCE.deserialize(source);
+                        double[][] binEdges = new double[numColumns][];
+
+                        DoublePrimitiveArraySerializer doubleArraySerializer =
+                                DoublePrimitiveArraySerializer.INSTANCE;
+                        for (int i = 0; i < numColumns; i++) {
+                            binEdges[i] = doubleArraySerializer.deserialize(source);
+                        }
+
+                        return new KBinsDiscretizerModelData(binEdges);
+                    } catch (EOFException e) {
+                        return null;
+                    }
+                }
+
+                @Override
+                public void close() throws IOException {
+                    stream.close();
+                }
+            };
+        }
+
+        @Override
+        public TypeInformation<KBinsDiscretizerModelData> getProducedType() {
+            return TypeInformation.of(KBinsDiscretizerModelData.class);
+        }
+    }
+}
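
Note (editorial sketch, not part of this commit): the binEdges layout documented above holds one double[] of edges per input dimension. The construction below uses the same values as the UNIFORM_MODEL_DATA constant in the test further down this diff:

    double[][] binEdges = {
        {1, 5, 9, 13},                        // three uniform bins for feature 0
        {Double.MIN_VALUE, Double.MAX_VALUE}, // constant feature 1 collapses to a single bin
        {0, 1, 2, 3}                          // three uniform bins for feature 2
    };
    KBinsDiscretizerModelData modelData = new KBinsDiscretizerModelData(binEdges);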
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizerModelParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizerModelParams.java
new file mode 100644
index 0000000..3e0fd3d
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizerModelParams.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.kbinsdiscretizer;
+
+import org.apache.flink.ml.common.param.HasInputCol;
+import org.apache.flink.ml.common.param.HasOutputCol;
+
+/**
+ * Params for {@link KBinsDiscretizerModel}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface KBinsDiscretizerModelParams<T> extends HasInputCol<T>, HasOutputCol<T> {}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizerParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizerParams.java
new file mode 100644
index 0000000..c1f0b00
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizerParams.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.kbinsdiscretizer;
+
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringParam;
+
+/**
+ * Params for {@link KBinsDiscretizer}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface KBinsDiscretizerParams<T> extends KBinsDiscretizerModelParams<T> {
+    String UNIFORM = "uniform";
+    String QUANTILE = "quantile";
+    String KMEANS = "kmeans";
+
+    /**
+     * Supported options to define the widths of the bins are listed as follows.
+     *
+     * <ul>
+     *   <li>uniform: all bins in each feature have identical widths.
+     *   <li>quantile: all bins in each feature have the same number of points.
+     *   <li>kmeans: values in each bin have the same nearest center of a 1D kmeans cluster.
+     * </ul>
+     */
+    Param<String> STRATEGY =
+            new StringParam(
+                    "strategy",
+                    "Strategy used to define the width of the bin.",
+                    QUANTILE,
+                    ParamValidators.inArray(UNIFORM, QUANTILE, KMEANS));
+
+    Param<Integer> NUM_BINS =
+            new IntParam("numBins", "Number of bins to produce.", 5, ParamValidators.gtEq(2));
+
+    Param<Integer> SUB_SAMPLES =
+            new IntParam(
+                    "subSamples",
+                    "Maximum number of samples used to fit the model.",
+                    200000,
+                    ParamValidators.gtEq(2));
+
+    default String getStrategy() {
+        return get(STRATEGY);
+    }
+
+    default T setStrategy(String value) {
+        return set(STRATEGY, value);
+    }
+
+    default int getNumBins() {
+        return get(NUM_BINS);
+    }
+
+    default T setNumBins(int value) {
+        return set(NUM_BINS, value);
+    }
+
+    default int getSubSamples() {
+        return get(SUB_SAMPLES);
+    }
+
+    default T setSubSamples(Integer value) {
+        return set(SUB_SAMPLES, value);
+    }
+}
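
Note (editorial sketch, not part of this commit): the defaults declared above can be read back through the getters; the values in the comments match the assertions in KBinsDiscretizerTest#testParam below:

    KBinsDiscretizer discretizer = new KBinsDiscretizer();
    System.out.println(discretizer.getStrategy());   // "quantile"
    System.out.println(discretizer.getNumBins());    // 5
    System.out.println(discretizer.getSubSamples()); // 200000
    System.out.println(discretizer.getInputCol());   // "input"
    System.out.println(discretizer.getOutputCol());  // "output"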
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScaler.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScaler.java
index 4e59851..28545e9 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScaler.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScaler.java
@@ -119,7 +119,7 @@ public class MinMaxScaler
      * A stream operator to compute the min and max values in each partition of the input bounded
      * data stream.
      */
-    private static class MinMaxReduceFunctionOperator extends AbstractStreamOperator<DenseVector>
+    public static class MinMaxReduceFunctionOperator extends AbstractStreamOperator<DenseVector>
             implements OneInputStreamOperator<DenseVector, DenseVector>, BoundedOneInput {
         private ListState<DenseVector> minState;
         private ListState<DenseVector> maxState;
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/KBinsDiscretizerTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/KBinsDiscretizerTest.java
new file mode 100644
index 0000000..17ad6b5
--- /dev/null
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/KBinsDiscretizerTest.java
@@ -0,0 +1,285 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizer;
+import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizerModel;
+import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizerModelData;
+import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizerParams;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.TestUtils;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.exception.ExceptionUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import static org.apache.flink.table.api.Expressions.$;
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
+
+/** Tests {@link KBinsDiscretizer} and {@link KBinsDiscretizerModel}. */
+public class KBinsDiscretizerTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table trainTable;
+    private Table testTable;
+
+    // Column 0 for normal cases, column 1 for constant cases, column 2 for numDistinct < numBins
+    // cases.
+    private static final List<Row> TRAIN_INPUT =
+            Arrays.asList(
+                    Row.of(Vectors.dense(1, 10, 0)),
+                    Row.of(Vectors.dense(1, 10, 0)),
+                    Row.of(Vectors.dense(1, 10, 0)),
+                    Row.of(Vectors.dense(4, 10, 0)),
+                    Row.of(Vectors.dense(5, 10, 0)),
+                    Row.of(Vectors.dense(6, 10, 0)),
+                    Row.of(Vectors.dense(7, 10, 0)),
+                    Row.of(Vectors.dense(10, 10, 0)),
+                    Row.of(Vectors.dense(13, 10, 3)));
+
+    private static final List<Row> TEST_INPUT =
+            Arrays.asList(
+                    Row.of(Vectors.dense(-1, 0, 0)),
+                    Row.of(Vectors.dense(1, 1, 1)),
+                    Row.of(Vectors.dense(1.5, 1, 2)),
+                    Row.of(Vectors.dense(5, 2, 3)),
+                    Row.of(Vectors.dense(7.25, 3, 4)),
+                    Row.of(Vectors.dense(13, 4, 5)),
+                    Row.of(Vectors.dense(15, 4, 6)));
+
+    private static final double[][] UNIFORM_MODEL_DATA =
+            new double[][] {
+                new double[] {1, 5, 9, 13},
+                new double[] {Double.MIN_VALUE, Double.MAX_VALUE},
+                new double[] {0, 1, 2, 3}
+            };
+
+    private static final List<Row> UNIFORM_OUTPUT =
+            Arrays.asList(
+                    Row.of(Vectors.dense(0, 0, 0)),
+                    Row.of(Vectors.dense(0, 0, 1)),
+                    Row.of(Vectors.dense(0, 0, 2)),
+                    Row.of(Vectors.dense(1, 0, 2)),
+                    Row.of(Vectors.dense(1, 0, 2)),
+                    Row.of(Vectors.dense(2, 0, 2)),
+                    Row.of(Vectors.dense(2, 0, 2)));
+
+    private static final List<Row> QUANTILE_OUTPUT =
+            Arrays.asList(
+                    Row.of(Vectors.dense(0, 0, 0)),
+                    Row.of(Vectors.dense(0, 0, 0)),
+                    Row.of(Vectors.dense(0, 0, 0)),
+                    Row.of(Vectors.dense(1, 0, 0)),
+                    Row.of(Vectors.dense(2, 0, 0)),
+                    Row.of(Vectors.dense(2, 0, 0)),
+                    Row.of(Vectors.dense(2, 0, 0)));
+
+    private static final List<Row> KMEANS_OUTPUT =
+            Arrays.asList(
+                    Row.of(Vectors.dense(0, 0, 0)),
+                    Row.of(Vectors.dense(0, 0, 1)),
+                    Row.of(Vectors.dense(0, 0, 2)),
+                    Row.of(Vectors.dense(1, 0, 2)),
+                    Row.of(Vectors.dense(1, 0, 2)),
+                    Row.of(Vectors.dense(2, 0, 2)),
+                    Row.of(Vectors.dense(2, 0, 2)));
+
+    private static final double TOLERANCE = 1e-7;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        trainTable = tEnv.fromDataStream(env.fromCollection(TRAIN_INPUT)).as("input");
+        testTable = tEnv.fromDataStream(env.fromCollection(TEST_INPUT)).as("input");
+    }
+
+    @SuppressWarnings("unchecked, ConstantConditions")
+    private void verifyPredictionResult(
+            List<Row> expectedOutput, Table output, String predictionCol) throws Exception {
+        List<Row> collectedResult =
+                IteratorUtils.toList(
+                        tEnv.toDataStream(output.select($(predictionCol))).executeAndCollect());
+        compareResultCollections(
+                expectedOutput,
+                collectedResult,
+                (o1, o2) ->
+                        TestUtils.compare(
+                                (DenseVector) o1.getField(0), (DenseVector) o2.getField(0)));
+    }
+
+    @Test
+    public void testParam() {
+        KBinsDiscretizer kBinsDiscretizer = new KBinsDiscretizer();
+
+        assertEquals("input", kBinsDiscretizer.getInputCol());
+        assertEquals(5, kBinsDiscretizer.getNumBins());
+        assertEquals("quantile", kBinsDiscretizer.getStrategy());
+        assertEquals(200000, kBinsDiscretizer.getSubSamples());
+        assertEquals("output", kBinsDiscretizer.getOutputCol());
+
+        kBinsDiscretizer
+                .setInputCol("test_input")
+                .setNumBins(10)
+                .setStrategy(KBinsDiscretizerParams.KMEANS)
+                .setSubSamples(1000)
+                .setOutputCol("test_output");
+
+        assertEquals("test_input", kBinsDiscretizer.getInputCol());
+        assertEquals(10, kBinsDiscretizer.getNumBins());
+        assertEquals("kmeans", kBinsDiscretizer.getStrategy());
+        assertEquals(1000, kBinsDiscretizer.getSubSamples());
+        assertEquals("test_output", kBinsDiscretizer.getOutputCol());
+    }
+
+    @Test
+    public void testOutputSchema() {
+        Table tempTable =
+                tEnv.fromDataStream(env.fromElements(Row.of("", "")))
+                        .as("test_input", "dummy_input");
+        KBinsDiscretizer kBinsDiscretizer =
+                new KBinsDiscretizer().setInputCol("test_input").setOutputCol("test_output");
+        Table output = kBinsDiscretizer.fit(tempTable).transform(tempTable)[0];
+
+        assertEquals(
+                Arrays.asList("test_input", "dummy_input", "test_output"),
+                output.getResolvedSchema().getColumnNames());
+    }
+
+    @Test
+    public void testFitAndPredict() throws Exception {
+        KBinsDiscretizer kBinsDiscretizer = new KBinsDiscretizer().setNumBins(3);
+        Table output;
+
+        // Tests uniform strategy.
+        kBinsDiscretizer.setStrategy(KBinsDiscretizerParams.UNIFORM);
+        output = kBinsDiscretizer.fit(trainTable).transform(testTable)[0];
+        verifyPredictionResult(UNIFORM_OUTPUT, output, kBinsDiscretizer.getOutputCol());
+
+        // Tests quantile strategy.
+        kBinsDiscretizer.setStrategy(KBinsDiscretizerParams.QUANTILE);
+        output = kBinsDiscretizer.fit(trainTable).transform(testTable)[0];
+        verifyPredictionResult(QUANTILE_OUTPUT, output, kBinsDiscretizer.getOutputCol());
+
+        // Tests kmeans strategy.
+        kBinsDiscretizer.setStrategy(KBinsDiscretizerParams.KMEANS);
+        output = kBinsDiscretizer.fit(trainTable).transform(testTable)[0];
+        verifyPredictionResult(KMEANS_OUTPUT, output, kBinsDiscretizer.getOutputCol());
+    }
+
+    @Test
+    public void testSaveLoadAndPredict() throws Exception {
+        KBinsDiscretizer kBinsDiscretizer =
+                new KBinsDiscretizer().setNumBins(3).setStrategy(KBinsDiscretizerParams.UNIFORM);
+        kBinsDiscretizer =
+                TestUtils.saveAndReload(
+                        tEnv, kBinsDiscretizer, tempFolder.newFolder().getAbsolutePath());
+
+        KBinsDiscretizerModel model = kBinsDiscretizer.fit(trainTable);
+        model = TestUtils.saveAndReload(tEnv, model, tempFolder.newFolder().getAbsolutePath());
+
+        assertEquals(
+                Collections.singletonList("binEdges"),
+                model.getModelData()[0].getResolvedSchema().getColumnNames());
+
+        Table output = model.transform(testTable)[0];
+        verifyPredictionResult(UNIFORM_OUTPUT, output, kBinsDiscretizer.getOutputCol());
+    }
+
+    @Test
+    @SuppressWarnings("unchecked")
+    public void testGetModelData() throws Exception {
+        KBinsDiscretizer kBinsDiscretizer =
+                new KBinsDiscretizer().setNumBins(3).setStrategy(KBinsDiscretizerParams.UNIFORM);
+        KBinsDiscretizerModel model = kBinsDiscretizer.fit(trainTable);
+        Table modelDataTable = model.getModelData()[0];
+
+        assertEquals(
+                Collections.singletonList("binEdges"),
+                modelDataTable.getResolvedSchema().getColumnNames());
+
+        List<KBinsDiscretizerModelData> collectedModelData =
+                (List<KBinsDiscretizerModelData>)
+                        IteratorUtils.toList(
+                                KBinsDiscretizerModelData.getModelDataStream(modelDataTable)
+                                        .executeAndCollect());
+        assertEquals(1, collectedModelData.size());
+
+        KBinsDiscretizerModelData modelData = collectedModelData.get(0);
+        assertEquals(UNIFORM_MODEL_DATA.length, modelData.binEdges.length);
+        for (int i = 0; i < modelData.binEdges.length; i++) {
+            assertArrayEquals(UNIFORM_MODEL_DATA[i], modelData.binEdges[i], TOLERANCE);
+        }
+    }
+
+    @Test
+    public void testSetModelData() throws Exception {
+        KBinsDiscretizer kBinsDiscretizer =
+                new KBinsDiscretizer().setNumBins(3).setStrategy(KBinsDiscretizerParams.UNIFORM);
+
+        KBinsDiscretizerModel model = kBinsDiscretizer.fit(trainTable);
+
+        KBinsDiscretizerModel newModel = new KBinsDiscretizerModel();
+        ReadWriteUtils.updateExistingParams(newModel, model.getParamMap());
+        newModel.setModelData(model.getModelData());
+        Table output = newModel.transform(testTable)[0];
+
+        verifyPredictionResult(UNIFORM_OUTPUT, output, kBinsDiscretizer.getOutputCol());
+    }
+
+    @Test
+    public void testFitOnEmptyData() {
+        Table emptyTable =
+                tEnv.fromDataStream(env.fromCollection(TRAIN_INPUT).filter(x -> x.getArity() == 0))
+                        .as("input");
+        KBinsDiscretizerModel model = new KBinsDiscretizer().fit(emptyTable);
+        Table modelDataTable = model.getModelData()[0];
+        try {
+            modelDataTable.execute().collect().next();
+            fail();
+        } catch (Throwable e) {
+            assertEquals("The training set is empty.", ExceptionUtils.getRootCause(e).getMessage());
+        }
+    }
+}
diff --git a/flink-ml-python/pyflink/examples/ml/feature/kbinsdiscreteizer_example.py b/flink-ml-python/pyflink/examples/ml/feature/kbinsdiscreteizer_example.py
new file mode 100644
index 0000000..d33d613
--- /dev/null
+++ b/flink-ml-python/pyflink/examples/ml/feature/kbinsdiscreteizer_example.py
@@ -0,0 +1,75 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+# Simple program that trains a KBinsDiscretizer model and uses it for feature
+# engineering.
+#
+# Before executing this program, please make sure you have followed Flink ML's
+# quick start guideline to set up Flink ML and Flink environment. The guideline
+# can be found at
+#
+# https://nightlies.apache.org/flink/flink-ml-docs-master/docs/try-flink-ml/quick-start/
+
+from pyflink.common import Types
+from pyflink.ml.core.linalg import Vectors, DenseVectorTypeInfo
+from pyflink.datastream import StreamExecutionEnvironment
+from pyflink.ml.lib.feature.kbinsdiscretizer import KBinsDiscretizer
+from pyflink.table import StreamTableEnvironment
+
+# Creates a new StreamExecutionEnvironment.
+env = StreamExecutionEnvironment.get_execution_environment()
+
+# Creates a StreamTableEnvironment.
+t_env = StreamTableEnvironment.create(env)
+
+# Generates input for training and prediction.
+input_table = t_env.from_data_stream(
+    env.from_collection([
+        (Vectors.dense(1, 10, 0),),
+        (Vectors.dense(1, 10, 0),),
+        (Vectors.dense(1, 10, 0),),
+        (Vectors.dense(4, 10, 0),),
+        (Vectors.dense(5, 10, 0),),
+        (Vectors.dense(6, 10, 0),),
+        (Vectors.dense(7, 10, 0),),
+        (Vectors.dense(10, 10, 0),),
+        (Vectors.dense(13, 10, 0),),
+    ],
+        type_info=Types.ROW_NAMED(
+            ['input', ],
+            [DenseVectorTypeInfo(), ])))
+
+# Creates a KBinsDiscretizer object and initializes its parameters.
+k_bins_discretizer = KBinsDiscretizer() \
+    .set_input_col('input') \
+    .set_output_col('output') \
+    .set_num_bins(3) \
+    .set_strategy('uniform')
+
+# Trains the KBinsDiscretizer Model.
+model = k_bins_discretizer.fit(input_table)
+
+# Uses the KBinsDiscretizer Model for predictions.
+output = model.transform(input_table)[0]
+
+# Extracts and displays the results.
+field_names = output.get_schema().get_field_names()
+for result in t_env.to_data_stream(output).execute_and_collect():
+    print('Input Value: ' + str(result[field_names.index(k_bins_discretizer.get_input_col())])
+          + '\tOutput Value: ' +
+          str(result[field_names.index(k_bins_discretizer.get_output_col())]))
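
Note: the standalone Python sketch below is not part of this commit and is not the
Flink ML operator itself; it only illustrates, for the first vector component, the
kind of binning the example above demonstrates with the 'uniform' strategy. The
helper names and the edge handling (out-of-range values clamped to the closest bin,
as documented for KBinsDiscretizerModel) are assumptions made for illustration.

    from typing import List

    def uniform_bin_edges(values: List[float], num_bins: int) -> List[float]:
        # Uniform strategy: num_bins equally wide bins between the column min and max.
        lo, hi = min(values), max(values)
        width = (hi - lo) / num_bins
        return [lo + i * width for i in range(num_bins + 1)]

    def discretize(v: float, edges: List[float]) -> int:
        # Map v to bin i if v lies in [edges[i], edges[i + 1]); values outside the
        # edge range are clamped to the closest (first or last) bin.
        if v < edges[0]:
            return 0
        for i in range(len(edges) - 1):
            if edges[i] <= v < edges[i + 1]:
                return i
        return len(edges) - 2

    train_col0 = [1, 1, 1, 4, 5, 6, 7, 10, 13]
    edges = uniform_bin_edges(train_col0, 3)  # [1.0, 5.0, 9.0, 13.0]
    print([discretize(v, edges) for v in train_col0])
    # -> [0, 0, 0, 0, 1, 1, 1, 2, 2] with this toy mapping
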
diff --git a/flink-ml-python/pyflink/ml/lib/feature/kbinsdiscretizer.py b/flink-ml-python/pyflink/ml/lib/feature/kbinsdiscretizer.py
new file mode 100644
index 0000000..03ab5e7
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/feature/kbinsdiscretizer.py
@@ -0,0 +1,168 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+import typing
+
+from pyflink.ml.core.param import IntParam, StringParam, ParamValidators
+from pyflink.ml.core.wrapper import JavaWithParams
+from pyflink.ml.lib.feature.common import JavaFeatureModel, JavaFeatureEstimator
+from pyflink.ml.lib.param import HasInputCol, HasOutputCol
+
+
+class _KBinsDiscretizerModelParams(
+    JavaWithParams,
+    HasInputCol,
+    HasOutputCol
+):
+    """
+    Params for :class:`KBinsDiscretizerModel`.
+    """
+
+    def __init__(self, java_params):
+        super(_KBinsDiscretizerModelParams, self).__init__(java_params)
+
+
+class _KBinsDiscretizerParams(_KBinsDiscretizerModelParams):
+    """
+    Params for :class:`KBinsDiscretizer`.
+    """
+
+    """
+    Supported options to define the widths of the bins are listed as follows.
+    <ul>
+        <li>uniform: all bins in each feature have identical widths.
+        <li>quantile: all bins in each feature have the same number of points.
+        <li>kmeans: values in each bin have the same nearest center of a 1D kmeans cluster.
+    </ul>
+    """
+    STRATEGY: StringParam = StringParam(
+        "strategy",
+        "Strategy used to define the width of the bin.",
+        'quantile',
+        ParamValidators.in_array(['uniform', 'quantile', 'kmeans']))
+
+    NUM_BINS: IntParam = IntParam(
+        "num_bins",
+        "Number of bins to produce.",
+        5,
+        ParamValidators.gt_eq(2)
+    )
+
+    SUB_SAMPLES: IntParam = IntParam(
+        "sub_samples",
+        "Maximum number of samples used to fit the model.",
+        200000,
+        ParamValidators.gt_eq(2)
+    )
+
+    def __init__(self, java_params):
+        super(_KBinsDiscretizerParams, self).__init__(java_params)
+
+    def set_strategy(self, value: str):
+        return typing.cast(_KBinsDiscretizerParams, self.set(self.STRATEGY, value))
+
+    def get_strategy(self) -> str:
+        return self.get(self.STRATEGY)
+
+    def set_num_bins(self, value: int):
+        return typing.cast(_KBinsDiscretizerParams, self.set(self.NUM_BINS, value))
+
+    def get_num_bins(self) -> int:
+        return self.get(self.NUM_BINS)
+
+    def set_sub_samples(self, value: int):
+        return typing.cast(_KBinsDiscretizerParams, self.set(self.SUB_SAMPLES, value))
+
+    def get_sub_samples(self) -> int:
+        return self.get(self.SUB_SAMPLES)
+
+    @property
+    def strategy(self):
+        return self.get_strategy()
+
+    @property
+    def num_bins(self):
+        return self.get_num_bins()
+
+    @property
+    def sub_samples(self):
+        return self.get_sub_samples()
+
+
+class KBinsDiscretizerModel(JavaFeatureModel, _KBinsDiscretizerModelParams):
+    """
+    A Model which transforms continuous features into discrete features using the model data
+    computed by :class:`KBinsDiscretizer`.
+
+    <p>A feature value `v` is mapped to the bin with edges `{left, right}` if `v` lies in
+    `[left, right)`. If `v` does not fall into any of the bins, it is mapped to the
+    closest bin. For example, suppose the bin edges are `{-1, 0, 1}` for one column; then
+    we have two bins `{-1, 0}` and `{0, 1}`. In this case, -2 is mapped into the 0-th bin,
+    0 is mapped into the 1-st bin, and 2 is mapped into the 1-st bin.
+    """
+
+    def __init__(self, java_model=None):
+        super(KBinsDiscretizerModel, self).__init__(java_model)
+
+    @classmethod
+    def _java_model_package_name(cls) -> str:
+        return "kbinsdiscretizer"
+
+    @classmethod
+    def _java_model_class_name(cls) -> str:
+        return "KBinsDiscretizerModel"
+
+
+class KBinsDiscretizer(JavaFeatureEstimator, _KBinsDiscretizerParams):
+    """
+    An Estimator which implements discretization (also known as quantization or binning) to
+    transform continuous features into discrete ones. The output values are in [0, numBins).
+
+    <p>KBinsDiscretizer implements three different binning strategies, which can be selected via
+    {@link KBinsDiscretizerParams#STRATEGY}. If the strategy is set to
+    {@link KBinsDiscretizerParams#KMEANS} or {@link KBinsDiscretizerParams#QUANTILE},
+    users should further set {@link KBinsDiscretizerParams#SUB_SAMPLES} for
+    better performance.
+
+    <p>There are several corner cases for different inputs as listed below:
+
+    <ul>
+        <li>When the input values of one column are all the same, they are mapped to the
+        same bin (i.e., the zero-th bin). In this case the corresponding bin edges are
+        {Double.MIN_VALUE, Double.MAX_VALUE}.
+        <li>When the number of distinct values of one column is less than the specified
+        number of bins and the {@link KBinsDiscretizerParams#STRATEGY} is set to {@link
+        KBinsDiscretizerParams#KMEANS}, the strategy falls back to
+        {@link KBinsDiscretizerParams#UNIFORM}.
+        <li>When the width of one output bin is zero, i.e., the left edge equals the right
+        edge of the bin, that bin is removed.
+    </ul>
+    """
+
+    def __init__(self):
+        super(KBinsDiscretizer, self).__init__()
+
+    @classmethod
+    def _create_model(cls, java_model) -> KBinsDiscretizerModel:
+        return KBinsDiscretizerModel(java_model)
+
+    @classmethod
+    def _java_estimator_package_name(cls) -> str:
+        return "kbinsdiscretizer"
+
+    @classmethod
+    def _java_estimator_class_name(cls) -> str:
+        return "KBinsDiscretizer"
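
As a rough illustration of the 'quantile' strategy and the corner cases listed in the
KBinsDiscretizer docstring above, here is a toy sketch (an assumption-laden
re-implementation for explanation only, not the operator added by this commit): edges
are taken at evenly spaced ranks of the sorted column, and zero-width bins (equal
consecutive edges) are dropped. Note that for a fully constant column the real
operator keeps a single bin with the edges {Double.MIN_VALUE, Double.MAX_VALUE}, per
the docstring, rather than simply deduplicating as this sketch does.

    from typing import List

    def quantile_bin_edges(values: List[float], num_bins: int) -> List[float]:
        data = sorted(values)
        n = len(data)
        # Candidate edges at evenly spaced ranks of the sorted data.
        edges = [data[min(i * n // num_bins, n - 1)] for i in range(num_bins + 1)]
        edges[-1] = data[-1]
        # Drop zero-width bins, i.e. consecutive equal edges.
        deduped = [edges[0]]
        for e in edges[1:]:
            if e != deduped[-1]:
                deduped.append(e)
        return deduped

    print(quantile_bin_edges([1, 1, 1, 4, 5, 6, 7, 10, 13], 3))  # [1, 4, 7, 13]
    print(quantile_bin_edges([10] * 9, 3))  # collapses to [10] in this toy version
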
diff --git a/flink-ml-python/pyflink/ml/lib/feature/tests/test_kbinsdiscretizer.py b/flink-ml-python/pyflink/ml/lib/feature/tests/test_kbinsdiscretizer.py
new file mode 100644
index 0000000..2dacb68
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/feature/tests/test_kbinsdiscretizer.py
@@ -0,0 +1,172 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+import os
+
+from pyflink.common import Types
+
+from pyflink.ml.core.linalg import Vectors, DenseVectorTypeInfo
+from pyflink.ml.lib.feature.kbinsdiscretizer import KBinsDiscretizer, KBinsDiscretizerModel
+from pyflink.ml.tests.test_utils import PyFlinkMLTestCase
+
+
+class KBinsDiscretizerTest(PyFlinkMLTestCase):
+    def setUp(self):
+        super(KBinsDiscretizerTest, self).setUp()
+        self.train_table = self.t_env.from_data_stream(
+            self.env.from_collection([
+                (Vectors.dense(1, 10, 0),),
+                (Vectors.dense(1, 10, 0),),
+                (Vectors.dense(1, 10, 0),),
+                (Vectors.dense(4, 10, 0),),
+                (Vectors.dense(5, 10, 0),),
+                (Vectors.dense(6, 10, 0),),
+                (Vectors.dense(7, 10, 0),),
+                (Vectors.dense(10, 10, 0),),
+                (Vectors.dense(13, 10, 3),),
+            ],
+                type_info=Types.ROW_NAMED(
+                    ['input', ],
+                    [DenseVectorTypeInfo(), ])))
+
+        self.predict_table = self.t_env.from_data_stream(
+            self.env.from_collection([
+                (Vectors.dense(-1, 0, 0),),
+                (Vectors.dense(1, 1, 1),),
+                (Vectors.dense(1.5, 1, 2),),
+                (Vectors.dense(5, 2, 3),),
+                (Vectors.dense(7.25, 3, 4),),
+                (Vectors.dense(13, 4, 5),),
+                (Vectors.dense(15, 4, 6),),
+            ],
+                type_info=Types.ROW_NAMED(
+                    ['input', ],
+                    [DenseVectorTypeInfo(), ])))
+
+        self.uniform_output = [
+            Vectors.dense(0, 0, 0),
+            Vectors.dense(0, 0, 1),
+            Vectors.dense(0, 0, 2),
+            Vectors.dense(1, 0, 2),
+            Vectors.dense(1, 0, 2),
+            Vectors.dense(2, 0, 2),
+            Vectors.dense(2, 0, 2),
+        ]
+
+        self.quantile_output = [
+            Vectors.dense(0, 0, 0),
+            Vectors.dense(0, 0, 0),
+            Vectors.dense(0, 0, 0),
+            Vectors.dense(1, 0, 0),
+            Vectors.dense(2, 0, 0),
+            Vectors.dense(2, 0, 0),
+            Vectors.dense(2, 0, 0),
+        ]
+
+        self.kmeans_output = [
+            Vectors.dense(0, 0, 0),
+            Vectors.dense(0, 0, 1),
+            Vectors.dense(0, 0, 2),
+            Vectors.dense(1, 0, 2),
+            Vectors.dense(1, 0, 2),
+            Vectors.dense(2, 0, 2),
+            Vectors.dense(2, 0, 2),
+        ]
+
+    def test_param(self):
+        k_bins_discretizer = KBinsDiscretizer()
+
+        self.assertEqual("input", k_bins_discretizer.input_col)
+        self.assertEqual(5, k_bins_discretizer.num_bins)
+        self.assertEqual("quantile", k_bins_discretizer.strategy)
+        self.assertEqual(200000, k_bins_discretizer.sub_samples)
+        self.assertEqual("output", k_bins_discretizer.output_col)
+
+        k_bins_discretizer \
+            .set_input_col("test_input") \
+            .set_num_bins(10) \
+            .set_strategy('kmeans') \
+            .set_sub_samples(1000) \
+            .set_output_col("test_output")
+
+        self.assertEqual("test_input", k_bins_discretizer.input_col)
+        self.assertEqual(10, k_bins_discretizer.num_bins)
+        self.assertEqual("kmeans", k_bins_discretizer.strategy)
+        self.assertEqual(1000, k_bins_discretizer.sub_samples)
+        self.assertEqual("test_output", k_bins_discretizer.output_col)
+
+    def test_output_schema(self):
+        k_bins_discretizer = KBinsDiscretizer() \
+            .set_input_col("test_input") \
+            .set_output_col("test_output")
+        input_data_table = self.t_env.from_data_stream(
+            self.env.from_collection([
+                ('', ''),
+            ],
+                type_info=Types.ROW_NAMED(
+                    ['test_input', 'dummy_input'],
+                    [Types.STRING(), Types.STRING()])))
+        output = k_bins_discretizer \
+            .fit(input_data_table) \
+            .transform(input_data_table)[0]
+
+        self.assertEqual(
+            [k_bins_discretizer.input_col, 'dummy_input', k_bins_discretizer.output_col],
+            output.get_schema().get_field_names())
+
+    def verify_prediction_result(self, expected, output_table):
+        predicted_results = [result[1] for result in
+                             self.t_env.to_data_stream(output_table).execute_and_collect()]
+
+        predicted_results.sort(key=lambda x: (x[0], x[1], x[2]))
+        expected.sort(key=lambda x: (x[0], x[1], x[2]))
+
+        self.assertEqual(expected, predicted_results)
+
+    def test_fit_and_predict(self):
+        k_bins_discretizer = KBinsDiscretizer().set_num_bins(3)
+
+        # Tests uniform strategy.
+        k_bins_discretizer.set_strategy('uniform')
+        output = k_bins_discretizer.fit(self.train_table).transform(self.predict_table)[0]
+        self.verify_prediction_result(self.uniform_output, output)
+
+        # Tests quantile strategy.
+        k_bins_discretizer.set_strategy('quantile')
+        output = k_bins_discretizer.fit(self.train_table).transform(self.predict_table)[0]
+        self.verify_prediction_result(self.quantile_output, output)
+
+        # Tests kmeans strategy.
+        k_bins_discretizer.set_strategy('kmeans')
+        output = k_bins_discretizer.fit(self.train_table).transform(self.predict_table)[0]
+        self.verify_prediction_result(self.kmeans_output, output)
+
+    def test_save_load_predict(self):
+        k_bins_discretizer = KBinsDiscretizer().set_num_bins(3)
+        estimator_path = os.path.join(self.temp_dir, 'test_save_load_predict_kbinsdiscretizer')
+        k_bins_discretizer.save(estimator_path)
+        k_bins_discretizer = KBinsDiscretizer.load(self.t_env, estimator_path)
+
+        model = k_bins_discretizer.fit(self.train_table)
+        model_path = os.path.join(self.temp_dir, 'test_save_load_predict_kbinsdiscretizer_model')
+        model.save(model_path)
+        self.env.execute('save_model')
+        model = KBinsDiscretizerModel.load(self.t_env, model_path)
+
+        output = model.transform(self.predict_table)[0]
+        self.verify_prediction_result(self.quantile_output, output)