You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by zh...@apache.org on 2022/08/17 06:28:18 UTC
[flink-ml] branch master updated: [FLINK-28803] Add Transformer and Estimator for KBinsDiscretizer
This is an automated email from the ASF dual-hosted git repository.
zhangzp pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push:
new be546ad [FLINK-28803] Add Transformer and Estimator for KBinsDiscretizer
be546ad is described below
commit be546ad47887bdf2bbcf66b78d33bbbc5cdf643b
Author: Zhipeng Zhang <zh...@gmail.com>
AuthorDate: Wed Aug 17 14:28:14 2022 +0800
[FLINK-28803] Add Transformer and Estimator for KBinsDiscretizer
This closes #139.
---
.../examples/feature/KBinsDiscretizerExample.java | 71 +++++
.../feature/kbinsdiscretizer/KBinsDiscretizer.java | 341 +++++++++++++++++++++
.../kbinsdiscretizer/KBinsDiscretizerModel.java | 172 +++++++++++
.../KBinsDiscretizerModelData.java | 128 ++++++++
.../KBinsDiscretizerModelParams.java | 29 ++
.../kbinsdiscretizer/KBinsDiscretizerParams.java | 85 +++++
.../ml/feature/minmaxscaler/MinMaxScaler.java | 2 +-
.../flink/ml/feature/KBinsDiscretizerTest.java | 285 +++++++++++++++++
.../ml/feature/kbinsdiscreteizer_example.py | 75 +++++
.../pyflink/ml/lib/feature/kbinsdiscretizer.py | 168 ++++++++++
.../ml/lib/feature/tests/test_kbinsdiscretizer.py | 172 +++++++++++
11 files changed, 1527 insertions(+), 1 deletion(-)
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/KBinsDiscretizerExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/KBinsDiscretizerExample.java
new file mode 100644
index 0000000..1478f8c
--- /dev/null
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/KBinsDiscretizerExample.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.examples.feature;
+
+import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizer;
+import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizerModel;
+import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizerParams;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.CloseableIterator;
+
+/**
+ * Simple program that trains a KBinsDiscretizer model and uses it for feature engineering.
+ *
+ * <p>The program fits a {@link KBinsDiscretizer} with three bins and the uniform strategy on a
+ * small in-memory data set, then prints every input vector next to its discretized output.
+ */
+public class KBinsDiscretizerExample {
+    /**
+     * Runs the example.
+     *
+     * @param args Command line arguments; not used.
+     * @throws Exception If closing the result iterator fails.
+     */
+    public static void main(String[] args) throws Exception {
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+        // Generates input data.
+        DataStream<Row> inputStream =
+                env.fromElements(
+                        Row.of(Vectors.dense(1, 10, 0)),
+                        Row.of(Vectors.dense(1, 10, 0)),
+                        Row.of(Vectors.dense(1, 10, 0)),
+                        Row.of(Vectors.dense(4, 10, 0)),
+                        Row.of(Vectors.dense(5, 10, 0)),
+                        Row.of(Vectors.dense(6, 10, 0)),
+                        Row.of(Vectors.dense(7, 10, 0)),
+                        Row.of(Vectors.dense(10, 10, 0)),
+                        Row.of(Vectors.dense(13, 10, 3)));
+        Table inputTable = tEnv.fromDataStream(inputStream).as("input");
+
+        // Creates a KBinsDiscretizer object and initializes its parameters.
+        KBinsDiscretizer kBinsDiscretizer =
+                new KBinsDiscretizer().setNumBins(3).setStrategy(KBinsDiscretizerParams.UNIFORM);
+
+        // Trains the KBinsDiscretizer Model.
+        KBinsDiscretizerModel model = kBinsDiscretizer.fit(inputTable);
+
+        // Uses the KBinsDiscretizer Model for predictions.
+        Table outputTable = model.transform(inputTable)[0];
+
+        // Extracts and displays the results. The iterator is closed by try-with-resources so
+        // that it is released even if printing a row fails.
+        try (CloseableIterator<Row> it = outputTable.execute().collect()) {
+            while (it.hasNext()) {
+                Row row = it.next();
+                DenseVector inputValue =
+                        (DenseVector) row.getField(kBinsDiscretizer.getInputCol());
+                DenseVector outputValue =
+                        (DenseVector) row.getField(kBinsDiscretizer.getOutputCol());
+                System.out.printf("Input Value: %s\tOutput Value: %s\n", inputValue, outputValue);
+            }
+        }
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizer.java
new file mode 100644
index 0000000..accae29
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizer.java
@@ -0,0 +1,341 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.kbinsdiscretizer;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.feature.minmaxscaler.MinMaxScaler.MinMaxReduceFunctionOperator;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * An Estimator which implements discretization (also known as quantization or binning) to transform
+ * continuous features into discrete ones. The output values are in [0, numBins).
+ *
+ * <p>KBinsDiscretizer implements three different binning strategies, and it can be set by {@link
+ * KBinsDiscretizerParams#STRATEGY}. If the strategy is set as {@link KBinsDiscretizerParams#KMEANS}
+ * or {@link KBinsDiscretizerParams#QUANTILE}, users should further set {@link
+ * KBinsDiscretizerParams#SUB_SAMPLES} for better performance.
+ *
+ * <p>There are several corner cases for different inputs as listed below:
+ *
+ * <ul>
+ * <li>When the input values of one column are all the same, then they should be mapped to the
+ * same bin (i.e., the zero-th bin). Thus the corresponding bin edges are `{Double.MIN_VALUE,
+ * Double.MAX_VALUE}`.
+ * <li>When the number of distinct values of one column is less than the specified number of bins
+ * and the {@link KBinsDiscretizerParams#STRATEGY} is set as {@link
+ * KBinsDiscretizerParams#KMEANS}, we switch to {@link KBinsDiscretizerParams#UNIFORM}.
+ * <li>When the width of one output bin is zero, i.e., the left edge equals the right edge of
+ * the bin, we remove it.
+ * </ul>
+ */
+public class KBinsDiscretizer
+ implements Estimator<KBinsDiscretizer, KBinsDiscretizerModel>,
+ KBinsDiscretizerParams<KBinsDiscretizer> {
+ private static final Logger LOG = LoggerFactory.getLogger(KBinsDiscretizer.class);
+ private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+ /** Creates an instance and lets {@link ParamUtils} fill the parameter map with default values. */
+ public KBinsDiscretizer() {
+ ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+ }
+
+ /**
+ * Computes the bin edges of every column of the input vectors and returns a model that carries
+ * them as model data.
+ *
+ * @param inputs Exactly one table; its {@code inputCol} column is read as a {@link Vector}.
+ * @return A {@link KBinsDiscretizerModel} whose model data is the computed bin edges and whose
+ * parameters are updated from the parameters of this estimator.
+ */
+ @Override
+ public KBinsDiscretizerModel fit(Table... inputs) {
+ Preconditions.checkArgument(inputs.length == 1);
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+ // The parameters are read into locals; the lambda and the anonymous function below use
+ // these locals instead of calling the getters.
+ String inputCol = getInputCol();
+ String strategy = getStrategy();
+ int numBins = getNumBins();
+
+ // Converts every row into the dense form of its input vector.
+ DataStream<DenseVector> inputData =
+ tEnv.toDataStream(inputs[0])
+ .map(
+ (MapFunction<Row, DenseVector>)
+ value -> ((Vector) value.getField(inputCol)).toDense());
+
+ DataStream<DenseVector> preprocessedData;
+ if (strategy.equals(UNIFORM)) {
+ // Two chained MinMaxReduceFunctionOperator stages, the second with parallelism 1.
+ // NOTE(review): presumably the result is the per-column minimum vector followed by the
+ // per-column maximum vector, which is how findBinEdgesWithUniformStrategy reads
+ // elements 0 and 1 - confirm against MinMaxScaler.
+ preprocessedData =
+ inputData
+ .transform(
+ "reduceInEachPartition",
+ inputData.getType(),
+ new MinMaxReduceFunctionOperator())
+ .transform(
+ "reduceInFinalPartition",
+ inputData.getType(),
+ new MinMaxReduceFunctionOperator())
+ .setParallelism(1);
+ } else {
+ // The other strategies work on the output of DataStreamUtils.sample, called with
+ // getSubSamples() and a fixed seed derived from the class name.
+ preprocessedData =
+ DataStreamUtils.sample(
+ inputData, getSubSamples(), getClass().getName().hashCode());
+ }
+
+ DataStream<KBinsDiscretizerModelData> modelData =
+ DataStreamUtils.mapPartition(
+ preprocessedData,
+ new MapPartitionFunction<DenseVector, KBinsDiscretizerModelData>() {
+ @Override
+ public void mapPartition(
+ Iterable<DenseVector> iterable,
+ Collector<KBinsDiscretizerModelData> collector) {
+ // Materializes the partition because the strategies need random access.
+ List<DenseVector> list = new ArrayList<>();
+ iterable.iterator().forEachRemaining(list::add);
+
+ if (list.size() == 0) {
+ throw new RuntimeException("The training set is empty.");
+ }
+
+ double[][] binEdges;
+ switch (strategy) {
+ case UNIFORM:
+ binEdges = findBinEdgesWithUniformStrategy(list, numBins);
+ break;
+ case QUANTILE:
+ binEdges = findBinEdgesWithQuantileStrategy(list, numBins);
+ break;
+ case KMEANS:
+ binEdges = findBinEdgesWithKMeansStrategy(list, numBins);
+ break;
+ default:
+ throw new UnsupportedOperationException(
+ "Unsupported "
+ + STRATEGY
+ + " type: "
+ + strategy
+ + ".");
+ }
+
+ collector.collect(new KBinsDiscretizerModelData(binEdges));
+ }
+ });
+ // Parallelism 1 for the mapPartition operator; presumably so that one partition sees all
+ // preprocessed records and exactly one model data record is produced.
+ modelData.getTransformation().setParallelism(1);
+
+ KBinsDiscretizerModel model =
+ new KBinsDiscretizerModel().setModelData(tEnv.fromDataStream(modelData));
+ ReadWriteUtils.updateExistingParams(model, getParamMap());
+ return model;
+ }
+
+ /** Returns the map that stores the parameter values of this estimator (not a copy). */
+ @Override
+ public Map<Param<?>, Object> getParamMap() {
+ return paramMap;
+ }
+
+ /**
+ * Saves the metadata of this estimator by delegating to {@link ReadWriteUtils#saveMetadata}.
+ *
+ * @param path The path handed to {@code ReadWriteUtils.saveMetadata}.
+ * @throws IOException Declared for failures while saving.
+ */
+ @Override
+ public void save(String path) throws IOException {
+ ReadWriteUtils.saveMetadata(this, path);
+ }
+
+ /**
+ * Loads an estimator by delegating to {@link ReadWriteUtils#loadStageParam}.
+ *
+ * @param tEnv Not used by this method.
+ * @param path The path handed to {@code ReadWriteUtils.loadStageParam}.
+ * @return The loaded estimator.
+ * @throws IOException Declared for failures while loading.
+ */
+ public static KBinsDiscretizer load(StreamTableEnvironment tEnv, String path)
+ throws IOException {
+ return ReadWriteUtils.loadStageParam(path);
+ }
+
+    /**
+     * Computes equal-width bin edges for every column.
+     *
+     * <p>A column whose minimum equals its maximum gets the single bin {@code {Double.MIN_VALUE,
+     * Double.MAX_VALUE}}; every other column gets {@code numBins + 1} edges that start at the
+     * minimum and grow by {@code (max - min) / numBins}.
+     *
+     * @param input A list whose element 0 is read as the per-column minimums and whose element 1
+     *     is read as the per-column maximums.
+     * @param numBins The number of bins per column.
+     * @return The bin edges, indexed by column.
+     */
+    private static double[][] findBinEdgesWithUniformStrategy(
+            List<DenseVector> input, int numBins) {
+        DenseVector lowerBounds = input.get(0);
+        DenseVector upperBounds = input.get(1);
+        int dim = lowerBounds.size();
+        double[][] result = new double[dim][];
+
+        for (int col = 0; col < dim; col++) {
+            double lo = lowerBounds.get(col);
+            double hi = upperBounds.get(col);
+
+            if (lo == hi) {
+                LOG.warn("Feature " + col + " is constant and the output will all be zero.");
+                result[col] = new double[] {Double.MIN_VALUE, Double.MAX_VALUE};
+                continue;
+            }
+
+            double step = (hi - lo) / numBins;
+            double[] edges = new double[numBins + 1];
+            edges[0] = lo;
+            // Each edge is the previous edge plus one step.
+            for (int k = 1; k <= numBins; k++) {
+                edges[k] = edges[k - 1] + step;
+            }
+            result[col] = edges;
+        }
+        return result;
+    }
+
+    /**
+     * Computes bin edges for every column from evenly spaced order statistics of its values.
+     *
+     * <p>A constant column gets the single bin {@code {Double.MIN_VALUE, Double.MAX_VALUE}}. For
+     * any other column the candidate edges are the sorted values at positions {@code (int) (k *
+     * n / numBins)} for {@code k} in {@code [0, numBins)} plus the largest value; duplicated
+     * candidates are dropped, so fewer than {@code numBins} bins may remain.
+     *
+     * @param input The training vectors; the number of columns is taken from the first one.
+     * @param numBins The requested number of bins per column.
+     * @return The sorted, distinct bin edges, indexed by column.
+     */
+    private static double[][] findBinEdgesWithQuantileStrategy(
+            List<DenseVector> input, int numBins) {
+        int dim = input.get(0).size();
+        int n = input.size();
+        double[][] result = new double[dim][];
+        // Scratch buffer that is refilled and sorted for every column.
+        double[] sorted = new double[n];
+
+        for (int col = 0; col < dim; col++) {
+            int row = 0;
+            for (DenseVector vector : input) {
+                sorted[row++] = vector.get(col);
+            }
+            Arrays.sort(sorted);
+
+            if (sorted[0] == sorted[n - 1]) {
+                LOG.warn("Feature " + col + " is constant and the output will all be zero.");
+                result[col] = new double[] {Double.MIN_VALUE, Double.MAX_VALUE};
+                continue;
+            }
+
+            double stride = 1.0 * n / numBins;
+            double[] candidates = new double[numBins + 1];
+            for (int k = 0; k < numBins; k++) {
+                candidates[k] = sorted[(int) (k * stride)];
+            }
+            candidates[numBins] = sorted[n - 1];
+
+            // Removes bins that are empty, i.e., the left edge equals to the right edge.
+            result[col] = Arrays.stream(candidates).distinct().sorted().toArray();
+        }
+        return result;
+    }
+
+    /**
+     * Computes bin edges for every column by running a one-dimensional k-means clustering on the
+     * column values and placing an edge halfway between every two neighboring centroids.
+     *
+     * <p>Columns are handled as follows:
+     *
+     * <ul>
+     *   <li>A constant column gets the single bin {@code {Double.MIN_VALUE, Double.MAX_VALUE}}.
+     *   <li>A column with at most {@code numBins} distinct values falls back to equal-width
+     *       bins between its minimum and its maximum.
+     *   <li>Otherwise the centroids are seeded with evenly spaced order statistics of the
+     *       sorted column and refined until the change of the mean absolute distance is not
+     *       larger than the tolerance or the iteration limit is reached.
+     * </ul>
+     *
+     * <p>A cluster that receives no point in an iteration keeps its previous centroid. Dividing
+     * its zero sum by its zero count would produce {@code NaN}, which would end up in the bin
+     * edges; this can happen when the seed centroids contain duplicates.
+     *
+     * @param input The training vectors; the number of columns is taken from the first one.
+     * @param numBins The number of bins, which is also the number of clusters.
+     * @return The bin edges, indexed by column.
+     */
+    private static double[][] findBinEdgesWithKMeansStrategy(List<DenseVector> input, int numBins) {
+        int numColumns = input.get(0).size();
+        int numData = input.size();
+        // Every row is assigned in the loop below, so only the outer array is allocated here.
+        double[][] binEdges = new double[numColumns][];
+        double[] features = new double[numData];
+
+        double[] kMeansCentroids = new double[numBins];
+        double[] sumByCluster = new double[numBins];
+        int[] countByCluster = new int[numBins];
+
+        // Default values for KMeans.
+        final double tolerance = 1e-4;
+        final int maxIterations = 300;
+
+        for (int columnId = 0; columnId < numColumns; columnId++) {
+            for (int i = 0; i < numData; i++) {
+                features[i] = input.get(i).get(columnId);
+            }
+            Arrays.sort(features);
+
+            if (features[0] == features[numData - 1]) {
+                LOG.warn("Feature " + columnId + " is constant and the output will all be zero.");
+                binEdges[columnId] = new double[] {Double.MIN_VALUE, Double.MAX_VALUE};
+                continue;
+            }
+
+            // Checks whether there are more than {numBins} distinct feature values in this
+            // column. Counting stops as soon as {numBins + 1} distinct values are found.
+            Set<Double> distinctFeatureValues = new HashSet<>(numBins + 1);
+            for (double feature : features) {
+                distinctFeatureValues.add(feature);
+                if (distinctFeatureValues.size() >= numBins + 1) {
+                    break;
+                }
+            }
+
+            if (distinctFeatureValues.size() <= numBins) {
+                // With at most {numBins} distinct values there is no need to conduct KMeans.
+                // Equal-width bins, as in the uniform strategy, are used instead.
+                double min = features[0];
+                double max = features[features.length - 1];
+                double width = (max - min) / numBins;
+                binEdges[columnId] = new double[numBins + 1];
+                binEdges[columnId][0] = min;
+                for (int edgeId = 1; edgeId < numBins + 1; edgeId++) {
+                    binEdges[columnId][edgeId] = binEdges[columnId][edgeId - 1] + width;
+                }
+                continue;
+            }
+
+            // Conducts KMeans here. The initial centroids are evenly spaced order statistics.
+            double width = 1.0 * features.length / numBins;
+            for (int clusterId = 0; clusterId < numBins; clusterId++) {
+                kMeansCentroids[clusterId] = features[(int) (clusterId * width)];
+            }
+
+            double oldLoss = Double.MAX_VALUE;
+            double relativeLoss = Double.MAX_VALUE;
+            int iter = 0;
+            while (iter < maxIterations && relativeLoss > tolerance) {
+                Arrays.fill(sumByCluster, 0);
+                Arrays.fill(countByCluster, 0);
+
+                // Assigns every value to its closest centroid; ties go to the lower index.
+                double loss = 0;
+                for (double featureValue : features) {
+                    double minDistance = Math.abs(kMeansCentroids[0] - featureValue);
+                    int clusterId = 0;
+                    for (int i = 1; i < kMeansCentroids.length; i++) {
+                        double distance = Math.abs(kMeansCentroids[i] - featureValue);
+                        if (distance < minDistance) {
+                            minDistance = distance;
+                            clusterId = i;
+                        }
+                    }
+                    countByCluster[clusterId]++;
+                    sumByCluster[clusterId] += featureValue;
+                    loss += minDistance;
+                }
+
+                // Updates the centroids. An empty cluster keeps its previous centroid because
+                // 0.0 / 0 would turn it into NaN and corrupt the resulting bin edges.
+                for (int clusterId = 0; clusterId < kMeansCentroids.length; clusterId++) {
+                    if (countByCluster[clusterId] > 0) {
+                        kMeansCentroids[clusterId] =
+                                sumByCluster[clusterId] / countByCluster[clusterId];
+                    }
+                }
+
+                loss /= features.length;
+                relativeLoss = Math.abs(loss - oldLoss);
+                oldLoss = loss;
+                iter++;
+            }
+
+            // The outer edges are the column minimum and maximum; every inner edge is the
+            // midpoint of two neighboring centroids.
+            Arrays.sort(kMeansCentroids);
+            binEdges[columnId] = new double[numBins + 1];
+            binEdges[columnId][0] = features[0];
+            binEdges[columnId][numBins] = features[features.length - 1];
+            for (int binEdgeId = 1; binEdgeId < numBins; binEdgeId++) {
+                binEdges[columnId][binEdgeId] =
+                        (kMeansCentroids[binEdgeId - 1] + kMeansCentroids[binEdgeId]) / 2;
+            }
+        }
+        return binEdges;
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizerModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizerModel.java
new file mode 100644
index 0000000..7053e31
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizerModel.java
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.kbinsdiscretizer;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * A Model which transforms continuous features into discrete features using the model data computed
+ * by {@link KBinsDiscretizer}.
+ *
+ * <p>A feature value `v` should be mapped to a bin with edges as `{left, right}` if `v` is in
+ * `[left, right)`. If `v` does not fall into any of the bins, it is mapped to the closest bin. For
+ * example suppose the bin edges are `{-1, 0, 1}` for one column, then we have two bins `{-1, 0}`
+ * and `{0, 1}`. In this case, -2 is mapped into 0-th bin, 0 is mapped into the 1-st bin and 2 is
+ * mapped into the 1-st bin.
+ */
+public class KBinsDiscretizerModel
+ implements Model<KBinsDiscretizerModel>,
+ KBinsDiscretizerModelParams<KBinsDiscretizerModel> {
+ private final Map<Param<?>, Object> paramMap = new HashMap<>();
+ private Table modelDataTable;
+
+ /** Creates an instance and lets {@link ParamUtils} fill the parameter map with default values. */
+ public KBinsDiscretizerModel() {
+ ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+ }
+
+ /**
+ * Appends a column with the bin indices of the input vector to every row of the input table.
+ *
+ * @param inputs Exactly one table that contains the {@code inputCol} column.
+ * @return A single-element array with the input table extended by the {@code outputCol}
+ * column of type {@link DenseVector}.
+ */
+ @Override
+ public Table[] transform(Table... inputs) {
+ Preconditions.checkArgument(inputs.length == 1);
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+ DataStream<Row> inputData = tEnv.toDataStream(inputs[0]);
+ DataStream<KBinsDiscretizerModelData> modelData =
+ KBinsDiscretizerModelData.getModelDataStream(modelDataTable);
+
+ final String broadcastModelKey = "broadcastModelKey";
+ // The output schema is the input schema plus one DenseVector column named by
+ // getOutputCol().
+ RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+ RowTypeInfo outputTypeInfo =
+ new RowTypeInfo(
+ ArrayUtils.addAll(
+ inputTypeInfo.getFieldTypes(),
+ TypeInformation.of(DenseVector.class)),
+ ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol()));
+
+ // The model data stream is registered under broadcastModelKey; FindBinFunction reads it
+ // with getBroadcastVariable using the same key.
+ DataStream<Row> output =
+ BroadcastUtils.withBroadcastStream(
+ Collections.singletonList(inputData),
+ Collections.singletonMap(broadcastModelKey, modelData),
+ inputList -> {
+ DataStream input = inputList.get(0);
+ return input.map(
+ new FindBinFunction(getInputCol(), broadcastModelKey),
+ outputTypeInfo);
+ });
+ return new Table[] {tEnv.fromDataStream(output)};
+ }
+
+    /**
+     * Sets the table that holds the model data of this model.
+     *
+     * @param inputs Exactly one table; {@link #transform} reads field 0 of its rows as the bin
+     *     edges.
+     * @return This model.
+     * @throws IllegalArgumentException If the number of tables is not one.
+     */
+    @Override
+    public KBinsDiscretizerModel setModelData(Table... inputs) {
+        // Same arity check as transform(): rejects a missing table with a clear exception
+        // instead of an ArrayIndexOutOfBoundsException and does not silently drop extra tables.
+        Preconditions.checkArgument(inputs.length == 1);
+        modelDataTable = inputs[0];
+        return this;
+    }
+
+ /** Returns a single-element array that holds the model data table set on this model. */
+ @Override
+ public Table[] getModelData() {
+ return new Table[] {modelDataTable};
+ }
+
+ /** Returns the map that stores the parameter values of this model (not a copy). */
+ @Override
+ public Map<Param<?>, Object> getParamMap() {
+ return paramMap;
+ }
+
+ /**
+ * Saves the metadata of this model and its model data.
+ *
+ * @param path The path handed to {@link ReadWriteUtils} for both the metadata and the model
+ * data.
+ * @throws IOException Declared for failures while saving.
+ */
+ @Override
+ public void save(String path) throws IOException {
+ ReadWriteUtils.saveMetadata(this, path);
+ // The model data table is converted to a stream and written with ModelDataEncoder.
+ ReadWriteUtils.saveModelData(
+ KBinsDiscretizerModelData.getModelDataStream(modelDataTable),
+ path,
+ new KBinsDiscretizerModelData.ModelDataEncoder());
+ }
+
+ /**
+ * Loads a model: the stage parameters via {@link ReadWriteUtils#loadStageParam} and the model
+ * data via {@link ReadWriteUtils#loadModelData} with a {@code ModelDataDecoder}.
+ *
+ * @param tEnv The table environment handed to {@code ReadWriteUtils.loadModelData}.
+ * @param path The path to load from.
+ * @return The loaded model with its model data set.
+ * @throws IOException Declared for failures while loading.
+ */
+ public static KBinsDiscretizerModel load(StreamTableEnvironment tEnv, String path)
+ throws IOException {
+ KBinsDiscretizerModel model = ReadWriteUtils.loadStageParam(path);
+ Table modelDataTable =
+ ReadWriteUtils.loadModelData(
+ tEnv, path, new KBinsDiscretizerModelData.ModelDataDecoder());
+ return model.setModelData(modelDataTable);
+ }
+
+    /**
+     * Replaces every value of the input vector with the index of the bin it belongs to and
+     * appends the resulting vector to the row.
+     */
+    private static class FindBinFunction extends RichMapFunction<Row, Row> {
+        private final String inputCol;
+        private final String broadcastKey;
+        /** Model data used to find bins for each feature. */
+        private double[][] binEdges;
+
+        public FindBinFunction(String inputCol, String broadcastKey) {
+            this.inputCol = inputCol;
+            this.broadcastKey = broadcastKey;
+        }
+
+        @Override
+        public Row map(Row row) {
+            // The bin edges are fetched from the broadcast variable on the first record only.
+            if (binEdges == null) {
+                Object broadcast = getRuntimeContext().getBroadcastVariable(broadcastKey).get(0);
+                binEdges = ((KBinsDiscretizerModelData) broadcast).binEdges;
+            }
+
+            DenseVector features = ((Vector) row.getField(inputCol)).toDense();
+            // Works on a copy so that the vector stored in the input row is left untouched.
+            DenseVector bins = features.clone();
+            for (int col = 0; col < features.size(); col++) {
+                bins.set(col, locateBin(binEdges[col], features.get(col)));
+            }
+            return Row.join(row, Row.of(bins));
+        }
+
+        /**
+         * Finds the bin of a value among sorted bin edges.
+         *
+         * @param edges The sorted bin edges of one column.
+         * @param value The value to look up.
+         * @return The index of the bin, clamped to {@code [0, edges.length - 2]} and never
+         *     negative.
+         */
+        private static int locateBin(double[] edges, double value) {
+            int position = Arrays.binarySearch(edges, value);
+            if (position < 0) {
+                // binarySearch returned -(insertionPoint) - 1; the bin on the left of the
+                // insertion point is insertionPoint - 1, i.e., -position - 2.
+                position = -position - 2;
+            }
+            // Handles the boundary.
+            int lastBin = edges.length - 2;
+            return Math.max(Math.min(position, lastBin), 0);
+        }
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizerModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizerModelData.java
new file mode 100644
index 0000000..8dd402a
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizerModelData.java
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.kbinsdiscretizer;
+
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.api.common.typeutils.base.array.DoublePrimitiveArraySerializer;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+
+import java.io.EOFException;
+import java.io.IOException;
+import java.io.OutputStream;
+
+/**
+ * Model data of {@link KBinsDiscretizerModel}.
+ *
+ * <p>This class also provides methods to convert model data from Table to a data stream, and
+ * classes to save/load model data.
+ */
+public class KBinsDiscretizerModelData {
+ /**
+ * The edges of bins for each column, e.g., binEdges[0] is the edges for features at 0-th
+ * dimension.
+ */
+ public double[][] binEdges;
+
+ /**
+ * Creates an instance without bin edges.
+ *
+ * <p>NOTE(review): presumably kept so that the class can be instantiated reflectively by
+ * serialization frameworks - confirm before removing.
+ */
+ public KBinsDiscretizerModelData() {}
+
+ /**
+ * Creates an instance that holds the given bin edges.
+ *
+ * @param binEdges The bin edges, indexed by column; the array is stored as is, not copied.
+ */
+ public KBinsDiscretizerModelData(double[][] binEdges) {
+ this.binEdges = binEdges;
+ }
+
+ /**
+ * Converts the table model to a data stream.
+ *
+ * @param modelDataTable The table model data.
+ * @return The data stream model data.
+ */
+ public static DataStream<KBinsDiscretizerModelData> getModelDataStream(Table modelDataTable) {
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl) modelDataTable).getTableEnvironment();
+ return tEnv.toDataStream(modelDataTable)
+ .map(x -> new KBinsDiscretizerModelData((double[][]) x.getField(0)));
+ }
+
+    /**
+     * Encoder for {@link KBinsDiscretizerModelData}.
+     *
+     * <p>The encoded form is the number of columns followed by the serialized edge array of
+     * every column.
+     */
+    public static class ModelDataEncoder implements Encoder<KBinsDiscretizerModelData> {
+
+        @Override
+        public void encode(KBinsDiscretizerModelData modelData, OutputStream outputStream)
+                throws IOException {
+            final DataOutputView target = new DataOutputViewStreamWrapper(outputStream);
+            final double[][] edgesByColumn = modelData.binEdges;
+
+            // Column count first, then one double[] per column.
+            IntSerializer.INSTANCE.serialize(edgesByColumn.length, target);
+            for (int col = 0; col < edgesByColumn.length; col++) {
+                DoublePrimitiveArraySerializer.INSTANCE.serialize(edgesByColumn[col], target);
+            }
+        }
+    }
+
+    /**
+     * Decoder for {@link KBinsDiscretizerModelData}.
+     *
+     * <p>Reads the layout written by {@code ModelDataEncoder}: the number of columns followed by
+     * the serialized edge array of every column.
+     */
+    public static class ModelDataDecoder extends SimpleStreamFormat<KBinsDiscretizerModelData> {
+        @Override
+        public Reader<KBinsDiscretizerModelData> createReader(
+                Configuration config, FSDataInputStream stream) {
+            return new Reader<KBinsDiscretizerModelData>() {
+
+                /** Reads the next record, or returns null once the stream is exhausted. */
+                @Override
+                public KBinsDiscretizerModelData read() throws IOException {
+                    DataInputView in = new DataInputViewStreamWrapper(stream);
+                    double[][] edgesByColumn;
+                    try {
+                        int columnCount = IntSerializer.INSTANCE.deserialize(in);
+                        edgesByColumn = new double[columnCount][];
+                        for (int col = 0; col < columnCount; col++) {
+                            edgesByColumn[col] =
+                                    DoublePrimitiveArraySerializer.INSTANCE.deserialize(in);
+                        }
+                    } catch (EOFException endOfStream) {
+                        // Running out of input is reported as "no more records".
+                        return null;
+                    }
+                    return new KBinsDiscretizerModelData(edgesByColumn);
+                }
+
+                @Override
+                public void close() throws IOException {
+                    stream.close();
+                }
+            };
+        }
+
+        @Override
+        public TypeInformation<KBinsDiscretizerModelData> getProducedType() {
+            return TypeInformation.of(KBinsDiscretizerModelData.class);
+        }
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizerModelParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizerModelParams.java
new file mode 100644
index 0000000..3e0fd3d
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizerModelParams.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.kbinsdiscretizer;
+
+import org.apache.flink.ml.common.param.HasInputCol;
+import org.apache.flink.ml.common.param.HasOutputCol;
+
+/**
+ * Params for {@link KBinsDiscretizerModel}.
+ *
+ * <p>The interface declares no parameters of its own; it only combines the parameters inherited
+ * from {@link HasInputCol} and {@link HasOutputCol}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface KBinsDiscretizerModelParams<T> extends HasInputCol<T>, HasOutputCol<T> {}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizerParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizerParams.java
new file mode 100644
index 0000000..c1f0b00
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/kbinsdiscretizer/KBinsDiscretizerParams.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.kbinsdiscretizer;
+
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringParam;
+
+/**
+ * Params for {@link KBinsDiscretizer}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface KBinsDiscretizerParams<T> extends KBinsDiscretizerModelParams<T> {
+ String UNIFORM = "uniform";
+ String QUANTILE = "quantile";
+ String KMEANS = "kmeans";
+
+ /**
+ * Supported options to define the widths of the bins are listed as follows.
+ *
+ * <ul>
+ * <li>uniform: all bins in each feature have identical widths.
+ * <li>quantile: all bins in each feature have the same number of points.
+ * <li>kmeans: values in each bin have the same nearest center of a 1D kmeans cluster.
+ * </ul>
+ */
+ Param<String> STRATEGY =
+ new StringParam(
+ "strategy",
+ "Strategy used to define the width of the bin.",
+ QUANTILE,
+ ParamValidators.inArray(UNIFORM, QUANTILE, KMEANS));
+
+ Param<Integer> NUM_BINS =
+ new IntParam("numBins", "Number of bins to produce.", 5, ParamValidators.gtEq(2));
+
+ Param<Integer> SUB_SAMPLES =
+ new IntParam(
+ "subSamples",
+ "Maximum number of samples used to fit the model.",
+ 200000,
+ ParamValidators.gtEq(2));
+
+ default String getStrategy() {
+ return get(STRATEGY);
+ }
+
+ default T setStrategy(String value) {
+ return set(STRATEGY, value);
+ }
+
+ default int getNumBins() {
+ return get(NUM_BINS);
+ }
+
+ default T setNumBins(int value) {
+ return set(NUM_BINS, value);
+ }
+
+ default int getSubSamples() {
+ return get(SUB_SAMPLES);
+ }
+
+ default T setSubSamples(Integer value) {
+ return set(SUB_SAMPLES, value);
+ }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScaler.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScaler.java
index 4e59851..28545e9 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScaler.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScaler.java
@@ -119,7 +119,7 @@ public class MinMaxScaler
* A stream operator to compute the min and max values in each partition of the input bounded
* data stream.
*/
- private static class MinMaxReduceFunctionOperator extends AbstractStreamOperator<DenseVector>
+ public static class MinMaxReduceFunctionOperator extends AbstractStreamOperator<DenseVector>
implements OneInputStreamOperator<DenseVector, DenseVector>, BoundedOneInput {
private ListState<DenseVector> minState;
private ListState<DenseVector> maxState;
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/KBinsDiscretizerTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/KBinsDiscretizerTest.java
new file mode 100644
index 0000000..17ad6b5
--- /dev/null
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/KBinsDiscretizerTest.java
@@ -0,0 +1,285 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizer;
+import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizerModel;
+import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizerModelData;
+import org.apache.flink.ml.feature.kbinsdiscretizer.KBinsDiscretizerParams;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.TestUtils;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.exception.ExceptionUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import static org.apache.flink.table.api.Expressions.$;
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
+
+/** Tests {@link KBinsDiscretizer} and {@link KBinsDiscretizerModel}. */
+public class KBinsDiscretizerTest extends AbstractTestBase {
+ @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+ private StreamExecutionEnvironment env;
+ private StreamTableEnvironment tEnv;
+ private Table trainTable;
+ private Table testTable;
+
+ // Column0 for normal cases, column1 for constant cases, column2 for numDistinct < numBins
+ // cases.
+ private static final List<Row> TRAIN_INPUT =
+ Arrays.asList(
+ Row.of(Vectors.dense(1, 10, 0)),
+ Row.of(Vectors.dense(1, 10, 0)),
+ Row.of(Vectors.dense(1, 10, 0)),
+ Row.of(Vectors.dense(4, 10, 0)),
+ Row.of(Vectors.dense(5, 10, 0)),
+ Row.of(Vectors.dense(6, 10, 0)),
+ Row.of(Vectors.dense(7, 10, 0)),
+ Row.of(Vectors.dense(10, 10, 0)),
+ Row.of(Vectors.dense(13, 10, 3)));
+
+ private static final List<Row> TEST_INPUT =
+ Arrays.asList(
+ Row.of(Vectors.dense(-1, 0, 0)),
+ Row.of(Vectors.dense(1, 1, 1)),
+ Row.of(Vectors.dense(1.5, 1, 2)),
+ Row.of(Vectors.dense(5, 2, 3)),
+ Row.of(Vectors.dense(7.25, 3, 4)),
+ Row.of(Vectors.dense(13, 4, 5)),
+ Row.of(Vectors.dense(15, 4, 6)));
+
+ private static final double[][] UNIFORM_MODEL_DATA =
+ new double[][] {
+ new double[] {1, 5, 9, 13},
+ new double[] {Double.MIN_VALUE, Double.MAX_VALUE},
+ new double[] {0, 1, 2, 3}
+ };
+
+ private static final List<Row> UNIFORM_OUTPUT =
+ Arrays.asList(
+ Row.of(Vectors.dense(0, 0, 0)),
+ Row.of(Vectors.dense(0, 0, 1)),
+ Row.of(Vectors.dense(0, 0, 2)),
+ Row.of(Vectors.dense(1, 0, 2)),
+ Row.of(Vectors.dense(1, 0, 2)),
+ Row.of(Vectors.dense(2, 0, 2)),
+ Row.of(Vectors.dense(2, 0, 2)));
+
+ private static final List<Row> QUANTILE_OUTPUT =
+ Arrays.asList(
+ Row.of(Vectors.dense(0, 0, 0)),
+ Row.of(Vectors.dense(0, 0, 0)),
+ Row.of(Vectors.dense(0, 0, 0)),
+ Row.of(Vectors.dense(1, 0, 0)),
+ Row.of(Vectors.dense(2, 0, 0)),
+ Row.of(Vectors.dense(2, 0, 0)),
+ Row.of(Vectors.dense(2, 0, 0)));
+
+ private static final List<Row> KMEANS_OUTPUT =
+ Arrays.asList(
+ Row.of(Vectors.dense(0, 0, 0)),
+ Row.of(Vectors.dense(0, 0, 1)),
+ Row.of(Vectors.dense(0, 0, 2)),
+ Row.of(Vectors.dense(1, 0, 2)),
+ Row.of(Vectors.dense(1, 0, 2)),
+ Row.of(Vectors.dense(2, 0, 2)),
+ Row.of(Vectors.dense(2, 0, 2)));
+
+ private static final double TOLERANCE = 1e-7;
+
+ @Before
+ public void before() {
+ Configuration config = new Configuration();
+ config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+ env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+ env.setParallelism(4);
+ env.enableCheckpointing(100);
+ env.setRestartStrategy(RestartStrategies.noRestart());
+ tEnv = StreamTableEnvironment.create(env);
+ trainTable = tEnv.fromDataStream(env.fromCollection(TRAIN_INPUT)).as("input");
+ testTable = tEnv.fromDataStream(env.fromCollection(TEST_INPUT)).as("input");
+ }
+
+ @SuppressWarnings("unchecked, ConstantConditions")
+ private void verifyPredictionResult(
+ List<Row> expectedOutput, Table output, String predictionCol) throws Exception {
+ List<Row> collectedResult =
+ IteratorUtils.toList(
+ tEnv.toDataStream(output.select($(predictionCol))).executeAndCollect());
+ compareResultCollections(
+ expectedOutput,
+ collectedResult,
+ (o1, o2) ->
+ TestUtils.compare(
+ (DenseVector) o1.getField(0), (DenseVector) o2.getField(0)));
+ }
+
+ @Test
+ public void testParam() {
+ KBinsDiscretizer kBinsDiscretizer = new KBinsDiscretizer();
+
+ assertEquals("input", kBinsDiscretizer.getInputCol());
+ assertEquals(5, kBinsDiscretizer.getNumBins());
+ assertEquals("quantile", kBinsDiscretizer.getStrategy());
+ assertEquals(200000, kBinsDiscretizer.getSubSamples());
+ assertEquals("output", kBinsDiscretizer.getOutputCol());
+
+ kBinsDiscretizer
+ .setInputCol("test_input")
+ .setNumBins(10)
+ .setStrategy(KBinsDiscretizerParams.KMEANS)
+ .setSubSamples(1000)
+ .setOutputCol("test_output");
+
+ assertEquals("test_input", kBinsDiscretizer.getInputCol());
+ assertEquals(10, kBinsDiscretizer.getNumBins());
+ assertEquals("kmeans", kBinsDiscretizer.getStrategy());
+ assertEquals(1000, kBinsDiscretizer.getSubSamples());
+ assertEquals("test_output", kBinsDiscretizer.getOutputCol());
+ }
+
+ @Test
+ public void testOutputSchema() {
+ Table tempTable =
+ tEnv.fromDataStream(env.fromElements(Row.of("", "")))
+ .as("test_input", "dummy_input");
+ KBinsDiscretizer kBinsDiscretizer =
+ new KBinsDiscretizer().setInputCol("test_input").setOutputCol("test_output");
+ Table output = kBinsDiscretizer.fit(tempTable).transform(tempTable)[0];
+
+ assertEquals(
+ Arrays.asList("test_input", "dummy_input", "test_output"),
+ output.getResolvedSchema().getColumnNames());
+ }
+
+ @Test
+ public void testFitAndPredict() throws Exception {
+ KBinsDiscretizer kBinsDiscretizer = new KBinsDiscretizer().setNumBins(3);
+ Table output;
+
+ // Tests uniform strategy.
+ kBinsDiscretizer.setStrategy(KBinsDiscretizerParams.UNIFORM);
+ output = kBinsDiscretizer.fit(trainTable).transform(testTable)[0];
+ verifyPredictionResult(UNIFORM_OUTPUT, output, kBinsDiscretizer.getOutputCol());
+
+ // Tests quantile strategy.
+ kBinsDiscretizer.setStrategy(KBinsDiscretizerParams.QUANTILE);
+ output = kBinsDiscretizer.fit(trainTable).transform(testTable)[0];
+ verifyPredictionResult(QUANTILE_OUTPUT, output, kBinsDiscretizer.getOutputCol());
+
+ // Tests kmeans strategy.
+ kBinsDiscretizer.setStrategy(KBinsDiscretizerParams.KMEANS);
+ output = kBinsDiscretizer.fit(trainTable).transform(testTable)[0];
+ verifyPredictionResult(KMEANS_OUTPUT, output, kBinsDiscretizer.getOutputCol());
+ }
+
+ @Test
+ public void testSaveLoadAndPredict() throws Exception {
+ KBinsDiscretizer kBinsDiscretizer =
+ new KBinsDiscretizer().setNumBins(3).setStrategy(KBinsDiscretizerParams.UNIFORM);
+ kBinsDiscretizer =
+ TestUtils.saveAndReload(
+ tEnv, kBinsDiscretizer, tempFolder.newFolder().getAbsolutePath());
+
+ KBinsDiscretizerModel model = kBinsDiscretizer.fit(trainTable);
+ model = TestUtils.saveAndReload(tEnv, model, tempFolder.newFolder().getAbsolutePath());
+
+ assertEquals(
+ Collections.singletonList("binEdges"),
+ model.getModelData()[0].getResolvedSchema().getColumnNames());
+
+ Table output = model.transform(testTable)[0];
+ verifyPredictionResult(UNIFORM_OUTPUT, output, kBinsDiscretizer.getOutputCol());
+ }
+
+ @Test
+ @SuppressWarnings("unchecked")
+ public void testGetModelData() throws Exception {
+ KBinsDiscretizer kBinsDiscretizer =
+ new KBinsDiscretizer().setNumBins(3).setStrategy(KBinsDiscretizerParams.UNIFORM);
+ KBinsDiscretizerModel model = kBinsDiscretizer.fit(trainTable);
+ Table modelDataTable = model.getModelData()[0];
+
+ assertEquals(
+ Collections.singletonList("binEdges"),
+ modelDataTable.getResolvedSchema().getColumnNames());
+
+ List<KBinsDiscretizerModelData> collectedModelData =
+ (List<KBinsDiscretizerModelData>)
+ IteratorUtils.toList(
+ KBinsDiscretizerModelData.getModelDataStream(modelDataTable)
+ .executeAndCollect());
+ assertEquals(1, collectedModelData.size());
+
+ KBinsDiscretizerModelData modelData = collectedModelData.get(0);
+ assertEquals(UNIFORM_MODEL_DATA.length, modelData.binEdges.length);
+ for (int i = 0; i < modelData.binEdges.length; i++) {
+ assertArrayEquals(UNIFORM_MODEL_DATA[i], modelData.binEdges[i], TOLERANCE);
+ }
+ }
+
+ @Test
+ public void testSetModelData() throws Exception {
+ KBinsDiscretizer kBinsDiscretizer =
+ new KBinsDiscretizer().setNumBins(3).setStrategy(KBinsDiscretizerParams.UNIFORM);
+
+ KBinsDiscretizerModel model = kBinsDiscretizer.fit(trainTable);
+
+ KBinsDiscretizerModel newModel = new KBinsDiscretizerModel();
+ ReadWriteUtils.updateExistingParams(newModel, model.getParamMap());
+ newModel.setModelData(model.getModelData());
+ Table output = newModel.transform(testTable)[0];
+
+ verifyPredictionResult(UNIFORM_OUTPUT, output, kBinsDiscretizer.getOutputCol());
+ }
+
+ @Test
+ public void testFitOnEmptyData() {
+ Table emptyTable =
+ tEnv.fromDataStream(env.fromCollection(TRAIN_INPUT).filter(x -> x.getArity() == 0))
+ .as("input");
+ KBinsDiscretizerModel model = new KBinsDiscretizer().fit(emptyTable);
+ Table modelDataTable = model.getModelData()[0];
+ try {
+ modelDataTable.execute().collect().next();
+ fail();
+ } catch (Throwable e) {
+ assertEquals("The training set is empty.", ExceptionUtils.getRootCause(e).getMessage());
+ }
+ }
+}
diff --git a/flink-ml-python/pyflink/examples/ml/feature/kbinsdiscreteizer_example.py b/flink-ml-python/pyflink/examples/ml/feature/kbinsdiscreteizer_example.py
new file mode 100644
index 0000000..d33d613
--- /dev/null
+++ b/flink-ml-python/pyflink/examples/ml/feature/kbinsdiscreteizer_example.py
@@ -0,0 +1,75 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+# Simple program that trains a KBinsDiscretizer model and uses it for feature
+# engineering.
+#
+# Before executing this program, please make sure you have followed Flink ML's
+# quick start guideline to set up Flink ML and Flink environment. The guideline
+# can be found at
+#
+# https://nightlies.apache.org/flink/flink-ml-docs-master/docs/try-flink-ml/quick-start/
+
+from pyflink.common import Types
+from pyflink.ml.core.linalg import Vectors, DenseVectorTypeInfo
+from pyflink.datastream import StreamExecutionEnvironment
+from pyflink.ml.lib.feature.kbinsdiscretizer import KBinsDiscretizer
+from pyflink.table import StreamTableEnvironment
+
+# Creates a new StreamExecutionEnvironment.
+env = StreamExecutionEnvironment.get_execution_environment()
+
+# Creates a StreamTableEnvironment.
+t_env = StreamTableEnvironment.create(env)
+
+# Generates input for training and prediction.
+input_table = t_env.from_data_stream(
+ env.from_collection([
+ (Vectors.dense(1, 10, 0),),
+ (Vectors.dense(1, 10, 0),),
+ (Vectors.dense(1, 10, 0),),
+ (Vectors.dense(4, 10, 0),),
+ (Vectors.dense(5, 10, 0),),
+ (Vectors.dense(6, 10, 0),),
+ (Vectors.dense(7, 10, 0),),
+ (Vectors.dense(10, 10, 0),),
+ (Vectors.dense(13, 10, 0),),
+ ],
+ type_info=Types.ROW_NAMED(
+ ['input', ],
+ [DenseVectorTypeInfo(), ])))
+
+# Creates a KBinsDiscretizer object and initializes its parameters.
+k_bins_discretizer = KBinsDiscretizer() \
+ .set_input_col('input') \
+ .set_output_col('output') \
+ .set_num_bins(3) \
+ .set_strategy('uniform')
+
+# Trains the KBinsDiscretizer Model.
+model = k_bins_discretizer.fit(input_table)
+
+# Uses the KBinsDiscretizer Model for predictions.
+output = model.transform(input_table)[0]
+
+# Extracts and displays the results.
+field_names = output.get_schema().get_field_names()
+for result in t_env.to_data_stream(output).execute_and_collect():
+ print('Input Value: ' + str(result[field_names.index(k_bins_discretizer.get_input_col())])
+ + '\tOutput Value: ' +
+ str(result[field_names.index(k_bins_discretizer.get_output_col())]))
diff --git a/flink-ml-python/pyflink/ml/lib/feature/kbinsdiscretizer.py b/flink-ml-python/pyflink/ml/lib/feature/kbinsdiscretizer.py
new file mode 100644
index 0000000..03ab5e7
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/feature/kbinsdiscretizer.py
@@ -0,0 +1,168 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+import typing
+
+from pyflink.ml.core.param import IntParam, StringParam, ParamValidators
+from pyflink.ml.core.wrapper import JavaWithParams
+from pyflink.ml.lib.feature.common import JavaFeatureModel, JavaFeatureEstimator
+from pyflink.ml.lib.param import HasInputCol, HasOutputCol
+
+
+class _KBinsDiscretizerModelParams(
+ JavaWithParams,
+ HasInputCol,
+ HasOutputCol
+):
+ """
+ Params for :class:`KBinsDiscretizerModel`.
+ """
+
+ def __init__(self, java_params):
+ super(_KBinsDiscretizerModelParams, self).__init__(java_params)
+
+
+class _KBinsDiscretizerParams(_KBinsDiscretizerModelParams):
+ """
+ Params for :class:`KBinsDiscretizer`.
+ """
+
+ """
+ Supported options to define the widths of the bins are listed as follows.
+ <ul>
+ <li>uniform: all bins in each feature have identical widths.
+ <li>quantile: all bins in each feature have the same number of points.
+ <li>kmeans: values in each bin have the same nearest center of a 1D kmeans cluster.
+ </ul>
+ """
+ STRATEGY: StringParam = StringParam(
+ "strategy",
+ "Strategy used to define the width of the bin.",
+ 'quantile',
+ ParamValidators.in_array(['uniform', 'quantile', 'kmeans']))
+
+ NUM_BINS: IntParam = IntParam(
+ "num_bins",
+ "Number of bins to produce.",
+ 5,
+ ParamValidators.gt_eq(2)
+ )
+
+ SUB_SAMPLES: IntParam = IntParam(
+ "sub_samples",
+ "Maximum number of samples used to fit the model.",
+ 200000,
+ ParamValidators.gt_eq(2)
+ )
+
+ def __init__(self, java_params):
+ super(_KBinsDiscretizerParams, self).__init__(java_params)
+
+ def set_strategy(self, value: str):
+ return typing.cast(_KBinsDiscretizerParams, self.set(self.STRATEGY, value))
+
+ def get_strategy(self) -> str:
+ return self.get(self.STRATEGY)
+
+ def set_num_bins(self, value: int):
+ return typing.cast(_KBinsDiscretizerParams, self.set(self.NUM_BINS, value))
+
+ def get_num_bins(self) -> int:
+ return self.get(self.NUM_BINS)
+
+ def set_sub_samples(self, value: int):
+ return typing.cast(_KBinsDiscretizerParams, self.set(self.SUB_SAMPLES, value))
+
+ def get_sub_samples(self) -> int:
+ return self.get(self.SUB_SAMPLES)
+
+ @property
+ def strategy(self):
+ return self.get_strategy()
+
+ @property
+ def num_bins(self):
+ return self.get_num_bins()
+
+ @property
+ def sub_samples(self):
+ return self.get_sub_samples()
+
+
+class KBinsDiscretizerModel(JavaFeatureModel, _KBinsDiscretizerModelParams):
+ """
+ A Model which transforms continuous features into discrete features using the model data
+ computed by :class:`KBinsDiscretizer`.
+
+ <p>A feature value `v` should be mapped to a bin with edges as `{left, right}` if `v` is
+ in `[left, right)`. If `v` does not fall into any of the bins, it is mapped to the
+ closest bin. For example, suppose the bin edges are `{-1, 0, 1}` for one column, then
+ we have two bins `{-1, 0}` and `{0, 1}`. In this case, -2 is mapped into 0-th bin,
+ 0 is mapped into the 1-st bin and 2 is mapped into the 1-st bin.
+ """
+
+ def __init__(self, java_model=None):
+ super(KBinsDiscretizerModel, self).__init__(java_model)
+
+ @classmethod
+ def _java_model_package_name(cls) -> str:
+ return "kbinsdiscretizer"
+
+ @classmethod
+ def _java_model_class_name(cls) -> str:
+ return "KBinsDiscretizerModel"
+
+
+class KBinsDiscretizer(JavaFeatureEstimator, _KBinsDiscretizerParams):
+ """
+ An Estimator which implements discretization (also known as quantization or binning) to
+ transform continuous features into discrete ones. The output values are in [0, numBins).
+
+ <p>KBinsDiscretizer implements three different binning strategies, and it can be set by {@link
+ KBinsDiscretizerParams#STRATEGY}. If the strategy is set as
+ {@link KBinsDiscretizerParams#KMEANS} or {@link KBinsDiscretizerParams#QUANTILE},
+ users should further set {@link KBinsDiscretizerParams#SUB_SAMPLES} for
+ better performance.
+
+ <p>There are several corner cases for different inputs as listed below:
+
+ <ul>
+ <li>When the input values of one column are all the same, then they should be mapped
+ to the same bin (i.e., the zero-th bin). Thus the corresponding bin edges are
+ {Double.MIN_VALUE, Double.MAX_VALUE}.
+ <li>When the number of distinct values of one column is less than the specified
+ number of bins and the {@link KBinsDiscretizerParams#STRATEGY} is set as {@link
+ KBinsDiscretizerParams#KMEANS}, we switch to {@link KBinsDiscretizerParams#UNIFORM}.
+ <li>When the width of one output bin is zero, i.e., the left edge equals the right
+ edge of the bin, we remove it.
+ </ul>
+ """
+
+ def __init__(self):
+ super(KBinsDiscretizer, self).__init__()
+
+ @classmethod
+ def _create_model(cls, java_model) -> KBinsDiscretizerModel:
+ return KBinsDiscretizerModel(java_model)
+
+ @classmethod
+ def _java_estimator_package_name(cls) -> str:
+ return "kbinsdiscretizer"
+
+ @classmethod
+ def _java_estimator_class_name(cls) -> str:
+ return "KBinsDiscretizer"
diff --git a/flink-ml-python/pyflink/ml/lib/feature/tests/test_kbinsdiscretizer.py b/flink-ml-python/pyflink/ml/lib/feature/tests/test_kbinsdiscretizer.py
new file mode 100644
index 0000000..2dacb68
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/feature/tests/test_kbinsdiscretizer.py
@@ -0,0 +1,172 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+import os
+
+from pyflink.common import Types
+
+from pyflink.ml.core.linalg import Vectors, DenseVectorTypeInfo
+from pyflink.ml.lib.feature.kbinsdiscretizer import KBinsDiscretizer, KBinsDiscretizerModel
+from pyflink.ml.tests.test_utils import PyFlinkMLTestCase
+
+
+class KBinsDiscretizerTest(PyFlinkMLTestCase):
+ def setUp(self):
+ super(KBinsDiscretizerTest, self).setUp()
+ self.train_table = self.t_env.from_data_stream(
+ self.env.from_collection([
+ (Vectors.dense(1, 10, 0),),
+ (Vectors.dense(1, 10, 0),),
+ (Vectors.dense(1, 10, 0),),
+ (Vectors.dense(4, 10, 0),),
+ (Vectors.dense(5, 10, 0),),
+ (Vectors.dense(6, 10, 0),),
+ (Vectors.dense(7, 10, 0),),
+ (Vectors.dense(10, 10, 0),),
+ (Vectors.dense(13, 10, 3),),
+ ],
+ type_info=Types.ROW_NAMED(
+ ['input', ],
+ [DenseVectorTypeInfo(), ])))
+
+ self.predict_table = self.t_env.from_data_stream(
+ self.env.from_collection([
+ (Vectors.dense(-1, 0, 0),),
+ (Vectors.dense(1, 1, 1),),
+ (Vectors.dense(1.5, 1, 2),),
+ (Vectors.dense(5, 2, 3),),
+ (Vectors.dense(7.25, 3, 4),),
+ (Vectors.dense(13, 4, 5),),
+ (Vectors.dense(15, 4, 6),),
+ ],
+ type_info=Types.ROW_NAMED(
+ ['input', ],
+ [DenseVectorTypeInfo(), ])))
+
+ self.uniform_output = [
+ Vectors.dense(0, 0, 0),
+ Vectors.dense(0, 0, 1),
+ Vectors.dense(0, 0, 2),
+ Vectors.dense(1, 0, 2),
+ Vectors.dense(1, 0, 2),
+ Vectors.dense(2, 0, 2),
+ Vectors.dense(2, 0, 2),
+ ]
+
+ self.quantile_output = [
+ Vectors.dense(0, 0, 0),
+ Vectors.dense(0, 0, 0),
+ Vectors.dense(0, 0, 0),
+ Vectors.dense(1, 0, 0),
+ Vectors.dense(2, 0, 0),
+ Vectors.dense(2, 0, 0),
+ Vectors.dense(2, 0, 0),
+ ]
+
+ self.kmeans_output = [
+ Vectors.dense(0, 0, 0),
+ Vectors.dense(0, 0, 1),
+ Vectors.dense(0, 0, 2),
+ Vectors.dense(1, 0, 2),
+ Vectors.dense(1, 0, 2),
+ Vectors.dense(2, 0, 2),
+ Vectors.dense(2, 0, 2),
+ ]
+
+ def test_param(self):
+ k_bins_discretizer = KBinsDiscretizer()
+
+ self.assertEqual("input", k_bins_discretizer.input_col)
+ self.assertEqual(5, k_bins_discretizer.num_bins)
+ self.assertEqual("quantile", k_bins_discretizer.strategy)
+ self.assertEqual(200000, k_bins_discretizer.sub_samples)
+ self.assertEqual("output", k_bins_discretizer.output_col)
+
+ k_bins_discretizer \
+ .set_input_col("test_input") \
+ .set_num_bins(10) \
+ .set_strategy('kmeans') \
+ .set_sub_samples(1000) \
+ .set_output_col("test_output")
+
+ self.assertEqual("test_input", k_bins_discretizer.input_col)
+ self.assertEqual(10, k_bins_discretizer.num_bins)
+ self.assertEqual("kmeans", k_bins_discretizer.strategy)
+ self.assertEqual(1000, k_bins_discretizer.sub_samples)
+ self.assertEqual("test_output", k_bins_discretizer.output_col)
+
+ def test_output_schema(self):
+ k_bins_discretizer = KBinsDiscretizer() \
+ .set_input_col("test_input") \
+ .set_output_col("test_output")
+ input_data_table = self.t_env.from_data_stream(
+ self.env.from_collection([
+ ('', ''),
+ ],
+ type_info=Types.ROW_NAMED(
+ ['test_input', 'dummy_input'],
+ [Types.STRING(), Types.STRING()])))
+ output = k_bins_discretizer \
+ .fit(input_data_table) \
+ .transform(input_data_table)[0]
+
+ self.assertEqual(
+ [k_bins_discretizer.input_col, 'dummy_input', k_bins_discretizer.output_col],
+ output.get_schema().get_field_names())
+
+ def verify_prediction_result(self, expected, output_table):
+ predicted_results = [result[1] for result in
+ self.t_env.to_data_stream(output_table).execute_and_collect()]
+
+ predicted_results.sort(key=lambda x: (x[0], x[1], x[2]))
+ expected.sort(key=lambda x: (x[0], x[1], x[2]))
+
+ self.assertEqual(expected, predicted_results)
+
+ def test_fit_and_predict(self):
+ k_bins_discretizer = KBinsDiscretizer().set_num_bins(3)
+
+ # Tests uniform strategy.
+ k_bins_discretizer.set_strategy('uniform')
+ output = k_bins_discretizer.fit(self.train_table).transform(self.predict_table)[0]
+ self.verify_prediction_result(self.uniform_output, output)
+
+ # Tests quantile strategy.
+ k_bins_discretizer.set_strategy('quantile')
+ output = k_bins_discretizer.fit(self.train_table).transform(self.predict_table)[0]
+ self.verify_prediction_result(self.quantile_output, output)
+
+ # Tests kmeans strategy.
+ k_bins_discretizer.set_strategy('kmeans')
+ output = k_bins_discretizer.fit(self.train_table).transform(self.predict_table)[0]
+ self.verify_prediction_result(self.kmeans_output, output)
+
+ def test_save_load_predict(self):
+ k_bins_discretizer = KBinsDiscretizer().set_num_bins(3)
+ estimator_path = os.path.join(self.temp_dir, 'test_save_load_predict_kbinsdiscretizer')
+ k_bins_discretizer.save(estimator_path)
+ k_bins_discretizer = KBinsDiscretizer.load(self.t_env, estimator_path)
+
+ model = k_bins_discretizer.fit(self.train_table)
+ model_path = os.path.join(self.temp_dir, 'test_save_load_predict_kbinsdiscretizer_model')
+ model.save(model_path)
+ self.env.execute('save_model')
+ model = KBinsDiscretizerModel.load(self.t_env, model_path)
+
+ output = model.transform(self.predict_table)[0]
+ self.verify_prediction_result(self.quantile_output, output)