Posted to issues@flink.apache.org by GitBox <gi...@apache.org> on 2022/10/28 05:40:14 UTC

[GitHub] [flink-ml] yunfengzhou-hub commented on a diff in pull request #166: [FLINK-29598] Add Estimator and Transformer for Imputer

yunfengzhou-hub commented on code in PR #166:
URL: https://github.com/apache/flink-ml/pull/166#discussion_r1007561913


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/imputer/ImputerParams.java:
##########
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.imputer;
+
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringParam;
+
+/**
+ * Params of {@link Imputer}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface ImputerParams<T> extends ImputerModelParams<T> {

Review Comment:
   Spark's `Imputer` also has a `relativeError` parameter. Should we also add this parameter to Flink ML?



##########
flink-ml-core/src/main/java/org/apache/flink/ml/param/DoubleParam.java:
##########
@@ -32,4 +34,12 @@ public DoubleParam(
     public DoubleParam(String name, String description, Double defaultValue) {
         this(name, description, defaultValue, ParamValidators.alwaysTrue());
     }
+
+    @Override
+    public Double jsonDecode(Object json) throws IOException {
+        if (json instanceof String && json.equals(String.valueOf(Double.NaN))) {
+            return Double.NaN;
+        }
+        return (Double) json;

Review Comment:
   How about the following implementation?
   ```java
   if (json instanceof String) {
       return Double.valueOf((String) json);
   }
   return (Double) json;
   ```
   This handles `Double.NaN` as well as other special values, such as positive and negative infinity, when they are encoded as strings.
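
A minimal, self-contained sketch of the suggestion above, for illustration only. `DoubleDecodeSketch` is a hypothetical stand-in, not the actual `DoubleParam` class, and it assumes that special values reach the decoder as the strings produced by `String.valueOf(double)`.

```java
import java.util.Arrays;
import java.util.List;

/** Hypothetical stand-in for DoubleParam; only the decoding logic is illustrated. */
final class DoubleDecodeSketch {

    /** Decodes a JSON value that is either a Double or the string form of a double. */
    static Double jsonDecode(Object json) {
        if (json instanceof String) {
            // Double.valueOf(String) parses "NaN", "Infinity" and "-Infinity"
            // in addition to ordinary numeric literals.
            return Double.valueOf((String) json);
        }
        // Assumption: any non-string input is already a Double.
        return (Double) json;
    }

    public static void main(String[] args) {
        List<Object> samples = Arrays.asList("NaN", "Infinity", "-Infinity", 1.5);
        for (Object sample : samples) {
            System.out.println(sample + " -> " + jsonDecode(sample));
        }
    }
}
```

Note that `Double.valueOf(String)` throws a `NumberFormatException` for strings that are not valid doubles, so a malformed string fails at decode time instead of with a later `ClassCastException`.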



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/imputer/Imputer.java:
##########
@@ -0,0 +1,328 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.imputer;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.util.QuantileSummary;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * The Imputer estimator completes missing values in a dataset. Missing values can be imputed using

Review Comment:
   nit: it might be better to use "bounded stream" instead of "dataset", as "dataset" has a specific meaning in Flink.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/imputer/ImputerModel.java:
##########
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.imputer;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/** A Model which replace the missing values using the model data computed by {@link Imputer}. */
+public class ImputerModel implements Model<ImputerModel>, ImputerModelParams<ImputerModel> {
+
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public ImputerModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public ImputerModel setModelData(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        modelDataTable = inputs[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+        ReadWriteUtils.saveModelData(
+                ImputerModelData.getModelDataStream(modelDataTable),
+                path,
+                new ImputerModelData.ModelDataEncoder());
+    }
+
+    public static ImputerModel load(StreamTableEnvironment tEnv, String path) throws IOException {
+        ImputerModel model = ReadWriteUtils.loadStageParam(path);
+        Table modelDataTable =
+                ReadWriteUtils.loadModelData(tEnv, path, new ImputerModelData.ModelDataDecoder());
+        return model.setModelData(modelDataTable);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        String[] inputCols = getInputCols();
+        String[] outputCols = getOutputCols();
+        Preconditions.checkArgument(inputCols.length == outputCols.length);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Row> dataStream = tEnv.toDataStream(inputs[0]);
+        DataStream<ImputerModelData> imputerModel =
+                ImputerModelData.getModelDataStream(modelDataTable);
+
+        final String broadcastModelKey = "broadcastModelKey";
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        TypeInformation<?>[] outputTypes = new TypeInformation[outputCols.length];
+        Arrays.fill(outputTypes, BasicTypeInfo.DOUBLE_TYPE_INFO);
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), outputTypes),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), outputCols));
+
+        DataStream<Row> output =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(dataStream),
+                        Collections.singletonMap(broadcastModelKey, imputerModel),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(
+                                    new PredictOutputFunction(
+                                            getMissingValue(), inputCols, broadcastModelKey),
+                                    outputTypeInfo);
+                        });
+        return new Table[] {tEnv.fromDataStream(output)};
+    }
+
+    /** This operator loads model data and predicts result. */
+    private static class PredictOutputFunction extends RichMapFunction<Row, Row> {
+
+        private final String[] inputCols;
+        private final String broadcastKey;
+        private final double missingValue;
+        private Map<String, Double> surrogates;
+
+        public PredictOutputFunction(double missingValue, String[] inputCols, String broadcastKey) {
+            this.missingValue = missingValue;
+            this.inputCols = inputCols;
+            this.broadcastKey = broadcastKey;
+        }
+
+        @Override
+        public Row map(Row row) throws Exception {
+            if (surrogates == null) {
+                ImputerModelData imputerModelData =
+                        (ImputerModelData)
+                                getRuntimeContext().getBroadcastVariable(broadcastKey).get(0);
+                surrogates = imputerModelData.surrogates;
+                Preconditions.checkArgument(
+                        surrogates.size() == inputCols.length,
+                        "Imputer is expecting %s input columns as input but only got %s.",
+                        surrogates.size(),
+                        inputCols.length);
+            }
+
+            Row outputRow = new Row(inputCols.length);
+            for (int i = 0; i < inputCols.length; i++) {
+                Object value = row.getField(i);

Review Comment:
   nit: `Double value = row.getFieldAs(i)` might be better.



##########
flink-ml-core/src/main/java/org/apache/flink/ml/param/DoubleParam.java:
##########
@@ -32,4 +34,12 @@ public DoubleParam(
     public DoubleParam(String name, String description, Double defaultValue) {
         this(name, description, defaultValue, ParamValidators.alwaysTrue());
     }
+
+    @Override
+    public Double jsonDecode(Object json) throws IOException {

Review Comment:
   The `throws IOException` clause can be removed, since nothing in this method throws it.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/imputer/ImputerModel.java:
##########
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.imputer;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/** A Model which replace the missing values using the model data computed by {@link Imputer}. */
+public class ImputerModel implements Model<ImputerModel>, ImputerModelParams<ImputerModel> {
+
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public ImputerModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public ImputerModel setModelData(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        modelDataTable = inputs[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+        ReadWriteUtils.saveModelData(
+                ImputerModelData.getModelDataStream(modelDataTable),
+                path,
+                new ImputerModelData.ModelDataEncoder());
+    }
+
+    public static ImputerModel load(StreamTableEnvironment tEnv, String path) throws IOException {
+        ImputerModel model = ReadWriteUtils.loadStageParam(path);
+        Table modelDataTable =
+                ReadWriteUtils.loadModelData(tEnv, path, new ImputerModelData.ModelDataDecoder());
+        return model.setModelData(modelDataTable);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        String[] inputCols = getInputCols();
+        String[] outputCols = getOutputCols();
+        Preconditions.checkArgument(inputCols.length == outputCols.length);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Row> dataStream = tEnv.toDataStream(inputs[0]);
+        DataStream<ImputerModelData> imputerModel =
+                ImputerModelData.getModelDataStream(modelDataTable);
+
+        final String broadcastModelKey = "broadcastModelKey";
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        TypeInformation<?>[] outputTypes = new TypeInformation[outputCols.length];
+        Arrays.fill(outputTypes, BasicTypeInfo.DOUBLE_TYPE_INFO);
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), outputTypes),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), outputCols));
+
+        DataStream<Row> output =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(dataStream),
+                        Collections.singletonMap(broadcastModelKey, imputerModel),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(
+                                    new PredictOutputFunction(
+                                            getMissingValue(), inputCols, broadcastModelKey),
+                                    outputTypeInfo);
+                        });
+        return new Table[] {tEnv.fromDataStream(output)};
+    }
+
+    /** This operator loads model data and predicts result. */
+    private static class PredictOutputFunction extends RichMapFunction<Row, Row> {
+
+        private final String[] inputCols;
+        private final String broadcastKey;
+        private final double missingValue;
+        private Map<String, Double> surrogates;
+
+        public PredictOutputFunction(double missingValue, String[] inputCols, String broadcastKey) {
+            this.missingValue = missingValue;
+            this.inputCols = inputCols;
+            this.broadcastKey = broadcastKey;
+        }
+
+        @Override
+        public Row map(Row row) throws Exception {
+            if (surrogates == null) {
+                ImputerModelData imputerModelData =
+                        (ImputerModelData)
+                                getRuntimeContext().getBroadcastVariable(broadcastKey).get(0);
+                surrogates = imputerModelData.surrogates;
+                Preconditions.checkArgument(
+                        surrogates.size() == inputCols.length,
+                        "Imputer is expecting %s input columns as input but only got %s.",

Review Comment:
   I think this should be acceptable as long as `inputCols` is a subset of the keys of `imputerModelData.surrogates`. The two sets do not need to be equal.
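
A minimal sketch of the relaxed check described above, for illustration only. `SurrogateCheckSketch` and `coversInputCols` are hypothetical names, and the sketch assumes that `surrogates` is keyed by column name.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

/** Hypothetical helper; not part of the PR under review. */
final class SurrogateCheckSketch {

    /** Returns true if every input column has a surrogate value in the model data. */
    static boolean coversInputCols(Map<String, Double> surrogates, String[] inputCols) {
        // Assumption: the map keys are column names.
        return surrogates.keySet().containsAll(Arrays.asList(inputCols));
    }

    public static void main(String[] args) {
        Map<String, Double> surrogates = new HashMap<>();
        surrogates.put("f1", 1.0);
        surrogates.put("f2", 2.0);

        // The model knows more columns than are requested: acceptable under the subset rule.
        System.out.println(coversInputCols(surrogates, new String[] {"f1"}));
        // A requested column is missing from the model data: rejected.
        System.out.println(coversInputCols(surrogates, new String[] {"f1", "f3"}));
    }
}
```

In `PredictOutputFunction`, such a check could replace the size comparison passed to `Preconditions.checkArgument`, with the error message updated to name the missing columns.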


