You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by zh...@apache.org on 2022/08/23 06:03:21 UTC
[flink-ml] branch master updated: [FLINK-28943] Add Transformer and Estimator for MaxAbsScaler
This is an automated email from the ASF dual-hosted git repository.
zhangzp pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push:
new 3aa2175 [FLINK-28943] Add Transformer and Estimator for MaxAbsScaler
3aa2175 is described below
commit 3aa21751ddb1b561c44ec00ac5b43646ead6fcfd
Author: weibo <wb...@pku.edu.cn>
AuthorDate: Tue Aug 23 14:03:17 2022 +0800
[FLINK-28943] Add Transformer and Estimator for MaxAbsScaler
This closes #142.
---
.../ml/examples/feature/MaxAbsScalerExample.java | 72 ++++++
.../ml/feature/maxabsscaler/MaxAbsScaler.java | 180 +++++++++++++++
.../MaxAbsScalerModel.java} | 90 +++-----
.../maxabsscaler/MaxAbsScalerModelData.java | 112 +++++++++
.../feature/maxabsscaler/MaxAbsScalerParams.java | 29 +++
.../ml/feature/minmaxscaler/MinMaxScalerModel.java | 4 +-
.../apache/flink/ml/feature/MaxAbsScalerTest.java | 254 +++++++++++++++++++++
.../examples/ml/feature/maxabsscaler_example.py | 79 +++++++
.../pyflink/ml/lib/feature/maxabsscaler.py | 71 ++++++
.../pyflink/ml/lib/feature/minmaxscaler.py | 3 +-
.../ml/lib/feature/tests/test_maxabsscaler.py | 99 ++++++++
11 files changed, 936 insertions(+), 57 deletions(-)
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/MaxAbsScalerExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/MaxAbsScalerExample.java
new file mode 100644
index 0000000..cd53394
--- /dev/null
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/MaxAbsScalerExample.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.examples.feature;
+
+import org.apache.flink.ml.feature.maxabsscaler.MaxAbsScaler;
+import org.apache.flink.ml.feature.maxabsscaler.MaxAbsScalerModel;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.CloseableIterator;
+
+/** Simple program that trains a MaxAbsScaler model and uses it for feature engineering. */
+public class MaxAbsScalerExample {
+ public static void main(String[] args) {
+ StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+ StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+ // Generates input training and prediction data.
+ DataStream<Row> trainStream =
+ env.fromElements(
+ Row.of(Vectors.dense(0.0, 3.0)),
+ Row.of(Vectors.dense(2.1, 0.0)),
+ Row.of(Vectors.dense(4.1, 5.1)),
+ Row.of(Vectors.dense(6.1, 8.1)),
+ Row.of(Vectors.dense(200, 400)));
+ Table trainTable = tEnv.fromDataStream(trainStream).as("input");
+
+ DataStream<Row> predictStream =
+ env.fromElements(
+ Row.of(Vectors.dense(150.0, 90.0)),
+ Row.of(Vectors.dense(50.0, 40.0)),
+ Row.of(Vectors.dense(100.0, 50.0)));
+ Table predictTable = tEnv.fromDataStream(predictStream).as("input");
+
+ // Creates a MaxAbsScaler object and initializes its parameters.
+ MaxAbsScaler maxAbsScaler = new MaxAbsScaler();
+
+ // Trains the MaxAbsScaler Model.
+ MaxAbsScalerModel maxAbsScalerModel = maxAbsScaler.fit(trainTable);
+
+ // Uses the MaxAbsScaler Model for predictions.
+ Table outputTable = maxAbsScalerModel.transform(predictTable)[0];
+
+ // Extracts and displays the results.
+ for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
+ Row row = it.next();
+ DenseVector inputValue = (DenseVector) row.getField(maxAbsScaler.getInputCol());
+ DenseVector outputValue = (DenseVector) row.getField(maxAbsScaler.getOutputCol());
+ System.out.printf("Input Value: %-15s\tOutput Value: %s\n", inputValue, outputValue);
+ }
+ }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/maxabsscaler/MaxAbsScaler.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/maxabsscaler/MaxAbsScaler.java
new file mode 100644
index 0000000..0f3a7b8
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/maxabsscaler/MaxAbsScaler.java
@@ -0,0 +1,180 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.maxabsscaler;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the MaxAbsScaler algorithm. This algorithm rescales feature values
+ * to the range [-1, 1] by dividing through the largest maximum absolute value in each feature. It
+ * does not shift/center the data and thus does not destroy any sparsity.
+ */
+public class MaxAbsScaler
+ implements Estimator<MaxAbsScaler, MaxAbsScalerModel>, MaxAbsScalerParams<MaxAbsScaler> {
+ private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+ public MaxAbsScaler() {
+ ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+ }
+
+ @Override
+ public MaxAbsScalerModel fit(Table... inputs) {
+ Preconditions.checkArgument(inputs.length == 1);
+ final String inputCol = getInputCol();
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+ DataStream<Vector> inputData =
+ tEnv.toDataStream(inputs[0])
+ .map(
+ (MapFunction<Row, Vector>)
+ value -> ((Vector) value.getField(inputCol)));
+
+ DataStream<Vector> maxAbsValues =
+ inputData
+ .transform(
+ "reduceInEachPartition",
+ inputData.getType(),
+ new MaxAbsReduceFunctionOperator())
+ .transform(
+ "reduceInFinalPartition",
+ inputData.getType(),
+ new MaxAbsReduceFunctionOperator())
+ .setParallelism(1);
+
+ DataStream<MaxAbsScalerModelData> modelData =
+ maxAbsValues.map(
+ (MapFunction<Vector, MaxAbsScalerModelData>)
+ vector -> new MaxAbsScalerModelData((DenseVector) vector));
+
+ MaxAbsScalerModel model =
+ new MaxAbsScalerModel().setModelData(tEnv.fromDataStream(modelData));
+ ReadWriteUtils.updateExistingParams(model, getParamMap());
+ return model;
+ }
+
+ /**
+ * A stream operator to compute the maximum absolute values in each partition of the input
+ * bounded data stream.
+ */
+ private static class MaxAbsReduceFunctionOperator extends AbstractStreamOperator<Vector>
+ implements OneInputStreamOperator<Vector, Vector>, BoundedOneInput {
+ private ListState<DenseVector> maxAbsState;
+ private DenseVector maxAbsVector;
+
+ @Override
+ public void endInput() {
+ if (maxAbsVector != null) {
+ output.collect(new StreamRecord<>(maxAbsVector));
+ }
+ }
+
+ @Override
+ public void processElement(StreamRecord<Vector> streamRecord) {
+ Vector currentValue = streamRecord.getValue();
+ if (currentValue == null) {
+ throw new RuntimeException("Input column data cannot be null.");
+ }
+
+ maxAbsVector =
+ maxAbsVector == null ? new DenseVector(currentValue.size()) : maxAbsVector;
+ Preconditions.checkArgument(
+ currentValue.size() == maxAbsVector.size(),
+ "The training data should all have same dimensions.");
+
+ if (currentValue instanceof DenseVector) {
+ double[] values = ((DenseVector) currentValue).values;
+ for (int i = 0; i < currentValue.size(); ++i) {
+ maxAbsVector.values[i] = Math.max(maxAbsVector.values[i], Math.abs(values[i]));
+ }
+ } else if (currentValue instanceof SparseVector) {
+ int[] indices = ((SparseVector) currentValue).indices;
+ double[] values = ((SparseVector) currentValue).values;
+
+ for (int i = 0; i < indices.length; ++i) {
+ maxAbsVector.values[indices[i]] =
+ Math.max(maxAbsVector.values[indices[i]], Math.abs(values[i]));
+ }
+ }
+ }
+
+ @Override
+ public void initializeState(StateInitializationContext context) throws Exception {
+ super.initializeState(context);
+ maxAbsState =
+ context.getOperatorStateStore()
+ .getListState(
+ new ListStateDescriptor<>(
+ "maxAbsState", DenseVectorTypeInfo.INSTANCE));
+
+ OperatorStateUtils.getUniqueElement(maxAbsState, "maxAbsState")
+ .ifPresent(x -> maxAbsVector = x);
+ }
+
+ @Override
+ public void snapshotState(StateSnapshotContext context) throws Exception {
+ super.snapshotState(context);
+ maxAbsState.clear();
+ if (maxAbsVector != null) {
+ maxAbsState.add(maxAbsVector);
+ }
+ }
+ }
+
+ @Override
+ public Map<Param<?>, Object> getParamMap() {
+ return paramMap;
+ }
+
+ @Override
+ public void save(String path) throws IOException {
+ ReadWriteUtils.saveMetadata(this, path);
+ }
+
+ public static MaxAbsScaler load(StreamTableEnvironment tEnv, String path) throws IOException {
+ return ReadWriteUtils.loadStageParam(path);
+ }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/maxabsscaler/MaxAbsScalerModel.java
similarity index 60%
copy from flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerModel.java
copy to flink-ml-lib/src/main/java/org/apache/flink/ml/feature/maxabsscaler/MaxAbsScalerModel.java
index 59b6aa0..5f5d7e4 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerModel.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/maxabsscaler/MaxAbsScalerModel.java
@@ -16,16 +16,17 @@
* limitations under the License.
*/
-package org.apache.flink.ml.feature.minmaxscaler;
+package org.apache.flink.ml.feature.maxabsscaler;
import org.apache.flink.api.common.functions.RichMapFunction;
-import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.ml.api.Model;
import org.apache.flink.ml.common.broadcast.BroadcastUtils;
import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
@@ -43,20 +44,18 @@ import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
-/**
- * A Model which do a minMax scaler operation using the model data computed by {@link MinMaxScaler}.
- */
-public class MinMaxScalerModel
- implements Model<MinMaxScalerModel>, MinMaxScalerParams<MinMaxScalerModel> {
+/** A Model which transforms data using the model data computed by {@link MaxAbsScaler}. */
+public class MaxAbsScalerModel
+ implements Model<MaxAbsScalerModel>, MaxAbsScalerParams<MaxAbsScalerModel> {
private final Map<Param<?>, Object> paramMap = new HashMap<>();
private Table modelDataTable;
- public MinMaxScalerModel() {
+ public MaxAbsScalerModel() {
ParamUtils.initializeMapWithDefaultValues(paramMap, this);
}
@Override
- public MinMaxScalerModel setModelData(Table... inputs) {
+ public MaxAbsScalerModel setModelData(Table... inputs) {
modelDataTable = inputs[0];
return this;
}
@@ -72,28 +71,29 @@ public class MinMaxScalerModel
Preconditions.checkArgument(inputs.length == 1);
StreamTableEnvironment tEnv =
(StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
DataStream<Row> data = tEnv.toDataStream(inputs[0]);
- DataStream<MinMaxScalerModelData> minMaxScalerModel =
- MinMaxScalerModelData.getModelDataStream(modelDataTable);
+ DataStream<MaxAbsScalerModelData> maxAbsScalerModel =
+ MaxAbsScalerModelData.getModelDataStream(modelDataTable);
+
final String broadcastModelKey = "broadcastModelKey";
RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
RowTypeInfo outputTypeInfo =
new RowTypeInfo(
- ArrayUtils.addAll(
- inputTypeInfo.getFieldTypes(),
- TypeInformation.of(DenseVector.class)),
+ ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), VectorTypeInfo.INSTANCE),
ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol()));
+
DataStream<Row> output =
BroadcastUtils.withBroadcastStream(
Collections.singletonList(data),
- Collections.singletonMap(broadcastModelKey, minMaxScalerModel),
+ Collections.singletonMap(broadcastModelKey, maxAbsScalerModel),
inputList -> {
DataStream input = inputList.get(0);
return input.map(
- new PredictOutputFunction(
- broadcastModelKey, getMax(), getMin(), getInputCol()),
+ new PredictOutputFunction(broadcastModelKey, getInputCol()),
outputTypeInfo);
});
+
return new Table[] {tEnv.fromDataStream(output)};
}
@@ -106,9 +106,9 @@ public class MinMaxScalerModel
public void save(String path) throws IOException {
ReadWriteUtils.saveMetadata(this, path);
ReadWriteUtils.saveModelData(
- MinMaxScalerModelData.getModelDataStream(modelDataTable),
+ MaxAbsScalerModelData.getModelDataStream(modelDataTable),
path,
- new MinMaxScalerModelData.ModelDataEncoder());
+ new MaxAbsScalerModelData.ModelDataEncoder());
}
/**
@@ -116,30 +116,25 @@ public class MinMaxScalerModel
*
* @param tEnv Stream table environment.
* @param path Model path.
- * @return MinMaxScalerModel model.
+ * @return MaxAbsScalerModel model.
*/
- public static MinMaxScalerModel load(StreamTableEnvironment tEnv, String path)
+ public static MaxAbsScalerModel load(StreamTableEnvironment tEnv, String path)
throws IOException {
- MinMaxScalerModel model = ReadWriteUtils.loadStageParam(path);
+ MaxAbsScalerModel model = ReadWriteUtils.loadStageParam(path);
+
Table modelDataTable =
ReadWriteUtils.loadModelData(
- tEnv, path, new MinMaxScalerModelData.ModelDataDecoder());
+ tEnv, path, new MaxAbsScalerModelData.ModelDataDecoder());
return model.setModelData(modelDataTable);
}
- /** This operator loads model data and predicts result. */
+ /** This function loads model data and predicts result. */
private static class PredictOutputFunction extends RichMapFunction<Row, Row> {
private final String inputCol;
private final String broadcastKey;
- private final double upperBound;
- private final double lowerBound;
private DenseVector scaleVector;
- private DenseVector offsetVector;
- public PredictOutputFunction(
- String broadcastKey, double upperBound, double lowerBound, String inputCol) {
- this.upperBound = upperBound;
- this.lowerBound = lowerBound;
+ public PredictOutputFunction(String broadcastKey, String inputCol) {
this.broadcastKey = broadcastKey;
this.inputCol = inputCol;
}
@@ -147,32 +142,23 @@ public class MinMaxScalerModel
@Override
public Row map(Row row) {
if (scaleVector == null) {
- MinMaxScalerModelData minMaxScalerModelData =
- (MinMaxScalerModelData)
+ MaxAbsScalerModelData maxAbsScalerModelData =
+ (MaxAbsScalerModelData)
getRuntimeContext().getBroadcastVariable(broadcastKey).get(0);
- DenseVector minVector = minMaxScalerModelData.minVector;
- DenseVector maxVector = minMaxScalerModelData.maxVector;
- scaleVector = new DenseVector(minVector.size());
- offsetVector = new DenseVector(minVector.size());
- for (int i = 0; i < maxVector.size(); ++i) {
- if (Math.abs(minVector.values[i] - maxVector.values[i]) < 1.0e-5) {
- scaleVector.values[i] = 0.0;
- offsetVector.values[i] = (upperBound + lowerBound) / 2;
+ scaleVector = maxAbsScalerModelData.maxVector;
+
+ for (int i = 0; i < scaleVector.size(); ++i) {
+ if (scaleVector.values[i] != 0) {
+ scaleVector.values[i] = 1.0 / scaleVector.values[i];
} else {
- scaleVector.values[i] =
- (upperBound - lowerBound)
- / (maxVector.values[i] - minVector.values[i]);
- offsetVector.values[i] =
- lowerBound - minVector.values[i] * scaleVector.values[i];
+ scaleVector.values[i] = 1.0;
}
}
}
- DenseVector inputVec = ((Vector) row.getField(inputCol)).toDense();
- DenseVector outputVec = new DenseVector(scaleVector.size());
- for (int i = 0; i < scaleVector.size(); ++i) {
- outputVec.values[i] =
- inputVec.values[i] * scaleVector.values[i] + offsetVector.values[i];
- }
+
+ Vector inputVec = row.getFieldAs(inputCol);
+ Vector outputVec = inputVec.clone();
+ BLAS.hDot(scaleVector, outputVec);
return Row.join(row, Row.of(outputVec));
}
}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/maxabsscaler/MaxAbsScalerModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/maxabsscaler/MaxAbsScalerModelData.java
new file mode 100644
index 0000000..4c1e76d
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/maxabsscaler/MaxAbsScalerModelData.java
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.maxabsscaler;
+
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+
+import java.io.EOFException;
+import java.io.IOException;
+import java.io.OutputStream;
+
+/**
+ * Model data of {@link MaxAbsScalerModel}.
+ *
+ * <p>This class also provides methods to convert model data from Table to a data stream, and
+ * classes to save/load model data.
+ */
+public class MaxAbsScalerModelData {
+ public DenseVector maxVector;
+
+ public MaxAbsScalerModelData() {}
+
+ public MaxAbsScalerModelData(DenseVector maxVector) {
+ this.maxVector = maxVector;
+ }
+
+ /**
+ * Converts the table model to a data stream.
+ *
+ * @param modelDataTable The table model data.
+ * @return The data stream model data.
+ */
+ public static DataStream<MaxAbsScalerModelData> getModelDataStream(Table modelDataTable) {
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl) modelDataTable).getTableEnvironment();
+ return tEnv.toDataStream(modelDataTable)
+ .map(x -> new MaxAbsScalerModelData((DenseVector) x.getField(0)));
+ }
+
+ /** Encoder for {@link MaxAbsScalerModelData}. */
+ public static class ModelDataEncoder implements Encoder<MaxAbsScalerModelData> {
+ private final DenseVectorSerializer serializer = new DenseVectorSerializer();
+
+ @Override
+ public void encode(MaxAbsScalerModelData modelData, OutputStream outputStream)
+ throws IOException {
+ DataOutputView dataOutputView = new DataOutputViewStreamWrapper(outputStream);
+ serializer.serialize(modelData.maxVector, dataOutputView);
+ }
+ }
+
+ /** Decoder for {@link MaxAbsScalerModelData}. */
+ public static class ModelDataDecoder extends SimpleStreamFormat<MaxAbsScalerModelData> {
+ @Override
+ public Reader<MaxAbsScalerModelData> createReader(
+ Configuration config, FSDataInputStream stream) {
+ return new Reader<MaxAbsScalerModelData>() {
+ private final DenseVectorSerializer serializer = new DenseVectorSerializer();
+
+ @Override
+ public MaxAbsScalerModelData read() throws IOException {
+ DataInputView source = new DataInputViewStreamWrapper(stream);
+ try {
+ DenseVector maxVector = serializer.deserialize(source);
+ return new MaxAbsScalerModelData(maxVector);
+ } catch (EOFException e) {
+ return null;
+ }
+ }
+
+ @Override
+ public void close() throws IOException {
+ stream.close();
+ }
+ };
+ }
+
+ @Override
+ public TypeInformation<MaxAbsScalerModelData> getProducedType() {
+ return TypeInformation.of(MaxAbsScalerModelData.class);
+ }
+ }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/maxabsscaler/MaxAbsScalerParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/maxabsscaler/MaxAbsScalerParams.java
new file mode 100644
index 0000000..bec4110
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/maxabsscaler/MaxAbsScalerParams.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.maxabsscaler;
+
+import org.apache.flink.ml.common.param.HasInputCol;
+import org.apache.flink.ml.common.param.HasOutputCol;
+
+/**
+ * Params for {@link MaxAbsScaler}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface MaxAbsScalerParams<T> extends HasInputCol<T>, HasOutputCol<T> {}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerModel.java
index 59b6aa0..858c0f4 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerModel.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerModel.java
@@ -43,9 +43,7 @@ import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
-/**
- * A Model which do a minMax scaler operation using the model data computed by {@link MinMaxScaler}.
- */
+/** A Model which transforms data using the model data computed by {@link MinMaxScaler}. */
public class MinMaxScalerModel
implements Model<MinMaxScalerModel>, MinMaxScalerParams<MinMaxScalerModel> {
private final Map<Param<?>, Object> paramMap = new HashMap<>();
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MaxAbsScalerTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MaxAbsScalerTest.java
new file mode 100644
index 0000000..68e443d
--- /dev/null
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MaxAbsScalerTest.java
@@ -0,0 +1,254 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.feature.maxabsscaler.MaxAbsScaler;
+import org.apache.flink.ml.feature.maxabsscaler.MaxAbsScalerModel;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.TestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.exception.ExceptionUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import static org.apache.flink.test.util.TestBaseUtils.compareResultCollections;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
+
+/** Tests {@link MaxAbsScaler} and {@link MaxAbsScalerModel}. */
+public class MaxAbsScalerTest {
+ @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+ private StreamTableEnvironment tEnv;
+ private StreamExecutionEnvironment env;
+
+ private Table trainDataTable;
+ private Table predictDataTable;
+ private Table trainSparseDataTable;
+ private Table predictSparseDataTable;
+
+ private static final List<Row> TRAIN_DATA =
+ new ArrayList<>(
+ Arrays.asList(
+ Row.of(Vectors.dense(0.0, 3.0, 0.0)),
+ Row.of(Vectors.dense(2.1, 0.0, 0.0)),
+ Row.of(Vectors.dense(4.1, 5.1, 0.0)),
+ Row.of(Vectors.dense(6.1, 8.1, 0.0)),
+ Row.of(Vectors.dense(200, -400, 0.0))));
+
+ private static final List<Row> PREDICT_DATA =
+ new ArrayList<>(
+ Arrays.asList(
+ Row.of(Vectors.dense(150.0, 90.0, 1.0)),
+ Row.of(Vectors.dense(50.0, 40.0, 1.0)),
+ Row.of(Vectors.dense(100.0, 50.0, 0.5))));
+
+ private static final List<Row> TRAIN_SPARSE_DATA =
+ new ArrayList<>(
+ Arrays.asList(
+ Row.of(Vectors.sparse(4, new int[] {1, 3}, new double[] {4.0, 3.0})),
+ Row.of(Vectors.sparse(4, new int[] {0, 2}, new double[] {2.0, -6.0})),
+ Row.of(Vectors.sparse(4, new int[] {1, 2}, new double[] {1.0, 3.0})),
+ Row.of(Vectors.sparse(4, new int[] {0, 1}, new double[] {2.0, 8.0})),
+ Row.of(Vectors.sparse(4, new int[] {1, 3}, new double[] {1.0, 5.0}))));
+
+ private static final List<Row> PREDICT_SPARSE_DATA =
+ new ArrayList<>(
+ Arrays.asList(
+ Row.of(Vectors.sparse(4, new int[] {0, 1}, new double[] {2.0, 4.0})),
+ Row.of(Vectors.sparse(4, new int[] {0, 2}, new double[] {1.0, 3.0})),
+ Row.of(Vectors.sparse(4, new int[] {}, new double[] {})),
+ Row.of(Vectors.sparse(4, new int[] {1, 3}, new double[] {1.0, 2.0}))));
+
+ private static final List<Vector> EXPECTED_DATA =
+ new ArrayList<>(
+ Arrays.asList(
+ Vectors.dense(0.25, 0.1, 1.0),
+ Vectors.dense(0.5, 0.125, 0.5),
+ Vectors.dense(0.75, 0.225, 1.0)));
+
+ private static final List<Vector> EXPECTED_SPARSE_DATA =
+ new ArrayList<>(
+ Arrays.asList(
+ Vectors.sparse(4, new int[] {0, 1}, new double[] {1.0, 0.5}),
+ Vectors.sparse(4, new int[] {0, 2}, new double[] {0.5, 0.5}),
+ Vectors.sparse(4, new int[] {}, new double[] {}),
+ Vectors.sparse(4, new int[] {1, 3}, new double[] {0.125, 0.4})));
+
+ @Before
+ public void before() {
+ Configuration config = new Configuration();
+ config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+
+ env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+ env.setParallelism(4);
+ env.enableCheckpointing(100);
+ env.setRestartStrategy(RestartStrategies.noRestart());
+ tEnv = StreamTableEnvironment.create(env);
+
+ trainDataTable = tEnv.fromDataStream(env.fromCollection(TRAIN_DATA)).as("input");
+ predictDataTable = tEnv.fromDataStream(env.fromCollection(PREDICT_DATA)).as("input");
+
+ trainSparseDataTable =
+ tEnv.fromDataStream(env.fromCollection(TRAIN_SPARSE_DATA)).as("input");
+ predictSparseDataTable =
+ tEnv.fromDataStream(env.fromCollection(PREDICT_SPARSE_DATA)).as("input");
+ }
+
+    /**
+     * Executes {@code output} and checks that the vectors in column {@code outputCol} match
+     * {@code expectedData} using {@link TestUtils#compare}.
+     */
+    private static void verifyPredictionResult(
+            Table output, String outputCol, List<Vector> expectedData) throws Exception {
+        // Recover the table environment that owns the output table.
+        TableImpl outputImpl = (TableImpl) output;
+        StreamTableEnvironment tableEnv =
+                (StreamTableEnvironment) outputImpl.getTableEnvironment();
+
+        MapFunction<Row, Vector> extractOutput = row -> row.getFieldAs(outputCol);
+        DataStream<Vector> vectors = tableEnv.toDataStream(output).map(extractOutput);
+
+        List<Vector> actualData = IteratorUtils.toList(vectors.executeAndCollect());
+        compareResultCollections(expectedData, actualData, TestUtils::compare);
+    }
+
+    /** Checks the default column names and that the setters update them. */
+    @Test
+    public void testParam() {
+        MaxAbsScaler scaler = new MaxAbsScaler();
+        assertEquals("input", scaler.getInputCol());
+        assertEquals("output", scaler.getOutputCol());
+
+        scaler.setInputCol("test_input").setOutputCol("test_output");
+
+        assertEquals("test_input", scaler.getInputCol());
+        assertEquals("test_output", scaler.getOutputCol());
+    }
+
+ @Test
+    /** The transformed table keeps the input column and appends the configured output column. */
+    @Test
+    public void testOutputSchema() {
+        MaxAbsScaler scaler =
+                new MaxAbsScaler().setInputCol("test_input").setOutputCol("test_output");
+        Table renamedTrainData = trainDataTable.as("test_input");
+        Table renamedPredictData = predictDataTable.as("test_input");
+
+        MaxAbsScalerModel model = scaler.fit(renamedTrainData);
+        Table output = model.transform(renamedPredictData)[0];
+
+        List<String> expectedColumns = Arrays.asList("test_input", "test_output");
+        assertEquals(expectedColumns, output.getResolvedSchema().getColumnNames());
+    }
+
+ @Test
+ public void testFitAndPredict() throws Exception {
+ MaxAbsScaler maxAbsScaler = new MaxAbsScaler();
+ MaxAbsScalerModel maxAbsScalerModel = maxAbsScaler.fit(trainDataTable);
+ Table output = maxAbsScalerModel.transform(predictDataTable)[0];
+ verifyPredictionResult(output, maxAbsScaler.getOutputCol(), EXPECTED_DATA);
+ }
+
+    /** Fitting on a table that contains a null vector fails with a descriptive root cause. */
+    @Test
+    public void testFitDataWithNullValue() {
+        List<Row> trainData =
+                new ArrayList<>(
+                        Arrays.asList(
+                                Row.of(Vectors.dense(0.0, 3.0)),
+                                Row.of(Vectors.dense(2.1, 0.0)),
+                                Row.of((Object) null),
+                                Row.of(Vectors.dense(6.1, 8.1)),
+                                Row.of(Vectors.dense(200, 400))));
+        Table trainDataWithInvalidData =
+                tEnv.fromDataStream(env.fromCollection(trainData)).as("input");
+
+        Throwable rootCause = null;
+        try {
+            Table modelData = new MaxAbsScaler().fit(trainDataWithInvalidData).getModelData()[0];
+            // Fitting is lazy; the failure only surfaces once the job actually runs.
+            IteratorUtils.toList(tEnv.toDataStream(modelData).executeAndCollect());
+            fail();
+        } catch (Exception e) {
+            rootCause = ExceptionUtils.getRootCause(e);
+        }
+        assertEquals("Input column data cannot be null.", rootCause.getMessage());
+    }
+
+ @Test
+ public void testFitAndPredictSparse() throws Exception {
+ MaxAbsScaler maxAbsScaler = new MaxAbsScaler();
+ MaxAbsScalerModel maxAbsScalerModel = maxAbsScaler.fit(trainSparseDataTable);
+ Table output = maxAbsScalerModel.transform(predictSparseDataTable)[0];
+ verifyPredictionResult(output, maxAbsScaler.getOutputCol(), EXPECTED_SPARSE_DATA);
+ }
+
+    /** Round-trips both the estimator and the fitted model through save/load, then predicts. */
+    @Test
+    public void testSaveLoadAndPredict() throws Exception {
+        MaxAbsScaler scaler = new MaxAbsScaler();
+        String estimatorPath = tempFolder.newFolder().getAbsolutePath();
+        MaxAbsScaler reloadedScaler = TestUtils.saveAndReload(tEnv, scaler, estimatorPath);
+
+        MaxAbsScalerModel fittedModel = reloadedScaler.fit(trainDataTable);
+        String modelPath = tempFolder.newFolder().getAbsolutePath();
+        MaxAbsScalerModel reloadedModel = TestUtils.saveAndReload(tEnv, fittedModel, modelPath);
+
+        Table output = reloadedModel.transform(predictDataTable)[0];
+        verifyPredictionResult(output, scaler.getOutputCol(), EXPECTED_DATA);
+    }
+
+    /** The model data is a single "maxVector" column holding the trained maxima. */
+    @Test
+    public void testGetModelData() throws Exception {
+        Table modelData = new MaxAbsScaler().fit(trainDataTable).getModelData()[0];
+        assertEquals(
+                Collections.singletonList("maxVector"),
+                modelData.getResolvedSchema().getColumnNames());
+
+        List<Row> modelRows =
+                IteratorUtils.toList(tEnv.toDataStream(modelData).executeAndCollect());
+        DenseVector expectedMaxVector = new DenseVector(new double[] {200.0, 400.0, 0.0});
+        assertEquals(expectedMaxVector, modelRows.get(0).getField(0));
+    }
+
+    /** A model built from another model's data and params predicts the same results. */
+    @Test
+    public void testSetModelData() throws Exception {
+        MaxAbsScaler scaler = new MaxAbsScaler();
+        MaxAbsScalerModel sourceModel = scaler.fit(trainDataTable);
+        Table sourceModelData = sourceModel.getModelData()[0];
+
+        MaxAbsScalerModel targetModel = new MaxAbsScalerModel().setModelData(sourceModelData);
+        ReadWriteUtils.updateExistingParams(targetModel, sourceModel.getParamMap());
+
+        Table output = targetModel.transform(predictDataTable)[0];
+        verifyPredictionResult(output, scaler.getOutputCol(), EXPECTED_DATA);
+    }
+}
diff --git a/flink-ml-python/pyflink/examples/ml/feature/maxabsscaler_example.py b/flink-ml-python/pyflink/examples/ml/feature/maxabsscaler_example.py
new file mode 100644
index 0000000..f6c0008
--- /dev/null
+++ b/flink-ml-python/pyflink/examples/ml/feature/maxabsscaler_example.py
@@ -0,0 +1,79 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+# Simple program that trains a MaxAbsScaler model and uses it for feature
+# engineering.
+#
+# Before executing this program, please make sure you have followed Flink ML's
+# quick start guideline to set up Flink ML and Flink environment. The guideline
+# can be found at
+#
+# https://nightlies.apache.org/flink/flink-ml-docs-master/docs/try-flink-ml/quick-start/
+
+from pyflink.common import Types
+from pyflink.datastream import StreamExecutionEnvironment
+from pyflink.ml.core.linalg import Vectors, DenseVectorTypeInfo
+from pyflink.ml.lib.feature.maxabsscaler import MaxAbsScaler
+from pyflink.table import StreamTableEnvironment
+
+# create a new StreamExecutionEnvironment
+env = StreamExecutionEnvironment.get_execution_environment()
+
+# create a StreamTableEnvironment
+t_env = StreamTableEnvironment.create(env)
+
+# generate input training and prediction data
+train_data = t_env.from_data_stream(
+ env.from_collection([
+ (Vectors.dense(0.0, 3.0),),
+ (Vectors.dense(2.1, 0.0),),
+ (Vectors.dense(4.1, 5.1),),
+ (Vectors.dense(6.1, 8.1),),
+ (Vectors.dense(200, 400),),
+ ],
+ type_info=Types.ROW_NAMED(
+ ['input'],
+ [DenseVectorTypeInfo()])
+ ))
+
+predict_data = t_env.from_data_stream(
+ env.from_collection([
+ (Vectors.dense(150.0, 90.0),),
+ (Vectors.dense(50.0, 40.0),),
+ (Vectors.dense(100.0, 50.0),),
+ ],
+ type_info=Types.ROW_NAMED(
+ ['input'],
+ [DenseVectorTypeInfo()])
+ ))
+
+# create a maxabs scaler object and initialize its parameters
+max_abs_scaler = MaxAbsScaler()
+
+# train the maxabs scaler model
+model = max_abs_scaler.fit(train_data)
+
+# use the maxabs scaler model for predictions
+output = model.transform(predict_data)[0]
+
+# extract and display the results
+field_names = output.get_schema().get_field_names()
+for result in t_env.to_data_stream(output).execute_and_collect():
+ input_value = result[field_names.index(max_abs_scaler.get_input_col())]
+ output_value = result[field_names.index(max_abs_scaler.get_output_col())]
+ print('Input Value: ' + str(input_value) + ' \tOutput Value: ' + str(output_value))
diff --git a/flink-ml-python/pyflink/ml/lib/feature/maxabsscaler.py b/flink-ml-python/pyflink/ml/lib/feature/maxabsscaler.py
new file mode 100644
index 0000000..67e1246
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/feature/maxabsscaler.py
@@ -0,0 +1,71 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+from pyflink.ml.core.wrapper import JavaWithParams
+from pyflink.ml.lib.feature.common import JavaFeatureModel, JavaFeatureEstimator
+from pyflink.ml.lib.param import HasInputCol, HasOutputCol
+
+
+class _MaxAbsScalerParams(
+ JavaWithParams,
+ HasInputCol,
+ HasOutputCol
+):
+
+ def __init__(self, java_params):
+ super(_MaxAbsScalerParams, self).__init__(java_params)
+
+
+class MaxAbsScalerModel(JavaFeatureModel, _MaxAbsScalerParams):
+ """
+ A Model which transforms data using the model data computed by :class:`MaxAbsScaler`.
+ """
+
+ def __init__(self, java_model=None):
+ super(MaxAbsScalerModel, self).__init__(java_model)
+
+ @classmethod
+ def _java_model_package_name(cls) -> str:
+ return "maxabsscaler"
+
+ @classmethod
+ def _java_model_class_name(cls) -> str:
+ return "MaxAbsScalerModel"
+
+
+class MaxAbsScaler(JavaFeatureEstimator, _MaxAbsScalerParams):
+ """
+ An Estimator which implements the MaxAbsScaler algorithm. This algorithm rescales feature
+ values to the range [-1, 1] by dividing through the largest maximum absolute value in each
+ feature. It does not shift/center the data and thus does not destroy any sparsity.
+ """
+
+ def __init__(self):
+ super(MaxAbsScaler, self).__init__()
+
+ @classmethod
+ def _create_model(cls, java_model) -> MaxAbsScalerModel:
+ return MaxAbsScalerModel(java_model)
+
+ @classmethod
+ def _java_estimator_package_name(cls) -> str:
+ return "maxabsscaler"
+
+ @classmethod
+ def _java_estimator_class_name(cls) -> str:
+ return "MaxAbsScaler"
diff --git a/flink-ml-python/pyflink/ml/lib/feature/minmaxscaler.py b/flink-ml-python/pyflink/ml/lib/feature/minmaxscaler.py
index 3db5234..d02a760 100644
--- a/flink-ml-python/pyflink/ml/lib/feature/minmaxscaler.py
+++ b/flink-ml-python/pyflink/ml/lib/feature/minmaxscaler.py
@@ -70,8 +70,7 @@ class _MinMaxScalerParams(
class MinMaxScalerModel(JavaFeatureModel, _MinMaxScalerParams):
"""
- * A Model which do a minMax scaler operation using the model data computed
- by :class:`MinMaxScaler`.
+ A Model which transforms data using the model data computed by :class:`MinMaxScaler`.
"""
def __init__(self, java_model=None):
diff --git a/flink-ml-python/pyflink/ml/lib/feature/tests/test_maxabsscaler.py b/flink-ml-python/pyflink/ml/lib/feature/tests/test_maxabsscaler.py
new file mode 100644
index 0000000..925ae4a
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/feature/tests/test_maxabsscaler.py
@@ -0,0 +1,99 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+from typing import List
+
+from pyflink.common import Types
+from pyflink.table import Table
+
+from pyflink.ml.core.linalg import Vectors, DenseVectorTypeInfo, DenseVector
+from pyflink.ml.lib.feature.maxabsscaler import MaxAbsScaler
+from pyflink.ml.tests.test_utils import PyFlinkMLTestCase
+
+
+class MaxAbsScalerTest(PyFlinkMLTestCase):
+ def setUp(self):
+ super(MaxAbsScalerTest, self).setUp()
+ self.train_data = self.t_env.from_data_stream(
+ self.env.from_collection([
+ (Vectors.dense([0.0, 3.0]),),
+ (Vectors.dense([2.1, 0.0]),),
+ (Vectors.dense([4.1, 5.1]),),
+ (Vectors.dense([6.1, 8.1]),),
+ (Vectors.dense([200., 400.]),),
+ ],
+ type_info=Types.ROW_NAMED(
+ ['input'],
+ [DenseVectorTypeInfo()])))
+
+ self.predict_data = self.t_env.from_data_stream(
+ self.env.from_collection([
+ (Vectors.dense([150.0, 90.0]),),
+ (Vectors.dense([50.0, 40.0]),),
+ (Vectors.dense([100.0, 50.0]),),
+ ],
+ type_info=Types.ROW_NAMED(
+ ['input'],
+ [DenseVectorTypeInfo()])))
+ self.expected_data = [
+ Vectors.dense(0.25, 0.1),
+ Vectors.dense(0.5, 0.125),
+ Vectors.dense(0.75, 0.225)]
+
+ def test_param(self):
+ max_abs_scalar = MaxAbsScaler()
+ self.assertEqual("input", max_abs_scalar.input_col)
+ self.assertEqual("output", max_abs_scalar.output_col)
+ max_abs_scalar.set_input_col('test_input') \
+ .set_output_col('test_output')
+ self.assertEqual('test_input', max_abs_scalar.input_col)
+ self.assertEqual('test_output', max_abs_scalar.output_col)
+
+ def test_output_schema(self):
+ max_abs_scalar = MaxAbsScaler() \
+ .set_input_col('test_input') \
+ .set_output_col('test_output')
+
+ model = max_abs_scalar.fit(self.train_data.alias('test_input'))
+ output = model.transform(self.predict_data.alias('test_input'))[0]
+ self.assertEqual(
+ ['test_input', 'test_output'],
+ output.get_schema().get_field_names())
+
+ def test_fit_and_predict(self):
+ max_abs_scalar = MaxAbsScaler()
+ model = max_abs_scalar.fit(self.train_data)
+ output = model.transform(self.predict_data)[0]
+ self.verify_output_result(
+ output,
+ max_abs_scalar.get_output_col(),
+ output.get_schema().get_field_names(),
+ self.expected_data)
+
+ def verify_output_result(
+ self, output: Table,
+ output_col: str,
+ field_names: List[str],
+ expected_result: List[DenseVector]):
+ collected_results = [result for result in
+ self.t_env.to_data_stream(output).execute_and_collect()]
+ results = []
+ for item in collected_results:
+ item.set_field_names(field_names)
+ results.append(item[output_col])
+ results.sort(key=lambda x: x[0])
+ self.assertEqual(expected_result, results)