You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by li...@apache.org on 2022/11/04 10:22:09 UTC
[flink-ml] branch master updated: [FLINK-29598] Add Estimator and Transformer for Imputer (#166)
This is an automated email from the ASF dual-hosted git repository.
lindong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push:
new b96ce2f [FLINK-29598] Add Estimator and Transformer for Imputer (#166)
b96ce2f is described below
commit b96ce2f369edb61e873ea23aa63caa5973ec641e
Author: JiangXin <ji...@alibaba-inc.com>
AuthorDate: Fri Nov 4 18:22:05 2022 +0800
[FLINK-29598] Add Estimator and Transformer for Imputer (#166)
This closes #166.
---
docs/content/docs/operators/feature/imputer.md | 196 +++++++++++
.../org/apache/flink/ml/param/DoubleParam.java | 10 +
.../java/org/apache/flink/ml/param/FloatParam.java | 2 +
.../java/org/apache/flink/ml/api/StageTest.java | 36 ++
.../flink/ml/examples/feature/ImputerExample.java | 78 +++++
.../flink/ml/common/param/HasRelativeError.java | 34 +-
.../flink/ml/common/util/QuantileSummary.java | 4 +-
.../apache/flink/ml/feature/imputer/Imputer.java | 335 +++++++++++++++++++
.../flink/ml/feature/imputer/ImputerModel.java | 170 ++++++++++
.../flink/ml/feature/imputer/ImputerModelData.java | 117 +++++++
.../ml/feature/imputer/ImputerModelParams.java | 36 +-
.../flink/ml/feature/imputer/ImputerParams.java | 60 ++++
.../flink/ml/common/util/QuantileSummaryTest.java | 14 +-
.../org/apache/flink/ml/feature/ImputerTest.java | 361 +++++++++++++++++++++
.../pyflink/examples/ml/feature/imputer_example.py | 68 ++++
.../pyflink/ml/core/tests/test_param.py | 16 +-
flink-ml-python/pyflink/ml/lib/feature/imputer.py | 129 ++++++++
.../pyflink/ml/lib/feature/tests/test_imputer.py | 139 ++++++++
flink-ml-python/pyflink/ml/lib/param.py | 21 ++
19 files changed, 1788 insertions(+), 38 deletions(-)
diff --git a/docs/content/docs/operators/feature/imputer.md b/docs/content/docs/operators/feature/imputer.md
new file mode 100644
index 0000000..f06bc98
--- /dev/null
+++ b/docs/content/docs/operators/feature/imputer.md
@@ -0,0 +1,196 @@
+---
+title: "Imputer"
+weight: 1
+type: docs
+aliases:
+- /operators/feature/imputer.html
+---
+
+<!--
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements. See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership. The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License. You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied. See the License for the
+specific language governing permissions and limitations
+under the License.
+-->
+
+## Imputer
+The imputer for completing missing values of the input columns.
+
+Missing values can be imputed using the statistics (mean, median or
+most frequent) of each column in which the missing values are located.
+The input columns should be of numeric type.
+
+__Note__ The `mean`/`median`/`most frequent` value is computed after
+filtering out missing values and null values. Null values are always
+treated as missing, and so are also imputed.
+
+__Note__ The parameter `relativeError` is only effective when the strategy
+ is `median`.
+
+### Input Columns
+
+| Param name | Type | Default | Description |
+|:-----------|:-------|:--------|:------------------------|
+| inputCols | Number | `null` | Features to be imputed. |
+
+### Output Columns
+
+| Param name | Type | Default | Description |
+|:-----------|:-------|:--------|:------------------|
+| outputCols | Double | `null` | Imputed features. |
+
+### Parameters
+
+Below are the parameters required by `ImputerModel`.
+
+| Key | Default | Type | Required | Description |
+|:--------------|:-------------|:------------|:---------|:-------------------------------------------------------------------------------------------|
+| inputCols | `null` | String[] | yes | Input column names. |
+| outputCols | `null` | String[] | yes | Output column names. |
+| missingValue | `Double.NaN` | Double | no | The placeholder for the missing values. All occurrences of missing values will be imputed. |
+
+`Imputer` needs the parameters above as well as the parameters below.
+
+| Key | Default | Type | Required | Description |
+|:--------------|:-------------|:------------|:---------|:------------------------------------------------------------------------------|
+| strategy | `"mean"` | String | no | The imputation strategy. Supported values: 'mean', 'median', 'most_frequent'. |
+| relativeError | `0.001` | Double | no | The relative target precision for the approximate quantile algorithm. |
+
+### Examples
+
+{{< tabs examples >}}
+
+{{< tab "Java">}}
+
+```java
+import org.apache.flink.ml.feature.imputer.Imputer;
+import org.apache.flink.ml.feature.imputer.ImputerModel;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.CloseableIterator;
+
+import java.util.Arrays;
+
+/** Simple program that trains a {@link Imputer} model and uses it for feature engineering. */
+public class ImputerExample {
+
+ public static void main(String[] args) {
+ StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+ StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+ // Generates input training and prediction data.
+ DataStream<Row> trainStream =
+ env.fromElements(
+ Row.of(Double.NaN, 9.0),
+ Row.of(1.0, 9.0),
+ Row.of(1.5, 9.0),
+ Row.of(2.5, Double.NaN),
+ Row.of(5.0, 5.0),
+ Row.of(5.0, 4.0));
+ Table trainTable = tEnv.fromDataStream(trainStream).as("input1", "input2");
+
+ // Creates an Imputer object and initialize its parameters
+ Imputer imputer =
+ new Imputer()
+ .setInputCols("input1", "input2")
+ .setOutputCols("output1", "output2")
+ .setStrategy("mean")
+ .setMissingValue(Double.NaN);
+
+ // Trains the Imputer model.
+ ImputerModel model = imputer.fit(trainTable);
+
+ // Uses the Imputer model for predictions.
+ Table outputTable = model.transform(trainTable)[0];
+
+ // Extracts and displays the results.
+ for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
+ Row row = it.next();
+ double[] inputValues = new double[imputer.getInputCols().length];
+ double[] outputValues = new double[imputer.getInputCols().length];
+ for (int i = 0; i < inputValues.length; i++) {
+ inputValues[i] = (double) row.getField(imputer.getInputCols()[i]);
+ outputValues[i] = (double) row.getField(imputer.getOutputCols()[i]);
+ }
+ System.out.printf(
+ "Input Values: %s\tOutput Values: %s\n",
+ Arrays.toString(inputValues), Arrays.toString(outputValues));
+ }
+ }
+}
+```
+
+{{< /tab>}}
+
+{{< tab "Python">}}
+
+```python
+
# Simple program that creates an Imputer instance and uses it for feature
# engineering.

from pyflink.common import Types
from pyflink.datastream import StreamExecutionEnvironment
from pyflink.ml.lib.feature.imputer import Imputer
from pyflink.table import StreamTableEnvironment

# Sets up the stream execution environment and a table environment on top of it.
env = StreamExecutionEnvironment.get_execution_environment()
t_env = StreamTableEnvironment.create(env)

# Generates input training and prediction data. NaN is configured below as the
# missing value; None (null) is always treated as missing.
rows = [
    (float('NaN'), 9.0,),
    (1.0, 9.0,),
    (1.5, 7.0,),
    (1.5, float('NaN'),),
    (4.0, 5.0,),
    (None, 4.0,),
]
row_type = Types.ROW_NAMED(
    ['input1', 'input2'],
    [Types.DOUBLE(), Types.DOUBLE()])
train_data = t_env.from_data_stream(
    env.from_collection(rows, type_info=row_type))

# Creates an Imputer object and initializes its parameters.
imputer = (Imputer()
           .set_input_cols('input1', 'input2')
           .set_output_cols('output1', 'output2')
           .set_strategy('mean')
           .set_missing_value(float('NaN')))

# Trains the Imputer Model.
model = imputer.fit(train_data)

# Uses the Imputer Model for predictions.
output = model.transform(train_data)[0]

# Extracts and displays the results, pairing each input column with the output
# column at the same position.
field_names = output.get_schema().get_field_names()
for result in t_env.to_data_stream(output).execute_and_collect():
    input_values = [result[field_names.index(col)]
                    for col in imputer.get_input_cols()]
    output_values = [result[field_names.index(col)]
                     for col in imputer.get_output_cols()]
    print('Input Values: ' + str(input_values) + '\tOutput Values: ' + str(output_values))
+```
+
+{{< /tab>}}
+
+{{< /tabs>}}
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/param/DoubleParam.java b/flink-ml-core/src/main/java/org/apache/flink/ml/param/DoubleParam.java
index f6d4911..f94fa0e 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/param/DoubleParam.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/param/DoubleParam.java
@@ -18,6 +18,8 @@
package org.apache.flink.ml.param;
+import java.io.IOException;
+
/** Class for the double parameter. */
public class DoubleParam extends Param<Double> {
@@ -32,4 +34,12 @@ public class DoubleParam extends Param<Double> {
public DoubleParam(String name, String description, Double defaultValue) {
this(name, description, defaultValue, ParamValidators.alwaysTrue());
}
+
+    /**
+     * Decodes a double param value from its parsed JSON representation.
+     *
+     * @param json the parsed JSON value; a String for non-finite values, otherwise a number
+     * @return the decoded double value, or null if {@code json} is null
+     * @throws IOException declared by the overridden method
+     */
+    @Override
+    public Double jsonDecode(Object json) throws IOException {
+        // NaN and +/-Infinity have no JSON number representation and are presumably written as
+        // strings such as "NaN" or "Infinity"; Double.valueOf parses those forms back.
+        if (json instanceof String) {
+            return Double.valueOf((String) json);
+        }
+        // An integral JSON number (e.g. 4 instead of 4.0) may be parsed as Integer or Long, which
+        // a plain (Double) cast would reject with a ClassCastException.
+        if (json instanceof Number) {
+            return ((Number) json).doubleValue();
+        }
+        return (Double) json;
+    }
}
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/param/FloatParam.java b/flink-ml-core/src/main/java/org/apache/flink/ml/param/FloatParam.java
index f86ecfe..2cb639d 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/param/FloatParam.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/param/FloatParam.java
@@ -36,6 +36,8 @@ public class FloatParam extends Param<Float> {
public Float jsonDecode(Object json) throws IOException {
+ // Converts the parsed JSON representation of a float param value back into a Float.
if (json instanceof Double) {
+ // Numeric JSON values may arrive as Double; narrow them to float.
return ((Double) json).floatValue();
+ } else if (json instanceof String) {
+ // Non-finite values (NaN, Infinity) are presumably encoded as JSON strings; parse them back.
+ return Float.valueOf((String) json);
}
return (Float) json;
}
diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/api/StageTest.java b/flink-ml-core/src/test/java/org/apache/flink/ml/api/StageTest.java
index 5d1a707..a58a0ef 100644
--- a/flink-ml-core/src/test/java/org/apache/flink/ml/api/StageTest.java
+++ b/flink-ml-core/src/test/java/org/apache/flink/ml/api/StageTest.java
@@ -88,9 +88,15 @@ public class StageTest {
Param<Float> FLOAT_PARAM =
new FloatParam("floatParam", "Description", 3.0f, ParamValidators.lt(100));
+ // Float param with a non-finite default (NaN), used to check that special values survive save/load.
+ Param<Float> SPECIAL_FLOAT_PARAM =
+ new FloatParam("specialFloatParam", "Description", Float.NaN);
+
Param<Double> DOUBLE_PARAM =
new DoubleParam("doubleParam", "Description", 4.0, ParamValidators.lt(100));
+ // Double param with a non-finite default (NaN), used to check that special values survive save/load.
+ Param<Double> SPECIAL_DOUBLE_PARAM =
+ new DoubleParam("specialDoubleParam", "Description", Double.NaN);
+
Param<String> STRING_PARAM = new StringParam("stringParam", "Description", "5");
Param<Integer[]> INT_ARRAY_PARAM =
@@ -449,6 +455,36 @@ public class StageTest {
loadedStage.get(MyParams.WINDOWS_PARAM));
}
+ @Test
+ public void testSaveLoadWithSpecialParams() throws IOException {
+ StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+ StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+ MyStage stage = new MyStage();
+ stage.set(stage.paramWithNullDefault, 1);
+
+ stage.set(MyParams.SPECIAL_FLOAT_PARAM, Float.NaN);
+ stage.set(MyParams.SPECIAL_DOUBLE_PARAM, Double.NaN);
+ Stage<?> loadedStage = validateStageSaveLoad(tEnv, stage, Collections.emptyMap());
+ Assert.assertEquals(Float.NaN, loadedStage.get(MyParams.SPECIAL_FLOAT_PARAM), 0.0001);
+ Assert.assertEquals(Double.NaN, loadedStage.get(MyParams.SPECIAL_DOUBLE_PARAM), 0.0001);
+
+ stage.set(MyParams.SPECIAL_FLOAT_PARAM, Float.POSITIVE_INFINITY);
+ stage.set(MyParams.SPECIAL_DOUBLE_PARAM, Double.POSITIVE_INFINITY);
+ loadedStage = validateStageSaveLoad(tEnv, stage, Collections.emptyMap());
+ Assert.assertEquals(
+ Float.POSITIVE_INFINITY, loadedStage.get(MyParams.SPECIAL_FLOAT_PARAM), 0.0001);
+ Assert.assertEquals(
+ Double.POSITIVE_INFINITY, loadedStage.get(MyParams.SPECIAL_DOUBLE_PARAM), 0.0001);
+
+ stage.set(MyParams.SPECIAL_FLOAT_PARAM, Float.NEGATIVE_INFINITY);
+ stage.set(MyParams.SPECIAL_DOUBLE_PARAM, Double.NEGATIVE_INFINITY);
+ loadedStage = validateStageSaveLoad(tEnv, stage, Collections.emptyMap());
+ Assert.assertEquals(
+ Float.NEGATIVE_INFINITY, loadedStage.get(MyParams.SPECIAL_FLOAT_PARAM), 0.0001);
+ Assert.assertEquals(
+ Double.NEGATIVE_INFINITY, loadedStage.get(MyParams.SPECIAL_DOUBLE_PARAM), 0.0001);
+ }
+
@Test
public void testStageSaveLoadWithParamOverrides() throws IOException {
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/ImputerExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/ImputerExample.java
new file mode 100644
index 0000000..62cf781
--- /dev/null
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/ImputerExample.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.examples.feature;
+
+import org.apache.flink.ml.feature.imputer.Imputer;
+import org.apache.flink.ml.feature.imputer.ImputerModel;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.CloseableIterator;
+
+import java.util.Arrays;
+
+/** Simple program that trains an {@link Imputer} model and uses it for feature engineering. */
+public class ImputerExample {
+
+    public static void main(String[] args) {
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+        // Generates input training and prediction data. Double.NaN marks a missing value.
+        DataStream<Row> trainStream =
+                env.fromElements(
+                        Row.of(Double.NaN, 9.0),
+                        Row.of(1.0, 9.0),
+                        Row.of(1.5, 9.0),
+                        Row.of(2.5, Double.NaN),
+                        Row.of(5.0, 5.0),
+                        Row.of(5.0, 4.0));
+        Table trainTable = tEnv.fromDataStream(trainStream).as("input1", "input2");
+
+        // Creates an Imputer object and initializes its parameters.
+        Imputer imputer =
+                new Imputer()
+                        .setInputCols("input1", "input2")
+                        .setOutputCols("output1", "output2")
+                        .setStrategy("mean")
+                        .setMissingValue(Double.NaN);
+
+        // Trains the Imputer model.
+        ImputerModel model = imputer.fit(trainTable);
+
+        // Uses the Imputer model for predictions.
+        Table outputTable = model.transform(trainTable)[0];
+
+        // Extracts and displays the results. Each output column pairs with the input column at
+        // the same position.
+        String[] inputCols = imputer.getInputCols();
+        String[] outputCols = imputer.getOutputCols();
+        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
+            Row row = it.next();
+            double[] inputValues = new double[inputCols.length];
+            double[] outputValues = new double[outputCols.length];
+            for (int i = 0; i < inputCols.length; i++) {
+                inputValues[i] = (double) row.getField(inputCols[i]);
+                outputValues[i] = (double) row.getField(outputCols[i]);
+            }
+            System.out.printf(
+                    "Input Values: %s\tOutput Values: %s\n",
+                    Arrays.toString(inputValues), Arrays.toString(outputValues));
+        }
+    }
+}
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/param/FloatParam.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasRelativeError.java
similarity index 51%
copy from flink-ml-core/src/main/java/org/apache/flink/ml/param/FloatParam.java
copy to flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasRelativeError.java
index f86ecfe..e3c308b 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/param/FloatParam.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasRelativeError.java
@@ -16,27 +16,27 @@
* limitations under the License.
*/
-package org.apache.flink.ml.param;
+package org.apache.flink.ml.common.param;
-import java.io.IOException;
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
-/** Class for the float parameter. */
-public class FloatParam extends Param<Float> {
+/** Interface for shared param relativeError. */
+public interface HasRelativeError<T> extends WithParams<T> {
+ /** Relative target precision of the approximate quantile algorithm; validated to lie in [0, 1]. */
+ Param<Double> RELATIVE_ERROR =
+ new DoubleParam(
+ "relativeError",
+ "The relative target precision for the approximate quantile algorithm.",
+ 0.001,
+ ParamValidators.inRange(0, 1));
- public FloatParam(
- String name, String description, Float defaultValue, ParamValidator<Float> validator) {
- super(name, Float.class, description, defaultValue, validator);
+ /** Returns the relative target precision for the approximate quantile algorithm. */
+ default double getRelativeError() {
+ return get(RELATIVE_ERROR);
}
- public FloatParam(String name, String description, Float defaultValue) {
- this(name, description, defaultValue, ParamValidators.alwaysTrue());
- }
-
- @Override
- public Float jsonDecode(Object json) throws IOException {
- if (json instanceof Double) {
- return ((Double) json).floatValue();
- }
- return (Float) json;
+    /**
+     * Sets the relative target precision for the approximate quantile algorithm.
+     *
+     * @param value the relative error, which must lie in [0, 1]
+     * @return this instance, for chaining
+     */
+    default T setRelativeError(double value) {
+        return set(RELATIVE_ERROR, value);
+    }
+
+    /**
+     * Sets the relative target precision for the approximate quantile algorithm.
+     *
+     * @param value the relative error, which must lie in [0, 1]
+     * @return this instance, for chaining
+     * @deprecated This setter is misnamed (it does not touch any features column); use {@link
+     *     #setRelativeError(double)} instead.
+     */
+    @Deprecated
+    default T setFeaturesCol(double value) {
+        return setRelativeError(value);
}
}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/util/QuantileSummary.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/util/QuantileSummary.java
index 7ae1578..49730fe 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/util/QuantileSummary.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/util/QuantileSummary.java
@@ -102,8 +102,8 @@ public class QuantileSummary implements Serializable {
long count,
boolean compressed) {
Preconditions.checkArgument(
- relativeError > 0 && relativeError < 1,
- "An appropriate relative error must lay between 0 and 1.");
+ relativeError >= 0 && relativeError <= 1,
+ "An appropriate relative error must be in the range [0, 1].");
Preconditions.checkArgument(
compressThreshold > 0, "An compress threshold must greater than 0.");
this.relativeError = relativeError;
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/imputer/Imputer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/imputer/Imputer.java
new file mode 100644
index 0000000..b0b7e67
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/imputer/Imputer.java
@@ -0,0 +1,335 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.imputer;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.util.QuantileSummary;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.FlinkRuntimeException;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * The imputer for completing missing values of the input columns.
+ *
+ * <p>Missing values can be imputed using the statistics (mean, median or most frequent) of each
+ * column in which the missing values are located. The input columns should be of numeric type.
+ *
+ * <p>Note that the mean/median/most_frequent value is computed after filtering out missing values.
+ * All null values in the input columns are also treated as missing, and so are imputed.
+ *
+ * <p>Every input column must contain at least one valid value, i.e. a value that is not null, not
+ * NaN and not equal to the missing value. Otherwise no surrogate can be computed for that column and
+ * the job that computes the model data fails.
+ */
+public class Imputer implements Estimator<Imputer, ImputerModel>, ImputerParams<Imputer> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public Imputer() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public ImputerModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        Preconditions.checkArgument(
+                getInputCols().length == getOutputCols().length,
+                "Num of input columns and output columns are inconsistent.");
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Row> inputData = tEnv.toDataStream(inputs[0]);
+
+        DataStream<ImputerModelData> modelData;
+        switch (getStrategy()) {
+            case MEAN:
+                modelData =
+                        DataStreamUtils.aggregate(
+                                inputData,
+                                new MeanStrategyAggregator(getInputCols(), getMissingValue()));
+                break;
+            case MEDIAN:
+                modelData =
+                        DataStreamUtils.aggregate(
+                                inputData,
+                                new MedianStrategyAggregator(
+                                        getInputCols(), getMissingValue(), getRelativeError()));
+                break;
+            case MOST_FREQUENT:
+                modelData =
+                        DataStreamUtils.aggregate(
+                                inputData,
+                                new MostFrequentStrategyAggregator(
+                                        getInputCols(), getMissingValue()));
+                break;
+            default:
+                throw new RuntimeException("Unsupported strategy of Imputer: " + getStrategy());
+        }
+        ImputerModel model = new ImputerModel().setModelData(tEnv.fromDataStream(modelData));
+        ReadWriteUtils.updateExistingParams(model, getParamMap());
+        return model;
+    }
+
+    /**
+     * Converts a raw field value to a double, or returns null if the value must be ignored when
+     * computing statistics, i.e. if it is null, NaN or equal to {@code missingValue}.
+     */
+    private static Double toValidValue(Object rawValue, double missingValue) {
+        if (rawValue == null) {
+            return null;
+        }
+        // Going through the string form maps e.g. 1.1f to 1.1 rather than to its widened binary
+        // value 1.100000023841858.
+        Double value = Double.valueOf(rawValue.toString());
+        // Double#equals compares canonical bit patterns, so every NaN matches Double.NaN.
+        if (value.equals(missingValue) || value.equals(Double.NaN)) {
+            return null;
+        }
+        return value;
+    }
+
+    /**
+     * A stream operator to compute the mean value of all input columns of the input bounded data
+     * stream.
+     */
+    private static class MeanStrategyAggregator
+            implements AggregateFunction<Row, Map<String, Tuple2<Double, Long>>, ImputerModelData> {
+
+        private final String[] columnNames;
+        private final double missingValue;
+
+        public MeanStrategyAggregator(String[] columnNames, double missingValue) {
+            this.columnNames = columnNames;
+            this.missingValue = missingValue;
+        }
+
+        /** Creates one (sum, count) pair per input column, both starting at zero. */
+        @Override
+        public Map<String, Tuple2<Double, Long>> createAccumulator() {
+            Map<String, Tuple2<Double, Long>> accumulators = new HashMap<>();
+            Arrays.stream(columnNames).forEach(x -> accumulators.put(x, Tuple2.of(0.0, 0L)));
+            return accumulators;
+        }
+
+        @Override
+        public Map<String, Tuple2<Double, Long>> add(
+                Row row, Map<String, Tuple2<Double, Long>> accumulators) {
+            accumulators.forEach(
+                    (col, sumAndNum) -> {
+                        Double value = toValidValue(row.getField(col), missingValue);
+                        if (value != null) {
+                            sumAndNum.f0 += value;
+                            sumAndNum.f1 += 1;
+                        }
+                    });
+            return accumulators;
+        }
+
+        @Override
+        public ImputerModelData getResult(Map<String, Tuple2<Double, Long>> map) {
+            // Every column needs at least one valid value; otherwise its mean would be 0 / 0 = NaN.
+            boolean allColumnsValid = map.values().stream().allMatch(sumAndNum -> sumAndNum.f1 > 0);
+            Preconditions.checkState(
+                    allColumnsValid, "The training set is empty or does not contains valid data.");
+
+            Map<String, Double> surrogates = new HashMap<>();
+            map.forEach((col, sumAndNum) -> surrogates.put(col, sumAndNum.f0 / sumAndNum.f1));
+            return new ImputerModelData(surrogates);
+        }
+
+        @Override
+        public Map<String, Tuple2<Double, Long>> merge(
+                Map<String, Tuple2<Double, Long>> acc1, Map<String, Tuple2<Double, Long>> acc2) {
+            Preconditions.checkArgument(acc1.size() == acc2.size());
+
+            acc1.forEach(
+                    (col, sumAndNum) -> {
+                        acc2.get(col).f0 += sumAndNum.f0;
+                        acc2.get(col).f1 += sumAndNum.f1;
+                    });
+            return acc2;
+        }
+    }
+
+    /**
+     * A stream operator to compute the median value of all input columns of the input bounded data
+     * stream.
+     */
+    private static class MedianStrategyAggregator
+            implements AggregateFunction<Row, Map<String, QuantileSummary>, ImputerModelData> {
+        private final String[] columnNames;
+        private final double missingValue;
+        private final double relativeError;
+
+        public MedianStrategyAggregator(
+                String[] columnNames, double missingValue, double relativeError) {
+            this.columnNames = columnNames;
+            this.missingValue = missingValue;
+            this.relativeError = relativeError;
+        }
+
+        /** Creates one empty quantile summary per input column. */
+        @Override
+        public Map<String, QuantileSummary> createAccumulator() {
+            Map<String, QuantileSummary> summaries = new HashMap<>();
+            Arrays.stream(columnNames)
+                    .forEach(x -> summaries.put(x, new QuantileSummary(relativeError)));
+            return summaries;
+        }
+
+        @Override
+        public Map<String, QuantileSummary> add(Row row, Map<String, QuantileSummary> summaries) {
+            summaries.forEach(
+                    (col, summary) -> {
+                        Double value = toValidValue(row.getField(col), missingValue);
+                        if (value != null) {
+                            summary.insert(value);
+                        }
+                    });
+            return summaries;
+        }
+
+        @Override
+        public ImputerModelData getResult(Map<String, QuantileSummary> summaries) {
+            Map<String, Double> surrogates = new HashMap<>();
+            summaries.forEach(
+                    (col, summary) -> {
+                        QuantileSummary compressed = summary.compress();
+                        if (compressed.isEmpty()) {
+                            throw new FlinkRuntimeException(
+                                    String.format(
+                                            "Surrogate cannot be computed. All the values in column [%s] are null, NaN or missingValue.",
+                                            col));
+                        }
+                        double median = compressed.query(0.5);
+                        surrogates.put(col, median);
+                    });
+            return new ImputerModelData(surrogates);
+        }
+
+        @Override
+        public Map<String, QuantileSummary> merge(
+                Map<String, QuantileSummary> acc1, Map<String, QuantileSummary> acc2) {
+            Preconditions.checkArgument(acc1.size() == acc2.size());
+
+            acc1.forEach(
+                    (col, summary1) -> {
+                        QuantileSummary summary2 = acc2.get(col).compress();
+                        acc2.put(col, summary2.merge(summary1.compress()));
+                    });
+            return acc2;
+        }
+    }
+
+    /**
+     * A stream operator to compute the most frequent value of all input columns of the input
+     * bounded data stream.
+     */
+    private static class MostFrequentStrategyAggregator
+            implements AggregateFunction<Row, Map<String, Map<Double, Long>>, ImputerModelData> {
+        private final String[] columnNames;
+        private final double missingValue;
+
+        public MostFrequentStrategyAggregator(String[] columnNames, double missingValue) {
+            this.columnNames = columnNames;
+            this.missingValue = missingValue;
+        }
+
+        /** Creates one empty value-to-count map per input column. */
+        @Override
+        public Map<String, Map<Double, Long>> createAccumulator() {
+            Map<String, Map<Double, Long>> accumulators = new HashMap<>();
+            Arrays.stream(columnNames).forEach(x -> accumulators.put(x, new HashMap<>()));
+            return accumulators;
+        }
+
+        @Override
+        public Map<String, Map<Double, Long>> add(
+                Row row, Map<String, Map<Double, Long>> accumulators) {
+            accumulators.forEach(
+                    (col, counts) -> {
+                        Double value = toValidValue(row.getField(col), missingValue);
+                        if (value != null) {
+                            counts.merge(value, 1L, Long::sum);
+                        }
+                    });
+            return accumulators;
+        }
+
+        @Override
+        public ImputerModelData getResult(Map<String, Map<Double, Long>> map) {
+            // Every column needs at least one valid value; otherwise its surrogate would be NaN.
+            boolean allColumnsValid = map.values().stream().noneMatch(Map::isEmpty);
+            Preconditions.checkState(
+                    allColumnsValid, "The training set is empty or does not contains valid data.");
+
+            Map<String, Double> surrogates = new HashMap<>();
+            map.forEach(
+                    (col, counts) -> {
+                        long maxCnt = Long.MIN_VALUE;
+                        double value = Double.NaN;
+                        for (Map.Entry<Double, Long> entry : counts.entrySet()) {
+                            // Ties between equally frequent values resolve to the smallest value.
+                            if (maxCnt <= entry.getValue()) {
+                                value =
+                                        maxCnt == entry.getValue()
+                                                ? Math.min(entry.getKey(), value)
+                                                : entry.getKey();
+                                maxCnt = entry.getValue();
+                            }
+                        }
+                        surrogates.put(col, value);
+                    });
+            return new ImputerModelData(surrogates);
+        }
+
+        @Override
+        public Map<String, Map<Double, Long>> merge(
+                Map<String, Map<Double, Long>> acc1, Map<String, Map<Double, Long>> acc2) {
+            Preconditions.checkArgument(acc1.size() == acc2.size());
+
+            acc1.forEach(
+                    (col, counts) -> {
+                        Map<Double, Long> mergedCounts = acc2.get(col);
+                        counts.forEach((value, cnt) -> mergedCounts.merge(value, cnt, Long::sum));
+                    });
+            return acc2;
+        }
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static Imputer load(StreamTableEnvironment tEnv, String path) throws IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/imputer/ImputerModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/imputer/ImputerModel.java
new file mode 100644
index 0000000..fde41da
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/imputer/ImputerModel.java
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.imputer;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/** A Model which replaces the missing values using the model data computed by {@link Imputer}. */
+public class ImputerModel implements Model<ImputerModel>, ImputerModelParams<ImputerModel> {
+
+ private final Map<Param<?>, Object> paramMap = new HashMap<>();
+ private Table modelDataTable;
+
+ public ImputerModel() {
+ ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+ }
+
+ @Override
+ public ImputerModel setModelData(Table... inputs) {
+ Preconditions.checkArgument(inputs.length == 1);
+ modelDataTable = inputs[0];
+ return this;
+ }
+
+ @Override
+ public Table[] getModelData() {
+ return new Table[] {modelDataTable};
+ }
+
+ @Override
+ public void save(String path) throws IOException {
+ ReadWriteUtils.saveMetadata(this, path);
+ ReadWriteUtils.saveModelData(
+ ImputerModelData.getModelDataStream(modelDataTable),
+ path,
+ new ImputerModelData.ModelDataEncoder());
+ }
+
+ public static ImputerModel load(StreamTableEnvironment tEnv, String path) throws IOException {
+ ImputerModel model = ReadWriteUtils.loadStageParam(path);
+ Table modelDataTable =
+ ReadWriteUtils.loadModelData(tEnv, path, new ImputerModelData.ModelDataDecoder());
+ return model.setModelData(modelDataTable);
+ }
+
+ @Override
+ public Map<Param<?>, Object> getParamMap() {
+ return paramMap;
+ }
+
+ @Override
+ @SuppressWarnings("unchecked")
+ public Table[] transform(Table... inputs) {
+ Preconditions.checkArgument(inputs.length == 1);
+ String[] inputCols = getInputCols();
+ String[] outputCols = getOutputCols();
+ Preconditions.checkArgument(inputCols.length == outputCols.length);
+
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+ DataStream<Row> dataStream = tEnv.toDataStream(inputs[0]);
+ DataStream<ImputerModelData> imputerModel =
+ ImputerModelData.getModelDataStream(modelDataTable);
+
+ final String broadcastModelKey = "broadcastModelKey";
+ RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+ TypeInformation<?>[] outputTypes = new TypeInformation[outputCols.length];
+ Arrays.fill(outputTypes, BasicTypeInfo.DOUBLE_TYPE_INFO);
+ RowTypeInfo outputTypeInfo =
+ new RowTypeInfo(
+ ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), outputTypes),
+ ArrayUtils.addAll(inputTypeInfo.getFieldNames(), outputCols));
+
+ DataStream<Row> output =
+ BroadcastUtils.withBroadcastStream(
+ Collections.singletonList(dataStream),
+ Collections.singletonMap(broadcastModelKey, imputerModel),
+ inputList -> {
+ DataStream input = inputList.get(0);
+ return input.map(
+ new PredictOutputFunction(
+ getMissingValue(), inputCols, broadcastModelKey),
+ outputTypeInfo);
+ });
+ return new Table[] {tEnv.fromDataStream(output)};
+ }
+
+ /** This operator loads model data and predicts result. */
+ private static class PredictOutputFunction extends RichMapFunction<Row, Row> {
+
+ private final String[] inputCols;
+ private final String broadcastKey;
+ private final double missingValue;
+ private Map<String, Double> surrogates;
+
+ public PredictOutputFunction(double missingValue, String[] inputCols, String broadcastKey) {
+ this.missingValue = missingValue;
+ this.inputCols = inputCols;
+ this.broadcastKey = broadcastKey;
+ }
+
+ @Override
+ public Row map(Row row) throws Exception {
+ if (surrogates == null) {
+ ImputerModelData imputerModelData =
+ (ImputerModelData)
+ getRuntimeContext().getBroadcastVariable(broadcastKey).get(0);
+ surrogates = imputerModelData.surrogates;
+ Arrays.stream(inputCols)
+ .forEach(
+ col ->
+ Preconditions.checkArgument(
+ surrogates.containsKey(col),
+ "Column %s is unacceptable for the Imputer model.",
+ col));
+ }
+
+ Row outputRow = new Row(inputCols.length);
+ for (int i = 0; i < inputCols.length; i++) {
+ Object value = row.getField(i);
+ if (value == null || Double.valueOf(value.toString()).equals(missingValue)) {
+ double surrogate = surrogates.get(inputCols[i]);
+ outputRow.setField(i, surrogate);
+ } else {
+ outputRow.setField(i, Double.valueOf(value.toString()));
+ }
+ }
+
+ return Row.join(row, outputRow);
+ }
+ }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/imputer/ImputerModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/imputer/ImputerModelData.java
new file mode 100644
index 0000000..b454455
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/imputer/ImputerModelData.java
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.imputer;
+
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.base.DoubleSerializer;
+import org.apache.flink.api.common.typeutils.base.MapSerializer;
+import org.apache.flink.api.common.typeutils.base.StringSerializer;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+
+import java.io.EOFException;
+import java.io.IOException;
+import java.io.OutputStream;
+import java.util.Map;
+
+/**
+ * Model data of {@link ImputerModel}: one surrogate value per input column.
+ *
+ * <p>This class also provides methods to convert model data from Table to a data stream, and
+ * classes to save/load model data.
+ */
+public class ImputerModelData {
+
+    /** Maps each input column name to the value that replaces its missing entries. */
+    public Map<String, Double> surrogates;
+
+    public ImputerModelData() {}
+
+    public ImputerModelData(Map<String, Double> surrogates) {
+        this.surrogates = surrogates;
+    }
+
+    /** Creates the serializer used by both the encoder and the decoder for the surrogates. */
+    private static MapSerializer<String, Double> newSurrogatesSerializer() {
+        return new MapSerializer<>(StringSerializer.INSTANCE, DoubleSerializer.INSTANCE);
+    }
+
+    /**
+     * Converts the table model to a data stream.
+     *
+     * @param modelDataTable The table model data; its first field holds the surrogates map.
+     * @return The data stream model data.
+     */
+    @SuppressWarnings("unchecked")
+    public static DataStream<ImputerModelData> getModelDataStream(Table modelDataTable) {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) modelDataTable).getTableEnvironment();
+        return tEnv.toDataStream(modelDataTable)
+                .map(row -> new ImputerModelData((Map<String, Double>) row.getField(0)));
+    }
+
+    /** Encoder for {@link ImputerModelData}. */
+    public static class ModelDataEncoder implements Encoder<ImputerModelData> {
+        @Override
+        public void encode(ImputerModelData modelData, OutputStream outputStream)
+                throws IOException {
+            DataOutputView target = new DataOutputViewStreamWrapper(outputStream);
+            newSurrogatesSerializer().serialize(modelData.surrogates, target);
+        }
+    }
+
+    /** Decoder for {@link ImputerModelData}. */
+    public static class ModelDataDecoder extends SimpleStreamFormat<ImputerModelData> {
+
+        @Override
+        public Reader<ImputerModelData> createReader(
+                Configuration configuration, FSDataInputStream fsDataInputStream) {
+            return new Reader<ImputerModelData>() {
+                @Override
+                public ImputerModelData read() throws IOException {
+                    DataInputView source = new DataInputViewStreamWrapper(fsDataInputStream);
+                    Map<String, Double> surrogates;
+                    try {
+                        surrogates = newSurrogatesSerializer().deserialize(source);
+                    } catch (EOFException e) {
+                        // Running out of input signals that no more records are available.
+                        return null;
+                    }
+                    return new ImputerModelData(surrogates);
+                }
+
+                @Override
+                public void close() throws IOException {
+                    fsDataInputStream.close();
+                }
+            };
+        }
+
+        @Override
+        public TypeInformation<ImputerModelData> getProducedType() {
+            return TypeInformation.of(ImputerModelData.class);
+        }
+    }
+}
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/param/FloatParam.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/imputer/ImputerModelParams.java
similarity index 50%
copy from flink-ml-core/src/main/java/org/apache/flink/ml/param/FloatParam.java
copy to flink-ml-lib/src/main/java/org/apache/flink/ml/feature/imputer/ImputerModelParams.java
index f86ecfe..0464f65 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/param/FloatParam.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/imputer/ImputerModelParams.java
@@ -16,27 +16,31 @@
* limitations under the License.
*/
-package org.apache.flink.ml.param;
+package org.apache.flink.ml.feature.imputer;
-import java.io.IOException;
+import org.apache.flink.ml.common.param.HasInputCols;
+import org.apache.flink.ml.common.param.HasOutputCols;
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.Param;
-/** Class for the float parameter. */
-public class FloatParam extends Param<Float> {
+/**
+ * Params for {@link ImputerModel}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface ImputerModelParams<T> extends HasInputCols<T>, HasOutputCols<T> {
- public FloatParam(
- String name, String description, Float defaultValue, ParamValidator<Float> validator) {
- super(name, Float.class, description, defaultValue, validator);
- }
+ /** The placeholder value that marks a missing entry. Defaults to {@code Double.NaN}. */
+ Param<Double> MISSING_VALUE =
+ new DoubleParam(
+ "missingValue",
+ "The placeholder for the missing values. All occurrences of missingValue will be imputed.",
+ Double.NaN);
- public FloatParam(String name, String description, Float defaultValue) {
- this(name, description, defaultValue, ParamValidators.alwaysTrue());
+ /** Returns the value of {@link #MISSING_VALUE}. */
+ default double getMissingValue() {
+ return get(MISSING_VALUE);
}
- @Override
- public Float jsonDecode(Object json) throws IOException {
- if (json instanceof Double) {
- return ((Double) json).floatValue();
- }
- return (Float) json;
+ /** Sets {@link #MISSING_VALUE} and returns this instance for chaining. */
+ default T setMissingValue(double value) {
+ return set(MISSING_VALUE, value);
}
}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/imputer/ImputerParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/imputer/ImputerParams.java
new file mode 100644
index 0000000..4ae6cbc
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/imputer/ImputerParams.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.imputer;
+
+import org.apache.flink.ml.common.param.HasRelativeError;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringParam;
+
+/**
+ * Params of {@link Imputer}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface ImputerParams<T> extends HasRelativeError<T>, ImputerModelParams<T> {
+    /** Strategy that imputes with the mean of each column. */
+    String MEAN = "mean";
+    /** Strategy that imputes with the median of each column. */
+    String MEDIAN = "median";
+    /** Strategy that imputes with the most frequent value of each column. */
+    String MOST_FREQUENT = "most_frequent";
+
+    /**
+     * Supported options of the imputation strategy. Defaults to {@link #MEAN}.
+     *
+     * <ul>
+     *   <li>mean: replace missing values using the mean along each column.
+     *   <li>median: replace missing values using the median along each column.
+     *   <li>most_frequent: replace missing values using the most frequent value along each
+     *       column. If there is more than one such value, only the smallest is returned.
+     * </ul>
+     */
+    Param<String> STRATEGY =
+            new StringParam(
+                    "strategy",
+                    "The imputation strategy.",
+                    MEAN,
+                    ParamValidators.inArray(MEAN, MEDIAN, MOST_FREQUENT));
+
+    /** Returns the imputation strategy: one of {@link #MEAN}, {@link #MEDIAN}, {@link #MOST_FREQUENT}. */
+    default String getStrategy() {
+        return get(STRATEGY);
+    }
+
+    /** Sets the imputation strategy and returns this instance for chaining. */
+    default T setStrategy(String value) {
+        return set(STRATEGY, value);
+    }
+}
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/util/QuantileSummaryTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/util/QuantileSummaryTest.java
index b972b5c..0bd1dc4 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/util/QuantileSummaryTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/util/QuantileSummaryTest.java
@@ -76,7 +76,10 @@ public class QuantileSummaryTest {
+ Arrays.stream(data).filter(x -> x < approx).count())
/ 2.0);
double lower = Math.floor((percentile - summary.getRelativeError()) * data.length);
- double upper = Math.ceil((percentile + summary.getRelativeError()) * data.length);
+ double upper =
+ summary.getRelativeError() == 0
+ ? Math.ceil((percentile + summary.getRelativeError()) * data.length) + 1
+ : Math.ceil((percentile + summary.getRelativeError()) * data.length);
String errMessage =
String.format(
"Rank not in [%s, %s], percentile: %s, approx returned: %s",
@@ -114,6 +117,15 @@ public class QuantileSummaryTest {
}
}
+ @Test
+ public void testNoRelativeError() {
+ for (double[] data : datasets) {
+ QuantileSummary summary = buildSummary(data, 0.0);
+ double[] percentiles = {0, 0.01, 0.1, 0.25, 0.75, 0.5, 0.9, 0.99, 1};
+ checkQuantiles(data, percentiles, summary);
+ }
+ }
+
@Test
public void testOnEmptyDataset() {
double[] data = new double[0];
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/ImputerTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/ImputerTest.java
new file mode 100644
index 0000000..29288ed
--- /dev/null
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/ImputerTest.java
@@ -0,0 +1,361 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.feature.imputer.Imputer;
+import org.apache.flink.ml.feature.imputer.ImputerModel;
+import org.apache.flink.ml.util.TestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.exception.ExceptionUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+import static org.apache.flink.ml.feature.imputer.ImputerParams.MEAN;
+import static org.apache.flink.ml.feature.imputer.ImputerParams.MEDIAN;
+import static org.apache.flink.ml.feature.imputer.ImputerParams.MOST_FREQUENT;
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
+
+/** Tests {@link Imputer} and {@link ImputerModel}. */
+public class ImputerTest extends AbstractTestBase {
+ @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+ private StreamExecutionEnvironment env;
+ private StreamTableEnvironment tEnv;
+ private Table trainDataTable;
+ private Table predictDataTable;
+
+ private static final double EPS = 1.0e-5;
+ private static final List<Row> TRAIN_DATA =
+ new ArrayList<>(
+ Arrays.asList(
+ Row.of(Double.NaN, 9.0, 1, 9.0f),
+ Row.of(1.0, 9.0, null, 9.0f),
+ Row.of(1.5, 7.0, 1, 7.0f),
+ Row.of(1.5, Double.NaN, 2, Float.NaN),
+ Row.of(4.0, 5.0, 4, 5.0f),
+ Row.of(null, 4.0, null, 4.0f)));
+
+ private static final List<Row> EXPECTED_MEAN_STRATEGY_OUTPUT =
+ new ArrayList<>(
+ Arrays.asList(
+ Row.of(2.0, 9.0, 1.0, 9.0),
+ Row.of(1.0, 9.0, 2.0, 9.0),
+ Row.of(1.5, 7.0, 1.0, 7.0),
+ Row.of(1.5, 6.8, 2.0, 6.8),
+ Row.of(4.0, 5.0, 4.0, 5.0),
+ Row.of(2.0, 4.0, 2.0, 4.0)));
+
+ private static final List<Row> EXPECTED_MEDIAN_STRATEGY_OUTPUT =
+ new ArrayList<>(
+ Arrays.asList(
+ Row.of(1.5, 9.0, 1.0, 9.0),
+ Row.of(1.0, 9.0, 1.0, 9.0),
+ Row.of(1.5, 7.0, 1.0, 7.0),
+ Row.of(1.5, 7.0, 2.0, 7.0),
+ Row.of(4.0, 5.0, 4.0, 5.0),
+ Row.of(1.5, 4.0, 1.0, 4.0)));
+
+ private static final List<Row> EXPECTED_MOST_FREQUENT_STRATEGY_OUTPUT =
+ new ArrayList<>(
+ Arrays.asList(
+ Row.of(1.5, 9.0, 1.0, 9.0),
+ Row.of(1.0, 9.0, 1.0, 9.0),
+ Row.of(1.5, 7.0, 1.0, 7.0),
+ Row.of(1.5, 9.0, 2.0, 9.0),
+ Row.of(4.0, 5.0, 4.0, 5.0),
+ Row.of(1.5, 4.0, 1.0, 4.0)));
+
+ private static final Map<String, List<Row>> strategyAndExpectedOutputs =
+ new HashMap<String, List<Row>>() {
+ {
+ put(MEAN, EXPECTED_MEAN_STRATEGY_OUTPUT);
+ put(MEDIAN, EXPECTED_MEDIAN_STRATEGY_OUTPUT);
+ put(MOST_FREQUENT, EXPECTED_MOST_FREQUENT_STRATEGY_OUTPUT);
+ }
+ };
+
+ @Before
+ public void before() {
+ Configuration config = new Configuration();
+ config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+ env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+ env.setParallelism(4);
+ env.enableCheckpointing(100);
+ env.setRestartStrategy(RestartStrategies.noRestart());
+ tEnv = StreamTableEnvironment.create(env);
+
+ trainDataTable =
+ tEnv.fromDataStream(env.fromCollection(TRAIN_DATA)).as("f1", "f2", "f3", "f4");
+ predictDataTable =
+ tEnv.fromDataStream(env.fromCollection(TRAIN_DATA)).as("f1", "f2", "f3", "f4");
+ }
+
+ @SuppressWarnings("unchecked")
+ private static void verifyPredictionResult(
+ Table output, List<String> outputCols, List<Row> expected) throws Exception {
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment();
+ DataStream<Row> outputDataStream = tEnv.toDataStream(output);
+ List<Row> result = IteratorUtils.toList(outputDataStream.executeAndCollect());
+ result =
+ result.stream()
+ .map(
+ row -> {
+ Row outputRow = new Row(outputCols.size());
+ for (int i = 0; i < outputCols.size(); i++) {
+ outputRow.setField(i, row.getField(outputCols.get(i)));
+ }
+ return outputRow;
+ })
+ .collect(Collectors.toList());
+ compareResultCollections(
+ expected,
+ result,
+ (row1, row2) -> {
+ int arity = Math.min(row1.getArity(), row2.getArity());
+ for (int i = 0; i < arity; i++) {
+ int cmp =
+ String.valueOf(row1.getField(i))
+ .compareTo(String.valueOf(row2.getField(i)));
+ if (cmp != 0) {
+ return cmp;
+ }
+ }
+ return 0;
+ });
+ }
+
+ @Test
+ public void testParam() {
+ Imputer imputer =
+ new Imputer()
+ .setInputCols("f1", "f2", "f3", "f4")
+ .setOutputCols("o1", "o2", "o3", "o4");
+ assertArrayEquals(new String[] {"f1", "f2", "f3", "f4"}, imputer.getInputCols());
+ assertArrayEquals(new String[] {"o1", "o2", "o3", "o4"}, imputer.getOutputCols());
+ assertEquals(MEAN, imputer.getStrategy());
+ assertEquals(Double.NaN, imputer.getMissingValue(), EPS);
+ assertEquals(0.001, imputer.getRelativeError(), EPS);
+
+ imputer.setMissingValue(0.0).setStrategy(MEDIAN);
+ assertEquals(MEDIAN, imputer.getStrategy());
+ assertEquals(0.0, imputer.getMissingValue(), EPS);
+ }
+
+ @Test
+ public void testOutputSchema() {
+ Imputer imputer =
+ new Imputer()
+ .setInputCols("f1", "f2", "f3", "f4")
+ .setOutputCols("o1", "o2", "o3", "o4");
+ ImputerModel model = imputer.fit(trainDataTable);
+ Table output = model.transform(predictDataTable)[0];
+ assertEquals(
+ Arrays.asList("f1", "f2", "f3", "f4", "o1", "o2", "o3", "o4"),
+ output.getResolvedSchema().getColumnNames());
+ }
+
+ @Test
+ public void testFitAndPredict() throws Exception {
+ for (Map.Entry<String, List<Row>> entry : strategyAndExpectedOutputs.entrySet()) {
+ Imputer imputer =
+ new Imputer()
+ .setInputCols("f1", "f2", "f3", "f4")
+ .setOutputCols("o1", "o2", "o3", "o4")
+ .setStrategy(entry.getKey());
+ ImputerModel model = imputer.fit(trainDataTable);
+ Table output = model.transform(predictDataTable)[0];
+ verifyPredictionResult(output, Arrays.asList("o1", "o2", "o3", "o4"), entry.getValue());
+ }
+ }
+
+ @Test
+ public void testSaveLoadAndPredict() throws Exception {
+ Imputer imputer =
+ new Imputer()
+ .setInputCols("f1", "f2", "f3", "f4")
+ .setOutputCols("o1", "o2", "o3", "o4");
+ Imputer loadedImputer =
+ TestUtils.saveAndReload(tEnv, imputer, tempFolder.newFolder().getAbsolutePath());
+ ImputerModel model = loadedImputer.fit(trainDataTable);
+ ImputerModel loadedModel =
+ TestUtils.saveAndReload(tEnv, model, tempFolder.newFolder().getAbsolutePath());
+ assertEquals(
+ Collections.singletonList("surrogates"),
+ model.getModelData()[0].getResolvedSchema().getColumnNames());
+ Table output = loadedModel.transform(predictDataTable)[0];
+ verifyPredictionResult(
+ output, Arrays.asList(imputer.getOutputCols()), EXPECTED_MEAN_STRATEGY_OUTPUT);
+ }
+
+ @Test
+ public void testFitOnEmptyData() {
+ Table emptyTable =
+ tEnv.fromDataStream(env.fromCollection(TRAIN_DATA).filter(x -> x.getArity() == 0))
+ .as("f1", "f2", "f3", "f4");
+
+ strategyAndExpectedOutputs.remove(MEDIAN);
+ for (Map.Entry<String, List<Row>> entry : strategyAndExpectedOutputs.entrySet()) {
+ Imputer imputer =
+ new Imputer()
+ .setInputCols("f1", "f2", "f3", "f4")
+ .setOutputCols("o1", "o2", "o3", "o4")
+ .setStrategy(entry.getKey());
+ ImputerModel model = imputer.fit(emptyTable);
+ Table modelDataTable = model.getModelData()[0];
+ try {
+ modelDataTable.execute().print();
+ fail();
+ } catch (Throwable e) {
+ assertEquals(
+ "The training set is empty or does not contains valid data.",
+ ExceptionUtils.getRootCause(e).getMessage());
+ }
+ }
+ }
+
+ @Test
+ public void testNoValidDataOnMedianStrategy() {
+ final List<Row> trainData =
+ new ArrayList<>(
+ Arrays.asList(
+ Row.of(Double.NaN, Float.NaN),
+ Row.of(null, null),
+ Row.of(1.0, 1.0f)));
+ trainDataTable = tEnv.fromDataStream(env.fromCollection(trainData)).as("f1", "f2");
+ Imputer imputer =
+ new Imputer()
+ .setInputCols("f1", "f2")
+ .setOutputCols("o1", "o2")
+ .setStrategy(MEDIAN)
+ .setMissingValue(1.0);
+ ImputerModel model = imputer.fit(trainDataTable);
+ Table modelDataTable = model.getModelData()[0];
+ try {
+ modelDataTable.execute().print();
+ fail();
+ } catch (Throwable e) {
+ assertEquals(
+ "Surrogate cannot be computed. All the values in column [f1] are null, NaN or missingValue.",
+ ExceptionUtils.getRootCause(e).getMessage());
+ }
+ }
+
+ @Test
+ @SuppressWarnings("unchecked")
+ public void testMultipleModeOnMostFrequentStrategy() throws Exception {
+ final List<Row> trainData =
+ new ArrayList<>(
+ Arrays.asList(
+ Row.of(1.0, 2.0),
+ Row.of(1.0, 2.0),
+ Row.of(2.0, 1.0),
+ Row.of(2.0, 1.0)));
+ trainDataTable = tEnv.fromDataStream(env.fromCollection(trainData)).as("f1", "f2");
+ Imputer imputer =
+ new Imputer()
+ .setInputCols("f1", "f2")
+ .setOutputCols("o1", "o2")
+ .setStrategy(MOST_FREQUENT);
+ ImputerModel model = imputer.fit(trainDataTable);
+ Table modelData = model.getModelData()[0];
+ DataStream<Row> output = tEnv.toDataStream(modelData);
+ List<Row> modelRows = IteratorUtils.toList(output.executeAndCollect());
+ Map<String, Double> surrogates = (Map<String, Double>) modelRows.get(0).getField(0);
+ assert surrogates != null;
+ assertEquals(1.0, surrogates.get("f1"), EPS);
+ assertEquals(1.0, surrogates.get("f2"), EPS);
+ }
+
+ @Test
+ public void testInconsistentInputsAndOutputs() {
+ Imputer imputer =
+ new Imputer().setInputCols("f1", "f2", "f3", "f4").setOutputCols("o1", "o2", "o3");
+ try {
+ imputer.fit(trainDataTable);
+ fail();
+ } catch (Throwable e) {
+ assertEquals(
+ "Num of input columns and output columns are inconsistent.", e.getMessage());
+ }
+ }
+
+ @Test
+ @SuppressWarnings("unchecked")
+ public void testGetModelData() throws Exception {
+ Imputer imputer =
+ new Imputer()
+ .setInputCols("f1", "f2", "f3", "f4")
+ .setOutputCols("o1", "o2", "o3", "o4");
+ ImputerModel model = imputer.fit(trainDataTable);
+ Table modelData = model.getModelData()[0];
+ assertEquals(
+ Collections.singletonList("surrogates"),
+ modelData.getResolvedSchema().getColumnNames());
+ DataStream<Row> output = tEnv.toDataStream(modelData);
+ List<Row> modelRows = IteratorUtils.toList(output.executeAndCollect());
+ Map<String, Double> surrogates = (Map<String, Double>) modelRows.get(0).getField(0);
+ assert surrogates != null;
+ assertEquals(2.0, surrogates.get("f1"), EPS);
+ assertEquals(6.8, surrogates.get("f2"), EPS);
+ assertEquals(2.0, surrogates.get("f3"), EPS);
+ assertEquals(6.8, surrogates.get("f4"), EPS);
+ }
+
+ @Test
+ public void testSetModelData() throws Exception {
+ Imputer imputer =
+ new Imputer()
+ .setInputCols("f1", "f2", "f3", "f4")
+ .setOutputCols("o1", "o2", "o3", "o4");
+ ImputerModel modelA = imputer.fit(trainDataTable);
+
+ Table modelData = modelA.getModelData()[0];
+ ImputerModel modelB =
+ new ImputerModel()
+ .setModelData(modelData)
+ .setInputCols("f1", "f2", "f3", "f4")
+ .setOutputCols("o1", "o2", "o3", "o4");
+ Table output = modelB.transform(predictDataTable)[0];
+ verifyPredictionResult(
+ output, Arrays.asList(imputer.getOutputCols()), EXPECTED_MEAN_STRATEGY_OUTPUT);
+ }
+}
diff --git a/flink-ml-python/pyflink/examples/ml/feature/imputer_example.py b/flink-ml-python/pyflink/examples/ml/feature/imputer_example.py
new file mode 100644
index 0000000..1f17536
--- /dev/null
+++ b/flink-ml-python/pyflink/examples/ml/feature/imputer_example.py
@@ -0,0 +1,68 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+# Simple program that creates an Imputer instance and uses it for feature
+# engineering.
+
+from pyflink.common import Types
+from pyflink.datastream import StreamExecutionEnvironment
+from pyflink.ml.lib.feature.imputer import Imputer
+from pyflink.table import StreamTableEnvironment
+
+env = StreamExecutionEnvironment.get_execution_environment()
+
+# create a StreamTableEnvironment
+t_env = StreamTableEnvironment.create(env)
+
+# generate input training and prediction data
+train_data = t_env.from_data_stream(
+ env.from_collection([
+ (float('NaN'), 9.0,),
+ (1.0, 9.0,),
+ (1.5, 7.0,),
+ (1.5, float('NaN'),),
+ (4.0, 5.0,),
+ (None, 4.0,),
+ ],
+ type_info=Types.ROW_NAMED(
+ ['input1', 'input2'],
+ [Types.DOUBLE(), Types.DOUBLE()])
+ ))
+
+# Creates an Imputer object and initializes its parameters.
+imputer = Imputer()\
+ .set_input_cols('input1', 'input2')\
+ .set_output_cols('output1', 'output2')\
+ .set_strategy('mean')\
+ .set_missing_value(float('NaN'))
+
+# Trains the Imputer Model.
+model = imputer.fit(train_data)
+
+# Uses the Imputer Model for predictions.
+output = model.transform(train_data)[0]
+
+# Extracts and displays the results.
+field_names = output.get_schema().get_field_names()
+for result in t_env.to_data_stream(output).execute_and_collect():
+ input_values = []
+ output_values = []
+ for i in range(len(imputer.get_input_cols())):
+ input_values.append(result[field_names.index(imputer.get_input_cols()[i])])
+ output_values.append(result[field_names.index(imputer.get_output_cols()[i])])
+ print('Input Values: ' + str(input_values) + '\tOutput Values: ' + str(output_values))
diff --git a/flink-ml-python/pyflink/ml/core/tests/test_param.py b/flink-ml-python/pyflink/ml/core/tests/test_param.py
index f8393a8..4b1e40e 100644
--- a/flink-ml-python/pyflink/ml/core/tests/test_param.py
+++ b/flink-ml-python/pyflink/ml/core/tests/test_param.py
@@ -22,7 +22,7 @@ from pyflink.ml.core.param import Param
from pyflink.ml.lib.param import HasDistanceMeasure, HasFeaturesCol, HasGlobalBatchSize, \
HasHandleInvalid, HasInputCols, HasLabelCol, HasLearningRate, HasMaxIter, HasMultiClass, \
HasOutputCols, HasPredictionCol, HasRawPredictionCol, HasReg, HasSeed, HasTol, HasWeightCol, \
- HasWindows
+ HasWindows, HasRelativeError
from pyflink.ml.core.windows import GlobalWindows, CountTumblingWindows
@@ -30,7 +30,7 @@ from pyflink.ml.core.windows import GlobalWindows, CountTumblingWindows
class TestParams(HasDistanceMeasure, HasFeaturesCol, HasGlobalBatchSize, HasHandleInvalid,
HasInputCols, HasLabelCol, HasLearningRate, HasMaxIter, HasMultiClass,
HasOutputCols, HasPredictionCol, HasRawPredictionCol, HasReg, HasSeed, HasTol,
- HasWeightCol, HasWindows):
+ HasWeightCol, HasWindows, HasRelativeError):
def __init__(self):
self._param_map = {}
@@ -215,3 +215,15 @@ class ParamTests(unittest.TestCase):
param.set_windows(CountTumblingWindows.of(100))
self.assertEqual(param.get_windows(), CountTumblingWindows.of(100))
+
+ def test_relative_error(self):
+ param = TestParams()
+ relative_error = param.RELATIVE_ERROR
+ self.assertEqual(relative_error.name, "relative_error")
+ self.assertEqual(relative_error.description,
+ "The relative target precision for the approximate"
+ " quantile algorithm.")
+ self.assertEqual(relative_error.default_value, 0.001)
+
+ param.set_relative_error(0.1)
+ self.assertEqual(param.get_relative_error(), 0.1)
diff --git a/flink-ml-python/pyflink/ml/lib/feature/imputer.py b/flink-ml-python/pyflink/ml/lib/feature/imputer.py
new file mode 100644
index 0000000..8aa749c
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/feature/imputer.py
@@ -0,0 +1,129 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+import typing
+
+from pyflink.ml.core.param import FloatParam, StringParam, ParamValidators
+from pyflink.ml.core.wrapper import JavaWithParams
+from pyflink.ml.lib.feature.common import JavaFeatureModel, JavaFeatureEstimator
+from pyflink.ml.lib.param import HasInputCols, HasOutputCols, HasRelativeError
+
+
class _ImputerModelParams(
    JavaWithParams,
    HasInputCols,
    HasOutputCols,
    HasRelativeError
):
    """
    Params for :class:`ImputerModel`.

    On top of the input/output columns and the relative error, this declares the
    value that marks an entry as missing.
    """
    # The value treated as "missing"; every occurrence is imputed. Defaults to NaN.
    MISSING_VALUE: FloatParam = FloatParam(
        "missing_value",
        "The placeholder for the missing values. All occurrences of missing value will be imputed.",
        float("NaN")
    )

    def __init__(self, java_params):
        super().__init__(java_params)

    def set_missing_value(self, value: float):
        updated = self.set(self.MISSING_VALUE, value)
        return typing.cast(_ImputerModelParams, updated)

    def get_missing_value(self):
        missing_value = self.get(self.MISSING_VALUE)
        return missing_value

    @property
    def missing_value(self):
        return self.get_missing_value()
+
+
class _ImputerParams(_ImputerModelParams):
    """
    Params for :class:`Imputer`.

    Supported options of the imputation strategy (``strategy``):

    - ``mean``: replace missing values using the mean along each column.
    - ``median``: replace missing values using the median along each column.
    - ``most_frequent``: replace missing values using the most frequent value along each
      column. If there is more than one such value, only the smallest is returned.
    """
    # The imputation strategy; one of 'mean' (default), 'median' or 'most_frequent'.
    STRATEGY: StringParam = StringParam(
        "strategy",
        "The imputation strategy.",
        'mean',
        ParamValidators.in_array(['mean', 'median', 'most_frequent']))

    def __init__(self, java_params):
        super(_ImputerParams, self).__init__(java_params)

    def set_strategy(self, value: str):
        return typing.cast(_ImputerParams, self.set(self.STRATEGY, value))

    def get_strategy(self) -> str:
        return self.get(self.STRATEGY)

    @property
    def strategy(self) -> str:
        return self.get_strategy()
+
+
class ImputerModel(JavaFeatureModel, _ImputerModelParams):
    """
    A Model that fills in missing values with the statistics that :class:`Imputer`
    computed as its model data.
    """

    def __init__(self, java_model=None):
        super().__init__(java_model)

    # The two hooks below name the wrapped Java model: the ``imputer`` feature
    # package and the ``ImputerModel`` class inside it.
    @classmethod
    def _java_model_package_name(cls) -> str:
        return "imputer"

    @classmethod
    def _java_model_class_name(cls) -> str:
        return "ImputerModel"
+
+
class Imputer(JavaFeatureEstimator, _ImputerParams):
    """
    An Estimator that completes the missing values of its input columns, which should be
    numeric.

    For every input column, a statistic (the mean, the median or the most frequent value,
    chosen via ``strategy``) is computed after the missing values of that column have been
    filtered out, and each missing value in the column is replaced by it. Null values in
    the input columns also count as missing and are imputed as well.
    """

    def __init__(self):
        super().__init__()

    # The two hooks below name the wrapped Java estimator: the ``imputer`` feature
    # package and the ``Imputer`` class inside it.
    @classmethod
    def _java_estimator_package_name(cls) -> str:
        return "imputer"

    @classmethod
    def _java_estimator_class_name(cls) -> str:
        return "Imputer"

    @classmethod
    def _create_model(cls, java_model) -> ImputerModel:
        # Wraps the Java model produced by fit() in its Python counterpart.
        return ImputerModel(java_model)
diff --git a/flink-ml-python/pyflink/ml/lib/feature/tests/test_imputer.py b/flink-ml-python/pyflink/ml/lib/feature/tests/test_imputer.py
new file mode 100644
index 0000000..3addeee
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/feature/tests/test_imputer.py
@@ -0,0 +1,139 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+from typing import List
+
+import numpy as np
+from pyflink.table import Table
+from pyflink.common import Types, Row
+from pyflink.ml.tests.test_utils import PyFlinkMLTestCase
+from pyflink.ml.lib.feature.imputer import Imputer
+
+
class ImputerTest(PyFlinkMLTestCase):
    """Tests for :class:`Imputer` and the model it produces."""

    def setUp(self):
        super(ImputerTest, self).setUp()
        # f1/f2 mark missing entries with both NaN (the default missing value) and
        # None (null); the integer column f3 marks them with None only.
        self.train_table = self.t_env.from_data_stream(
            self.env.from_collection([
                (float('NaN'), 9.0, 1,),
                (1.0, 9.0, None,),
                (1.5, 7.0, 1,),
                (1.5, float('NaN'), 2,),
                (4.0, 5.0, 4,),
                (None, 4.0, None,),
            ],
                type_info=Types.ROW_NAMED(
                    ['f1', 'f2', 'f3'],
                    [Types.DOUBLE(), Types.DOUBLE(), Types.INT()])
            ))
        self.expected_mean_strategy_output = [
            Row(2.0, 9.0, 1.0,),
            Row(1.0, 9.0, 2.0,),
            Row(1.5, 7.0, 1.0,),
            Row(1.5, 6.8, 2.0,),
            Row(4.0, 5.0, 4.0,),
            Row(2.0, 4.0, 2.0,),
        ]
        self.expected_median_strategy_output = [
            Row(1.5, 9.0, 1.0,),
            Row(1.0, 9.0, 1.0,),
            Row(1.5, 7.0, 1.0,),
            Row(1.5, 7.0, 2.0,),
            Row(4.0, 5.0, 4.0,),
            Row(1.5, 4.0, 1.0,),
        ]
        self.expected_most_frequent_strategy_output = [
            Row(1.5, 9.0, 1.0,),
            Row(1.0, 9.0, 1.0,),
            Row(1.5, 7.0, 1.0,),
            Row(1.5, 9.0, 2.0,),
            Row(4.0, 5.0, 4.0,),
            Row(1.5, 4.0, 1.0,),
        ]
        self.strategy_and_expected_outputs = {
            'mean': self.expected_mean_strategy_output,
            'median': self.expected_median_strategy_output,
            'most_frequent': self.expected_most_frequent_strategy_output
        }

    def test_param(self):
        imputer = Imputer().\
            set_input_cols('f1', 'f2', 'f3').\
            set_output_cols('o1', 'o2', 'o3')

        self.assertEqual(('f1', 'f2', 'f3'), imputer.input_cols)
        self.assertEqual(('o1', 'o2', 'o3'), imputer.output_cols)
        self.assertEqual('mean', imputer.strategy)
        self.assertTrue(np.isnan(imputer.missing_value))

        imputer.set_strategy('median').set_missing_value(1.0)
        self.assertEqual('median', imputer.strategy)
        self.assertEqual(1.0, imputer.missing_value)

    def test_output_schema(self):
        imputer = Imputer().\
            set_input_cols('f1', 'f2', 'f3').\
            set_output_cols('o1', 'o2', 'o3')

        model = imputer.fit(self.train_table)
        output = model.transform(self.train_table)[0]
        self.assertEqual(
            ['f1', 'f2', 'f3', 'o1', 'o2', 'o3'],
            output.get_schema().get_field_names())

    def test_fit_and_predict(self):
        for strategy, expected_output in self.strategy_and_expected_outputs.items():
            imputer = Imputer().\
                set_input_cols('f1', 'f2', 'f3').\
                set_output_cols('o1', 'o2', 'o3').\
                set_strategy(strategy)
            model = imputer.fit(self.train_table)
            output = model.transform(self.train_table)[0]
            field_names = output.get_schema().get_field_names()
            self.verify_output_result(
                output, imputer.get_output_cols(), field_names, expected_output)

    def test_save_load_predict(self):
        imputer = Imputer(). \
            set_input_cols('f1', 'f2', 'f3'). \
            set_output_cols('o1', 'o2', 'o3')
        reloaded_imputer = self.save_and_reload(imputer)
        model = reloaded_imputer.fit(self.train_table)
        reloaded_model = self.save_and_reload(model)
        output = reloaded_model.transform(self.train_table)[0]
        self.verify_output_result(
            output,
            imputer.get_output_cols(),
            output.get_schema().get_field_names(),
            self.expected_mean_strategy_output)

    def verify_output_result(
            self, output: Table,
            output_cols: List[str],
            field_names: List[str],
            expected_result: List[Row]):
        """
        Collects ``output`` and asserts that the values of ``output_cols`` equal
        ``expected_result``, ignoring row order.

        :param output: The table produced by the model under test.
        :param output_cols: Names of the columns to compare.
        :param field_names: All field names of ``output``, used to look columns up by name.
        :param expected_result: Expected rows, holding only the ``output_cols`` values.
        """
        collected_results = list(self.t_env.to_data_stream(output).execute_and_collect())
        results = []
        for item in collected_results:
            item.set_field_names(field_names)
            results.append(Row(*[item[col] for col in output_cols]))
        # list.sort() sorts in place and returns None, so comparing two sort() calls always
        # passed. sorted() returns new sorted lists (leaving the shared expected lists
        # untouched), which are what must be compared.
        self.assertEqual(sorted(expected_result, key=str), sorted(results, key=str))
diff --git a/flink-ml-python/pyflink/ml/lib/param.py b/flink-ml-python/pyflink/ml/lib/param.py
index 7bfbe2e..4ca3aa0 100644
--- a/flink-ml-python/pyflink/ml/lib/param.py
+++ b/flink-ml-python/pyflink/ml/lib/param.py
@@ -539,3 +539,24 @@ class HasWindows(WithParams, ABC):
@property
def windows(self):
return self.get(self.WINDOWS)
+
+
class HasRelativeError(WithParams, ABC):
    """
    Mixin that contributes the shared param relativeError (``relative_error``).
    """
    # Target relative precision of the approximate quantile algorithm. Defaults to
    # 0.001 and is validated by ParamValidators.in_range(0.0, 1.0).
    RELATIVE_ERROR: Param[float] = FloatParam(
        "relative_error",
        "The relative target precision for the approximate quantile algorithm.",
        0.001,
        ParamValidators.in_range(0.0, 1.0))

    def set_relative_error(self, value: float):
        updated = self.set(self.RELATIVE_ERROR, value)
        return updated

    def get_relative_error(self) -> float:
        current = self.get(self.RELATIVE_ERROR)
        return current

    @property
    def relative_error(self):
        return self.get(self.RELATIVE_ERROR)