You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by li...@apache.org on 2022/06/20 03:43:18 UTC
[flink-ml] 01/04: [FLINK-27877] Improve performance for StringIndexer
This is an automated email from the ASF dual-hosted git repository.
lindong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
commit ba7760792827634664bb1962cc7ca6e9c161f255
Author: zhangzp <zh...@gmail.com>
AuthorDate: Mon Jun 20 09:42:17 2022 +0800
[FLINK-27877] Improve performance for StringIndexer
---
.../ml/feature/stringindexer/StringIndexer.java | 219 +++++++++++----------
1 file changed, 117 insertions(+), 102 deletions(-)
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexer.java
index ee560e0..c8312fa 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexer.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexer.java
@@ -18,31 +18,40 @@
package org.apache.flink.ml.feature.stringindexer;
-import org.apache.flink.api.common.functions.FlatMapFunction;
-import org.apache.flink.api.common.functions.MapPartitionFunction;
-import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeHint;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple2;
-import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
import org.apache.flink.ml.api.Estimator;
import org.apache.flink.ml.common.datastream.DataStreamUtils;
import org.apache.flink.ml.common.param.HasHandleInvalid;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.types.Row;
-import org.apache.flink.util.Collector;
import org.apache.flink.util.Preconditions;
import java.io.IOException;
import java.util.ArrayList;
+import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
-import java.util.List;
import java.util.Map;
+import java.util.Map.Entry;
/**
* An Estimator which implements the string indexing algorithm.
@@ -91,19 +100,34 @@ public class StringIndexer
StreamTableEnvironment tEnv =
(StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
- DataStream<Tuple2<Integer, String>> columnIdAndString =
- tEnv.toDataStream(inputs[0]).flatMap(new ExtractColumnIdAndString(inputCols));
+ DataStream<HashMap<String, Long>[]> localCountedString =
+ tEnv.toDataStream(inputs[0])
+ .transform(
+ "countStringOperator",
+ TypeInformation.of(new TypeHint<HashMap<String, Long>[]>() {}),
+ new CountStringOperator(inputCols));
- DataStream<Tuple3<Integer, String, Long>> columnIdAndStringAndCnt =
- DataStreamUtils.mapPartition(
- columnIdAndString.keyBy(
- (KeySelector<Tuple2<Integer, String>, Integer>) Tuple2::hashCode),
- new CountStringsByColumn(inputCols.length));
+ DataStream<HashMap<String, Long>[]> countedString =
+ DataStreamUtils.reduce(
+ localCountedString,
+ (ReduceFunction<HashMap<String, Long>[]>)
+ (value1, value2) -> {
+ for (int i = 0; i < value1.length; i++) {
+ for (Entry<String, Long> stringAndCnt :
+ value2[i].entrySet()) {
+ value1[i].compute(
+ stringAndCnt.getKey(),
+ (k, v) ->
+ (v == null
+ ? stringAndCnt.getValue()
+ : v + stringAndCnt.getValue()));
+ }
+ }
+ return value1;
+ });
DataStream<StringIndexerModelData> modelData =
- DataStreamUtils.mapPartition(
- columnIdAndStringAndCnt,
- new GenerateModel(inputCols.length, getStringOrderType()));
+ countedString.map(new ModelGenerator(getStringOrderType()));
modelData.getTransformation().setParallelism(1);
StringIndexerModel model =
@@ -112,38 +136,93 @@ public class StringIndexer
return model;
}
+ /** Counts the occurrences of each string per input column and emits the counts once the input ends. */
+ private static class CountStringOperator extends AbstractStreamOperator<HashMap<String, Long>[]>
+ implements OneInputStreamOperator<Row, HashMap<String, Long>[]>, BoundedOneInput {
+ /** The names of the input columns. */
+ private final String[] inputCols;
+ /** The number of occurrences of each string, one map per input column. */
+ private HashMap<String, Long>[] stringCntByColumn;
+ /** The list state used to snapshot and restore stringCntByColumn. */
+ private ListState<HashMap<String, Long>[]> stringCntByColumnState;
+
+ public CountStringOperator(String[] inputCols) {
+ this.inputCols = inputCols;
+ stringCntByColumn = new HashMap[inputCols.length]; // one count map per input column
+ for (int i = 0; i < stringCntByColumn.length; i++) {
+ stringCntByColumn[i] = new HashMap<>();
+ }
+ }
+
+ @Override
+ public void endInput() {
+ output.collect(new StreamRecord<>(stringCntByColumn)); // emit the accumulated counts as a single record
+ stringCntByColumnState.clear(); // the counts have been emitted, so the state is cleared
+ }
+
+ @Override
+ public void processElement(StreamRecord<Row> element) {
+ Row r = element.getValue();
+ for (int i = 0; i < inputCols.length; i++) {
+ Object objVal = r.getField(inputCols[i]);
+ String stringVal;
+ if (objVal instanceof String) {
+ stringVal = (String) objVal;
+ } else if (objVal instanceof Number) {
+ stringVal = String.valueOf(objVal); // numbers are counted by their string form
+ } else { // any other value, including null, is rejected
+ throw new RuntimeException(
+ "The input column only supports string and numeric type.");
+ }
+ stringCntByColumn[i].compute(stringVal, (k, v) -> (v == null ? 1 : v + 1)); // start at 1 for a new string, otherwise increment
+ }
+ }
+
+ @Override
+ public void initializeState(StateInitializationContext context) throws Exception {
+ super.initializeState(context);
+ stringCntByColumnState =
+ context.getOperatorStateStore()
+ .getListState(
+ new ListStateDescriptor<>(
+ "stringCntByColumnState",
+ TypeInformation.of(
+ new TypeHint<HashMap<String, Long>[]>() {})));
+
+ OperatorStateUtils.getUniqueElement(stringCntByColumnState, "stringCntByColumnState")
+ .ifPresent(x -> stringCntByColumn = x); // use the counts from the state when an element is present
+ }
+
+ @Override
+ public void snapshotState(StateSnapshotContext context) throws Exception {
+ super.snapshotState(context);
+ stringCntByColumnState.update(Collections.singletonList(stringCntByColumn)); // store the current counts as the only state element
+ }
+ }
+
/**
* Merges all the extracted strings and generates the {@link StringIndexerModelData} according
* to the specified string order type.
*/
- private static class GenerateModel
- implements MapPartitionFunction<Tuple3<Integer, String, Long>, StringIndexerModelData> {
- private final int numCols;
+ private static class ModelGenerator
+ implements MapFunction<HashMap<String, Long>[], StringIndexerModelData> {
private final String stringOrderType;
- public GenerateModel(int numCols, String stringOrderType) {
- this.numCols = numCols;
+ public ModelGenerator(String stringOrderType) {
this.stringOrderType = stringOrderType;
}
@Override
- @SuppressWarnings("unchecked")
- public void mapPartition(
- Iterable<Tuple3<Integer, String, Long>> values,
- Collector<StringIndexerModelData> out) {
+ public StringIndexerModelData map(HashMap<String, Long>[] value) {
+ int numCols = value.length;
String[][] stringArrays = new String[numCols][];
- ArrayList<Tuple2<String, Long>>[] stringsAndCntsByColumn = new ArrayList[numCols];
+ ArrayList<Tuple2<String, Long>> stringsAndCnts = new ArrayList<>();
for (int i = 0; i < numCols; i++) {
- stringsAndCntsByColumn[i] = new ArrayList<>();
- }
-
- for (Tuple3<Integer, String, Long> colIdAndStringAndCnt : values) {
- stringsAndCntsByColumn[colIdAndStringAndCnt.f0].add(
- Tuple2.of(colIdAndStringAndCnt.f1, colIdAndStringAndCnt.f2));
- }
-
- for (int i = 0; i < stringsAndCntsByColumn.length; i++) {
- List<Tuple2<String, Long>> stringsAndCnts = stringsAndCntsByColumn[i];
+ stringsAndCnts.clear();
+ stringsAndCnts.ensureCapacity(value[i].size());
+ for (Map.Entry<String, Long> entry : value[i].entrySet()) {
+ stringsAndCnts.add(Tuple2.of(entry.getKey(), entry.getValue()));
+ }
switch (stringOrderType) {
case ALPHABET_ASC_ORDER:
stringsAndCnts.sort(Comparator.comparing(valAndCnt -> valAndCnt.f0));
@@ -171,74 +250,10 @@ public class StringIndexer
+ stringOrderType
+ ".");
}
-
- stringArrays[i] = new String[stringsAndCnts.size()];
- for (int stringId = 0; stringId < stringArrays[i].length; stringId++) {
- stringArrays[i][stringId] = stringsAndCnts.get(stringId).f0;
- }
- }
-
- out.collect(new StringIndexerModelData(stringArrays));
- }
- }
-
- /** Computes the frequency of strings in each column. */
- private static class CountStringsByColumn
- implements MapPartitionFunction<
- Tuple2<Integer, String>, Tuple3<Integer, String, Long>> {
- private final int numCols;
-
- public CountStringsByColumn(int numCols) {
- this.numCols = numCols;
- }
-
- @Override
- @SuppressWarnings("unchecked")
- public void mapPartition(
- Iterable<Tuple2<Integer, String>> values,
- Collector<Tuple3<Integer, String, Long>> out) {
- HashMap<String, Long>[] string2CntByColumn = new HashMap[numCols];
- for (int i = 0; i < numCols; i++) {
- string2CntByColumn[i] = new HashMap<>();
+ stringArrays[i] = stringsAndCnts.stream().map(x -> x.f0).toArray(String[]::new);
}
- for (Tuple2<Integer, String> columnIdAndString : values) {
- int colId = columnIdAndString.f0;
- String stringVal = columnIdAndString.f1;
- long cnt = string2CntByColumn[colId].getOrDefault(stringVal, 0L) + 1;
- string2CntByColumn[colId].put(stringVal, cnt);
- }
- for (int i = 0; i < numCols; i++) {
- for (Map.Entry<String, Long> entry : string2CntByColumn[i].entrySet()) {
- out.collect(Tuple3.of(i, entry.getKey(), entry.getValue()));
- }
- }
- }
- }
- /** Extracts strings by column. */
- private static class ExtractColumnIdAndString
- implements FlatMapFunction<Row, Tuple2<Integer, String>> {
- private final String[] inputCols;
-
- public ExtractColumnIdAndString(String[] inputCols) {
- this.inputCols = inputCols;
- }
-
- @Override
- public void flatMap(Row row, Collector<Tuple2<Integer, String>> out) {
- for (int i = 0; i < inputCols.length; i++) {
- Object objVal = row.getField(inputCols[i]);
- String stringVal;
- if (objVal instanceof String) {
- stringVal = (String) objVal;
- } else if (objVal instanceof Number) {
- stringVal = String.valueOf(objVal);
- } else {
- throw new RuntimeException(
- "The input column only supports string and numeric type.");
- }
- out.collect(Tuple2.of(i, stringVal));
- }
+ return new StringIndexerModelData(stringArrays);
}
}
}