You are viewing a plain text version of this content. The canonical link for it is here.
Posted to issues@flink.apache.org by GitBox <gi...@apache.org> on 2022/03/29 08:27:27 UTC

[GitHub] [flink-ml] lindong28 commented on a change in pull request #73: [FLINK-26626] Add Transformer and Estimator for StandardScaler

lindong28 commented on a change in pull request #73:
URL: https://github.com/apache/flink-ml/pull/73#discussion_r837111842



##########
File path: flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java
##########
@@ -32,9 +32,31 @@ public static double asum(DenseVector x) {
     }
 
     /** y += a * x . */
-    public static void axpy(double a, DenseVector x, DenseVector y) {
+    public static void axpy(double a, Vector x, DenseVector y) {
         Preconditions.checkArgument(x.size() == y.size(), "Vector size mismatched.");
-        JAVA_BLAS.daxpy(x.size(), a, x.values, 1, y.values, 1);
+        if (x instanceof SparseVector) {
+            axpy(a, (SparseVector) x, y);
+        } else {
+            axpy(a, (DenseVector) x, y);
+        }
+    }
+
+    /** Computes the hadamard product of the two vectors (y = y \hdot x). */
+    public static void hDot(Vector x, Vector y) {
+        Preconditions.checkArgument(x.size() == y.size(), "Vector size mismatched.");
+        if (y instanceof DenseVector) {
+            if (x instanceof SparseVector) {

Review comment:
       nits: it seems that we use both `instanceof DenseVector` and `instanceof SparseVector` in this class. Would it be slightly better to consistently use one of the two in this class?

##########
File path: flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java
##########
@@ -89,4 +111,56 @@ public static void gemv(
                 y.values,
                 1);
     }
+
+    /**
+     * Computes y += a * x for two dense vectors by delegating to {@code JAVA_BLAS.daxpy}. No size
+     * check is done here; the public {@code axpy(double, Vector, DenseVector)} overload checks it.
+     */
+    private static void axpy(double a, DenseVector x, DenseVector y) {
+        JAVA_BLAS.daxpy(x.size(), a, x.values, 1, y.values, 1);
+    }
+
+    private static void axpy(double a, SparseVector x, DenseVector y) {
+        for (int i = 0; i < x.indices.length; i++) {
+            int index = x.indices[i];
+            y.values[index] += a * x.values[i];
+        }
+    }
+
+    /**
+     * Computes the in-place Hadamard product y = y \hdot x for two sparse vectors.
+     *
+     * <p>The index arrays of x and y are walked together merge-style, which assumes both are
+     * sorted in ascending order. Each stored entry of y whose index is also stored in x is
+     * multiplied by x's value. Every other stored entry of y is set to 0, because x is implicitly
+     * 0 there. The indices of y are not modified; zeroed entries remain stored as explicit zeros.
+     */
+    private static void hDot(SparseVector x, SparseVector y) {
+        int idx = 0;
+        int idy = 0;
+        while (idx < x.indices.length && idy < y.indices.length) {
+            int indexX = x.indices[idx];
+            // Entries of y before the current x index have no counterpart in x.
+            while (idy < y.indices.length && y.indices[idy] < indexX) {
+                y.values[idy] = 0;
+                idy++;
+            }
+            if (idy < y.indices.length && y.indices[idy] == indexX) {
+                y.values[idy] *= x.values[idx];
+                idy++;
+            }
+            idx++;
+        }
+        // Once x is exhausted, the remaining entries of y have no counterpart in x either and
+        // must be zeroed as well (previously they were left unchanged).
+        while (idy < y.indices.length) {
+            y.values[idy] = 0;
+            idy++;
+        }
+    }
+
+    private static void hDot(SparseVector x, DenseVector y) {
+        int idx = 0;
+        for (int i = 0; i < y.size(); i++) {
+            if (x.indices[idx] == i) {

Review comment:
       Should we use the following check:
   
   `if (idx < x.indices.length && x.indices[idx] == i)`

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/StandardScaler.java
##########
@@ -0,0 +1,288 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.standardscaler;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the standard scaling algorithm.
+ *
+ * <p>Standardization is a common requirement for machine learning training, because many
+ * algorithms may behave badly if the individual features of an input do not look like standard
+ * normally distributed data
+ * (e.g. Gaussian with 0 mean and unit variance).
+ *
+ * <p>This estimator standardizes the input features by removing the mean and scaling each dimension
+ * to unit variance.
+ */
+public class StandardScaler
+        implements Estimator<StandardScaler, StandardScalerModel>,
+                StandardScalerParams<StandardScaler> {
+    /** Parameter values of this estimator, keyed by {@link Param}. */
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    /**
+     * Creates a StandardScaler whose {@code paramMap} is filled by {@code
+     * ParamUtils.initializeMapWithDefaultValues} (presumably with each param's default value).
+     */
+    public StandardScaler() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    /**
+     * Fits a {@link StandardScalerModel} on the single input table.
+     *
+     * <p>Fitting runs in two stages. First, each partition reduces its rows to a (sum, squared
+     * sum, count) tuple with {@link ComputeMetaOperator}. Then a {@link BuildModelOperator} with
+     * parallelism 1 merges those partial tuples into the model data.
+     */
+    @Override
+    public StandardScalerModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        Table input = inputs[0];
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) input).getTableEnvironment();
+
+        TypeInformation<DenseVector> vectorType = TypeInformation.of(DenseVector.class);
+        TupleTypeInfo<Tuple3<DenseVector, DenseVector, Long>> partialStatsType =
+                new TupleTypeInfo<>(vectorType, vectorType, BasicTypeInfo.LONG_TYPE_INFO);
+
+        DataStream<Tuple3<DenseVector, DenseVector, Long>> partialStats =
+                tEnv.toDataStream(input)
+                        .transform(
+                                "computeMeta",
+                                partialStatsType,
+                                new ComputeMetaOperator(getFeaturesCol()));
+
+        DataStream<StandardScalerModelData> modelData =
+                partialStats
+                        .transform(
+                                "buildModel",
+                                TypeInformation.of(StandardScalerModelData.class),
+                                new BuildModelOperator())
+                        .setParallelism(1);
+
+        Table modelDataTable = tEnv.fromDataStream(modelData);
+        StandardScalerModel model = new StandardScalerModel().setModelData(modelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /**
+     * Builds the {@link StandardScalerModelData} using the meta data computed on each partition.
+     */
+    /**
+     * Builds the {@link StandardScalerModelData} using the meta data computed on each partition.
+     *
+     * <p>Each incoming tuple holds one partition's (sum, squared sum, element count), and the
+     * tuples are merged as they arrive. At end of input, the per-dimension mean and sample
+     * standard deviation (denominator {@code n - 1}) are emitted as a single record. Nothing is
+     * emitted if no element was seen. The standard deviation is all zeros when exactly one
+     * element was seen.
+     */
+    private static class BuildModelOperator extends AbstractStreamOperator<StandardScalerModelData>
+            implements OneInputStreamOperator<
+                            Tuple3<DenseVector, DenseVector, Long>, StandardScalerModelData>,
+                    BoundedOneInput {
+        private ListState<DenseVector> sumState;
+        private ListState<DenseVector> squaredSumState;
+        private ListState<Long> numElementsState;
+        // Running totals; sum and squaredSum stay null until the first tuple arrives.
+        private DenseVector sum;
+        private DenseVector squaredSum;
+        private long numElements;
+
+        /** Turns the merged totals into mean / standard-deviation model data and emits it. */
+        @Override
+        public void endInput() {
+            if (numElements > 0) {
+                // sum is scaled in place and becomes the mean.
+                BLAS.scal(1.0 / numElements, sum);
+                double[] mean = sum.values;
+                // The squared-sum buffer is overwritten in place with the standard deviation.
+                double[] std = squaredSum.values;
+                if (numElements > 1) {
+                    for (int i = 0; i < mean.length; i++) {
+                        double variance =
+                                (squaredSum.values[i] - numElements * mean[i] * mean[i])
+                                        / (numElements - 1);
+                        // sum(x^2) - n * mean^2 is prone to cancellation and can come out
+                        // slightly negative for (nearly) constant features. Clamp it so that
+                        // Math.sqrt does not produce NaN.
+                        std[i] = Math.sqrt(Math.max(variance, 0.0));
+                    }
+                } else {
+                    Arrays.fill(std, 0.0);
+                }
+
+                output.collect(
+                        new StreamRecord<>(
+                                new StandardScalerModelData(
+                                        Vectors.dense(mean), Vectors.dense(std))));
+            }
+        }
+
+        /**
+         * Merges one partition's partial result into the running totals. The vectors of the first
+         * tuple are adopted as the accumulators, and later tuples are added into them in place.
+         */
+        @Override
+        public void processElement(StreamRecord<Tuple3<DenseVector, DenseVector, Long>> element) {
+            Tuple3<DenseVector, DenseVector, Long> value = element.getValue();
+            if (sum == null) {
+                sum = value.f0;
+                squaredSum = value.f1;
+                numElements = value.f2;
+            } else {
+                BLAS.axpy(1, value.f0, sum);
+                BLAS.axpy(1, value.f1, squaredSum);
+                numElements += value.f2;
+            }
+        }
+
+        /**
+         * Registers the operator list states and restores the running totals from them. If there
+         * is no restored state, sum and squaredSum are null and numElements is 0.
+         */
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            sumState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "sumState", TypeInformation.of(DenseVector.class)));
+            squaredSumState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "squaredSumState",
+                                            TypeInformation.of(DenseVector.class)));
+            numElementsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "numElementsState", BasicTypeInfo.LONG_TYPE_INFO));
+
+            sum = OperatorStateUtils.getUniqueElement(sumState, "sumState").orElse(null);
+            squaredSum =
+                    OperatorStateUtils.getUniqueElement(squaredSumState, "squaredSumState")
+                            .orElse(null);
+            numElements =
+                    OperatorStateUtils.getUniqueElement(numElementsState, "numElementsState")
+                            .orElse(0L);
+        }
+
+        /**
+         * Writes the running totals to operator state. This is skipped while no element has been
+         * seen.
+         */
+        @Override
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            if (numElements > 0) {
+                sumState.update(Collections.singletonList(sum));
+                squaredSumState.update(Collections.singletonList(squaredSum));
+                numElementsState.update(Collections.singletonList(numElements));
+            }
+        }
+    }
+
+    /** Computes sum, squared sum and number of elements in each partition. */
+    private static class ComputeMetaOperator
+            extends AbstractStreamOperator<Tuple3<DenseVector, DenseVector, Long>>
+            implements OneInputStreamOperator<Row, Tuple3<DenseVector, DenseVector, Long>>,
+                    BoundedOneInput {
+        private ListState<DenseVector> sumState;
+        private ListState<DenseVector> squaredSumState;
+        private ListState<Long> numElementsState;
+        private DenseVector sum;
+        private DenseVector squaredSum;
+        private long numElements;
+
+        private final String featuresCol;
+
+        public ComputeMetaOperator(String featuresCol) {
+            this.featuresCol = featuresCol;
+        }
+
+        @Override
+        public void endInput() {
+            if (numElements > 0) {
+                output.collect(new StreamRecord<>(Tuple3.of(sum, squaredSum, numElements)));
+            }
+        }
+
+        @Override
+        public void processElement(StreamRecord<Row> element) throws Exception {
+            Vector feature = (Vector) element.getValue().getField(featuresCol);
+            if (sum == null) {
+                sum = new DenseVector(feature.size());
+                squaredSum = new DenseVector(feature.size());
+            }
+            BLAS.axpy(1, feature, sum);
+            BLAS.hDot(feature, feature);
+            BLAS.axpy(1, feature, squaredSum);
+            numElements++;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            sumState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "sumState", TypeInformation.of(DenseVector.class)));
+            squaredSumState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "squaredSumState",
+                                            TypeInformation.of(DenseVector.class)));
+            numElementsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "numElementsState", BasicTypeInfo.LONG_TYPE_INFO));
+
+            sum = OperatorStateUtils.getUniqueElement(sumState, "sumState").orElse(null);
+            squaredSum =
+                    OperatorStateUtils.getUniqueElement(squaredSumState, "squaredSumState")
+                            .orElse(null);
+            numElements =
+                    OperatorStateUtils.getUniqueElement(numElementsState, "numElementsState")
+                            .orElse(0L);
+        }
+
+        @Override
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            if (numElements > 0) {

Review comment:
       This code aims to handle the scenario where `fit(...)` is called with an empty table. Under this scenario, the current implementation would generate an empty modelDataTable.
   
   Should we instead generate a modelTable with one element, whose value indicates that the model data can not be used for inference?
   
   By emitting a model data table with this value, model.transform(...) could throw a proper exception to indicate that it cannot do inference. In comparison, if we generate a model data table with no value, model.transform(...) would have to block forever, since it will not be able to differentiate this scenario from the scenario where the source is taking a long time to read the model data.
   
   Same for other usages of `if (numElements > 0)` in this PR.

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/StandardScalerModel.java
##########
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.standardscaler;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/** A Model which transforms data using the model data computed by {@link StandardScaler}. */
+public class StandardScalerModel
+        implements Model<StandardScalerModel>, StandardScalerParams<StandardScalerModel> {
+    /** Parameter values of this model, keyed by {@link Param}. */
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    /** Model data read by {@link #transform}; presumably assigned through setModelData. */
+    private Table modelDataTable;
+
+    /**
+     * Creates a model whose {@code paramMap} is filled by {@code
+     * ParamUtils.initializeMapWithDefaultValues} (presumably with each param's default value).
+     */
+    public StandardScalerModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    /**
+     * Standardizes the features column of the single input table using the model data.
+     *
+     * <p>The model data stream is broadcast to the prediction operator, and every input row is
+     * mapped by {@link PredictOutputFunction}. Result rows contain all input columns followed by
+     * one {@link Vector} column named by {@code getPredictionCol()}.
+     */
+    @Override
+    @SuppressWarnings({"unchecked", "rawtypes"})
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Row> inputStream = tEnv.toDataStream(inputs[0]);
+
+        // Output schema: all input columns followed by the prediction column.
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(), TypeInformation.of(Vector.class)),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol()));
+
+        final String broadcastModelKey = "broadcastModelKey";
+        DataStream<StandardScalerModelData> modelDataStream =
+                StandardScalerModelData.getModelDataStream(modelDataTable);
+
+        DataStream<Row> predictionResult =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(inputStream),
+                        Collections.singletonMap(broadcastModelKey, modelDataStream),
+                        inputList -> {
+                            // The raw DataStream and the map(...) call on it are what the
+                            // rawtypes/unchecked suppression on this method covers.
+                            DataStream inputData = inputList.get(0);
+                            return inputData.map(
+                                    new PredictOutputFunction(
+                                            broadcastModelKey,
+                                            getFeaturesCol(),
+                                            getWithMean(),
+                                            getWithStd()),
+                                    outputTypeInfo);
+                        });
+
+        return new Table[] {tEnv.fromDataStream(predictionResult)};
+    }
+
+    /** A utility function used for prediction. */
+    private static class PredictOutputFunction extends RichMapFunction<Row, Row> {
+        private final String broadcastModelKey;
+        private final String featuresCol;
+        private final boolean withMean;
+        private final boolean withStd;
+        private DenseVector mean;
+        private DenseVector scale;
+
+        public PredictOutputFunction(
+                String broadcastModelKey, String featuresCol, boolean withMean, boolean withStd) {
+            this.broadcastModelKey = broadcastModelKey;
+            this.featuresCol = featuresCol;
+            this.withMean = withMean;
+            this.withStd = withStd;
+        }
+
+        @Override
+        public Row map(Row dataPoint) {
+            if (mean == null) {
+                StandardScalerModelData modelData =
+                        (StandardScalerModelData)
+                                getRuntimeContext().getBroadcastVariable(broadcastModelKey).get(0);
+                mean = modelData.mean;
+                DenseVector std = modelData.std;
+
+                if (withStd) {
+                    scale = std;
+                    double[] scaleValues = scale.values;
+                    for (int i = 0; i < scaleValues.length; i++) {
+                        scaleValues[i] = scaleValues[i] == 0 ? 0 : 1 / scaleValues[i];
+                    }
+                }
+            }
+
+            Vector feature = (Vector) (dataPoint.getField(featuresCol));
+            Vector output;
+            if (feature instanceof DenseVector) {

Review comment:
       Would it be simpler to add the method `Vector::clone()` similar to Spark's `Vector::copy()`?

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/StandardScalerParams.java
##########
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.standardscaler;
+
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasPredictionCol;
+import org.apache.flink.ml.param.BooleanParam;
+import org.apache.flink.ml.param.Param;
+
+/**
+ * Params for {@link StandardScaler}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface StandardScalerParams<T> extends HasFeaturesCol<T>, HasPredictionCol<T> {
+    /**
+     * Whether to center the data by subtracting the mean before scaling. The BooleanParam is
+     * declared with {@code false} as its last argument (presumably its default value).
+     */
+    Param<Boolean> WITH_MEAN =
+            new BooleanParam(
+                    "withMean", "Whether centers the data with mean before scaling.", false);
+
+    /** Returns the value currently held for {@link #WITH_MEAN}. */
+    default Boolean getWithMean() {
+        return get(WITH_MEAN);
+    }
+
+    /**
+     * Sets {@link #WITH_MEAN} and returns the result of {@code set} (presumably this instance,
+     * for chaining).
+     */
+    default T setWithMean(boolean withMean) {
+        return set(WITH_MEAN, withMean);
+    }
+
+    Param<Boolean> WITH_STD =

Review comment:
       Could we move the variable declaration to be before the method declarations, similar to other Java classes in Flink ML?

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerModel.java
##########
@@ -172,11 +173,9 @@ public Row map(Row row) {
                 }
             }
             DenseVector feature = (DenseVector) row.getField(featureCol);
-            DenseVector outputVector = new DenseVector(scaleVector.size());
-            for (int i = 0; i < scaleVector.size(); ++i) {
-                outputVector.values[i] =
-                        feature.values[i] * scaleVector.values[i] + offsetVector.values[i];
-            }
+            DenseVector outputVector = feature.clone();
+            BLAS.hDot(scaleVector, outputVector);
+            BLAS.axpy(1, offsetVector, outputVector);

Review comment:
       The new implementation seems to be strictly slower than the previous implementation. Should we keep the previous implementation?
   
   The previous implementation just needs one for loop over the `outputVector`.
   
   The new implementation needs two for loops (i.e. `clone()` and `hDot()`) over the `outputVector`, plus one `JAVA_BLAS.daxpy()` call.

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/feature/standardscaler/StandardScaler.java
##########
@@ -0,0 +1,288 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.standardscaler;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the standard scaling algorithm.
+ *
+ * <p>Standardization is a common requirement for machine learning training, because many
+ * algorithms may behave badly if the individual features of an input do not look like standard
+ * normally distributed data
+ * (e.g. Gaussian with 0 mean and unit variance).
+ *
+ * <p>This estimator standardizes the input features by removing the mean and scaling each dimension
+ * to unit variance.
+ */
+public class StandardScaler
+        implements Estimator<StandardScaler, StandardScalerModel>,
+                StandardScalerParams<StandardScaler> {
+    /** Parameter values of this estimator, keyed by {@link Param}. */
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    /**
+     * Creates a StandardScaler whose {@code paramMap} is filled by {@code
+     * ParamUtils.initializeMapWithDefaultValues} (presumably with each param's default value).
+     */
+    public StandardScaler() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    /**
+     * Fits a {@link StandardScalerModel} on the single input table.
+     *
+     * <p>Fitting runs in two stages. First, each partition reduces its rows to a (sum, squared
+     * sum, count) tuple with {@link ComputeMetaOperator}. Then a {@link BuildModelOperator} with
+     * parallelism 1 merges those partial tuples into the model data.
+     */
+    @Override
+    public StandardScalerModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        Table input = inputs[0];
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) input).getTableEnvironment();
+
+        TypeInformation<DenseVector> vectorType = TypeInformation.of(DenseVector.class);
+        TupleTypeInfo<Tuple3<DenseVector, DenseVector, Long>> partialStatsType =
+                new TupleTypeInfo<>(vectorType, vectorType, BasicTypeInfo.LONG_TYPE_INFO);
+
+        DataStream<Tuple3<DenseVector, DenseVector, Long>> partialStats =
+                tEnv.toDataStream(input)
+                        .transform(
+                                "computeMeta",
+                                partialStatsType,
+                                new ComputeMetaOperator(getFeaturesCol()));
+
+        DataStream<StandardScalerModelData> modelData =
+                partialStats
+                        .transform(
+                                "buildModel",
+                                TypeInformation.of(StandardScalerModelData.class),
+                                new BuildModelOperator())
+                        .setParallelism(1);
+
+        Table modelDataTable = tEnv.fromDataStream(modelData);
+        StandardScalerModel model = new StandardScalerModel().setModelData(modelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /**
+     * Builds the {@link StandardScalerModelData} using the meta data computed on each partition.
+     */
+    /**
+     * Builds the {@link StandardScalerModelData} using the meta data computed on each partition.
+     *
+     * <p>Each incoming tuple holds one partition's (sum, squared sum, element count), and the
+     * tuples are merged as they arrive. At end of input, the per-dimension mean and sample
+     * standard deviation (denominator {@code n - 1}) are emitted as a single record. Nothing is
+     * emitted if no element was seen. The standard deviation is all zeros when exactly one
+     * element was seen.
+     */
+    private static class BuildModelOperator extends AbstractStreamOperator<StandardScalerModelData>
+            implements OneInputStreamOperator<
+                            Tuple3<DenseVector, DenseVector, Long>, StandardScalerModelData>,
+                    BoundedOneInput {
+        private ListState<DenseVector> sumState;
+        private ListState<DenseVector> squaredSumState;
+        private ListState<Long> numElementsState;
+        // Running totals; sum and squaredSum stay null until the first tuple arrives.
+        private DenseVector sum;
+        private DenseVector squaredSum;
+        private long numElements;
+
+        /** Turns the merged totals into mean / standard-deviation model data and emits it. */
+        @Override
+        public void endInput() {
+            if (numElements > 0) {
+                // sum is scaled in place and becomes the mean.
+                BLAS.scal(1.0 / numElements, sum);
+                double[] mean = sum.values;
+                // The squared-sum buffer is overwritten in place with the standard deviation.
+                double[] std = squaredSum.values;
+                if (numElements > 1) {
+                    for (int i = 0; i < mean.length; i++) {
+                        double variance =
+                                (squaredSum.values[i] - numElements * mean[i] * mean[i])
+                                        / (numElements - 1);
+                        // sum(x^2) - n * mean^2 is prone to cancellation and can come out
+                        // slightly negative for (nearly) constant features. Clamp it so that
+                        // Math.sqrt does not produce NaN.
+                        std[i] = Math.sqrt(Math.max(variance, 0.0));
+                    }
+                } else {
+                    Arrays.fill(std, 0.0);
+                }
+
+                output.collect(
+                        new StreamRecord<>(
+                                new StandardScalerModelData(
+                                        Vectors.dense(mean), Vectors.dense(std))));
+            }
+        }
+
+        /**
+         * Merges one partition's partial result into the running totals. The vectors of the first
+         * tuple are adopted as the accumulators, and later tuples are added into them in place.
+         */
+        @Override
+        public void processElement(StreamRecord<Tuple3<DenseVector, DenseVector, Long>> element) {
+            Tuple3<DenseVector, DenseVector, Long> value = element.getValue();
+            if (sum == null) {
+                sum = value.f0;
+                squaredSum = value.f1;
+                numElements = value.f2;
+            } else {
+                BLAS.axpy(1, value.f0, sum);
+                BLAS.axpy(1, value.f1, squaredSum);
+                numElements += value.f2;
+            }
+        }
+
+        /**
+         * Registers the operator list states and restores the running totals from them. If there
+         * is no restored state, sum and squaredSum are null and numElements is 0.
+         */
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            sumState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "sumState", TypeInformation.of(DenseVector.class)));
+            squaredSumState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "squaredSumState",
+                                            TypeInformation.of(DenseVector.class)));
+            numElementsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "numElementsState", BasicTypeInfo.LONG_TYPE_INFO));
+
+            sum = OperatorStateUtils.getUniqueElement(sumState, "sumState").orElse(null);
+            squaredSum =
+                    OperatorStateUtils.getUniqueElement(squaredSumState, "squaredSumState")
+                            .orElse(null);
+            numElements =
+                    OperatorStateUtils.getUniqueElement(numElementsState, "numElementsState")
+                            .orElse(0L);
+        }
+
+        /**
+         * Writes the running totals to operator state. This is skipped while no element has been
+         * seen.
+         */
+        @Override
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            if (numElements > 0) {
+                sumState.update(Collections.singletonList(sum));
+                squaredSumState.update(Collections.singletonList(squaredSum));
+                numElementsState.update(Collections.singletonList(numElements));
+            }
+        }
+    }
+
+    /** Computes sum, squared sum and number of elements in each partition. */
+    private static class ComputeMetaOperator
+            extends AbstractStreamOperator<Tuple3<DenseVector, DenseVector, Long>>
+            implements OneInputStreamOperator<Row, Tuple3<DenseVector, DenseVector, Long>>,
+                    BoundedOneInput {
+        private ListState<DenseVector> sumState;
+        private ListState<DenseVector> squaredSumState;
+        private ListState<Long> numElementsState;
+        private DenseVector sum;
+        private DenseVector squaredSum;
+        private long numElements;
+
+        private final String featuresCol;
+
+        public ComputeMetaOperator(String featuresCol) {
+            this.featuresCol = featuresCol;
+        }
+
+        @Override
+        public void endInput() {
+            if (numElements > 0) {

Review comment:
       Would it be simpler to remove this check?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@flink.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org