Posted to commits@flink.apache.org by li...@apache.org on 2022/06/20 03:43:18 UTC

[flink-ml] 01/04: [FLINK-27877] Improve performance for StringIndexer

This is an automated email from the ASF dual-hosted git repository.

lindong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git

commit ba7760792827634664bb1962cc7ca6e9c161f255
Author: zhangzp <zh...@gmail.com>
AuthorDate: Mon Jun 20 09:42:17 2022 +0800

    [FLINK-27877] Improve performance for StringIndexer
---
 .../ml/feature/stringindexer/StringIndexer.java    | 219 +++++++++++----------
 1 file changed, 117 insertions(+), 102 deletions(-)

diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexer.java
index ee560e0..c8312fa 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexer.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexer.java
@@ -18,31 +18,40 @@
 
 package org.apache.flink.ml.feature.stringindexer;
 
-import org.apache.flink.api.common.functions.FlatMapFunction;
-import org.apache.flink.api.common.functions.MapPartitionFunction;
-import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeHint;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.java.tuple.Tuple2;
-import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
 import org.apache.flink.ml.api.Estimator;
 import org.apache.flink.ml.common.datastream.DataStreamUtils;
 import org.apache.flink.ml.common.param.HasHandleInvalid;
 import org.apache.flink.ml.param.Param;
 import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
 import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.table.api.Table;
 import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
 import org.apache.flink.table.api.internal.TableImpl;
 import org.apache.flink.types.Row;
-import org.apache.flink.util.Collector;
 import org.apache.flink.util.Preconditions;
 
 import java.io.IOException;
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.Comparator;
 import java.util.HashMap;
-import java.util.List;
 import java.util.Map;
+import java.util.Map.Entry;
 
 /**
  * An Estimator which implements the string indexing algorithm.
@@ -91,19 +100,34 @@ public class StringIndexer
         StreamTableEnvironment tEnv =
                 (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
 
-        DataStream<Tuple2<Integer, String>> columnIdAndString =
-                tEnv.toDataStream(inputs[0]).flatMap(new ExtractColumnIdAndString(inputCols));
+        DataStream<HashMap<String, Long>[]> localCountedString =
+                tEnv.toDataStream(inputs[0])
+                        .transform(
+                                "countStringOperator",
+                                TypeInformation.of(new TypeHint<HashMap<String, Long>[]>() {}),
+                                new CountStringOperator(inputCols));
 
-        DataStream<Tuple3<Integer, String, Long>> columnIdAndStringAndCnt =
-                DataStreamUtils.mapPartition(
-                        columnIdAndString.keyBy(
-                                (KeySelector<Tuple2<Integer, String>, Integer>) Tuple2::hashCode),
-                        new CountStringsByColumn(inputCols.length));
+        DataStream<HashMap<String, Long>[]> countedString =
+                DataStreamUtils.reduce(
+                        localCountedString,
+                        (ReduceFunction<HashMap<String, Long>[]>)
+                                (value1, value2) -> {
+                                    for (int i = 0; i < value1.length; i++) {
+                                        for (Entry<String, Long> stringAndCnt :
+                                                value2[i].entrySet()) {
+                                            value1[i].compute(
+                                                    stringAndCnt.getKey(),
+                                                    (k, v) ->
+                                                            (v == null
+                                                                    ? stringAndCnt.getValue()
+                                                                    : v + stringAndCnt.getValue()));
+                                        }
+                                    }
+                                    return value1;
+                                });
 
         DataStream<StringIndexerModelData> modelData =
-                DataStreamUtils.mapPartition(
-                        columnIdAndStringAndCnt,
-                        new GenerateModel(inputCols.length, getStringOrderType()));
+                countedString.map(new ModelGenerator(getStringOrderType()));
         modelData.getTransformation().setParallelism(1);
 
         StringIndexerModel model =
@@ -112,38 +136,93 @@ public class StringIndexer
         return model;
     }
 
+    /** Computes the number of occurrences of each string in each input column. */
+    private static class CountStringOperator extends AbstractStreamOperator<HashMap<String, Long>[]>
+            implements OneInputStreamOperator<Row, HashMap<String, Long>[]>, BoundedOneInput {
+        /** The names of the input columns. */
+        private final String[] inputCols;
+        /** The number of occurrences of each string, per input column. */
+        private HashMap<String, Long>[] stringCntByColumn;
+        /** The state of stringCntByColumn. */
+        private ListState<HashMap<String, Long>[]> stringCntByColumnState;
+
+        public CountStringOperator(String[] inputCols) {
+            this.inputCols = inputCols;
+            stringCntByColumn = new HashMap[inputCols.length];
+            for (int i = 0; i < stringCntByColumn.length; i++) {
+                stringCntByColumn[i] = new HashMap<>();
+            }
+        }
+
+        @Override
+        public void endInput() {
+            output.collect(new StreamRecord<>(stringCntByColumn));
+            stringCntByColumnState.clear();
+        }
+
+        @Override
+        public void processElement(StreamRecord<Row> element) {
+            Row r = element.getValue();
+            for (int i = 0; i < inputCols.length; i++) {
+                Object objVal = r.getField(inputCols[i]);
+                String stringVal;
+                if (objVal instanceof String) {
+                    stringVal = (String) objVal;
+                } else if (objVal instanceof Number) {
+                    stringVal = String.valueOf(objVal);
+                } else {
+                    throw new RuntimeException(
+                            "The input column only supports string and numeric type.");
+                }
+                stringCntByColumn[i].compute(stringVal, (k, v) -> (v == null ? 1 : v + 1));
+            }
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            stringCntByColumnState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "stringCntByColumnState",
+                                            TypeInformation.of(
+                                                    new TypeHint<HashMap<String, Long>[]>() {})));
+
+            OperatorStateUtils.getUniqueElement(stringCntByColumnState, "stringCntByColumnState")
+                    .ifPresent(x -> stringCntByColumn = x);
+        }
+
+        @Override
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            stringCntByColumnState.update(Collections.singletonList(stringCntByColumn));
+        }
+    }
+
     /**
      * Merges all the extracted strings and generates the {@link StringIndexerModelData} according
      * to the specified string order type.
      */
-    private static class GenerateModel
-            implements MapPartitionFunction<Tuple3<Integer, String, Long>, StringIndexerModelData> {
-        private final int numCols;
+    private static class ModelGenerator
+            implements MapFunction<HashMap<String, Long>[], StringIndexerModelData> {
         private final String stringOrderType;
 
-        public GenerateModel(int numCols, String stringOrderType) {
-            this.numCols = numCols;
+        public ModelGenerator(String stringOrderType) {
             this.stringOrderType = stringOrderType;
         }
 
         @Override
-        @SuppressWarnings("unchecked")
-        public void mapPartition(
-                Iterable<Tuple3<Integer, String, Long>> values,
-                Collector<StringIndexerModelData> out) {
+        public StringIndexerModelData map(HashMap<String, Long>[] value) {
+            int numCols = value.length;
             String[][] stringArrays = new String[numCols][];
-            ArrayList<Tuple2<String, Long>>[] stringsAndCntsByColumn = new ArrayList[numCols];
+            ArrayList<Tuple2<String, Long>> stringsAndCnts = new ArrayList<>();
             for (int i = 0; i < numCols; i++) {
-                stringsAndCntsByColumn[i] = new ArrayList<>();
-            }
-
-            for (Tuple3<Integer, String, Long> colIdAndStringAndCnt : values) {
-                stringsAndCntsByColumn[colIdAndStringAndCnt.f0].add(
-                        Tuple2.of(colIdAndStringAndCnt.f1, colIdAndStringAndCnt.f2));
-            }
-
-            for (int i = 0; i < stringsAndCntsByColumn.length; i++) {
-                List<Tuple2<String, Long>> stringsAndCnts = stringsAndCntsByColumn[i];
+                stringsAndCnts.clear();
+                stringsAndCnts.ensureCapacity(value[i].size());
+                for (Map.Entry<String, Long> entry : value[i].entrySet()) {
+                    stringsAndCnts.add(Tuple2.of(entry.getKey(), entry.getValue()));
+                }
                 switch (stringOrderType) {
                     case ALPHABET_ASC_ORDER:
                         stringsAndCnts.sort(Comparator.comparing(valAndCnt -> valAndCnt.f0));
@@ -171,74 +250,10 @@ public class StringIndexer
                                         + stringOrderType
                                         + ".");
                 }
-
-                stringArrays[i] = new String[stringsAndCnts.size()];
-                for (int stringId = 0; stringId < stringArrays[i].length; stringId++) {
-                    stringArrays[i][stringId] = stringsAndCnts.get(stringId).f0;
-                }
-            }
-
-            out.collect(new StringIndexerModelData(stringArrays));
-        }
-    }
-
-    /** Computes the frequency of strings in each column. */
-    private static class CountStringsByColumn
-            implements MapPartitionFunction<
-                    Tuple2<Integer, String>, Tuple3<Integer, String, Long>> {
-        private final int numCols;
-
-        public CountStringsByColumn(int numCols) {
-            this.numCols = numCols;
-        }
-
-        @Override
-        @SuppressWarnings("unchecked")
-        public void mapPartition(
-                Iterable<Tuple2<Integer, String>> values,
-                Collector<Tuple3<Integer, String, Long>> out) {
-            HashMap<String, Long>[] string2CntByColumn = new HashMap[numCols];
-            for (int i = 0; i < numCols; i++) {
-                string2CntByColumn[i] = new HashMap<>();
+                stringArrays[i] = stringsAndCnts.stream().map(x -> x.f0).toArray(String[]::new);
             }
-            for (Tuple2<Integer, String> columnIdAndString : values) {
-                int colId = columnIdAndString.f0;
-                String stringVal = columnIdAndString.f1;
-                long cnt = string2CntByColumn[colId].getOrDefault(stringVal, 0L) + 1;
-                string2CntByColumn[colId].put(stringVal, cnt);
-            }
-            for (int i = 0; i < numCols; i++) {
-                for (Map.Entry<String, Long> entry : string2CntByColumn[i].entrySet()) {
-                    out.collect(Tuple3.of(i, entry.getKey(), entry.getValue()));
-                }
-            }
-        }
-    }
 
-    /** Extracts strings by column. */
-    private static class ExtractColumnIdAndString
-            implements FlatMapFunction<Row, Tuple2<Integer, String>> {
-        private final String[] inputCols;
-
-        public ExtractColumnIdAndString(String[] inputCols) {
-            this.inputCols = inputCols;
-        }
-
-        @Override
-        public void flatMap(Row row, Collector<Tuple2<Integer, String>> out) {
-            for (int i = 0; i < inputCols.length; i++) {
-                Object objVal = row.getField(inputCols[i]);
-                String stringVal;
-                if (objVal instanceof String) {
-                    stringVal = (String) objVal;
-                } else if (objVal instanceof Number) {
-                    stringVal = String.valueOf(objVal);
-                } else {
-                    throw new RuntimeException(
-                            "The input column only supports string and numeric type.");
-                }
-                out.collect(Tuple2.of(i, stringVal));
-            }
+            return new StringIndexerModelData(stringArrays);
         }
     }
 }
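
For readers skimming the diff: the core of this change is that each parallel subtask now counts strings locally in one HashMap<String, Long> per input column (CountStringOperator), a single ReduceFunction merges those per-column maps, and the model is then generated at parallelism 1. This replaces the earlier per-record keyBy/mapPartition shuffle of (columnId, string) tuples. Below is a minimal, self-contained sketch of the map-merge step only, with hypothetical class and method names; it uses Map#merge for brevity, whereas the commit uses compute with an explicit null check.

    import java.util.HashMap;
    import java.util.Map;

    /** Illustrative sketch (not part of the commit) of merging per-column string-count maps. */
    public class CountMapMergeSketch {

        /** Adds the counts of {@code other} into {@code target}, column by column. */
        static HashMap<String, Long>[] merge(
                HashMap<String, Long>[] target, HashMap<String, Long>[] other) {
            for (int i = 0; i < target.length; i++) {
                for (Map.Entry<String, Long> e : other[i].entrySet()) {
                    // Sum the occurrence counts of the same string from both partial results.
                    target[i].merge(e.getKey(), e.getValue(), Long::sum);
                }
            }
            return target;
        }

        public static void main(String[] args) {
            HashMap<String, Long> colA = new HashMap<>();
            colA.put("red", 2L);
            HashMap<String, Long> colB = new HashMap<>();
            colB.put("red", 1L);
            colB.put("blue", 3L);

            @SuppressWarnings("unchecked")
            HashMap<String, Long>[] left = new HashMap[] {colA};
            @SuppressWarnings("unchecked")
            HashMap<String, Long>[] right = new HashMap[] {colB};

            // Prints the merged counts for the single column: red=3, blue=3 (map order may vary).
            System.out.println(merge(left, right)[0]);
        }
    }

Because each subtask emits at most one count map per column instead of one record per row-column occurrence, the amount of data shuffled to the reduce step depends on the number of distinct strings rather than the number of input rows.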