You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by zh...@apache.org on 2022/03/22 09:54:49 UTC
[flink-ml] branch master updated: [FLINK-25552] Add Estimator and Transformer for MinMaxScaler
This is an automated email from the ASF dual-hosted git repository.
zhangzp pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push:
new 3ab3273 [FLINK-25552] Add Estimator and Transformer for MinMaxScaler
3ab3273 is described below
commit 3ab327394769d5bd4f88be04428d33320b0928a3
Author: weibo <we...@alibaba-inc.com>
AuthorDate: Mon Mar 21 19:33:03 2022 +0800
[FLINK-25552] Add Estimator and Transformer for MinMaxScaler
This closes #54.
---
.../flink/ml/classification/knn/KnnModelData.java | 10 +-
.../ml/feature/minmaxscaler/MinMaxScaler.java | 203 +++++++++++++++++++
.../ml/feature/minmaxscaler/MinMaxScalerModel.java | 183 +++++++++++++++++
.../minmaxscaler/MinMaxScalerModelData.java} | 69 +++----
.../feature/minmaxscaler/MinMaxScalerParams.java | 62 ++++++
.../apache/flink/ml/feature/MinMaxScalerTest.java | 218 +++++++++++++++++++++
6 files changed, 699 insertions(+), 46 deletions(-)
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModelData.java
index 89051e6..4bf0adb 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModelData.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModelData.java
@@ -82,13 +82,11 @@ public class KnnModelData {
/** Encoder for {@link KnnModelData}. */
public static class ModelDataEncoder implements Encoder<KnnModelData> {
@Override
- public void encode(KnnModelData knnModelData, OutputStream outputStream)
- throws IOException {
+ public void encode(KnnModelData modelData, OutputStream outputStream) throws IOException {
DataOutputView dataOutputView = new DataOutputViewStreamWrapper(outputStream);
- DenseMatrixSerializer.INSTANCE.serialize(knnModelData.packedFeatures, dataOutputView);
- DenseVectorSerializer.INSTANCE.serialize(
- knnModelData.featureNormSquares, dataOutputView);
- DenseVectorSerializer.INSTANCE.serialize(knnModelData.labels, dataOutputView);
+ DenseMatrixSerializer.INSTANCE.serialize(modelData.packedFeatures, dataOutputView);
+ DenseVectorSerializer.INSTANCE.serialize(modelData.featureNormSquares, dataOutputView);
+ DenseVectorSerializer.INSTANCE.serialize(modelData.labels, dataOutputView);
}
}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScaler.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScaler.java
new file mode 100644
index 0000000..19a9f6f
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScaler.java
@@ -0,0 +1,203 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.minmaxscaler;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the MinMaxScaler algorithm. This algorithm rescales feature values
+ * to a common range [min, max] which is defined by the user.
+ *
+ * <blockquote>
+ *
+ * $$ Rescaled(value) = \frac{value - E_{min}}{E_{max} - E_{min}} * (max - min) + min $$
+ *
+ * </blockquote>
+ *
+ * <p>For the case \(E_{max} == E_{min}\), \(Rescaled(value) = 0.5 * (max + min)\).
+ *
+ * <p>See https://en.wikipedia.org/wiki/Feature_scaling#Rescaling_(min-max_normalization).
+ */
+public class MinMaxScaler
+ implements Estimator<MinMaxScaler, MinMaxScalerModel>, MinMaxScalerParams<MinMaxScaler> {
+ private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+ public MinMaxScaler() {
+ ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+ }
+
+ @Override
+ public MinMaxScalerModel fit(Table... inputs) {
+ Preconditions.checkArgument(inputs.length == 1);
+ final String featureCol = getFeaturesCol();
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+ DataStream<DenseVector> features =
+ tEnv.toDataStream(inputs[0])
+ .map(
+ (MapFunction<Row, DenseVector>)
+ value -> (DenseVector) value.getField(featureCol));
+ DataStream<DenseVector> minMaxValues =
+ features.transform(
+ "reduceInEachPartition",
+ features.getType(),
+ new MinMaxReduceFunctionOperator())
+ .transform(
+ "reduceInFinalPartition",
+ features.getType(),
+ new MinMaxReduceFunctionOperator())
+ .setParallelism(1);
+ DataStream<MinMaxScalerModelData> modelData =
+ DataStreamUtils.mapPartition(
+ minMaxValues,
+ new RichMapPartitionFunction<DenseVector, MinMaxScalerModelData>() {
+ @Override
+ public void mapPartition(
+ Iterable<DenseVector> values,
+ Collector<MinMaxScalerModelData> out) {
+ Iterator<DenseVector> iter = values.iterator();
+ DenseVector minVector = iter.next();
+ DenseVector maxVector = iter.next();
+ out.collect(new MinMaxScalerModelData(minVector, maxVector));
+ }
+ });
+
+ MinMaxScalerModel model =
+ new MinMaxScalerModel().setModelData(tEnv.fromDataStream(modelData));
+ ReadWriteUtils.updateExistingParams(model, getParamMap());
+ return model;
+ }
+
+ /**
+ * A stream operator to compute the min and max values in each partition of the input bounded
+ * data stream.
+ */
+ private static class MinMaxReduceFunctionOperator extends AbstractStreamOperator<DenseVector>
+ implements OneInputStreamOperator<DenseVector, DenseVector>, BoundedOneInput {
+ private ListState<DenseVector> minState;
+ private ListState<DenseVector> maxState;
+
+ private DenseVector minVector;
+ private DenseVector maxVector;
+
+ @Override
+ public void endInput() {
+ if (minVector != null) {
+ output.collect(new StreamRecord<>(minVector));
+ output.collect(new StreamRecord<>(maxVector));
+ }
+ }
+
+ @Override
+ public void processElement(StreamRecord<DenseVector> streamRecord) {
+ DenseVector currentValue = streamRecord.getValue();
+ if (minVector == null) {
+ int vecSize = currentValue.size();
+ minVector = new DenseVector(vecSize);
+ maxVector = new DenseVector(vecSize);
+ System.arraycopy(currentValue.values, 0, minVector.values, 0, vecSize);
+ System.arraycopy(currentValue.values, 0, maxVector.values, 0, vecSize);
+ } else {
+ Preconditions.checkArgument(
+ currentValue.size() == maxVector.size(),
+ "CurrentValue should has same size with maxVector.");
+ for (int i = 0; i < currentValue.size(); ++i) {
+ minVector.values[i] = Math.min(minVector.values[i], currentValue.values[i]);
+ maxVector.values[i] = Math.max(maxVector.values[i], currentValue.values[i]);
+ }
+ }
+ }
+
+ @Override
+ @SuppressWarnings("unchecked")
+ public void initializeState(StateInitializationContext context) throws Exception {
+ super.initializeState(context);
+ minState =
+ context.getOperatorStateStore()
+ .getListState(
+ new ListStateDescriptor<>(
+ "minState", TypeInformation.of(DenseVector.class)));
+ maxState =
+ context.getOperatorStateStore()
+ .getListState(
+ new ListStateDescriptor<>(
+ "maxState", TypeInformation.of(DenseVector.class)));
+
+ OperatorStateUtils.getUniqueElement(minState, "minState").ifPresent(x -> minVector = x);
+ OperatorStateUtils.getUniqueElement(maxState, "maxState").ifPresent(x -> maxVector = x);
+ }
+
+ @Override
+ @SuppressWarnings("unchecked")
+ public void snapshotState(StateSnapshotContext context) throws Exception {
+ super.snapshotState(context);
+ minState.clear();
+ maxState.clear();
+ if (minVector != null) {
+ minState.add(minVector);
+ maxState.add(maxVector);
+ }
+ }
+ }
+
+ @Override
+ public Map<Param<?>, Object> getParamMap() {
+ return paramMap;
+ }
+
+ @Override
+ public void save(String path) throws IOException {
+ ReadWriteUtils.saveMetadata(this, path);
+ }
+
+ public static MinMaxScaler load(StreamExecutionEnvironment env, String path)
+ throws IOException {
+ return ReadWriteUtils.loadStageParam(path);
+ }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerModel.java
new file mode 100644
index 0000000..762d74a
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerModel.java
@@ -0,0 +1,183 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.minmaxscaler;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * A Model which does a MinMaxScaler operation using the model data computed by {@link MinMaxScaler}.
+ */
+public class MinMaxScalerModel
+ implements Model<MinMaxScalerModel>, MinMaxScalerParams<MinMaxScalerModel> {
+ private final Map<Param<?>, Object> paramMap = new HashMap<>();
+ private Table modelDataTable;
+
+ public MinMaxScalerModel() {
+ ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+ }
+
+ @Override
+ public MinMaxScalerModel setModelData(Table... inputs) {
+ modelDataTable = inputs[0];
+ return this;
+ }
+
+ @Override
+ public Table[] getModelData() {
+ return new Table[] {modelDataTable};
+ }
+
+ @Override
+ @SuppressWarnings("unchecked")
+ public Table[] transform(Table... inputs) {
+ Preconditions.checkArgument(inputs.length == 1);
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+ DataStream<Row> data = tEnv.toDataStream(inputs[0]);
+ DataStream<MinMaxScalerModelData> minMaxScalerModel =
+ MinMaxScalerModelData.getModelDataStream(modelDataTable);
+ final String broadcastModelKey = "broadcastModelKey";
+ RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+ RowTypeInfo outputTypeInfo =
+ new RowTypeInfo(
+ ArrayUtils.addAll(
+ inputTypeInfo.getFieldTypes(),
+ TypeInformation.of(DenseVector.class)),
+ ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol()));
+ DataStream<Row> output =
+ BroadcastUtils.withBroadcastStream(
+ Collections.singletonList(data),
+ Collections.singletonMap(broadcastModelKey, minMaxScalerModel),
+ inputList -> {
+ DataStream input = inputList.get(0);
+ return input.map(
+ new PredictOutputFunction(
+ broadcastModelKey,
+ getMax(),
+ getMin(),
+ getFeaturesCol()),
+ outputTypeInfo);
+ });
+ return new Table[] {tEnv.fromDataStream(output)};
+ }
+
+ @Override
+ public Map<Param<?>, Object> getParamMap() {
+ return paramMap;
+ }
+
+ @Override
+ public void save(String path) throws IOException {
+ ReadWriteUtils.saveMetadata(this, path);
+ ReadWriteUtils.saveModelData(
+ MinMaxScalerModelData.getModelDataStream(modelDataTable),
+ path,
+ new MinMaxScalerModelData.ModelDataEncoder());
+ }
+
+ /**
+ * Loads model data from path.
+ *
+ * @param env Stream execution environment.
+ * @param path Model path.
+ * @return MinMaxScalerModel model.
+ */
+ public static MinMaxScalerModel load(StreamExecutionEnvironment env, String path)
+ throws IOException {
+ StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+ MinMaxScalerModel model = ReadWriteUtils.loadStageParam(path);
+ DataStream<MinMaxScalerModelData> modelData =
+ ReadWriteUtils.loadModelData(
+ env, path, new MinMaxScalerModelData.ModelDataDecoder());
+ return model.setModelData(tEnv.fromDataStream(modelData));
+ }
+
+ /** This operator loads model data and predicts result. */
+ private static class PredictOutputFunction extends RichMapFunction<Row, Row> {
+ private final String featureCol;
+ private final String broadcastKey;
+ private final double upperBound;
+ private final double lowerBound;
+ private DenseVector scaleVector;
+ private DenseVector offsetVector;
+
+ public PredictOutputFunction(
+ String broadcastKey, double upperBound, double lowerBound, String featureCol) {
+ this.upperBound = upperBound;
+ this.lowerBound = lowerBound;
+ this.broadcastKey = broadcastKey;
+ this.featureCol = featureCol;
+ }
+
+ @Override
+ public Row map(Row row) {
+ if (scaleVector == null) {
+ MinMaxScalerModelData minMaxScalerModelData =
+ (MinMaxScalerModelData)
+ getRuntimeContext().getBroadcastVariable(broadcastKey).get(0);
+ DenseVector minVector = minMaxScalerModelData.minVector;
+ DenseVector maxVector = minMaxScalerModelData.maxVector;
+ scaleVector = new DenseVector(minVector.size());
+ offsetVector = new DenseVector(minVector.size());
+ for (int i = 0; i < maxVector.size(); ++i) {
+ if (Math.abs(minVector.values[i] - maxVector.values[i]) < 1.0e-5) {
+ scaleVector.values[i] = 0.0;
+ offsetVector.values[i] = (upperBound + lowerBound) / 2;
+ } else {
+ scaleVector.values[i] =
+ (upperBound - lowerBound)
+ / (maxVector.values[i] - minVector.values[i]);
+ offsetVector.values[i] =
+ lowerBound - minVector.values[i] * scaleVector.values[i];
+ }
+ }
+ }
+ DenseVector feature = (DenseVector) row.getField(featureCol);
+ DenseVector outputVector = new DenseVector(scaleVector.size());
+ for (int i = 0; i < scaleVector.size(); ++i) {
+ outputVector.values[i] =
+ feature.values[i] * scaleVector.values[i] + offsetVector.values[i];
+ }
+ return Row.join(row, Row.of(outputVector));
+ }
+ }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerModelData.java
similarity index 56%
copy from flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModelData.java
copy to flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerModelData.java
index 89051e6..301eadd 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModelData.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerModelData.java
@@ -16,7 +16,7 @@
* limitations under the License.
*/
-package org.apache.flink.ml.classification.knn;
+package org.apache.flink.ml.feature.minmaxscaler;
import org.apache.flink.api.common.serialization.Encoder;
import org.apache.flink.api.common.typeinfo.TypeInformation;
@@ -27,9 +27,7 @@ import org.apache.flink.core.memory.DataInputView;
import org.apache.flink.core.memory.DataInputViewStreamWrapper;
import org.apache.flink.core.memory.DataOutputView;
import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
-import org.apache.flink.ml.linalg.DenseMatrix;
import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.typeinfo.DenseMatrixSerializer;
import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.table.api.Table;
@@ -41,24 +39,21 @@ import java.io.IOException;
import java.io.OutputStream;
/**
- * Model data of {@link KnnModel}.
+ * Model data of {@link MinMaxScalerModel}.
*
* <p>This class also provides methods to convert model data from Table to a data stream, and
* classes to save/load model data.
*/
-public class KnnModelData {
+public class MinMaxScalerModelData {
+ public DenseVector minVector;
- public DenseMatrix packedFeatures;
- public DenseVector featureNormSquares;
- public DenseVector labels;
+ public DenseVector maxVector;
- public KnnModelData() {}
+ public MinMaxScalerModelData() {}
- public KnnModelData(
- DenseMatrix packedFeatures, DenseVector featureNormSquares, DenseVector labels) {
- this.packedFeatures = packedFeatures;
- this.featureNormSquares = featureNormSquares;
- this.labels = labels;
+ public MinMaxScalerModelData(DenseVector minVector, DenseVector maxVector) {
+ this.minVector = minVector;
+ this.maxVector = maxVector;
}
/**
@@ -67,47 +62,41 @@ public class KnnModelData {
* @param modelDataTable The table model data.
* @return The data stream model data.
*/
- public static DataStream<KnnModelData> getModelDataStream(Table modelDataTable) {
+ public static DataStream<MinMaxScalerModelData> getModelDataStream(Table modelDataTable) {
StreamTableEnvironment tEnv =
(StreamTableEnvironment) ((TableImpl) modelDataTable).getTableEnvironment();
return tEnv.toDataStream(modelDataTable)
.map(
x ->
- new KnnModelData(
- (DenseMatrix) x.getField(0),
- (DenseVector) x.getField(1),
- (DenseVector) x.getField(2)));
+ new MinMaxScalerModelData(
+ (DenseVector) x.getField(0), (DenseVector) x.getField(1)));
}
- /** Encoder for {@link KnnModelData}. */
- public static class ModelDataEncoder implements Encoder<KnnModelData> {
+ /** Encoder for {@link MinMaxScalerModelData}. */
+ public static class ModelDataEncoder implements Encoder<MinMaxScalerModelData> {
@Override
- public void encode(KnnModelData knnModelData, OutputStream outputStream)
+ public void encode(MinMaxScalerModelData modelData, OutputStream outputStream)
throws IOException {
DataOutputView dataOutputView = new DataOutputViewStreamWrapper(outputStream);
- DenseMatrixSerializer.INSTANCE.serialize(knnModelData.packedFeatures, dataOutputView);
- DenseVectorSerializer.INSTANCE.serialize(
- knnModelData.featureNormSquares, dataOutputView);
- DenseVectorSerializer.INSTANCE.serialize(knnModelData.labels, dataOutputView);
+ DenseVectorSerializer.INSTANCE.serialize(modelData.minVector, dataOutputView);
+ DenseVectorSerializer.INSTANCE.serialize(modelData.maxVector, dataOutputView);
}
}
- /** Decoder for {@link KnnModelData}. */
- public static class ModelDataDecoder extends SimpleStreamFormat<KnnModelData> {
+ /** Decoder for {@link MinMaxScalerModelData}. */
+ public static class ModelDataDecoder extends SimpleStreamFormat<MinMaxScalerModelData> {
@Override
- public Reader<KnnModelData> createReader(Configuration config, FSDataInputStream stream) {
- return new Reader<KnnModelData>() {
-
- private final DataInputView source = new DataInputViewStreamWrapper(stream);
+ public Reader<MinMaxScalerModelData> createReader(
+ Configuration config, FSDataInputStream stream) {
+ return new Reader<MinMaxScalerModelData>() {
@Override
- public KnnModelData read() throws IOException {
+ public MinMaxScalerModelData read() throws IOException {
+ DataInputView source = new DataInputViewStreamWrapper(stream);
try {
- DenseMatrix matrix = DenseMatrixSerializer.INSTANCE.deserialize(source);
- DenseVector normSquares =
- DenseVectorSerializer.INSTANCE.deserialize(source);
- DenseVector labels = DenseVectorSerializer.INSTANCE.deserialize(source);
- return new KnnModelData(matrix, normSquares, labels);
+ DenseVector minVector = DenseVectorSerializer.INSTANCE.deserialize(source);
+ DenseVector maxVector = DenseVectorSerializer.INSTANCE.deserialize(source);
+ return new MinMaxScalerModelData(minVector, maxVector);
} catch (EOFException e) {
return null;
}
@@ -121,8 +110,8 @@ public class KnnModelData {
}
@Override
- public TypeInformation<KnnModelData> getProducedType() {
- return TypeInformation.of(KnnModelData.class);
+ public TypeInformation<MinMaxScalerModelData> getProducedType() {
+ return TypeInformation.of(MinMaxScalerModelData.class);
}
}
}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerParams.java
new file mode 100644
index 0000000..aade500
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerParams.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.minmaxscaler;
+
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasPredictionCol;
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+
+/**
+ * Params for {@link MinMaxScaler}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface MinMaxScalerParams<T> extends HasFeaturesCol<T>, HasPredictionCol<T> {
+ Param<Double> MIN =
+ new DoubleParam(
+ "min",
+ "Lower bound of the output feature range.",
+ 0.0,
+ ParamValidators.notNull());
+
+ default Double getMin() {
+ return get(MIN);
+ }
+
+ default T setMin(Double value) {
+ return set(MIN, value);
+ }
+
+ Param<Double> MAX =
+ new DoubleParam(
+ "max",
+ "Upper bound of the output feature range.",
+ 1.0,
+ ParamValidators.notNull());
+
+ default Double getMax() {
+ return get(MAX);
+ }
+
+ default T setMax(Double value) {
+ return set(MAX, value);
+ }
+}
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinMaxScalerTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinMaxScalerTest.java
new file mode 100644
index 0000000..24ec885
--- /dev/null
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinMaxScalerTest.java
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.feature.minmaxscaler.MinMaxScaler;
+import org.apache.flink.ml.feature.minmaxscaler.MinMaxScalerModel;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+
+/** Tests {@link MinMaxScaler} and {@link MinMaxScalerModel}. */
+public class MinMaxScalerTest {
+ @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+ private StreamExecutionEnvironment env;
+ private StreamTableEnvironment tEnv;
+ private Table trainDataTable;
+ private Table predictDataTable;
+ private static final List<Row> TRAIN_DATA =
+ new ArrayList<>(
+ Arrays.asList(
+ Row.of(Vectors.dense(0.0, 3.0)),
+ Row.of(Vectors.dense(2.1, 0.0)),
+ Row.of(Vectors.dense(4.1, 5.1)),
+ Row.of(Vectors.dense(6.1, 8.1)),
+ Row.of(Vectors.dense(200, 400))));
+ private static final List<Row> PREDICT_DATA =
+ new ArrayList<>(
+ Arrays.asList(
+ Row.of(Vectors.dense(150.0, 90.0)),
+ Row.of(Vectors.dense(50.0, 40.0)),
+ Row.of(Vectors.dense(100.0, 50.0))));
+ private static final double EPS = 1.0e-5;
+ private static final List<DenseVector> EXPECTED_DATA =
+ new ArrayList<>(
+ Arrays.asList(
+ Vectors.dense(0.25, 0.1),
+ Vectors.dense(0.5, 0.125),
+ Vectors.dense(0.75, 0.225)));
+
+ /** Note: this comparator imposes orderings that are inconsistent with equals. */
+ private static int compare(DenseVector first, DenseVector second) {
+ for (int i = 0; i < first.size(); i++) {
+ int cmp = Double.compare(first.get(i), second.get(i));
+ if (cmp != 0) {
+ return cmp;
+ }
+ }
+ return 0;
+ }
+
+ @Before
+ public void before() {
+ Configuration config = new Configuration();
+ config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+ env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+ env.setParallelism(4);
+ env.enableCheckpointing(100);
+ env.setRestartStrategy(RestartStrategies.noRestart());
+ tEnv = StreamTableEnvironment.create(env);
+ trainDataTable = tEnv.fromDataStream(env.fromCollection(TRAIN_DATA)).as("features");
+ predictDataTable = tEnv.fromDataStream(env.fromCollection(PREDICT_DATA)).as("features");
+ }
+
+ private static void verifyPredictionResult(
+ Table output, String outputCol, List<DenseVector> expected) throws Exception {
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment();
+ DataStream<DenseVector> stream =
+ tEnv.toDataStream(output)
+ .map(
+ (MapFunction<Row, DenseVector>)
+ row -> (DenseVector) row.getField(outputCol));
+ List<DenseVector> result = IteratorUtils.toList(stream.executeAndCollect());
+ result.sort(MinMaxScalerTest::compare);
+ assertEquals(expected, result);
+ }
+
+ @Test
+ public void testParam() {
+ MinMaxScaler minMaxScaler = new MinMaxScaler();
+ assertEquals("features", minMaxScaler.getFeaturesCol());
+ assertEquals("prediction", minMaxScaler.getPredictionCol());
+ assertEquals(0.0, minMaxScaler.getMin(), EPS);
+ assertEquals(1.0, minMaxScaler.getMax(), EPS);
+ minMaxScaler
+ .setFeaturesCol("test_features")
+ .setPredictionCol("test_output")
+ .setMin(1.0)
+ .setMax(4.0);
+ assertEquals("test_features", minMaxScaler.getFeaturesCol());
+ assertEquals(1.0, minMaxScaler.getMin(), EPS);
+ assertEquals(4.0, minMaxScaler.getMax(), EPS);
+ assertEquals("test_output", minMaxScaler.getPredictionCol());
+ }
+
+ @Test
+ public void testFeaturePredictionParam() {
+ MinMaxScaler minMaxScaler =
+ new MinMaxScaler()
+ .setFeaturesCol("test_features")
+ .setPredictionCol("test_output")
+ .setMin(1.0)
+ .setMax(4.0);
+
+ MinMaxScalerModel model = minMaxScaler.fit(trainDataTable.as("test_features"));
+ Table output = model.transform(predictDataTable.as("test_features"))[0];
+ assertEquals(
+ Arrays.asList("test_features", "test_output"),
+ output.getResolvedSchema().getColumnNames());
+ }
+
+ @Test
+ public void testMaxValueEqualsMinValueButPredictValueNotEquals() throws Exception {
+ List<Row> trainData =
+ new ArrayList<>(Collections.singletonList(Row.of(Vectors.dense(40.0, 80.0))));
+ Table trainTable = tEnv.fromDataStream(env.fromCollection(trainData)).as("features");
+ List<Row> predictData =
+ new ArrayList<>(Collections.singletonList(Row.of(Vectors.dense(30.0, 50.0))));
+ Table predictDataTable =
+ tEnv.fromDataStream(env.fromCollection(predictData)).as("features");
+ MinMaxScaler minMaxScaler = new MinMaxScaler().setMax(10.0).setMin(0.0);
+ MinMaxScalerModel model = minMaxScaler.fit(trainTable);
+ Table result = model.transform(predictDataTable)[0];
+ verifyPredictionResult(
+ result,
+ minMaxScaler.getPredictionCol(),
+ Collections.singletonList(Vectors.dense(5.0, 5.0)));
+ }
+
+ @Test
+ public void testFitAndPredict() throws Exception {
+ MinMaxScaler minMaxScaler = new MinMaxScaler();
+ MinMaxScalerModel minMaxScalerModel = minMaxScaler.fit(trainDataTable);
+ Table output = minMaxScalerModel.transform(predictDataTable)[0];
+ verifyPredictionResult(output, minMaxScaler.getPredictionCol(), EXPECTED_DATA);
+ }
+
+ @Test
+ public void testSaveLoadAndPredict() throws Exception {
+ MinMaxScaler minMaxScaler = new MinMaxScaler();
+ MinMaxScaler loadedMinMaxScaler =
+ StageTestUtils.saveAndReload(
+ env, minMaxScaler, tempFolder.newFolder().getAbsolutePath());
+ MinMaxScalerModel model = loadedMinMaxScaler.fit(trainDataTable);
+ MinMaxScalerModel loadedModel =
+ StageTestUtils.saveAndReload(env, model, tempFolder.newFolder().getAbsolutePath());
+ assertEquals(
+ Arrays.asList("minVector", "maxVector"),
+ model.getModelData()[0].getResolvedSchema().getColumnNames());
+ Table output = loadedModel.transform(predictDataTable)[0];
+ verifyPredictionResult(output, minMaxScaler.getPredictionCol(), EXPECTED_DATA);
+ }
+
+ @Test
+ public void testGetModelData() throws Exception {
+ MinMaxScaler minMaxScaler = new MinMaxScaler();
+ MinMaxScalerModel minMaxScalerModel = minMaxScaler.fit(trainDataTable);
+ Table modelData = minMaxScalerModel.getModelData()[0];
+ assertEquals(
+ Arrays.asList("minVector", "maxVector"),
+ modelData.getResolvedSchema().getColumnNames());
+ DataStream<Row> output = tEnv.toDataStream(modelData);
+ List<Row> modelRows = IteratorUtils.toList(output.executeAndCollect());
+ assertEquals(new DenseVector(new double[] {0.0, 0.0}), modelRows.get(0).getField(0));
+ assertEquals(new DenseVector(new double[] {200.0, 400.0}), modelRows.get(0).getField(1));
+ }
+
+ @Test
+ public void testSetModelData() throws Exception {
+ MinMaxScaler minMaxScaler = new MinMaxScaler();
+ MinMaxScalerModel modelA = minMaxScaler.fit(trainDataTable);
+ Table modelData = modelA.getModelData()[0];
+ MinMaxScalerModel modelB = new MinMaxScalerModel().setModelData(modelData);
+ ReadWriteUtils.updateExistingParams(modelB, modelA.getParamMap());
+ Table output = modelB.transform(predictDataTable)[0];
+ verifyPredictionResult(output, minMaxScaler.getPredictionCol(), EXPECTED_DATA);
+ }
+}