You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by li...@apache.org on 2022/11/17 07:43:05 UTC
[flink-ml] branch master updated: [FLINK-29592] Add Estimator and Transformer for RobustScaler
This is an automated email from the ASF dual-hosted git repository.
lindong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push:
new acf92f0 [FLINK-29592] Add Estimator and Transformer for RobustScaler
acf92f0 is described below
commit acf92f03fb90e58351cc033c8b82d9e775a522f4
Author: JiangXin <ji...@alibaba-inc.com>
AuthorDate: Thu Nov 17 15:43:00 2022 +0800
[FLINK-29592] Add Estimator and Transformer for RobustScaler
This closes #172.
---
.../content/docs/operators/feature/robustscaler.md | 211 +++++++++++++
.../ml/examples/feature/RobustScalerExample.java | 75 +++++
.../flink/ml/common/param/HasRelativeError.java | 2 +-
.../ml/feature/robustscaler/RobustScaler.java | 188 +++++++++++
.../ml/feature/robustscaler/RobustScalerModel.java | 179 +++++++++++
.../robustscaler/RobustScalerModelData.java | 122 ++++++++
.../robustscaler/RobustScalerModelParams.java | 56 ++++
.../feature/robustscaler/RobustScalerParams.java | 61 ++++
.../org/apache/flink/ml/feature/ImputerTest.java | 9 +-
.../apache/flink/ml/feature/RobustScalerTest.java | 345 +++++++++++++++++++++
.../ml/feature/VarianceThresholdSelectorTest.java | 11 +-
.../examples/ml/feature/robustscaler_example.py | 74 +++++
.../pyflink/ml/lib/feature/robustscaler.py | 163 ++++++++++
.../ml/lib/feature/tests/test_robustscaler.py | 135 ++++++++
14 files changed, 1625 insertions(+), 6 deletions(-)
diff --git a/docs/content/docs/operators/feature/robustscaler.md b/docs/content/docs/operators/feature/robustscaler.md
new file mode 100644
index 0000000..5cf392d
--- /dev/null
+++ b/docs/content/docs/operators/feature/robustscaler.md
@@ -0,0 +1,211 @@
+---
+title: "Robust Scaler"
+weight: 1
+type: docs
+aliases:
+- /operators/feature/robustscaler.html
+---
+
+<!--
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements. See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership. The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied. See the License for the
+specific language governing permissions and limitations
+under the License.
+-->
+
+## Robust Scaler
+
+Robust Scaler is an algorithm that scales features using statistics that are
+robust to outliers.
+
+This scaler removes the median and scales the data according to the quantile
+range (defaults to IQR: Interquartile Range). The IQR is the range between
+the 1st quartile (25th percentile) and the 3rd quartile (75th percentile), but
+the quantiles can be configured.
+
+Centering and scaling happen independently on each feature by computing the
+relevant statistics on the samples in the training set. Median and quantile
+range are then stored to be used on later data using the transform method.
+
+Standardization of a dataset is a common requirement for many machine learning
+estimators. Typically this is done by removing the mean and scaling to unit
+variance. However, outliers can often influence the sample mean / variance
+in a negative way. In such cases, the median and the interquartile range
+often give better results.
+
+Note that NaN values are ignored in the computation of medians and ranges.
+
+### Input Columns
+
+| Param name | Type | Default | Description |
+|:-----------|:-------|:----------|:-----------------------|
+| inputCol | Vector | `"input"` | Features to be scaled. |
+
+### Output Columns
+
+| Param name | Type | Default | Description |
+|:-----------|:-------|:-----------|:-----------------|
+| outputCol | Vector | `"output"` | Scaled features. |
+
+### Parameters
+
+Below are the parameters accepted by `RobustScalerModel`.
+
+| Key | Default | Type | Required | Description |
+|---------------|------------|-------------|----------|-----------------------------------------------------------------------|
+| inputCol | `"input"` | String | no | Input column name. |
+| outputCol | `"output"` | String | no | Output column name. |
+| withCentering | `false` | Boolean | no | Whether to center the data with median before scaling. |
+| withScaling | `true` | Boolean | no | Whether to scale the data to quantile range. |
+
+In addition to the parameters above, `RobustScaler` also accepts the parameters below.
+
+| Key | Default | Type | Required | Description |
+|---------------|--------------|-------------|----------|-----------------------------------------------------------------------|
+| lower | `0.25` | Double | no | Lower quantile to calculate quantile range. |
+| upper | `0.75` | Double | no | Upper quantile to calculate quantile range. |
+| relativeError | `0.001` | Double | no | The relative target precision for the approximate quantile algorithm. |
+
+### Examples
+
+{{< tabs examples >}}
+
+{{< tab "Java">}}
+
+```java
+import org.apache.flink.ml.feature.robustscaler.RobustScaler;
+import org.apache.flink.ml.feature.robustscaler.RobustScalerModel;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.CloseableIterator;
+
+/** Simple program that trains a {@link RobustScaler} model and uses it for feature scaling. */
+public class RobustScalerExample {
+    public static void main(String[] args) {
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+        // Generates input data, which is used both for training and for prediction.
+        DataStream<Row> trainStream =
+                env.fromElements(
+                        Row.of(1, Vectors.dense(0.0, 0.0)),
+                        Row.of(2, Vectors.dense(1.0, -1.0)),
+                        Row.of(3, Vectors.dense(2.0, -2.0)),
+                        Row.of(4, Vectors.dense(3.0, -3.0)),
+                        Row.of(5, Vectors.dense(4.0, -4.0)),
+                        Row.of(6, Vectors.dense(5.0, -5.0)),
+                        Row.of(7, Vectors.dense(6.0, -6.0)),
+                        Row.of(8, Vectors.dense(7.0, -7.0)),
+                        Row.of(9, Vectors.dense(8.0, -8.0)));
+        Table trainTable = tEnv.fromDataStream(trainStream).as("id", "input");
+
+        // Creates a RobustScaler object and initializes its parameters.
+        RobustScaler robustScaler =
+                new RobustScaler()
+                        .setLower(0.25)
+                        .setUpper(0.75)
+                        .setRelativeError(0.001)
+                        .setWithScaling(true)
+                        .setWithCentering(true);
+
+        // Trains the RobustScaler model.
+        RobustScalerModel model = robustScaler.fit(trainTable);
+
+        // Uses the RobustScaler model to scale the same data it was trained on.
+        Table outputTable = model.transform(trainTable)[0];
+
+        // Extracts and displays the results.
+        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
+            Row row = it.next();
+            DenseVector inputValue = (DenseVector) row.getField(robustScaler.getInputCol());
+            DenseVector outputValue = (DenseVector) row.getField(robustScaler.getOutputCol());
+            System.out.printf("Input Value: %-15s\tOutput Value: %s\n", inputValue, outputValue);
+        }
+    }
+}
+```
+
+{{< /tab>}}
+
+{{< tab "Python">}}
+
+```python
+# Simple program that creates a RobustScaler instance and uses it for feature
+# engineering.
+
+from pyflink.common import Types
+from pyflink.datastream import StreamExecutionEnvironment
+from pyflink.table import StreamTableEnvironment
+
+from pyflink.ml.core.linalg import Vectors, DenseVectorTypeInfo
+
+from pyflink.ml.lib.feature.robustscaler import RobustScaler
+
+# Creates a new StreamExecutionEnvironment.
+env = StreamExecutionEnvironment.get_execution_environment()
+
+# Creates a StreamTableEnvironment.
+t_env = StreamTableEnvironment.create(env)
+
+# Generates input training and prediction data.
+train_data = t_env.from_data_stream(
+ env.from_collection([
+ (1, Vectors.dense(0.0, 0.0),),
+ (2, Vectors.dense(1.0, -1.0),),
+ (3, Vectors.dense(2.0, -2.0),),
+ (4, Vectors.dense(3.0, -3.0),),
+ (5, Vectors.dense(4.0, -4.0),),
+ (6, Vectors.dense(5.0, -5.0),),
+ (7, Vectors.dense(6.0, -6.0),),
+ (8, Vectors.dense(7.0, -7.0),),
+ (9, Vectors.dense(8.0, -8.0),),
+ ],
+ type_info=Types.ROW_NAMED(
+ ['id', 'input'],
+ [Types.INT(), DenseVectorTypeInfo()])
+ ))
+
+# Creates an RobustScaler object and initializes its parameters.
+robust_scaler = RobustScaler()\
+ .set_lower(0.25)\
+ .set_upper(0.75)\
+ .set_relative_error(0.001)\
+ .set_with_scaling(True)\
+ .set_with_centering(True)
+
+# Trains the RobustScaler Model.
+model = robust_scaler.fit(train_data)
+
+# Uses the RobustScaler Model for predictions.
+output = model.transform(train_data)[0]
+
+# Extracts and displays the results.
+field_names = output.get_schema().get_field_names()
+for result in t_env.to_data_stream(output).execute_and_collect():
+ input_index = field_names.index(robust_scaler.get_input_col())
+ output_index = field_names.index(robust_scaler.get_output_col())
+ print('Input Value: ' + str(result[input_index]) +
+ '\tOutput Value: ' + str(result[output_index]))
+
+```
+
+{{< /tab>}}
+
+{{< /tabs>}}
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/RobustScalerExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/RobustScalerExample.java
new file mode 100644
index 0000000..04f0823
--- /dev/null
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/RobustScalerExample.java
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.examples.feature;
+
+import org.apache.flink.ml.feature.robustscaler.RobustScaler;
+import org.apache.flink.ml.feature.robustscaler.RobustScalerModel;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.CloseableIterator;
+
+/** Simple program that trains a {@link RobustScaler} model and uses it for feature scaling. */
+public class RobustScalerExample {
+    public static void main(String[] args) {
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+        // Generates input data, which is used both for training and for prediction.
+        DataStream<Row> trainStream =
+                env.fromElements(
+                        Row.of(1, Vectors.dense(0.0, 0.0)),
+                        Row.of(2, Vectors.dense(1.0, -1.0)),
+                        Row.of(3, Vectors.dense(2.0, -2.0)),
+                        Row.of(4, Vectors.dense(3.0, -3.0)),
+                        Row.of(5, Vectors.dense(4.0, -4.0)),
+                        Row.of(6, Vectors.dense(5.0, -5.0)),
+                        Row.of(7, Vectors.dense(6.0, -6.0)),
+                        Row.of(8, Vectors.dense(7.0, -7.0)),
+                        Row.of(9, Vectors.dense(8.0, -8.0)));
+        Table trainTable = tEnv.fromDataStream(trainStream).as("id", "input");
+
+        // Creates a RobustScaler object and initializes its parameters.
+        RobustScaler robustScaler =
+                new RobustScaler()
+                        .setLower(0.25)
+                        .setUpper(0.75)
+                        .setRelativeError(0.001)
+                        .setWithScaling(true)
+                        .setWithCentering(true);
+
+        // Trains the RobustScaler model.
+        RobustScalerModel model = robustScaler.fit(trainTable);
+
+        // Uses the RobustScaler model to scale the same data it was trained on.
+        Table outputTable = model.transform(trainTable)[0];
+
+        // Extracts and displays the results.
+        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
+            Row row = it.next();
+            DenseVector inputValue = (DenseVector) row.getField(robustScaler.getInputCol());
+            DenseVector outputValue = (DenseVector) row.getField(robustScaler.getOutputCol());
+            System.out.printf("Input Value: %-15s\tOutput Value: %s\n", inputValue, outputValue);
+        }
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasRelativeError.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasRelativeError.java
index e3c308b..07323a3 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasRelativeError.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasRelativeError.java
@@ -36,7 +36,7 @@ public interface HasRelativeError<T> extends WithParams<T> {
return get(RELATIVE_ERROR);
}
- default T setFeaturesCol(double value) {
+ default T setRelativeError(double value) {
return set(RELATIVE_ERROR, value);
}
}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/robustscaler/RobustScaler.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/robustscaler/RobustScaler.java
new file mode 100644
index 0000000..4004915
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/robustscaler/RobustScaler.java
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.robustscaler;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.util.QuantileSummary;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+/**
+ * An Estimator which scales features using statistics that are robust to outliers.
+ *
+ * <p>This scaler removes the median and scales the data according to the quantile range (defaults
+ * to IQR: Interquartile Range). The IQR is the range between the 1st quartile (25th percentile)
+ * and the 3rd quartile (75th percentile), but the quantiles can be configured.
+ *
+ * <p>Centering and scaling happen independently on each feature by computing the relevant
+ * statistics on the samples in the training set. Median and quantile range are then stored to be
+ * used on later data using the transform method.
+ *
+ * <p>Standardization of a dataset is a common requirement for many machine learning estimators.
+ * Typically this is done by removing the mean and scaling to unit variance. However, outliers can
+ * often influence the sample mean / variance in a negative way. In such cases, the median and the
+ * interquartile range often give better results.
+ *
+ * <p>Note that NaN values are ignored in the computation of medians and ranges.
+ */
+public class RobustScaler
+        implements Estimator<RobustScaler, RobustScalerModel>, RobustScalerParams<RobustScaler> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public RobustScaler() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    /**
+     * Computes the per-feature medians and quantile ranges of the input column.
+     *
+     * @param inputs exactly one table containing the vector column named by {@link
+     *     #getInputCol()}.
+     * @return a {@link RobustScalerModel} holding the computed model data and a copy of this
+     *     estimator's matching parameters.
+     * @throws IllegalArgumentException if not exactly one table is given, or if the lower quantile
+     *     is not strictly less than the upper quantile.
+     */
+    @Override
+    public RobustScalerModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        final double lower = getLower();
+        final double upper = getUpper();
+        // lower >= upper would produce zero or negative quantile ranges, which the model turns
+        // into all-zero or sign-flipped outputs. Reject the misconfiguration up front.
+        Preconditions.checkArgument(
+                lower < upper,
+                "Parameter lower (%s) must be less than parameter upper (%s).",
+                lower,
+                upper);
+
+        final String inputCol = getInputCol();
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<DenseVector> inputData =
+                tEnv.toDataStream(inputs[0])
+                        .map(
+                                (MapFunction<Row, DenseVector>)
+                                        value -> ((Vector) value.getField(inputCol)).toDense());
+        DataStream<RobustScalerModelData> modelData =
+                DataStreamUtils.aggregate(
+                        inputData, new QuantileAggregator(getRelativeError(), lower, upper));
+        RobustScalerModel model =
+                new RobustScalerModel().setModelData(tEnv.fromDataStream(modelData));
+        ReadWriteUtils.updateExistingParams(model, getParamMap());
+        return model;
+    }
+
+    /**
+     * Computes the medians and quantile ranges of input column and builds the {@link
+     * RobustScalerModelData}.
+     *
+     * <p>The accumulator holds one {@link QuantileSummary} per feature. It starts as an empty array
+     * and is sized from the first vector added.
+     */
+    private static class QuantileAggregator
+            implements AggregateFunction<DenseVector, QuantileSummary[], RobustScalerModelData> {
+
+        private final double relativeError;
+        private final double lower;
+        private final double upper;
+
+        public QuantileAggregator(double relativeError, double lower, double upper) {
+            this.relativeError = relativeError;
+            this.lower = lower;
+            this.upper = upper;
+        }
+
+        @Override
+        public QuantileSummary[] createAccumulator() {
+            return new QuantileSummary[0];
+        }
+
+        @Override
+        public QuantileSummary[] add(DenseVector denseVector, QuantileSummary[] quantileSummaries) {
+            if (quantileSummaries.length == 0) {
+                quantileSummaries = new QuantileSummary[denseVector.size()];
+                for (int i = 0; i < denseVector.size(); i++) {
+                    quantileSummaries[i] = new QuantileSummary(relativeError);
+                }
+            }
+            Preconditions.checkState(
+                    denseVector.size() == quantileSummaries.length,
+                    "Number of features must be %s but got %s.",
+                    quantileSummaries.length,
+                    denseVector.size());
+
+            for (int i = 0; i < quantileSummaries.length; i++) {
+                double value = denseVector.get(i);
+                // NaN values are excluded from the statistics.
+                if (!Double.isNaN(value)) {
+                    quantileSummaries[i] = quantileSummaries[i].insert(value);
+                }
+            }
+            return quantileSummaries;
+        }
+
+        @Override
+        public RobustScalerModelData getResult(QuantileSummary[] quantileSummaries) {
+            Preconditions.checkState(quantileSummaries.length != 0, "The training set is empty.");
+            DenseVector medianVector = new DenseVector(quantileSummaries.length);
+            DenseVector rangeVector = new DenseVector(quantileSummaries.length);
+
+            for (int i = 0; i < quantileSummaries.length; i++) {
+                QuantileSummary compressed = quantileSummaries[i].compress();
+
+                // quantiles = {median, lower quantile, upper quantile}.
+                double[] quantiles = compressed.query(new double[] {0.5, lower, upper});
+                medianVector.values[i] = quantiles[0];
+                rangeVector.values[i] = quantiles[2] - quantiles[1];
+            }
+            return new RobustScalerModelData(medianVector, rangeVector);
+        }
+
+        @Override
+        public QuantileSummary[] merge(QuantileSummary[] summaries, QuantileSummary[] acc) {
+            // An empty side is an accumulator that never saw a record; the other side wins.
+            if (summaries.length == 0) {
+                return compressAll(acc);
+            }
+            if (acc.length == 0) {
+                return compressAll(summaries);
+            }
+            Preconditions.checkState(summaries.length == acc.length);
+
+            for (int i = 0; i < summaries.length; i++) {
+                acc[i] = acc[i].compress().merge(summaries[i].compress());
+            }
+            return acc;
+        }
+
+        /** Returns a new array holding the compressed form of every given summary. */
+        private static QuantileSummary[] compressAll(QuantileSummary[] summaries) {
+            return Arrays.stream(summaries)
+                    .map(QuantileSummary::compress)
+                    .collect(Collectors.toList())
+                    .toArray(new QuantileSummary[0]);
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static RobustScaler load(StreamTableEnvironment tEnv, String path) throws IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/robustscaler/RobustScalerModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/robustscaler/RobustScalerModel.java
new file mode 100644
index 0000000..deda6e3
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/robustscaler/RobustScalerModel.java
@@ -0,0 +1,179 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.robustscaler;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/** A Model which transforms data using the model data computed by {@link RobustScaler}. */
+public class RobustScalerModel
+ implements Model<RobustScalerModel>, RobustScalerModelParams<RobustScalerModel> {
+
+ private final Map<Param<?>, Object> paramMap = new HashMap<>();
+ private Table modelDataTable;
+
+ public RobustScalerModel() {
+ ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+ }
+
+ @Override
+ @SuppressWarnings("unchecked")
+ public Table[] transform(Table... inputs) {
+ Preconditions.checkArgument(inputs.length == 1);
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+ DataStream<Row> inputStream = tEnv.toDataStream(inputs[0]);
+
+ RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+ RowTypeInfo outputTypeInfo =
+ new RowTypeInfo(
+ ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), VectorTypeInfo.INSTANCE),
+ ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol()));
+ final String broadcastModelKey = "broadcastModelKey";
+ DataStream<RobustScalerModelData> modelDataStream =
+ RobustScalerModelData.getModelDataStream(modelDataTable);
+
+ DataStream<Row> output =
+ BroadcastUtils.withBroadcastStream(
+ Collections.singletonList(inputStream),
+ Collections.singletonMap(broadcastModelKey, modelDataStream),
+ inputList -> {
+ DataStream inputData = inputList.get(0);
+ return inputData.map(
+ new PredictOutputFunction(
+ broadcastModelKey,
+ getInputCol(),
+ getWithCentering(),
+ getWithScaling()),
+ outputTypeInfo);
+ });
+
+ return new Table[] {tEnv.fromDataStream(output)};
+ }
+
+ /** This operator loads model data and predicts result. */
+ private static class PredictOutputFunction extends RichMapFunction<Row, Row> {
+ private final String broadcastModelKey;
+ private final String inputCol;
+ private final boolean withCentering;
+ private final boolean withScaling;
+
+ private DenseVector medians;
+ private DenseVector scales;
+
+ public PredictOutputFunction(
+ String broadcastModelKey,
+ String inputCol,
+ boolean withCentering,
+ boolean withScaling) {
+ this.broadcastModelKey = broadcastModelKey;
+ this.inputCol = inputCol;
+ this.withCentering = withCentering;
+ this.withScaling = withScaling;
+ }
+
+ @Override
+ public Row map(Row row) throws Exception {
+ if (medians == null) {
+ RobustScalerModelData modelData =
+ (RobustScalerModelData)
+ getRuntimeContext().getBroadcastVariable(broadcastModelKey).get(0);
+ medians = modelData.medians;
+ scales =
+ new DenseVector(
+ Arrays.stream(modelData.ranges.values)
+ .map(range -> range == 0 ? 0 : 1 / range)
+ .toArray());
+ }
+ DenseVector outputVec = ((Vector) row.getField(inputCol)).clone().toDense();
+ Preconditions.checkState(
+ medians.size() == outputVec.size(),
+ "Number of features must be %s but got %s.",
+ medians.size(),
+ outputVec.size());
+
+ if (withCentering) {
+ BLAS.axpy(-1, medians, outputVec);
+ }
+ if (withScaling) {
+ BLAS.hDot(scales, outputVec);
+ }
+ return Row.join(row, Row.of(outputVec));
+ }
+ }
+
+ @Override
+ public void save(String path) throws IOException {
+ ReadWriteUtils.saveMetadata(this, path);
+ ReadWriteUtils.saveModelData(
+ RobustScalerModelData.getModelDataStream(modelDataTable),
+ path,
+ new RobustScalerModelData.ModelDataEncoder());
+ }
+
+ public static RobustScalerModel load(StreamTableEnvironment tEnv, String path)
+ throws IOException {
+ RobustScalerModel model = ReadWriteUtils.loadStageParam(path);
+ Table modelDataTable =
+ ReadWriteUtils.loadModelData(
+ tEnv, path, new RobustScalerModelData.ModelDataDecoder());
+ return model.setModelData(modelDataTable);
+ }
+
+ @Override
+ public RobustScalerModel setModelData(Table... inputs) {
+ Preconditions.checkArgument(inputs.length == 1);
+ this.modelDataTable = inputs[0];
+ return this;
+ }
+
+ @Override
+ public Table[] getModelData() {
+ return new Table[] {modelDataTable};
+ }
+
+ @Override
+ public Map<Param<?>, Object> getParamMap() {
+ return paramMap;
+ }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/robustscaler/RobustScalerModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/robustscaler/RobustScalerModelData.java
new file mode 100644
index 0000000..807fe24
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/robustscaler/RobustScalerModelData.java
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.robustscaler;
+
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+
+import java.io.EOFException;
+import java.io.IOException;
+import java.io.OutputStream;
+
+/**
+ * Model data of {@link RobustScalerModel}.
+ *
+ * <p>This class also provides methods to convert model data from Table to a data stream, and
+ * classes to save/load model data.
+ */
+public class RobustScalerModelData {
+    /** Per-feature medians of the training data. */
+    public DenseVector medians;
+
+    /** Per-feature quantile ranges, i.e. upper quantile minus lower quantile. */
+    public DenseVector ranges;
+
+    public RobustScalerModelData() {}
+
+    public RobustScalerModelData(DenseVector medians, DenseVector ranges) {
+        this.medians = medians;
+        this.ranges = ranges;
+    }
+
+    /**
+     * Converts the table model to a data stream.
+     *
+     * @param modelDataTable The table model data, with vector columns "medians" and "ranges".
+     * @return The data stream model data.
+     */
+    public static DataStream<RobustScalerModelData> getModelDataStream(Table modelDataTable) {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) modelDataTable).getTableEnvironment();
+        return tEnv.toDataStream(modelDataTable)
+                .map(
+                        x ->
+                                new RobustScalerModelData(
+                                        (DenseVector) x.getField("medians"),
+                                        (DenseVector) x.getField("ranges")));
+    }
+
+    /**
+     * Data encoder for the {@link RobustScalerModel} model data. Each record is written as the
+     * medians vector followed by the ranges vector.
+     */
+    public static class ModelDataEncoder implements Encoder<RobustScalerModelData> {
+        private final DenseVectorSerializer serializer = new DenseVectorSerializer();
+
+        @Override
+        public void encode(RobustScalerModelData modelData, OutputStream outputStream)
+                throws IOException {
+            DataOutputViewStreamWrapper outputViewStreamWrapper =
+                    new DataOutputViewStreamWrapper(outputStream);
+            serializer.serialize(modelData.medians, outputViewStreamWrapper);
+            serializer.serialize(modelData.ranges, outputViewStreamWrapper);
+        }
+    }
+
+    /**
+     * Data decoder for the {@link RobustScalerModel} model data. Reads records in the order
+     * written by {@link ModelDataEncoder}: medians, then ranges.
+     */
+    public static class ModelDataDecoder extends SimpleStreamFormat<RobustScalerModelData> {
+
+        @Override
+        public Reader<RobustScalerModelData> createReader(
+                Configuration configuration, FSDataInputStream inputStream) throws IOException {
+            return new Reader<RobustScalerModelData>() {
+                private final DenseVectorSerializer serializer = new DenseVectorSerializer();
+                // Bound once to the reader's input stream and reused by every read() call,
+                // instead of allocating a new wrapper per record.
+                private final DataInputViewStreamWrapper inputViewStreamWrapper =
+                        new DataInputViewStreamWrapper(inputStream);
+
+                @Override
+                public RobustScalerModelData read() throws IOException {
+                    try {
+                        DenseVector medians = serializer.deserialize(inputViewStreamWrapper);
+                        DenseVector ranges = serializer.deserialize(inputViewStreamWrapper);
+                        return new RobustScalerModelData(medians, ranges);
+                    } catch (EOFException e) {
+                        // End of the stream: null tells the source there are no more records.
+                        return null;
+                    }
+                }
+
+                @Override
+                public void close() throws IOException {
+                    inputStream.close();
+                }
+            };
+        }
+
+        @Override
+        public TypeInformation<RobustScalerModelData> getProducedType() {
+            return TypeInformation.of(RobustScalerModelData.class);
+        }
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/robustscaler/RobustScalerModelParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/robustscaler/RobustScalerModelParams.java
new file mode 100644
index 0000000..60fdcb5
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/robustscaler/RobustScalerModelParams.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.robustscaler;
+
+import org.apache.flink.ml.common.param.HasInputCol;
+import org.apache.flink.ml.common.param.HasOutputCol;
+import org.apache.flink.ml.param.BooleanParam;
+import org.apache.flink.ml.param.Param;
+
+/**
+ * Params for {@link RobustScalerModel}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface RobustScalerModelParams<T> extends HasInputCol<T>, HasOutputCol<T> {
+    /** Whether to center the data with the median before scaling. Defaults to {@code false}. */
+    Param<Boolean> WITH_CENTERING =
+            new BooleanParam(
+                    "withCentering",
+                    "Whether to center the data with median before scaling.",
+                    false);
+
+    /** Whether to scale the data to the quantile range. Defaults to {@code true}. */
+    Param<Boolean> WITH_SCALING =
+            new BooleanParam(
+                    "withScaling",
+                    "Whether to scale the data to quantile range.",
+                    true);
+
+    /**
+     * Sets {@link #WITH_CENTERING}.
+     *
+     * @param value whether to center the data with the median before scaling.
+     * @return this instance, for call chaining.
+     */
+    default T setWithCentering(boolean value) {
+        return set(WITH_CENTERING, value);
+    }
+
+    /** Returns the current value of {@link #WITH_CENTERING}. */
+    default boolean getWithCentering() {
+        return get(WITH_CENTERING);
+    }
+
+    /**
+     * Sets {@link #WITH_SCALING}.
+     *
+     * @param value whether to scale the data to the quantile range.
+     * @return this instance, for call chaining.
+     */
+    default T setWithScaling(boolean value) {
+        return set(WITH_SCALING, value);
+    }
+
+    /** Returns the current value of {@link #WITH_SCALING}. */
+    default boolean getWithScaling() {
+        return get(WITH_SCALING);
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/robustscaler/RobustScalerParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/robustscaler/RobustScalerParams.java
new file mode 100644
index 0000000..99fe326
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/robustscaler/RobustScalerParams.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.robustscaler;
+
+import org.apache.flink.ml.common.param.HasRelativeError;
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+
+/**
+ * Params for {@link RobustScaler}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface RobustScalerParams<T> extends HasRelativeError<T>, RobustScalerModelParams<T> {
+    /**
+     * Lower quantile of the quantile range. Defaults to 0.25 and must lie strictly between 0 and
+     * 1. It is expected to be smaller than {@link #UPPER}, but this interface does not enforce
+     * that relationship.
+     */
+    Param<Double> LOWER =
+            new DoubleParam(
+                    "lower",
+                    "Lower quantile to calculate quantile range.",
+                    0.25,
+                    ParamValidators.inRange(0.0, 1.0, false, false));
+
+    /**
+     * Upper quantile of the quantile range. Defaults to 0.75 and must lie strictly between 0 and
+     * 1.
+     */
+    Param<Double> UPPER =
+            new DoubleParam(
+                    "upper",
+                    "Upper quantile to calculate quantile range.",
+                    0.75,
+                    ParamValidators.inRange(0.0, 1.0, false, false));
+
+    /**
+     * Sets {@link #LOWER}.
+     *
+     * @param value the lower quantile.
+     * @return this instance, for call chaining.
+     */
+    default T setLower(Double value) {
+        return set(LOWER, value);
+    }
+
+    /** Returns the current value of {@link #LOWER}. */
+    default double getLower() {
+        return get(LOWER);
+    }
+
+    /**
+     * Sets {@link #UPPER}.
+     *
+     * @param value the upper quantile.
+     * @return this instance, for call chaining.
+     */
+    default T setUpper(Double value) {
+        return set(UPPER, value);
+    }
+
+    /** Returns the current value of {@link #UPPER}. */
+    default double getUpper() {
+        return get(UPPER);
+    }
+}
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/ImputerTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/ImputerTest.java
index 29288ed..f788a39 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/ImputerTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/ImputerTest.java
@@ -175,9 +175,16 @@ public class ImputerTest extends AbstractTestBase {
assertEquals(Double.NaN, imputer.getMissingValue(), EPS);
assertEquals(0.001, imputer.getRelativeError(), EPS);
- imputer.setMissingValue(0.0).setStrategy(MEDIAN);
+ imputer.setMissingValue(0.0)
+ .setStrategy(MEDIAN)
+ .setRelativeError(0.1)
+ .setInputCols("f1", "f2")
+ .setOutputCols("o1", "o2");
assertEquals(MEDIAN, imputer.getStrategy());
assertEquals(0.0, imputer.getMissingValue(), EPS);
+ assertEquals(0.1, imputer.getRelativeError(), EPS);
+ assertArrayEquals(new String[] {"f1", "f2"}, imputer.getInputCols());
+ assertArrayEquals(new String[] {"o1", "o2"}, imputer.getOutputCols());
}
@Test
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/RobustScalerTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/RobustScalerTest.java
new file mode 100644
index 0000000..2b1e430
--- /dev/null
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/RobustScalerTest.java
@@ -0,0 +1,345 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.feature.robustscaler.RobustScaler;
+import org.apache.flink.ml.feature.robustscaler.RobustScalerModel;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.TestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Expressions;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.exception.ExceptionUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+/** Tests {@link RobustScaler} and {@link RobustScalerModel}. */
+public class RobustScalerTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table trainDataTable;
+    private Table predictDataTable;
+
+    private static final List<Row> TRAIN_DATA =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of(0, Vectors.dense(0.0, 0.0)),
+                            Row.of(1, Vectors.dense(1.0, -1.0)),
+                            Row.of(2, Vectors.dense(2.0, -2.0)),
+                            Row.of(3, Vectors.dense(3.0, -3.0)),
+                            Row.of(4, Vectors.dense(4.0, -4.0)),
+                            Row.of(5, Vectors.dense(5.0, -5.0)),
+                            Row.of(6, Vectors.dense(6.0, -6.0)),
+                            Row.of(7, Vectors.dense(7.0, -7.0)),
+                            Row.of(8, Vectors.dense(8.0, -8.0))));
+    private static final List<Row> PREDICT_DATA =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of(Vectors.dense(3.0, -3.0)),
+                            Row.of(Vectors.dense(6.0, -6.0)),
+                            Row.of(Vectors.dense(99.0, -99.0))));
+    private static final double EPS = 1.0e-5;
+
+    /**
+     * Expected output of a default {@link RobustScaler} fitted on {@link #TRAIN_DATA} and applied
+     * to {@link #PREDICT_DATA}: every value divided by the quantile range 4.0 (no centering).
+     */
+    private static final List<DenseVector> EXPECTED_OUTPUT =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Vectors.dense(0.75, -0.75),
+                            Vectors.dense(1.5, -1.5),
+                            Vectors.dense(24.75, -24.75)));
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        trainDataTable = tEnv.fromDataStream(env.fromCollection(TRAIN_DATA)).as("id", "input");
+        predictDataTable = tEnv.fromDataStream(env.fromCollection(PREDICT_DATA)).as("input");
+    }
+
+    /**
+     * Collects column {@code outputCol} of {@code output} and checks it against {@code expected}
+     * via {@code compareResultCollections} using {@link TestUtils#compare}.
+     */
+    private static void verifyPredictionResult(
+            Table output, String outputCol, List<DenseVector> expected) throws Exception {
+        // Named tableEnv so it does not shadow the (inaccessible, static context) tEnv field.
+        StreamTableEnvironment tableEnv =
+                (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment();
+        DataStream<DenseVector> stream =
+                tableEnv.toDataStream(output)
+                        .map(
+                                (MapFunction<Row, DenseVector>)
+                                        row -> (DenseVector) row.getField(outputCol));
+        List<DenseVector> result = IteratorUtils.toList(stream.executeAndCollect());
+        compareResultCollections(expected, result, TestUtils::compare);
+    }
+
+    /**
+     * Executes {@code table}, expecting the job to fail, and returns the root cause of the failure.
+     *
+     * <p>{@code fail()} is deliberately called outside the {@code try} block. If it were called
+     * inside, the {@link AssertionError} it throws would be caught by {@code catch (Throwable)}.
+     * A job that unexpectedly succeeds would then show up as a confusing message mismatch or
+     * {@link NullPointerException} instead of a clear failure.
+     */
+    private static Throwable executeAndGetRootCause(Table table) {
+        try {
+            table.execute().print();
+        } catch (Throwable e) {
+            return ExceptionUtils.getRootCause(e);
+        }
+        fail("Expected the job to fail, but it completed successfully.");
+        return null; // Unreachable: fail() always throws.
+    }
+
+    @Test
+    public void testParam() {
+        RobustScaler robustScaler = new RobustScaler();
+        assertEquals("input", robustScaler.getInputCol());
+        assertEquals("output", robustScaler.getOutputCol());
+        assertEquals(0.25, robustScaler.getLower(), EPS);
+        assertEquals(0.75, robustScaler.getUpper(), EPS);
+        assertEquals(0.001, robustScaler.getRelativeError(), EPS);
+        assertFalse(robustScaler.getWithCentering());
+        assertTrue(robustScaler.getWithScaling());
+
+        robustScaler
+                .setInputCol("test_input")
+                .setOutputCol("test_output")
+                .setLower(0.1)
+                .setUpper(0.9)
+                .setRelativeError(0.01)
+                .setWithCentering(true)
+                .setWithScaling(false);
+        assertEquals("test_input", robustScaler.getInputCol());
+        assertEquals("test_output", robustScaler.getOutputCol());
+        assertEquals(0.1, robustScaler.getLower(), EPS);
+        assertEquals(0.9, robustScaler.getUpper(), EPS);
+        assertEquals(0.01, robustScaler.getRelativeError(), EPS);
+        assertTrue(robustScaler.getWithCentering());
+        assertFalse(robustScaler.getWithScaling());
+    }
+
+    @Test
+    public void testOutputSchema() {
+        Table tempTable = trainDataTable.as("id", "test_input");
+        RobustScaler robustScaler =
+                new RobustScaler().setInputCol("test_input").setOutputCol("test_output");
+        RobustScalerModel model = robustScaler.fit(tempTable);
+        Table output = model.transform(tempTable)[0];
+        assertEquals(
+                Arrays.asList("id", "test_input", "test_output"),
+                output.getResolvedSchema().getColumnNames());
+    }
+
+    @Test
+    public void testFitAndPredict() throws Exception {
+        RobustScaler robustScaler = new RobustScaler();
+        RobustScalerModel model = robustScaler.fit(trainDataTable);
+        Table output = model.transform(predictDataTable)[0];
+        verifyPredictionResult(output, robustScaler.getOutputCol(), EXPECTED_OUTPUT);
+    }
+
+    @Test
+    public void testInputTypeConversion() throws Exception {
+        trainDataTable =
+                TestUtils.convertDataTypesToSparseInt(
+                        tEnv, trainDataTable.select(Expressions.$("input")));
+        predictDataTable = TestUtils.convertDataTypesToSparseInt(tEnv, predictDataTable);
+        assertArrayEquals(
+                new Class<?>[] {SparseVector.class}, TestUtils.getColumnDataTypes(trainDataTable));
+        assertArrayEquals(
+                new Class<?>[] {SparseVector.class},
+                TestUtils.getColumnDataTypes(predictDataTable));
+
+        RobustScaler robustScaler = new RobustScaler();
+        RobustScalerModel model = robustScaler.fit(trainDataTable);
+        Table output = model.transform(predictDataTable)[0];
+        verifyPredictionResult(output, robustScaler.getOutputCol(), EXPECTED_OUTPUT);
+    }
+
+    @Test
+    public void testSaveLoadAndPredict() throws Exception {
+        RobustScaler robustScaler = new RobustScaler();
+        RobustScaler loadedRobustScaler =
+                TestUtils.saveAndReload(
+                        tEnv, robustScaler, tempFolder.newFolder().getAbsolutePath());
+        RobustScalerModel model = loadedRobustScaler.fit(trainDataTable);
+        RobustScalerModel loadedModel =
+                TestUtils.saveAndReload(tEnv, model, tempFolder.newFolder().getAbsolutePath());
+        assertEquals(
+                Arrays.asList("medians", "ranges"),
+                model.getModelData()[0].getResolvedSchema().getColumnNames());
+        Table output = loadedModel.transform(predictDataTable)[0];
+        verifyPredictionResult(output, robustScaler.getOutputCol(), EXPECTED_OUTPUT);
+    }
+
+    @Test
+    public void testFitOnEmptyData() {
+        // Every TRAIN_DATA row has arity 2, so this filter yields an empty stream of that type.
+        Table emptyTable =
+                tEnv.fromDataStream(env.fromCollection(TRAIN_DATA).filter(x -> x.getArity() == 0))
+                        .as("id", "input");
+        RobustScaler robustScaler = new RobustScaler();
+        RobustScalerModel model = robustScaler.fit(emptyTable);
+        Table modelDataTable = model.getModelData()[0];
+        Throwable rootCause = executeAndGetRootCause(modelDataTable);
+        assertEquals("The training set is empty.", rootCause.getMessage());
+    }
+
+    @Test
+    public void testWithCentering() throws Exception {
+        RobustScaler robustScaler = new RobustScaler().setWithCentering(true);
+        RobustScalerModel model = robustScaler.fit(trainDataTable);
+        Table output = model.transform(predictDataTable)[0];
+        List<DenseVector> expectedOutput =
+                new ArrayList<>(
+                        Arrays.asList(
+                                Vectors.dense(-0.25, 0.25),
+                                Vectors.dense(0.5, -0.5),
+                                Vectors.dense(23.75, -23.75)));
+        verifyPredictionResult(output, robustScaler.getOutputCol(), expectedOutput);
+    }
+
+    @Test
+    public void testWithoutScaling() throws Exception {
+        RobustScaler robustScaler = new RobustScaler().setWithCentering(true).setWithScaling(false);
+        RobustScalerModel model = robustScaler.fit(trainDataTable);
+        Table output = model.transform(predictDataTable)[0];
+        List<DenseVector> expectedOutput =
+                new ArrayList<>(
+                        Arrays.asList(
+                                Vectors.dense(-1, 1),
+                                Vectors.dense(2, -2),
+                                Vectors.dense(95, -95)));
+        verifyPredictionResult(output, robustScaler.getOutputCol(), expectedOutput);
+    }
+
+    @Test
+    public void testIncompatibleNumOfFeatures() {
+        RobustScaler robustScaler = new RobustScaler();
+        RobustScalerModel model = robustScaler.fit(trainDataTable);
+
+        List<Row> predictData =
+                new ArrayList<>(
+                        Arrays.asList(
+                                Row.of(Vectors.dense(1.0, 2.0, 3.0)),
+                                Row.of(Vectors.dense(-1.0, -2.0, -3.0))));
+        Table predictTable = tEnv.fromDataStream(env.fromCollection(predictData)).as("input");
+        Table output = model.transform(predictTable)[0];
+        Throwable rootCause = executeAndGetRootCause(output);
+        assertTrue(rootCause.getMessage().contains("Number of features must be"));
+    }
+
+    @Test
+    public void testZeroRange() throws Exception {
+        // Both the 0.25 and 0.75 quantiles of this data are 1.0, so the quantile range is zero.
+        // The expected all-zero output reflects how RobustScalerModel handles a zero range.
+        List<Row> trainData =
+                new ArrayList<>(
+                        Arrays.asList(
+                                Row.of(0, Vectors.dense(0.0, 0.0)),
+                                Row.of(1, Vectors.dense(1.0, 1.0)),
+                                Row.of(2, Vectors.dense(1.0, 1.0)),
+                                Row.of(3, Vectors.dense(1.0, 1.0)),
+                                Row.of(4, Vectors.dense(4.0, 4.0))));
+        List<DenseVector> expectedOutput =
+                new ArrayList<>(
+                        Arrays.asList(
+                                Vectors.dense(0.0, -0.0),
+                                Vectors.dense(0.0, -0.0),
+                                Vectors.dense(0.0, -0.0)));
+        Table trainTable = tEnv.fromDataStream(env.fromCollection(trainData)).as("id", "input");
+        RobustScaler robustScaler = new RobustScaler();
+        RobustScalerModel model = robustScaler.fit(trainTable);
+        Table output = model.transform(predictDataTable)[0];
+        verifyPredictionResult(output, robustScaler.getOutputCol(), expectedOutput);
+    }
+
+    @Test
+    public void testNaNData() throws Exception {
+        List<Row> trainData =
+                new ArrayList<>(
+                        Arrays.asList(
+                                Row.of(0, Vectors.dense(0.0, Double.NaN)),
+                                Row.of(1, Vectors.dense(Double.NaN, 0.0)),
+                                Row.of(2, Vectors.dense(1.0, -1.0)),
+                                Row.of(3, Vectors.dense(2.0, -2.0)),
+                                Row.of(4, Vectors.dense(3.0, -3.0)),
+                                Row.of(5, Vectors.dense(4.0, -4.0))));
+        List<DenseVector> expectedOutput =
+                new ArrayList<>(
+                        Arrays.asList(
+                                Vectors.dense(0.0, Double.NaN),
+                                Vectors.dense(Double.NaN, 0.0),
+                                Vectors.dense(0.5, -0.5),
+                                Vectors.dense(1.0, -1.0),
+                                Vectors.dense(1.5, -1.5),
+                                Vectors.dense(2.0, -2.0)));
+        Table trainTable = tEnv.fromDataStream(env.fromCollection(trainData)).as("id", "input");
+        RobustScaler robustScaler = new RobustScaler();
+        RobustScalerModel model = robustScaler.fit(trainTable);
+        Table output = model.transform(trainTable)[0];
+        verifyPredictionResult(output, robustScaler.getOutputCol(), expectedOutput);
+    }
+
+    @Test
+    public void testGetModelData() throws Exception {
+        RobustScaler robustScaler = new RobustScaler();
+        RobustScalerModel model = robustScaler.fit(trainDataTable);
+        Table modelData = model.getModelData()[0];
+        assertEquals(
+                Arrays.asList("medians", "ranges"), modelData.getResolvedSchema().getColumnNames());
+        DataStream<Row> output = tEnv.toDataStream(modelData);
+        List<Row> modelRows = IteratorUtils.toList(output.executeAndCollect());
+        DenseVector medians = (DenseVector) modelRows.get(0).getField(0);
+        DenseVector ranges = (DenseVector) modelRows.get(0).getField(1);
+
+        DenseVector expectedMedians = Vectors.dense(4.0, -4.0);
+        DenseVector expectedRanges = Vectors.dense(4.0, 4.0);
+        assertEquals(expectedMedians, medians);
+        assertEquals(expectedRanges, ranges);
+    }
+
+    @Test
+    public void testSetModelData() throws Exception {
+        RobustScaler robustScaler = new RobustScaler();
+        RobustScalerModel modelA = robustScaler.fit(trainDataTable);
+
+        Table modelData = modelA.getModelData()[0];
+        RobustScalerModel modelB = new RobustScalerModel().setModelData(modelData);
+        Table output = modelB.transform(predictDataTable)[0];
+        verifyPredictionResult(output, robustScaler.getOutputCol(), EXPECTED_OUTPUT);
+    }
+}
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VarianceThresholdSelectorTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VarianceThresholdSelectorTest.java
index 1230780..217e061 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VarianceThresholdSelectorTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VarianceThresholdSelectorTest.java
@@ -128,12 +128,15 @@ public class VarianceThresholdSelectorTest extends AbstractTestBase {
@Test
public void testOutputSchema() {
+ // Renames the input column so the test covers non-default input/output column names.
+ Table tempTable = trainDataTable.as("id", "test_input");
VarianceThresholdSelector varianceThresholdSelector =
- new VarianceThresholdSelector().setVarianceThreshold(0.5);
- VarianceThresholdSelectorModel model = varianceThresholdSelector.fit(trainDataTable);
- Table output = model.transform(trainDataTable)[0];
+ new VarianceThresholdSelector()
+ .setInputCol("test_input")
+ .setOutputCol("test_output");
+ VarianceThresholdSelectorModel model = varianceThresholdSelector.fit(tempTable);
+ Table output = model.transform(tempTable)[0];
assertEquals(
- Arrays.asList("id", "input", "output"),
+ Arrays.asList("id", "test_input", "test_output"),
output.getResolvedSchema().getColumnNames());
}
diff --git a/flink-ml-python/pyflink/examples/ml/feature/robustscaler_example.py b/flink-ml-python/pyflink/examples/ml/feature/robustscaler_example.py
new file mode 100644
index 0000000..51385b8
--- /dev/null
+++ b/flink-ml-python/pyflink/examples/ml/feature/robustscaler_example.py
@@ -0,0 +1,74 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+# Simple program that creates a RobustScaler instance and uses it for feature
+# engineering.
+
+from pyflink.common import Types
+from pyflink.datastream import StreamExecutionEnvironment
+from pyflink.table import StreamTableEnvironment
+
+from pyflink.ml.core.linalg import Vectors, DenseVectorTypeInfo
+
+from pyflink.ml.lib.feature.robustscaler import RobustScaler
+
+# Creates a new StreamExecutionEnvironment.
+env = StreamExecutionEnvironment.get_execution_environment()
+
+# Creates a StreamTableEnvironment.
+t_env = StreamTableEnvironment.create(env)
+
+# Generates input training and prediction data.
+train_data = t_env.from_data_stream(
+ env.from_collection([
+ (1, Vectors.dense(0.0, 0.0),),
+ (2, Vectors.dense(1.0, -1.0),),
+ (3, Vectors.dense(2.0, -2.0),),
+ (4, Vectors.dense(3.0, -3.0),),
+ (5, Vectors.dense(4.0, -4.0),),
+ (6, Vectors.dense(5.0, -5.0),),
+ (7, Vectors.dense(6.0, -6.0),),
+ (8, Vectors.dense(7.0, -7.0),),
+ (9, Vectors.dense(8.0, -8.0),),
+ ],
+ type_info=Types.ROW_NAMED(
+ ['id', 'input'],
+ [Types.INT(), DenseVectorTypeInfo()])
+ ))
+
+# Creates an RobustScaler object and initializes its parameters.
+robust_scaler = RobustScaler()\
+ .set_lower(0.25)\
+ .set_upper(0.75)\
+ .set_relative_error(0.001)\
+ .set_with_scaling(True)\
+ .set_with_centering(True)
+
+# Trains the RobustScaler Model.
+model = robust_scaler.fit(train_data)
+
+# Uses the RobustScaler Model for predictions.
+output = model.transform(train_data)[0]
+
+# Extracts and displays the results.
+field_names = output.get_schema().get_field_names()
+for result in t_env.to_data_stream(output).execute_and_collect():
+ input_index = field_names.index(robust_scaler.get_input_col())
+ output_index = field_names.index(robust_scaler.get_output_col())
+ print('Input Value: ' + str(result[input_index]) +
+ '\tOutput Value: ' + str(result[output_index]))
diff --git a/flink-ml-python/pyflink/ml/lib/feature/robustscaler.py b/flink-ml-python/pyflink/ml/lib/feature/robustscaler.py
new file mode 100644
index 0000000..b420edd
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/feature/robustscaler.py
@@ -0,0 +1,163 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+import typing
+
+from pyflink.ml.core.wrapper import JavaWithParams
+from pyflink.ml.lib.param import HasInputCol, HasOutputCol, HasRelativeError
+from pyflink.ml.core.param import BooleanParam, Param, FloatParam, ParamValidators
+
+from pyflink.ml.lib.feature.common import JavaFeatureModel, JavaFeatureEstimator
+
+
+class _RobustScalerModelParams(
+ JavaWithParams,
+ HasInputCol,
+ HasOutputCol
+):
+ """
+ Params for :class `RobustScalerModel`.
+ """
+ WITH_CENTERING: Param[bool] = BooleanParam(
+ "with_centering",
+ "Whether to center the data with median before scaling.",
+ False
+ )
+
+ WITH_SCALING: Param[bool] = BooleanParam(
+ "with_scaling",
+ "Whether to scale the data to quantile range.",
+ True
+ )
+
+ def __init__(self, java_params):
+ super(_RobustScalerModelParams, self).__init__(java_params)
+
+ def set_with_centering(self, value: bool):
+ return typing.cast(_RobustScalerModelParams, self.set(self.WITH_CENTERING, value))
+
+ def get_with_centering(self):
+ return self.get(self.WITH_CENTERING)
+
+ def set_with_scaling(self, value: bool):
+ return typing.cast(_RobustScalerModelParams, self.set(self.WITH_SCALING, value))
+
+ def get_with_scaling(self):
+ return self.get(self.WITH_SCALING)
+
+ @property
+ def with_centering(self):
+ return self.get_with_centering()
+
+ @property
+ def with_scaling(self):
+ return self.get_with_scaling()
+
+
+class _RobustScalerParams(HasRelativeError, _RobustScalerModelParams):
+ """
+ Params for :class `RobustScaler`.
+ """
+ LOWER: Param[float] = FloatParam(
+ "lower",
+ "Lower quantile to calculate quantile range.",
+ 0.25,
+ ParamValidators.in_range(0.0, 1.0, False, False)
+ )
+
+ UPPER: Param[float] = FloatParam(
+ "upper",
+ "Upper quantile to calculate quantile range.",
+ 0.75,
+ ParamValidators.in_range(0.0, 1.0, False, False)
+ )
+
+ def __init__(self, java_params):
+ super(_RobustScalerParams, self).__init__(java_params)
+
+ def set_lower(self, value: float):
+ return typing.cast(_RobustScalerParams, self.set(self.LOWER, value))
+
+ def get_lower(self):
+ return self.get(self.LOWER)
+
+ def set_upper(self, value: float):
+ return typing.cast(_RobustScalerParams, self.set(self.UPPER, value))
+
+ def get_upper(self):
+ return self.get(self.UPPER)
+
+ @property
+ def lower(self):
+ return self.get_lower()
+
+ @property
+ def upper(self):
+ return self.get_upper()
+
+
+class RobustScalerModel(JavaFeatureModel, _RobustScalerModelParams):
+ """
+ A Model which transforms data using the model data computed by :class::RobustScaler.
+ """
+
+ def __init__(self, java_model=None):
+ super(RobustScalerModel, self).__init__(java_model)
+
+ @classmethod
+ def _java_model_package_name(cls) -> str:
+ return "robustscaler"
+
+ @classmethod
+ def _java_model_class_name(cls) -> str:
+ return "RobustScalerModel"
+
+
+class RobustScaler(JavaFeatureEstimator, _RobustScalerParams):
+ """
+ An Estimator which scales features using statistics that are robust to outliers.
+
+ This Scaler removes the median and scales the data according to the quantile
+ range (defaults to IQR: Interquartile Range). The IQR is the range between the 1st
+ quartile (25th quantile) and the 3rd quartile (75th quantile) but can be configured.
+
+ Centering and scaling happen independently on each feature by computing the relevant
+ statistics on the samples in the training set. Median and quantile range are then
+ stored to be used on later data using the transform method.
+
+ Standardization of a dataset is a common requirement for many machine learning estimators.
+ Typically this is done by removing the mean and scaling to unit variance. However, outliers can
+ often influence the sample mean / variance in a negative way. In such cases, the median and the
+ interquartile range often give better results.
+
+ Note that NaN values are ignored in the computation of medians and ranges.
+ """
+
+ def __init__(self):
+ super(RobustScaler, self).__init__()
+
+ @classmethod
+ def _create_model(cls, java_model) -> RobustScalerModel:
+ return RobustScalerModel(java_model)
+
+ @classmethod
+ def _java_estimator_package_name(cls) -> str:
+ return "robustscaler"
+
+ @classmethod
+ def _java_estimator_class_name(cls) -> str:
+ return "RobustScaler"
diff --git a/flink-ml-python/pyflink/ml/lib/feature/tests/test_robustscaler.py b/flink-ml-python/pyflink/ml/lib/feature/tests/test_robustscaler.py
new file mode 100644
index 0000000..d423aee
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/feature/tests/test_robustscaler.py
@@ -0,0 +1,135 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+from typing import List
+
+from pyflink.common import Types
+from pyflink.ml.tests.test_utils import PyFlinkMLTestCase
+
+from pyflink.ml.core.linalg import Vectors, DenseVectorTypeInfo, DenseVector
+
+from pyflink.ml.lib.feature.robustscaler import RobustScaler
+from pyflink.table import Table
+
+
class RobustScalerTest(PyFlinkMLTestCase):
    """Tests for the Python RobustScaler estimator and the RobustScalerModel it produces."""

    def setUp(self):
        super(RobustScalerTest, self).setUp()
        # Training data: feature 0 runs over 0..8 and feature 1 over 0..-8.
        self.train_table = self.t_env.from_data_stream(
            self.env.from_collection([
                (1, Vectors.dense(0.0, 0.0)),
                (2, Vectors.dense(1.0, -1.0)),
                (3, Vectors.dense(2.0, -2.0)),
                (4, Vectors.dense(3.0, -3.0)),
                (5, Vectors.dense(4.0, -4.0)),
                (6, Vectors.dense(5.0, -5.0)),
                (7, Vectors.dense(6.0, -6.0)),
                (8, Vectors.dense(7.0, -7.0)),
                (9, Vectors.dense(8.0, -8.0)),
            ],
                type_info=Types.ROW_NAMED(
                    ['id', 'input'],
                    [Types.INT(), DenseVectorTypeInfo()])
            ))

        self.predict_table = self.t_env.from_data_stream(
            self.env.from_collection([
                (Vectors.dense(3.0, -3.0),),
                (Vectors.dense(6.0, -6.0),),
                (Vectors.dense(99.0, -99.0),),
            ],
                type_info=Types.ROW_NAMED(
                    ['input'],
                    [DenseVectorTypeInfo()])
            ))

        # With the default parameters (no centering, scaling by the [0.25, 0.75]
        # quantile range) every prediction value is divided by 4.0, the width of that
        # range for each training feature.
        self.expected_output = [
            Vectors.dense(0.75, -0.75),
            Vectors.dense(1.5, -1.5),
            Vectors.dense(24.75, -24.75)]

    def test_param(self):
        robust_scaler = RobustScaler()
        self.assertEqual("input", robust_scaler.input_col)
        self.assertEqual("output", robust_scaler.output_col)
        self.assertEqual(0.25, robust_scaler.lower)
        self.assertEqual(0.75, robust_scaler.upper)
        self.assertEqual(0.001, robust_scaler.relative_error)
        self.assertFalse(robust_scaler.with_centering)
        self.assertTrue(robust_scaler.with_scaling)

        (robust_scaler
         .set_input_col("test_input")
         .set_output_col("test_output")
         .set_lower(0.1)
         .set_upper(0.9)
         .set_relative_error(0.01)
         .set_with_centering(True)
         .set_with_scaling(False))

        self.assertEqual("test_input", robust_scaler.input_col)
        self.assertEqual("test_output", robust_scaler.output_col)
        self.assertEqual(0.1, robust_scaler.lower)
        self.assertEqual(0.9, robust_scaler.upper)
        self.assertEqual(0.01, robust_scaler.relative_error)
        self.assertTrue(robust_scaler.with_centering)
        self.assertFalse(robust_scaler.with_scaling)

    def test_output_schema(self):
        # Configure the input column too: the prediction table's vector column is
        # renamed to 'test_input', so a model left on the default 'input' column
        # would reference a column that table does not have.
        robust_scaler = RobustScaler() \
            .set_input_col('test_input') \
            .set_output_col('test_output')
        model = robust_scaler.fit(self.train_table.alias('id', 'test_input'))
        output = model.transform(self.predict_table.alias('test_input'))[0]
        self.assertEqual(
            ['test_input', 'test_output'],
            output.get_schema().get_field_names())

    def test_fit_and_predict(self):
        robust_scaler = RobustScaler()
        model = robust_scaler.fit(self.train_table)
        output = model.transform(self.predict_table)[0]
        self.verify_output_result(
            output,
            robust_scaler.get_output_col(),
            output.get_schema().get_field_names(),
            self.expected_output)

    def test_save_load_predict(self):
        robust_scaler = RobustScaler()
        reloaded_robust_scaler = self.save_and_reload(robust_scaler)
        model = reloaded_robust_scaler.fit(self.train_table)
        reloaded_model = self.save_and_reload(model)
        output = reloaded_model.transform(self.predict_table)[0]
        self.verify_output_result(
            output,
            robust_scaler.get_output_col(),
            output.get_schema().get_field_names(),
            self.expected_output)

    def verify_output_result(
            self, output: Table,
            output_col: str,
            field_names: List[str],
            expected_result: List[DenseVector]):
        """
        Collects ``output`` and compares the vectors in ``output_col`` with ``expected_result``.

        :param output: The table produced by the model's transform.
        :param output_col: Name of the column holding the scaled vectors.
        :param field_names: Field names to attach to each collected row.
        :param expected_result: Expected vectors, ordered by their first element.
        """
        results = []
        # Close the collect iterator deterministically, even if collection fails midway.
        with self.t_env.to_data_stream(output).execute_and_collect() as rows:
            for row in rows:
                row.set_field_names(field_names)
                results.append(row[output_col])
        # Row order from the job is not guaranteed; sort to compare deterministically.
        results.sort(key=lambda vector: vector[0])
        self.assertEqual(expected_result, results)