Posted to commits@flink.apache.org by zh...@apache.org on 2022/03/22 09:54:49 UTC

[flink-ml] branch master updated: [FLINK-25552] Add Estimator and Transformer for MinMaxScaler

This is an automated email from the ASF dual-hosted git repository.

zhangzp pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git


The following commit(s) were added to refs/heads/master by this push:
     new 3ab3273  [FLINK-25552] Add Estimator and Transformer for MinMaxScaler
3ab3273 is described below

commit 3ab327394769d5bd4f88be04428d33320b0928a3
Author: weibo <we...@alibaba-inc.com>
AuthorDate: Mon Mar 21 19:33:03 2022 +0800

    [FLINK-25552] Add Estimator and Transformer for MinMaxScaler
    
    This closes #54.
---
 .../flink/ml/classification/knn/KnnModelData.java  |  10 +-
 .../ml/feature/minmaxscaler/MinMaxScaler.java      | 203 +++++++++++++++++++
 .../ml/feature/minmaxscaler/MinMaxScalerModel.java | 183 +++++++++++++++++
 .../minmaxscaler/MinMaxScalerModelData.java}       |  69 +++----
 .../feature/minmaxscaler/MinMaxScalerParams.java   |  62 ++++++
 .../apache/flink/ml/feature/MinMaxScalerTest.java  | 218 +++++++++++++++++++++
 6 files changed, 699 insertions(+), 46 deletions(-)
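
For context, a minimal usage sketch of the new API (illustrative only, not part
of the commit; it assumes a Table "trainTable" and a Table "predictTable" whose
"features" column holds DenseVector values):

    // Train a MinMaxScaler that rescales features to the range [0.0, 1.0].
    MinMaxScaler minMaxScaler = new MinMaxScaler().setMin(0.0).setMax(1.0);
    MinMaxScalerModel model = minMaxScaler.fit(trainTable);
    // The output table keeps the input columns and appends the rescaled
    // vectors in the prediction column ("prediction" by default).
    Table output = model.transform(predictTable)[0];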

diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModelData.java
index 89051e6..4bf0adb 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModelData.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModelData.java
@@ -82,13 +82,11 @@ public class KnnModelData {
     /** Encoder for {@link KnnModelData}. */
     public static class ModelDataEncoder implements Encoder<KnnModelData> {
         @Override
-        public void encode(KnnModelData knnModelData, OutputStream outputStream)
-                throws IOException {
+        public void encode(KnnModelData modelData, OutputStream outputStream) throws IOException {
             DataOutputView dataOutputView = new DataOutputViewStreamWrapper(outputStream);
-            DenseMatrixSerializer.INSTANCE.serialize(knnModelData.packedFeatures, dataOutputView);
-            DenseVectorSerializer.INSTANCE.serialize(
-                    knnModelData.featureNormSquares, dataOutputView);
-            DenseVectorSerializer.INSTANCE.serialize(knnModelData.labels, dataOutputView);
+            DenseMatrixSerializer.INSTANCE.serialize(modelData.packedFeatures, dataOutputView);
+            DenseVectorSerializer.INSTANCE.serialize(modelData.featureNormSquares, dataOutputView);
+            DenseVectorSerializer.INSTANCE.serialize(modelData.labels, dataOutputView);
         }
     }
 
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScaler.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScaler.java
new file mode 100644
index 0000000..19a9f6f
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScaler.java
@@ -0,0 +1,203 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.minmaxscaler;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the MinMaxScaler algorithm. This algorithm rescales feature values
+ * to a common range [min, max] defined by the user.
+ *
+ * <blockquote>
+ *
+ * $$ Rescaled(value) = \frac{value - E_{min}}{E_{max} - E_{min}} * (max - min) + min $$
+ *
+ * </blockquote>
+ *
+ * <p>For the case \(E_{max} == E_{min}\), \(Rescaled(value) = 0.5 * (max + min)\).
+ *
+ * <p>See https://en.wikipedia.org/wiki/Feature_scaling#Rescaling_(min-max_normalization).
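+ *
+ * <p>For example (values taken from the tests in this commit), with \(E_{min} = 0\), \(E_{max} =
+ * 200\), \(min = 0\) and \(max = 1\), an input value of 150 rescales to \((150 - 0) / (200 - 0) =
+ * 0.75\).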
+ */
+public class MinMaxScaler
+        implements Estimator<MinMaxScaler, MinMaxScalerModel>, MinMaxScalerParams<MinMaxScaler> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public MinMaxScaler() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public MinMaxScalerModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        final String featureCol = getFeaturesCol();
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<DenseVector> features =
+                tEnv.toDataStream(inputs[0])
+                        .map(
+                                (MapFunction<Row, DenseVector>)
+                                        value -> (DenseVector) value.getField(featureCol));
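+        // First compute per-partition min/max vectors, then merge them into a
+        // single global min/max pair with a final parallelism-1 reduce.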
+        DataStream<DenseVector> minMaxValues =
+                features.transform(
+                                "reduceInEachPartition",
+                                features.getType(),
+                                new MinMaxReduceFunctionOperator())
+                        .transform(
+                                "reduceInFinalPartition",
+                                features.getType(),
+                                new MinMaxReduceFunctionOperator())
+                        .setParallelism(1);
+        DataStream<MinMaxScalerModelData> modelData =
+                DataStreamUtils.mapPartition(
+                        minMaxValues,
+                        new RichMapPartitionFunction<DenseVector, MinMaxScalerModelData>() {
+                            @Override
+                            public void mapPartition(
+                                    Iterable<DenseVector> values,
+                                    Collector<MinMaxScalerModelData> out) {
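+                                // The final reduce emits exactly two vectors:
+                                // the global min followed by the global max.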
+                                Iterator<DenseVector> iter = values.iterator();
+                                DenseVector minVector = iter.next();
+                                DenseVector maxVector = iter.next();
+                                out.collect(new MinMaxScalerModelData(minVector, maxVector));
+                            }
+                        });
+
+        MinMaxScalerModel model =
+                new MinMaxScalerModel().setModelData(tEnv.fromDataStream(modelData));
+        ReadWriteUtils.updateExistingParams(model, getParamMap());
+        return model;
+    }
+
+    /**
+     * A stream operator to compute the min and max values in each partition of the input bounded
+     * data stream.
+     */
+    private static class MinMaxReduceFunctionOperator extends AbstractStreamOperator<DenseVector>
+            implements OneInputStreamOperator<DenseVector, DenseVector>, BoundedOneInput {
+        private ListState<DenseVector> minState;
+        private ListState<DenseVector> maxState;
+
+        private DenseVector minVector;
+        private DenseVector maxVector;
+
+        @Override
+        public void endInput() {
+            if (minVector != null) {
+                output.collect(new StreamRecord<>(minVector));
+                output.collect(new StreamRecord<>(maxVector));
+            }
+        }
+
+        @Override
+        public void processElement(StreamRecord<DenseVector> streamRecord) {
+            DenseVector currentValue = streamRecord.getValue();
+            if (minVector == null) {
+                int vecSize = currentValue.size();
+                minVector = new DenseVector(vecSize);
+                maxVector = new DenseVector(vecSize);
+                System.arraycopy(currentValue.values, 0, minVector.values, 0, vecSize);
+                System.arraycopy(currentValue.values, 0, maxVector.values, 0, vecSize);
+            } else {
+                Preconditions.checkArgument(
+                        currentValue.size() == maxVector.size(),
+                        "currentValue should have the same size as maxVector.");
+                for (int i = 0; i < currentValue.size(); ++i) {
+                    minVector.values[i] = Math.min(minVector.values[i], currentValue.values[i]);
+                    maxVector.values[i] = Math.max(maxVector.values[i], currentValue.values[i]);
+                }
+            }
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            minState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "minState", TypeInformation.of(DenseVector.class)));
+            maxState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "maxState", TypeInformation.of(DenseVector.class)));
+
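+            // Restore any partial min/max computed before a checkpoint.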
+            OperatorStateUtils.getUniqueElement(minState, "minState").ifPresent(x -> minVector = x);
+            OperatorStateUtils.getUniqueElement(maxState, "maxState").ifPresent(x -> maxVector = x);
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            minState.clear();
+            maxState.clear();
+            if (minVector != null) {
+                minState.add(minVector);
+                maxState.add(maxVector);
+            }
+        }
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static MinMaxScaler load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerModel.java
new file mode 100644
index 0000000..762d74a
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerModel.java
@@ -0,0 +1,183 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.minmaxscaler;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * A Model which performs min-max rescaling using the model data computed by {@link MinMaxScaler}.
+ */
+public class MinMaxScalerModel
+        implements Model<MinMaxScalerModel>, MinMaxScalerParams<MinMaxScalerModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public MinMaxScalerModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public MinMaxScalerModel setModelData(Table... inputs) {
+        modelDataTable = inputs[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Row> data = tEnv.toDataStream(inputs[0]);
+        DataStream<MinMaxScalerModelData> minMaxScalerModel =
+                MinMaxScalerModelData.getModelDataStream(modelDataTable);
+        final String broadcastModelKey = "broadcastModelKey";
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(),
+                                TypeInformation.of(DenseVector.class)),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol()));
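+        // Broadcast the single model-data record so every parallel map task
+        // can rescale rows locally without a shuffle.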
+        DataStream<Row> output =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(data),
+                        Collections.singletonMap(broadcastModelKey, minMaxScalerModel),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(
+                                    new PredictOutputFunction(
+                                            broadcastModelKey,
+                                            getMax(),
+                                            getMin(),
+                                            getFeaturesCol()),
+                                    outputTypeInfo);
+                        });
+        return new Table[] {tEnv.fromDataStream(output)};
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+        ReadWriteUtils.saveModelData(
+                MinMaxScalerModelData.getModelDataStream(modelDataTable),
+                path,
+                new MinMaxScalerModelData.ModelDataEncoder());
+    }
+
+    /**
+     * Loads a {@link MinMaxScalerModel} and its model data from the given path.
+     *
+     * @param env Stream execution environment.
+     * @param path Model path.
+     * @return The loaded MinMaxScalerModel.
+     */
+    public static MinMaxScalerModel load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+        MinMaxScalerModel model = ReadWriteUtils.loadStageParam(path);
+        DataStream<MinMaxScalerModelData> modelData =
+                ReadWriteUtils.loadModelData(
+                        env, path, new MinMaxScalerModelData.ModelDataDecoder());
+        return model.setModelData(tEnv.fromDataStream(modelData));
+    }
+
+    /** A map function which rescales each input row using the broadcast model data. */
+    private static class PredictOutputFunction extends RichMapFunction<Row, Row> {
+        private final String featureCol;
+        private final String broadcastKey;
+        private final double upperBound;
+        private final double lowerBound;
+        private DenseVector scaleVector;
+        private DenseVector offsetVector;
+
+        public PredictOutputFunction(
+                String broadcastKey, double upperBound, double lowerBound, String featureCol) {
+            this.upperBound = upperBound;
+            this.lowerBound = lowerBound;
+            this.broadcastKey = broadcastKey;
+            this.featureCol = featureCol;
+        }
+
+        @Override
+        public Row map(Row row) {
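+            // On the first call, derive per-dimension scale and offset from the
+            // broadcast model data: scale = (max - min) / (E_max - E_min) and
+            // offset = min - E_min * scale.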
+            if (scaleVector == null) {
+                MinMaxScalerModelData minMaxScalerModelData =
+                        (MinMaxScalerModelData)
+                                getRuntimeContext().getBroadcastVariable(broadcastKey).get(0);
+                DenseVector minVector = minMaxScalerModelData.minVector;
+                DenseVector maxVector = minMaxScalerModelData.maxVector;
+                scaleVector = new DenseVector(minVector.size());
+                offsetVector = new DenseVector(minVector.size());
+                for (int i = 0; i < maxVector.size(); ++i) {
+                    if (Math.abs(minVector.values[i] - maxVector.values[i]) < 1.0e-5) {
+                        scaleVector.values[i] = 0.0;
+                        offsetVector.values[i] = (upperBound + lowerBound) / 2;
+                    } else {
+                        scaleVector.values[i] =
+                                (upperBound - lowerBound)
+                                        / (maxVector.values[i] - minVector.values[i]);
+                        offsetVector.values[i] =
+                                lowerBound - minVector.values[i] * scaleVector.values[i];
+                    }
+                }
+            }
+            DenseVector feature = (DenseVector) row.getField(featureCol);
+            DenseVector outputVector = new DenseVector(scaleVector.size());
+            for (int i = 0; i < scaleVector.size(); ++i) {
+                outputVector.values[i] =
+                        feature.values[i] * scaleVector.values[i] + offsetVector.values[i];
+            }
+            return Row.join(row, Row.of(outputVector));
+        }
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerModelData.java
similarity index 56%
copy from flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModelData.java
copy to flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerModelData.java
index 89051e6..301eadd 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModelData.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerModelData.java
@@ -16,7 +16,7 @@
  * limitations under the License.
  */
 
-package org.apache.flink.ml.classification.knn;
+package org.apache.flink.ml.feature.minmaxscaler;
 
 import org.apache.flink.api.common.serialization.Encoder;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
@@ -27,9 +27,7 @@ import org.apache.flink.core.memory.DataInputView;
 import org.apache.flink.core.memory.DataInputViewStreamWrapper;
 import org.apache.flink.core.memory.DataOutputView;
 import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
-import org.apache.flink.ml.linalg.DenseMatrix;
 import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.typeinfo.DenseMatrixSerializer;
 import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.table.api.Table;
@@ -41,24 +39,21 @@ import java.io.IOException;
 import java.io.OutputStream;
 
 /**
- * Model data of {@link KnnModel}.
+ * Model data of {@link MinMaxScalerModel}.
  *
  * <p>This class also provides methods to convert model data from Table to a data stream, and
  * classes to save/load model data.
  */
-public class KnnModelData {
+public class MinMaxScalerModelData {
+    public DenseVector minVector;
 
-    public DenseMatrix packedFeatures;
-    public DenseVector featureNormSquares;
-    public DenseVector labels;
+    public DenseVector maxVector;
 
-    public KnnModelData() {}
+    public MinMaxScalerModelData() {}
 
-    public KnnModelData(
-            DenseMatrix packedFeatures, DenseVector featureNormSquares, DenseVector labels) {
-        this.packedFeatures = packedFeatures;
-        this.featureNormSquares = featureNormSquares;
-        this.labels = labels;
+    public MinMaxScalerModelData(DenseVector minVector, DenseVector maxVector) {
+        this.minVector = minVector;
+        this.maxVector = maxVector;
     }
 
     /**
@@ -67,47 +62,41 @@ public class KnnModelData {
      * @param modelDataTable The table model data.
      * @return The data stream model data.
      */
-    public static DataStream<KnnModelData> getModelDataStream(Table modelDataTable) {
+    public static DataStream<MinMaxScalerModelData> getModelDataStream(Table modelDataTable) {
         StreamTableEnvironment tEnv =
                 (StreamTableEnvironment) ((TableImpl) modelDataTable).getTableEnvironment();
         return tEnv.toDataStream(modelDataTable)
                 .map(
                         x ->
-                                new KnnModelData(
-                                        (DenseMatrix) x.getField(0),
-                                        (DenseVector) x.getField(1),
-                                        (DenseVector) x.getField(2)));
+                                new MinMaxScalerModelData(
+                                        (DenseVector) x.getField(0), (DenseVector) x.getField(1)));
     }
 
-    /** Encoder for {@link KnnModelData}. */
-    public static class ModelDataEncoder implements Encoder<KnnModelData> {
+    /** Encoder for {@link MinMaxScalerModelData}. */
+    public static class ModelDataEncoder implements Encoder<MinMaxScalerModelData> {
         @Override
-        public void encode(KnnModelData knnModelData, OutputStream outputStream)
+        public void encode(MinMaxScalerModelData modelData, OutputStream outputStream)
                 throws IOException {
             DataOutputView dataOutputView = new DataOutputViewStreamWrapper(outputStream);
-            DenseMatrixSerializer.INSTANCE.serialize(knnModelData.packedFeatures, dataOutputView);
-            DenseVectorSerializer.INSTANCE.serialize(
-                    knnModelData.featureNormSquares, dataOutputView);
-            DenseVectorSerializer.INSTANCE.serialize(knnModelData.labels, dataOutputView);
+            DenseVectorSerializer.INSTANCE.serialize(modelData.minVector, dataOutputView);
+            DenseVectorSerializer.INSTANCE.serialize(modelData.maxVector, dataOutputView);
         }
     }
 
-    /** Decoder for {@link KnnModelData}. */
-    public static class ModelDataDecoder extends SimpleStreamFormat<KnnModelData> {
+    /** Decoder for {@link MinMaxScalerModelData}. */
+    public static class ModelDataDecoder extends SimpleStreamFormat<MinMaxScalerModelData> {
         @Override
-        public Reader<KnnModelData> createReader(Configuration config, FSDataInputStream stream) {
-            return new Reader<KnnModelData>() {
-
-                private final DataInputView source = new DataInputViewStreamWrapper(stream);
+        public Reader<MinMaxScalerModelData> createReader(
+                Configuration config, FSDataInputStream stream) {
+            return new Reader<MinMaxScalerModelData>() {
 
                 @Override
-                public KnnModelData read() throws IOException {
+                public MinMaxScalerModelData read() throws IOException {
+                    DataInputView source = new DataInputViewStreamWrapper(stream);
                     try {
-                        DenseMatrix matrix = DenseMatrixSerializer.INSTANCE.deserialize(source);
-                        DenseVector normSquares =
-                                DenseVectorSerializer.INSTANCE.deserialize(source);
-                        DenseVector labels = DenseVectorSerializer.INSTANCE.deserialize(source);
-                        return new KnnModelData(matrix, normSquares, labels);
+                        DenseVector minVector = DenseVectorSerializer.INSTANCE.deserialize(source);
+                        DenseVector maxVector = DenseVectorSerializer.INSTANCE.deserialize(source);
+                        return new MinMaxScalerModelData(minVector, maxVector);
                     } catch (EOFException e) {
                         return null;
                     }
@@ -121,8 +110,8 @@ public class KnnModelData {
         }
 
         @Override
-        public TypeInformation<KnnModelData> getProducedType() {
-            return TypeInformation.of(KnnModelData.class);
+        public TypeInformation<MinMaxScalerModelData> getProducedType() {
+            return TypeInformation.of(MinMaxScalerModelData.class);
         }
     }
 }
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerParams.java
new file mode 100644
index 0000000..aade500
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerParams.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.minmaxscaler;
+
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasPredictionCol;
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+
+/**
+ * Params for {@link MinMaxScaler}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface MinMaxScalerParams<T> extends HasFeaturesCol<T>, HasPredictionCol<T> {
+    Param<Double> MIN =
+            new DoubleParam(
+                    "min",
+                    "Lower bound of the output feature range.",
+                    0.0,
+                    ParamValidators.notNull());
+
+    default Double getMin() {
+        return get(MIN);
+    }
+
+    default T setMin(Double value) {
+        return set(MIN, value);
+    }
+
+    Param<Double> MAX =
+            new DoubleParam(
+                    "max",
+                    "Upper bound of the output feature range.",
+                    1.0,
+                    ParamValidators.notNull());
+
+    default Double getMax() {
+        return get(MAX);
+    }
+
+    default T setMax(Double value) {
+        return set(MAX, value);
+    }
+}
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinMaxScalerTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinMaxScalerTest.java
new file mode 100644
index 0000000..24ec885
--- /dev/null
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinMaxScalerTest.java
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.feature.minmaxscaler.MinMaxScaler;
+import org.apache.flink.ml.feature.minmaxscaler.MinMaxScalerModel;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+
+/** Tests {@link MinMaxScaler} and {@link MinMaxScalerModel}. */
+public class MinMaxScalerTest {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table trainDataTable;
+    private Table predictDataTable;
+    private static final List<Row> TRAIN_DATA =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of(Vectors.dense(0.0, 3.0)),
+                            Row.of(Vectors.dense(2.1, 0.0)),
+                            Row.of(Vectors.dense(4.1, 5.1)),
+                            Row.of(Vectors.dense(6.1, 8.1)),
+                            Row.of(Vectors.dense(200, 400))));
+    private static final List<Row> PREDICT_DATA =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of(Vectors.dense(150.0, 90.0)),
+                            Row.of(Vectors.dense(50.0, 40.0)),
+                            Row.of(Vectors.dense(100.0, 50.0))));
+    private static final double EPS = 1.0e-5;
+    private static final List<DenseVector> EXPECTED_DATA =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Vectors.dense(0.25, 0.1),
+                            Vectors.dense(0.5, 0.125),
+                            Vectors.dense(0.75, 0.225)));
+
+    /** Note: this comparator imposes orderings that are inconsistent with equals. */
+    private static int compare(DenseVector first, DenseVector second) {
+        for (int i = 0; i < first.size(); i++) {
+            int cmp = Double.compare(first.get(i), second.get(i));
+            if (cmp != 0) {
+                return cmp;
+            }
+        }
+        return 0;
+    }
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        trainDataTable = tEnv.fromDataStream(env.fromCollection(TRAIN_DATA)).as("features");
+        predictDataTable = tEnv.fromDataStream(env.fromCollection(PREDICT_DATA)).as("features");
+    }
+
+    private static void verifyPredictionResult(
+            Table output, String outputCol, List<DenseVector> expected) throws Exception {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment();
+        DataStream<DenseVector> stream =
+                tEnv.toDataStream(output)
+                        .map(
+                                (MapFunction<Row, DenseVector>)
+                                        row -> (DenseVector) row.getField(outputCol));
+        List<DenseVector> result = IteratorUtils.toList(stream.executeAndCollect());
+        result.sort(MinMaxScalerTest::compare);
+        assertEquals(expected, result);
+    }
+
+    @Test
+    public void testParam() {
+        MinMaxScaler minMaxScaler = new MinMaxScaler();
+        assertEquals("features", minMaxScaler.getFeaturesCol());
+        assertEquals("prediction", minMaxScaler.getPredictionCol());
+        assertEquals(0.0, minMaxScaler.getMin(), EPS);
+        assertEquals(1.0, minMaxScaler.getMax(), EPS);
+        minMaxScaler
+                .setFeaturesCol("test_features")
+                .setPredictionCol("test_output")
+                .setMin(1.0)
+                .setMax(4.0);
+        assertEquals("test_features", minMaxScaler.getFeaturesCol());
+        assertEquals(1.0, minMaxScaler.getMin(), EPS);
+        assertEquals(4.0, minMaxScaler.getMax(), EPS);
+        assertEquals("test_output", minMaxScaler.getPredictionCol());
+    }
+
+    @Test
+    public void testFeaturePredictionParam() {
+        MinMaxScaler minMaxScaler =
+                new MinMaxScaler()
+                        .setFeaturesCol("test_features")
+                        .setPredictionCol("test_output")
+                        .setMin(1.0)
+                        .setMax(4.0);
+
+        MinMaxScalerModel model = minMaxScaler.fit(trainDataTable.as("test_features"));
+        Table output = model.transform(predictDataTable.as("test_features"))[0];
+        assertEquals(
+                Arrays.asList("test_features", "test_output"),
+                output.getResolvedSchema().getColumnNames());
+    }
+
+    @Test
+    public void testMaxValueEqualsMinValueButPredictValueNotEquals() throws Exception {
+        List<Row> trainData =
+                new ArrayList<>(Collections.singletonList(Row.of(Vectors.dense(40.0, 80.0))));
+        Table trainTable = tEnv.fromDataStream(env.fromCollection(trainData)).as("features");
+        List<Row> predictData =
+                new ArrayList<>(Collections.singletonList(Row.of(Vectors.dense(30.0, 50.0))));
+        Table predictDataTable =
+                tEnv.fromDataStream(env.fromCollection(predictData)).as("features");
+        MinMaxScaler minMaxScaler = new MinMaxScaler().setMax(10.0).setMin(0.0);
+        MinMaxScalerModel model = minMaxScaler.fit(trainTable);
+        Table result = model.transform(predictDataTable)[0];
+        verifyPredictionResult(
+                result,
+                minMaxScaler.getPredictionCol(),
+                Collections.singletonList(Vectors.dense(5.0, 5.0)));
+    }
+
+    @Test
+    public void testFitAndPredict() throws Exception {
+        MinMaxScaler minMaxScaler = new MinMaxScaler();
+        MinMaxScalerModel minMaxScalerModel = minMaxScaler.fit(trainDataTable);
+        Table output = minMaxScalerModel.transform(predictDataTable)[0];
+        verifyPredictionResult(output, minMaxScaler.getPredictionCol(), EXPECTED_DATA);
+    }
+
+    @Test
+    public void testSaveLoadAndPredict() throws Exception {
+        MinMaxScaler minMaxScaler = new MinMaxScaler();
+        MinMaxScaler loadedMinMaxScaler =
+                StageTestUtils.saveAndReload(
+                        env, minMaxScaler, tempFolder.newFolder().getAbsolutePath());
+        MinMaxScalerModel model = loadedMinMaxScaler.fit(trainDataTable);
+        MinMaxScalerModel loadedModel =
+                StageTestUtils.saveAndReload(env, model, tempFolder.newFolder().getAbsolutePath());
+        assertEquals(
+                Arrays.asList("minVector", "maxVector"),
+                model.getModelData()[0].getResolvedSchema().getColumnNames());
+        Table output = loadedModel.transform(predictDataTable)[0];
+        verifyPredictionResult(output, minMaxScaler.getPredictionCol(), EXPECTED_DATA);
+    }
+
+    @Test
+    public void testGetModelData() throws Exception {
+        MinMaxScaler minMaxScaler = new MinMaxScaler();
+        MinMaxScalerModel minMaxScalerModel = minMaxScaler.fit(trainDataTable);
+        Table modelData = minMaxScalerModel.getModelData()[0];
+        assertEquals(
+                Arrays.asList("minVector", "maxVector"),
+                modelData.getResolvedSchema().getColumnNames());
+        DataStream<Row> output = tEnv.toDataStream(modelData);
+        List<Row> modelRows = IteratorUtils.toList(output.executeAndCollect());
+        assertEquals(new DenseVector(new double[] {0.0, 0.0}), modelRows.get(0).getField(0));
+        assertEquals(new DenseVector(new double[] {200.0, 400.0}), modelRows.get(0).getField(1));
+    }
+
+    @Test
+    public void testSetModelData() throws Exception {
+        MinMaxScaler minMaxScaler = new MinMaxScaler();
+        MinMaxScalerModel modelA = minMaxScaler.fit(trainDataTable);
+        Table modelData = modelA.getModelData()[0];
+        MinMaxScalerModel modelB = new MinMaxScalerModel().setModelData(modelData);
+        ReadWriteUtils.updateExistingParams(modelB, modelA.getParamMap());
+        Table output = modelB.transform(predictDataTable)[0];
+        verifyPredictionResult(output, minMaxScaler.getPredictionCol(), EXPECTED_DATA);
+    }
+}