You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by li...@apache.org on 2022/03/23 02:28:37 UTC
[flink-ml] branch master updated: [FLINK-25527] Add Transformer and Estimator for StringIndexer
This is an automated email from the ASF dual-hosted git repository.
lindong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push:
new 8571317 [FLINK-25527] Add Transformer and Estimator for StringIndexer
8571317 is described below
commit 857131751e34e3bb1df4c063d9d1a4c57f5fd761
Author: Zhipeng Zhang <zh...@gmail.com>
AuthorDate: Wed Mar 23 10:28:33 2022 +0800
[FLINK-25527] Add Transformer and Estimator for StringIndexer
This closes #52.
---
.../ml/common/datastream/DataStreamUtils.java | 11 +-
.../ml/common/datastream/DataStreamUtilsTest.java | 5 +-
.../feature/stringindexer/IndexToStringModel.java | 174 +++++++++++
.../stringindexer/IndexToStringModelParams.java | 29 ++
.../ml/feature/stringindexer/StringIndexer.java | 238 +++++++++++++++
.../feature/stringindexer/StringIndexerModel.java | 203 +++++++++++++
.../stringindexer/StringIndexerModelData.java | 125 ++++++++
.../stringindexer/StringIndexerModelParams.java | 62 ++++
.../feature/stringindexer/StringIndexerParams.java | 68 +++++
.../stringindexer/IndexToStringModelTest.java | 174 +++++++++++
.../feature/stringindexer/StringIndexerTest.java | 335 +++++++++++++++++++++
11 files changed, 1417 insertions(+), 7 deletions(-)
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
index 58eae62..7eea6b0 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
@@ -25,7 +25,7 @@ import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.typeutils.TypeExtractor;
import org.apache.flink.runtime.state.StateInitializationContext;
import org.apache.flink.streaming.api.datastream.DataStream;
-import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
import org.apache.flink.streaming.api.operators.BoundedOneInput;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.api.operators.TimestampedCollector;
@@ -71,15 +71,14 @@ public class DataStreamUtils {
* A stream operator to apply {@link MapPartitionFunction} on each partition of the input
* bounded data stream.
*/
- private static class MapPartitionOperator<IN, OUT> extends AbstractStreamOperator<OUT>
+ private static class MapPartitionOperator<IN, OUT>
+ extends AbstractUdfStreamOperator<OUT, MapPartitionFunction<IN, OUT>>
implements OneInputStreamOperator<IN, OUT>, BoundedOneInput {
- private final MapPartitionFunction<IN, OUT> mapPartitionFunc;
-
private ListState<IN> valuesState;
public MapPartitionOperator(MapPartitionFunction<IN, OUT> mapPartitionFunc) {
- this.mapPartitionFunc = mapPartitionFunc;
+ super(mapPartitionFunc);
}
@Override
@@ -95,7 +94,7 @@ public class DataStreamUtils {
@Override
public void endInput() throws Exception {
- mapPartitionFunc.mapPartition(valuesState.get(), new TimestampedCollector<>(output));
+ userFunction.mapPartition(valuesState.get(), new TimestampedCollector<>(output));
valuesState.clear();
}
diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java b/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java
index 64c1d70..7933859 100644
--- a/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java
+++ b/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java
@@ -19,6 +19,7 @@
package org.apache.flink.ml.common.datastream;
import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.restartstrategy.RestartStrategies;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.configuration.Configuration;
@@ -35,6 +36,7 @@ import org.junit.Test;
import java.util.List;
import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertNotNull;
/** Tests the {@link DataStreamUtils}. */
public class DataStreamUtilsTest {
@@ -62,9 +64,10 @@ public class DataStreamUtilsTest {
}
/** A simple implementation for a {@link MapPartitionFunction}. */
- private static class TestMapPartitionFunc implements MapPartitionFunction<Long, Integer> {
+ private static class TestMapPartitionFunc extends RichMapPartitionFunction<Long, Integer> {
public void mapPartition(Iterable<Long> values, Collector<Integer> out) {
+ assertNotNull(getRuntimeContext());
int cnt = 0;
for (long ignored : values) {
cnt++;
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/IndexToStringModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/IndexToStringModel.java
new file mode 100644
index 0000000..b66a337
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/IndexToStringModel.java
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.stringindexer;
+
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * A Model which transforms input index column(s) to string column(s) using the model data computed
+ * by {@link StringIndexer}. It is a reverse operation of {@link StringIndexerModel}.
+ *
+ * <p>Each output row is the input row with one string column appended per input column. A {@link
+ * RuntimeException} is thrown when an input value is not an integer or when an index is outside
+ * the string array stored for its column.
+ */
+public class IndexToStringModel
+        implements Model<IndexToStringModel>, IndexToStringModelParams<IndexToStringModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public IndexToStringModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    /** Saves the metadata of this stage and its model data under the given path. */
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+        ReadWriteUtils.saveModelData(
+                StringIndexerModelData.getModelDataStream(modelDataTable),
+                path,
+                new StringIndexerModelData.ModelDataEncoder());
+    }
+
+    /**
+     * Restores a model from the given path: the stage parameters are loaded first and the model
+     * data read from the same path is then attached to the model.
+     */
+    public static IndexToStringModel load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+        IndexToStringModel model = ReadWriteUtils.loadStageParam(path);
+        DataStream<StringIndexerModelData> modelData =
+                ReadWriteUtils.loadModelData(
+                        env, path, new StringIndexerModelData.ModelDataDecoder());
+
+        return model.setModelData(tEnv.fromDataStream(modelData));
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    /** Uses the first given table as the model data; any further tables are ignored. */
+    @Override
+    public IndexToStringModel setModelData(Table... inputs) {
+        modelDataTable = inputs[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    /**
+     * Appends one string column per input column to the single input table.
+     *
+     * @param inputs exactly one table containing the configured input columns
+     * @return a single table holding the input columns followed by the output columns
+     * @throws IllegalArgumentException if the number of inputs is not one, or if the numbers of
+     *     input and output columns differ
+     */
+    @Override
+    @SuppressWarnings({"unchecked", "rawtypes"})
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        String[] inputCols = getInputCols();
+        String[] outputCols = getOutputCols();
+        Preconditions.checkArgument(inputCols.length == outputCols.length);
+        // The table environment is taken from the model data table, not from the input table.
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) modelDataTable).getTableEnvironment();
+
+        // Output schema = input schema followed by one STRING field per output column.
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        TypeInformation<?>[] outputTypes = new TypeInformation[outputCols.length];
+        Arrays.fill(outputTypes, BasicTypeInfo.STRING_TYPE_INFO);
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), outputTypes),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), outputCols));
+
+        final String broadcastModelKey = "broadcastModelKey";
+        DataStream<StringIndexerModelData> modelDataStream =
+                StringIndexerModelData.getModelDataStream(modelDataTable);
+
+        DataStream<Row> result =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(tEnv.toDataStream(inputs[0])),
+                        Collections.singletonMap(broadcastModelKey, modelDataStream),
+                        inputList -> {
+                            DataStream inputData = inputList.get(0);
+                            return inputData.flatMap(
+                                    new Index2String(broadcastModelKey, inputCols), outputTypeInfo);
+                        });
+
+        return new Table[] {tEnv.fromDataStream(result)};
+    }
+
+    /** Maps the input index values to string values according to the model data. */
+    private static class Index2String extends RichFlatMapFunction<Row, Row> {
+        /** Per-column string arrays; read from the broadcast variable on the first record. */
+        private String[][] stringArrays;
+
+        private final String broadcastModelKey;
+        private final String[] inputCols;
+
+        public Index2String(String broadcastModelKey, String[] inputCols) {
+            this.broadcastModelKey = broadcastModelKey;
+            this.inputCols = inputCols;
+        }
+
+        @Override
+        public void flatMap(Row input, Collector<Row> out) {
+            if (stringArrays == null) {
+                StringIndexerModelData modelData =
+                        (StringIndexerModelData)
+                                getRuntimeContext().getBroadcastVariable(broadcastModelKey).get(0);
+                stringArrays = modelData.stringArrays;
+            }
+
+            Row outputStrings = new Row(inputCols.length);
+            for (int i = 0; i < inputCols.length; i++) {
+                Object indexValue = input.getField(inputCols[i]);
+                // Check the type explicitly rather than catching a failed cast, so that the
+                // error message reports the offending value itself.
+                if (!(indexValue instanceof Integer)) {
+                    throw new RuntimeException(
+                            "The input contains non-integer value: " + indexValue + ".");
+                }
+
+                int stringId = (Integer) indexValue;
+                if (stringId < 0 || stringId >= stringArrays[i].length) {
+                    throw new RuntimeException(
+                            "The input contains unseen index: " + stringId + ".");
+                }
+                outputStrings.setField(i, stringArrays[i][stringId]);
+            }
+
+            out.collect(Row.join(input, outputStrings));
+        }
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/IndexToStringModelParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/IndexToStringModelParams.java
new file mode 100644
index 0000000..2b29083
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/IndexToStringModelParams.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.stringindexer;
+
+import org.apache.flink.ml.common.param.HasInputCols;
+import org.apache.flink.ml.common.param.HasOutputCols;
+
+/**
+ * Params for {@link IndexToStringModel}.
+ *
+ * <p>This interface declares no parameters of its own; it only combines {@link HasInputCols} and
+ * {@link HasOutputCols}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface IndexToStringModelParams<T> extends HasInputCols<T>, HasOutputCols<T> {}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexer.java
new file mode 100644
index 0000000..0a8f7e9
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexer.java
@@ -0,0 +1,238 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.stringindexer;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the string indexing algorithm.
+ *
+ * <p>A string indexer maps one or more columns (string/numerical value) of the input to one or more
+ * indexed output columns (integer value). The output indices of two data points are the same iff
+ * their corresponding input columns are the same. The indices are in [0,
+ * numDistinctValuesInThisColumn].
+ *
+ * <p>The input columns are cast to string if they are numeric values. By default, the output model
+ * is arbitrarily ordered. Users can control this by setting {@link
+ * StringIndexerParams#STRING_ORDER_TYPE}.
+ */
+public class StringIndexer
+        implements Estimator<StringIndexer, StringIndexerModel>,
+                StringIndexerParams<StringIndexer> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public StringIndexer() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    /** Saves the metadata of this estimator under the given path. */
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    /** Restores an estimator from the given path; the environment argument is not used. */
+    public static StringIndexer load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    /**
+     * Builds a {@link StringIndexerModel} from the distinct values of the configured input
+     * columns of the single input table.
+     *
+     * @param inputs exactly one table containing the configured input columns
+     * @return a model whose model data holds one string array per input column
+     * @throws IllegalArgumentException if the number of inputs is not one, or if the numbers of
+     *     input and output columns differ
+     */
+    @Override
+    public StringIndexerModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        String[] inputCols = getInputCols();
+        String[] outputCols = getOutputCols();
+        Preconditions.checkArgument(inputCols.length == outputCols.length);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        // Step 1: emit one (columnId, string) pair per input column of every row.
+        DataStream<Tuple2<Integer, String>> columnIdAndString =
+                tEnv.toDataStream(inputs[0]).flatMap(new ExtractColumnIdAndString(inputCols));
+
+        // Step 2: count the pairs per partition. The stream is keyed by the hash code of the
+        // pair, so equal pairs always share the same key.
+        DataStream<Tuple3<Integer, String, Long>> columnIdAndStringAndCnt =
+                DataStreamUtils.mapPartition(
+                        columnIdAndString.keyBy(
+                                (KeySelector<Tuple2<Integer, String>, Integer>) Tuple2::hashCode),
+                        new CountStringsByColumn(inputCols.length));
+
+        // Step 3: merge all counts into a single model data record; this operator is forced to
+        // parallelism 1.
+        DataStream<StringIndexerModelData> modelData =
+                DataStreamUtils.mapPartition(
+                        columnIdAndStringAndCnt,
+                        new GenerateModel(inputCols.length, getStringOrderType()));
+        modelData.getTransformation().setParallelism(1);
+
+        StringIndexerModel model =
+                new StringIndexerModel().setModelData(tEnv.fromDataStream(modelData));
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /**
+     * Merges all the extracted strings and generates the {@link StringIndexerModelData} according
+     * to the specified string order type.
+     */
+    private static class GenerateModel
+            implements MapPartitionFunction<Tuple3<Integer, String, Long>, StringIndexerModelData> {
+        private final int numCols;
+        private final String stringOrderType;
+
+        public GenerateModel(int numCols, String stringOrderType) {
+            this.numCols = numCols;
+            this.stringOrderType = stringOrderType;
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void mapPartition(
+                Iterable<Tuple3<Integer, String, Long>> values,
+                Collector<StringIndexerModelData> out) {
+            String[][] stringArrays = new String[numCols][];
+            List<Tuple2<String, Long>>[] stringsAndCntsByColumn = new ArrayList[numCols];
+            for (int i = 0; i < numCols; i++) {
+                stringsAndCntsByColumn[i] = new ArrayList<>();
+            }
+
+            // Group the (string, count) pairs by the column they belong to.
+            for (Tuple3<Integer, String, Long> colIdAndStringAndCnt : values) {
+                stringsAndCntsByColumn[colIdAndStringAndCnt.f0].add(
+                        Tuple2.of(colIdAndStringAndCnt.f1, colIdAndStringAndCnt.f2));
+            }
+
+            for (int i = 0; i < stringsAndCntsByColumn.length; i++) {
+                List<Tuple2<String, Long>> stringsAndCnts = stringsAndCntsByColumn[i];
+                // Descending orders swap the operands instead of negating the result of
+                // compareTo, which is not safe for arbitrary comparison results.
+                switch (stringOrderType) {
+                    case StringIndexerParams.ALPHABET_ASC_ORDER:
+                        stringsAndCnts.sort(Comparator.comparing(valAndCnt -> valAndCnt.f0));
+                        break;
+                    case StringIndexerParams.ALPHABET_DESC_ORDER:
+                        stringsAndCnts.sort(
+                                (valAndCnt1, valAndCnt2) ->
+                                        valAndCnt2.f0.compareTo(valAndCnt1.f0));
+                        break;
+                    case StringIndexerParams.FREQUENCY_ASC_ORDER:
+                        stringsAndCnts.sort(Comparator.comparing(valAndCnt -> valAndCnt.f1));
+                        break;
+                    case StringIndexerParams.FREQUENCY_DESC_ORDER:
+                        stringsAndCnts.sort(
+                                (valAndCnt1, valAndCnt2) ->
+                                        valAndCnt2.f1.compareTo(valAndCnt1.f1));
+                        break;
+                    case StringIndexerParams.ARBITRARY_ORDER:
+                        break;
+                    default:
+                        throw new IllegalStateException(
+                                "Unsupported string order type: " + stringOrderType);
+                }
+
+                // The position of a string in the (sorted) list becomes its index.
+                stringArrays[i] = new String[stringsAndCnts.size()];
+                for (int stringId = 0; stringId < stringArrays[i].length; stringId++) {
+                    stringArrays[i][stringId] = stringsAndCnts.get(stringId).f0;
+                }
+            }
+
+            out.collect(new StringIndexerModelData(stringArrays));
+        }
+    }
+
+    /** Computes the frequency of strings in each column. */
+    private static class CountStringsByColumn
+            implements MapPartitionFunction<
+                    Tuple2<Integer, String>, Tuple3<Integer, String, Long>> {
+        private final int numCols;
+
+        public CountStringsByColumn(int numCols) {
+            this.numCols = numCols;
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void mapPartition(
+                Iterable<Tuple2<Integer, String>> values,
+                Collector<Tuple3<Integer, String, Long>> out) {
+            Map<String, Long>[] string2CntByColumn = new HashMap[numCols];
+            for (int i = 0; i < numCols; i++) {
+                string2CntByColumn[i] = new HashMap<>();
+            }
+            for (Tuple2<Integer, String> columnIdAndString : values) {
+                // Starts the count at 1 for a new string, otherwise increments it.
+                string2CntByColumn[columnIdAndString.f0].merge(
+                        columnIdAndString.f1, 1L, Long::sum);
+            }
+            for (int i = 0; i < numCols; i++) {
+                for (Map.Entry<String, Long> entry : string2CntByColumn[i].entrySet()) {
+                    out.collect(Tuple3.of(i, entry.getKey(), entry.getValue()));
+                }
+            }
+        }
+    }
+
+    /** Extracts strings by column. */
+    private static class ExtractColumnIdAndString
+            implements FlatMapFunction<Row, Tuple2<Integer, String>> {
+        private final String[] inputCols;
+
+        public ExtractColumnIdAndString(String[] inputCols) {
+            this.inputCols = inputCols;
+        }
+
+        @Override
+        public void flatMap(Row row, Collector<Tuple2<Integer, String>> out) {
+            for (int i = 0; i < inputCols.length; i++) {
+                Object objVal = row.getField(inputCols[i]);
+                String stringVal;
+                if (objVal instanceof String) {
+                    stringVal = (String) objVal;
+                } else if (objVal instanceof Number) {
+                    stringVal = String.valueOf(objVal);
+                } else {
+                    // Any other type, including a null value, is rejected.
+                    throw new RuntimeException(
+                            "The input column only supports string and numeric type.");
+                }
+                out.collect(Tuple2.of(i, stringVal));
+            }
+        }
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexerModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexerModel.java
new file mode 100644
index 0000000..37307ba
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexerModel.java
@@ -0,0 +1,203 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.stringindexer;
+
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * A Model which transforms input string/numeric column(s) to integer column(s) using the model data
+ * computed by {@link StringIndexer}.
+ *
+ * <p>Each output row is the input row with one integer column appended per input column. Values
+ * that are absent from the model data are handled according to the handleInvalid parameter: the
+ * row is dropped, an exception is thrown, or the value is mapped to the number of strings known
+ * for that column.
+ */
+public class StringIndexerModel
+        implements Model<StringIndexerModel>, StringIndexerModelParams<StringIndexerModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public StringIndexerModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    /** Saves the metadata of this stage and its model data under the given path. */
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+        ReadWriteUtils.saveModelData(
+                StringIndexerModelData.getModelDataStream(modelDataTable),
+                path,
+                new StringIndexerModelData.ModelDataEncoder());
+    }
+
+    /**
+     * Restores a model from the given path: the stage parameters are loaded first and the model
+     * data read from the same path is then attached to the model.
+     */
+    public static StringIndexerModel load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+        StringIndexerModel model = ReadWriteUtils.loadStageParam(path);
+        DataStream<StringIndexerModelData> modelData =
+                ReadWriteUtils.loadModelData(
+                        env, path, new StringIndexerModelData.ModelDataDecoder());
+
+        return model.setModelData(tEnv.fromDataStream(modelData));
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    /** Uses the first given table as the model data; any further tables are ignored. */
+    @Override
+    public StringIndexerModel setModelData(Table... inputs) {
+        modelDataTable = inputs[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    /**
+     * Appends one integer column per input column to the single input table.
+     *
+     * @param inputs exactly one table containing the configured input columns
+     * @return a single table holding the input columns followed by the output columns
+     * @throws IllegalArgumentException if the number of inputs is not one, or if the numbers of
+     *     input and output columns differ
+     */
+    @Override
+    @SuppressWarnings({"unchecked", "rawtypes"})
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        String[] inputCols = getInputCols();
+        String[] outputCols = getOutputCols();
+        String handleInvalid = getHandleInvalid();
+        Preconditions.checkArgument(inputCols.length == outputCols.length);
+        // The table environment is taken from the model data table, not from the input table.
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) modelDataTable).getTableEnvironment();
+
+        // Output schema = input schema followed by one INT field per output column.
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        TypeInformation<?>[] outputTypes = new TypeInformation[outputCols.length];
+        Arrays.fill(outputTypes, BasicTypeInfo.INT_TYPE_INFO);
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), outputTypes),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), outputCols));
+
+        final String broadcastModelKey = "broadcastModelKey";
+        DataStream<StringIndexerModelData> modelDataStream =
+                StringIndexerModelData.getModelDataStream(modelDataTable);
+
+        DataStream<Row> result =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(tEnv.toDataStream(inputs[0])),
+                        Collections.singletonMap(broadcastModelKey, modelDataStream),
+                        inputList -> {
+                            DataStream inputData = inputList.get(0);
+                            return inputData.flatMap(
+                                    new String2Index(broadcastModelKey, inputCols, handleInvalid),
+                                    outputTypeInfo);
+                        });
+
+        return new Table[] {tEnv.fromDataStream(result)};
+    }
+
+    /** Maps the input columns to integer values according to the model data. */
+    private static class String2Index extends RichFlatMapFunction<Row, Row> {
+        /** Per-column string-to-index lookup; built from the broadcast variable on first use. */
+        private Map<String, Integer>[] modelDataMap;
+
+        private final String broadcastModelKey;
+        private final String[] inputCols;
+        private final String handleInvalid;
+
+        public String2Index(String broadcastModelKey, String[] inputCols, String handleInvalid) {
+            this.broadcastModelKey = broadcastModelKey;
+            this.inputCols = inputCols;
+            this.handleInvalid = handleInvalid;
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void flatMap(Row input, Collector<Row> out) {
+            if (modelDataMap == null) {
+                modelDataMap = new HashMap[inputCols.length];
+                StringIndexerModelData modelData =
+                        (StringIndexerModelData)
+                                getRuntimeContext().getBroadcastVariable(broadcastModelKey).get(0);
+                String[][] stringsArray = modelData.stringArrays;
+                for (int i = 0; i < stringsArray.length; i++) {
+                    int idx = 0;
+                    modelDataMap[i] = new HashMap<>(stringsArray[i].length);
+                    // The position of a string in the model array is its index.
+                    for (String string : stringsArray[i]) {
+                        modelDataMap[i].put(string, idx++);
+                    }
+                }
+            }
+
+            Row outputIndices = new Row(inputCols.length);
+            for (int i = 0; i < inputCols.length; i++) {
+                Object objVal = input.getField(inputCols[i]);
+                String stringVal;
+                if (objVal instanceof String) {
+                    stringVal = (String) objVal;
+                } else if (objVal instanceof Number) {
+                    stringVal = String.valueOf(objVal);
+                } else {
+                    // Any other type, including a null value, is rejected.
+                    throw new RuntimeException(
+                            "The input column only supports string and numeric type.");
+                }
+
+                // The map never stores null values, so null means the string is unseen.
+                Integer index = modelDataMap[i].get(stringVal);
+                if (index != null) {
+                    outputIndices.setField(i, index);
+                    continue;
+                }
+
+                switch (handleInvalid) {
+                    case StringIndexerModelParams.SKIP_INVALID:
+                        // Drop the whole row: nothing is emitted for this input.
+                        return;
+                    case StringIndexerModelParams.ERROR_INVALID:
+                        throw new RuntimeException(
+                                "The input contains unseen string: "
+                                        + stringVal
+                                        + ". See handleInvalid parameter for more options.");
+                    case StringIndexerModelParams.KEEP_INVALID:
+                        // Unseen strings share one extra index right after the known ones.
+                        outputIndices.setField(i, modelDataMap[i].size());
+                        break;
+                    default:
+                        throw new IllegalStateException(
+                                "Unsupported types of handling invalid data: " + handleInvalid);
+                }
+            }
+
+            out.collect(Row.join(input, outputIndices));
+        }
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexerModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexerModelData.java
new file mode 100644
index 0000000..2de8d38
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexerModelData.java
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.stringindexer;
+
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.api.common.typeutils.base.array.StringArraySerializer;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+
+import java.io.EOFException;
+import java.io.IOException;
+import java.io.OutputStream;
+
+/**
+ * Holds the model data shared by {@link StringIndexerModel} and {@link IndexToStringModel}.
+ *
+ * <p>Besides the data itself, this class offers a helper that turns model data given as a Table
+ * into a DataStream, plus the encoder/decoder pair used to write and read the model data.
+ */
+public class StringIndexerModelData {
+ /** The ordered strings of every input column. */
+ public String[][] stringArrays;
+
+ public StringIndexerModelData() {}
+
+ public StringIndexerModelData(String[][] stringArrays) {
+ this.stringArrays = stringArrays;
+ }
+
+ /**
+ * Turns model data given as a table into a data stream of {@link StringIndexerModelData}.
+ *
+ * @param modelData The table model data; its first field is read as the string arrays.
+ * @return The data stream model data.
+ */
+ public static DataStream<StringIndexerModelData> getModelDataStream(Table modelData) {
+ TableImpl tableImpl = (TableImpl) modelData;
+ StreamTableEnvironment tableEnv =
+ (StreamTableEnvironment) tableImpl.getTableEnvironment();
+ return tableEnv.toDataStream(modelData)
+ .map(
+ row -> {
+ String[][] columns = (String[][]) row.getField(0);
+ return new StringIndexerModelData(columns);
+ });
+ }
+
+ /** Encoder of the model data of {@link StringIndexerModel} and {@link IndexToStringModel}. */
+ public static class ModelDataEncoder implements Encoder<StringIndexerModelData> {
+
+ @Override
+ public void encode(StringIndexerModelData modelData, OutputStream outputStream)
+ throws IOException {
+ DataOutputViewStreamWrapper target = new DataOutputViewStreamWrapper(outputStream);
+ String[][] columns = modelData.stringArrays;
+
+ // Layout: the number of columns, followed by the string array of each column.
+ IntSerializer.INSTANCE.serialize(columns.length, target);
+ for (int col = 0; col < columns.length; col++) {
+ StringArraySerializer.INSTANCE.serialize(columns[col], target);
+ }
+ }
+ }
+
+ /** Decoder of the model data of {@link StringIndexerModel} and {@link IndexToStringModel}. */
+ public static class ModelDataDecoder extends SimpleStreamFormat<StringIndexerModelData> {
+
+ @Override
+ public Reader<StringIndexerModelData> createReader(
+ Configuration configuration, FSDataInputStream inputStream) {
+ return new Reader<StringIndexerModelData>() {
+
+ @Override
+ public StringIndexerModelData read() throws IOException {
+ DataInputViewStreamWrapper source = new DataInputViewStreamWrapper(inputStream);
+ try {
+ // Mirrors the encoder: column count first, then one array per column.
+ int numCols = IntSerializer.INSTANCE.deserialize(source);
+ String[][] columns = new String[numCols][];
+ int col = 0;
+ while (col < numCols) {
+ columns[col++] = StringArraySerializer.INSTANCE.deserialize(source);
+ }
+ return new StringIndexerModelData(columns);
+ } catch (EOFException endOfInput) {
+ // Running out of input is reported to the caller as null.
+ return null;
+ }
+ }
+
+ @Override
+ public void close() throws IOException {
+ inputStream.close();
+ }
+ };
+ }
+
+ @Override
+ public TypeInformation<StringIndexerModelData> getProducedType() {
+ return TypeInformation.of(StringIndexerModelData.class);
+ }
+ }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexerModelParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexerModelParams.java
new file mode 100644
index 0000000..c4d81ad
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexerModelParams.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.stringindexer;
+
+import org.apache.flink.ml.common.param.HasInputCols;
+import org.apache.flink.ml.common.param.HasOutputCols;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringParam;
+
+/**
+ * Params of {@link StringIndexerModel}: input/output columns plus the invalid-entry strategy.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface StringIndexerModelParams<T> extends HasInputCols<T>, HasOutputCols<T> {
+ String ERROR_INVALID = "error"; // handleInvalid option: raise an exception.
+ String SKIP_INVALID = "skip"; // handleInvalid option: filter out rows with bad values.
+ String KEEP_INVALID = "keep"; // handleInvalid option: give invalid entries an extra index.
+
+ /**
+ * Supported options and the corresponding behavior to handle invalid entries in
+ * StringIndexerModel are listed as follows. The value must be one of these options.
+ *
+ * <ul>
+ * <li>error: raise an exception (this is the default).
+ * <li>skip: filter out rows with bad values.
+ * <li>keep: put the invalid entries in a special bucket, whose index is the number of
+ * distinct values in this column.
+ * </ul>
+ */
+ Param<String> HANDLE_INVALID =
+ new StringParam(
+ "handleInvalid",
+ "Strategy to handle invalid entries.",
+ ERROR_INVALID,
+ ParamValidators.inArray(ERROR_INVALID, SKIP_INVALID, KEEP_INVALID));
+
+ default String getHandleInvalid() {
+ return get(HANDLE_INVALID); // the currently configured strategy
+ }
+
+ default T setHandleInvalid(String value) {
+ return set(HANDLE_INVALID, value); // returns whatever set(...) returns, typed T
+ }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexerParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexerParams.java
new file mode 100644
index 0000000..61c23cf
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexerParams.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.stringindexer;
+
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringParam;
+
+/**
+ * Params of {@link StringIndexer}: the model params plus the ordering of the strings.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface StringIndexerParams<T> extends StringIndexerModelParams<T> {
+ String ARBITRARY_ORDER = "arbitrary"; // stringOrderType option: no defined order.
+ String FREQUENCY_DESC_ORDER = "frequencyDesc"; // stringOrderType option: most frequent first.
+ String FREQUENCY_ASC_ORDER = "frequencyAsc"; // stringOrderType option: least frequent first.
+ String ALPHABET_DESC_ORDER = "alphabetDesc"; // stringOrderType option: reverse alphabetical.
+ String ALPHABET_ASC_ORDER = "alphabetAsc"; // stringOrderType option: alphabetical.
+
+ /**
+ * Supported options to decide the order of strings in each column are listed as follows. (The
+ * first label after ordering is assigned an index of 0).
+ *
+ * <ul>
+ * <li>arbitrary: the order of strings is arbitrary and depends on each execution.
+ * <li>frequencyDesc: descending order by the frequency.
+ * <li>frequencyAsc: ascending order by the frequency.
+ * <li>alphabetDesc: descending alphabetical order.
+ * <li>alphabetAsc: ascending alphabetical order.
+ * </ul>
+ */
+ Param<String> STRING_ORDER_TYPE =
+ new StringParam(
+ "stringOrderType",
+ "How to order strings of each column.",
+ ARBITRARY_ORDER,
+ ParamValidators.inArray(
+ ARBITRARY_ORDER,
+ FREQUENCY_DESC_ORDER,
+ FREQUENCY_ASC_ORDER,
+ ALPHABET_DESC_ORDER,
+ ALPHABET_ASC_ORDER));
+
+ default String getStringOrderType() {
+ return get(STRING_ORDER_TYPE); // the currently configured ordering
+ }
+
+ default T setStringOrderType(String value) {
+ return set(STRING_ORDER_TYPE, value); // returns whatever set(...) returns, typed T
+ }
+}
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/stringindexer/IndexToStringModelTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/stringindexer/IndexToStringModelTest.java
new file mode 100644
index 0000000..e02276a
--- /dev/null
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/stringindexer/IndexToStringModelTest.java
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.stringindexer;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
+
+/** Tests the {@link IndexToStringModel}. */
+public class IndexToStringModelTest extends AbstractTestBase {
+ @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+ private StreamExecutionEnvironment env;
+ private StreamTableEnvironment tEnv;
+ private Table predictTable;
+ private Table modelTable;
+ private Table predictTableWithUnseenValues;
+
+ private final List<Row> expectedPrediction =
+ Arrays.asList(Row.of(0, 3, "a", "2.0"), Row.of(1, 2, "b", "1.0"));
+ private final String[][] stringArrays =
+ new String[][] {{"a", "b", "c", "d"}, {"-1.0", "0.0", "1.0", "2.0"}};
+
+ /** Sets up the execution environments, the model data table and the two input tables. */
+ @Before
+ public void before() {
+ Configuration config = new Configuration();
+ config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+ env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+ env.setParallelism(4);
+ env.enableCheckpointing(100);
+ env.setRestartStrategy(RestartStrategies.noRestart());
+ tEnv = StreamTableEnvironment.create(env);
+
+ List<Row> knownIndices = Arrays.asList(Row.of(0, 3), Row.of(1, 2));
+ // Index 4 has no counterpart in the first column of the model data.
+ List<Row> withUnseenIndex = Arrays.asList(Row.of(0, 3), Row.of(1, 2), Row.of(4, 1));
+
+ modelTable =
+ tEnv.fromDataStream(env.fromElements(new StringIndexerModelData(stringArrays)))
+ .as("stringArrays");
+ predictTable = toInputTable(knownIndices);
+ predictTableWithUnseenValues = toInputTable(withUnseenIndex);
+ }
+
+ /** Wraps the given rows into a table with the two input columns used by all tests. */
+ private Table toInputTable(List<Row> rows) {
+ return tEnv.fromDataStream(env.fromCollection(rows)).as("inputCol1", "inputCol2");
+ }
+
+ /** Creates a model over both input columns that uses {@link #modelTable} as model data. */
+ private IndexToStringModel createModel() {
+ return new IndexToStringModel()
+ .setInputCols("inputCol1", "inputCol2")
+ .setOutputCols("outputCol1", "outputCol2")
+ .setModelData(modelTable);
+ }
+
+ /** Executes the job behind the given table and gathers all of its rows. */
+ @SuppressWarnings("unchecked")
+ private List<Row> collectRows(Table table) throws Exception {
+ return IteratorUtils.toList(tEnv.toDataStream(table).executeAndCollect());
+ }
+
+ @Test
+ public void testPredictParam() {
+ Table output = createModel().transform(predictTable)[0];
+
+ List<String> expectedColumns =
+ Arrays.asList("inputCol1", "inputCol2", "outputCol1", "outputCol2");
+ assertEquals(expectedColumns, output.getResolvedSchema().getColumnNames());
+ }
+
+ @Test
+ public void testInputWithUnseenValues() {
+ Table output = createModel().transform(predictTableWithUnseenValues)[0];
+
+ try {
+ collectRows(output);
+ fail();
+ } catch (Exception e) {
+ // The expected exception sits five causes below the one thrown by the job.
+ Throwable nested = e;
+ for (int depth = 0; depth < 5; depth++) {
+ nested = nested.getCause();
+ }
+ assertEquals("The input contains unseen index: 4.", nested.getMessage());
+ }
+ }
+
+ @Test
+ public void testPredict() throws Exception {
+ Table output = createModel().transform(predictTable)[0];
+
+ List<Row> predictedResult = collectRows(output);
+ StringIndexerTest.verifyPredictionResult(expectedPrediction, predictedResult);
+ }
+
+ @Test
+ public void testSaveLoadAndPredict() throws Exception {
+ String savePath = tempFolder.newFolder().getAbsolutePath();
+ IndexToStringModel model = StageTestUtils.saveAndReload(env, createModel(), savePath);
+
+ List<String> modelDataColumns =
+ model.getModelData()[0].getResolvedSchema().getColumnNames();
+ assertEquals(Collections.singletonList("stringArrays"), modelDataColumns);
+
+ Table output = model.transform(predictTable)[0];
+
+ List<Row> predictedResult = collectRows(output);
+ StringIndexerTest.verifyPredictionResult(expectedPrediction, predictedResult);
+ }
+
+ @Test
+ @SuppressWarnings("unchecked")
+ public void testGetModelData() throws Exception {
+ Table modelDataTable = createModel().getModelData()[0];
+
+ List<StringIndexerModelData> collectedModelData =
+ (List<StringIndexerModelData>)
+ IteratorUtils.toList(
+ StringIndexerModelData.getModelDataStream(modelDataTable)
+ .executeAndCollect());
+
+ assertEquals(1, collectedModelData.size());
+
+ StringIndexerModelData modelData = collectedModelData.get(0);
+ assertEquals(2, modelData.stringArrays.length);
+ for (int col = 0; col < stringArrays.length; col++) {
+ assertArrayEquals(stringArrays[col], modelData.stringArrays[col]);
+ }
+ }
+}
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/stringindexer/StringIndexerTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/stringindexer/StringIndexerTest.java
new file mode 100644
index 0000000..805da5b
--- /dev/null
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/stringindexer/StringIndexerTest.java
@@ -0,0 +1,335 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.stringindexer;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+/** Tests the {@link StringIndexer} and {@link StringIndexerModel}. */
+public class StringIndexerTest extends AbstractTestBase {
+ @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+ private StreamExecutionEnvironment env;
+ private StreamTableEnvironment tEnv;
+ private Table trainTable;
+ private Table predictTable;
+
+ private final String[][] expectedAlphabeticAscModelData =
+ new String[][] {{"a", "b", "c", "d"}, {"-1.0", "0.0", "1.0", "2.0"}};
+ private final List<Row> expectedAlphabeticAscPredictData =
+ Arrays.asList(Row.of("a", 2.0, 0, 3), Row.of("b", 1.0, 1, 2), Row.of("e", 2.0, 4, 3));
+ private final List<Row> expectedAlphabeticDescPredictData =
+ Arrays.asList(Row.of("a", 2.0, 3, 0), Row.of("b", 1.0, 2, 1), Row.of("e", 2.0, 4, 0));
+ private final List<Row> expectedFreqAscPredictData =
+ Arrays.asList(Row.of("a", 2.0, 2, 3), Row.of("b", 1.0, 3, 1), Row.of("e", 2.0, 4, 3));
+ private final List<Row> expectedFreqDescPredictData =
+ Arrays.asList(Row.of("a", 2.0, 1, 0), Row.of("b", 1.0, 0, 2), Row.of("e", 2.0, 4, 0));
+
+ @Before
+ public void before() {
+ Configuration config = new Configuration();
+ config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+ env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+ env.setParallelism(4);
+ env.enableCheckpointing(100);
+ env.setRestartStrategy(RestartStrategies.noRestart());
+ tEnv = StreamTableEnvironment.create(env);
+
+ List<Row> trainData =
+ Arrays.asList(
+ Row.of("a", 1.0),
+ Row.of("b", 1.0),
+ Row.of("b", 2.0),
+ Row.of("c", 0.0),
+ Row.of("d", 2.0),
+ Row.of("a", 2.0),
+ Row.of("b", 2.0),
+ Row.of("b", -1.0),
+ Row.of("a", -1.0),
+ Row.of("c", -1.0));
+ trainTable =
+ tEnv.fromDataStream(env.fromCollection(trainData)).as("inputCol1", "inputCol2");
+
+ List<Row> predictData = Arrays.asList(Row.of("a", 2.0), Row.of("b", 1.0), Row.of("e", 2.0));
+ predictTable =
+ tEnv.fromDataStream(env.fromCollection(predictData)).as("inputCol1", "inputCol2");
+ }
+
+ @Test
+ public void testFitParam() {
+ StringIndexer stringIndexer = new StringIndexer();
+ assertEquals(stringIndexer.getStringOrderType(), StringIndexerParams.ARBITRARY_ORDER);
+ assertEquals(stringIndexer.getHandleInvalid(), StringIndexerParams.ERROR_INVALID);
+
+ stringIndexer
+ .setInputCols("inputCol1", "inputCol2")
+ .setOutputCols("outputCol1", "outputCol2")
+ .setStringOrderType(StringIndexerParams.ALPHABET_ASC_ORDER)
+ .setHandleInvalid(StringIndexerParams.SKIP_INVALID);
+
+ assertArrayEquals(new String[] {"inputCol1", "inputCol2"}, stringIndexer.getInputCols());
+ assertArrayEquals(new String[] {"outputCol1", "outputCol2"}, stringIndexer.getOutputCols());
+ assertEquals(stringIndexer.getStringOrderType(), StringIndexerParams.ALPHABET_ASC_ORDER);
+ assertEquals(stringIndexer.getHandleInvalid(), StringIndexerParams.SKIP_INVALID);
+ }
+
+ @Test
+ public void testPredictParam() {
+ StringIndexer stringIndexer =
+ new StringIndexer()
+ .setInputCols("inputCol1", "inputCol2")
+ .setOutputCols("outputCol1", "outputCol2")
+ .setStringOrderType(StringIndexerParams.ALPHABET_ASC_ORDER)
+ .setHandleInvalid(StringIndexerParams.SKIP_INVALID);
+ Table output = stringIndexer.fit(trainTable).transform(predictTable)[0];
+
+ assertEquals(
+ Arrays.asList("inputCol1", "inputCol2", "outputCol1", "outputCol2"),
+ output.getResolvedSchema().getColumnNames());
+ }
+
+ @Test
+ @SuppressWarnings("all")
+ public void testStringOrderType() throws Exception {
+ StringIndexer stringIndexer =
+ new StringIndexer()
+ .setInputCols("inputCol1", "inputCol2")
+ .setOutputCols("outputCol1", "outputCol2")
+ .setHandleInvalid(StringIndexerParams.KEEP_INVALID);
+ Table output;
+ List<Row> predictedResult;
+
+ // AlphabetAsc order.
+ stringIndexer.setStringOrderType(StringIndexerParams.ALPHABET_ASC_ORDER);
+ output = stringIndexer.fit(trainTable).transform(predictTable)[0];
+ predictedResult = IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+ verifyPredictionResult(expectedAlphabeticAscPredictData, predictedResult);
+
+ // AlphabetDesc order.
+ stringIndexer.setStringOrderType(StringIndexerParams.ALPHABET_DESC_ORDER);
+ output = stringIndexer.fit(trainTable).transform(predictTable)[0];
+ predictedResult = IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+ verifyPredictionResult(expectedAlphabeticDescPredictData, predictedResult);
+
+ // FrequencyAsc order.
+ stringIndexer.setStringOrderType(StringIndexerParams.FREQUENCY_ASC_ORDER);
+ output = stringIndexer.fit(trainTable).transform(predictTable)[0];
+ predictedResult = IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+ verifyPredictionResult(expectedFreqAscPredictData, predictedResult);
+
+ // FrequencyDesc order.
+ stringIndexer.setStringOrderType(StringIndexerParams.FREQUENCY_DESC_ORDER);
+ output = stringIndexer.fit(trainTable).transform(predictTable)[0];
+ predictedResult = IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+ verifyPredictionResult(expectedFreqDescPredictData, predictedResult);
+
+ // Arbitrary order.
+ stringIndexer.setStringOrderType(StringIndexerParams.ARBITRARY_ORDER);
+ output = stringIndexer.fit(trainTable).transform(predictTable)[0];
+ predictedResult = IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+
+ Set<Integer> distinctStringsCol1 = new HashSet<>();
+ Set<Integer> distinctStringsCol2 = new HashSet<>();
+ int index;
+ for (Row r : predictedResult) {
+ index = (Integer) r.getField(2);
+ distinctStringsCol1.add(index);
+ assertTrue(index >= 0 && index <= 4);
+ index = (Integer) r.getField(3);
+ assertTrue(index >= 0 && index <= 3);
+ distinctStringsCol2.add(index);
+ }
+
+ assertEquals(3, distinctStringsCol1.size());
+ assertEquals(2, distinctStringsCol2.size());
+ }
+
+ @Test
+ @SuppressWarnings("unchecked")
+ public void testHandleInvalid() throws Exception {
+ StringIndexer stringIndexer =
+ new StringIndexer()
+ .setInputCols("inputCol1", "inputCol2")
+ .setOutputCols("outputCol1", "outputCol2")
+ .setStringOrderType(StringIndexerParams.ALPHABET_ASC_ORDER);
+
+ Table output;
+ List<Row> expectedResult;
+
+ // Keeps invalid data.
+ stringIndexer.setHandleInvalid(StringIndexerParams.KEEP_INVALID);
+ output = stringIndexer.fit(trainTable).transform(predictTable)[0];
+ List<Row> predictedResult =
+ IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+ verifyPredictionResult(expectedAlphabeticAscPredictData, predictedResult);
+
+ // Skips invalid data.
+ stringIndexer.setHandleInvalid(StringIndexerParams.SKIP_INVALID);
+ output = stringIndexer.fit(trainTable).transform(predictTable)[0];
+ predictedResult = IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+ expectedResult = Arrays.asList(Row.of("a", 2.0, 0, 3), Row.of("b", 1.0, 1, 2));
+ verifyPredictionResult(expectedResult, predictedResult);
+
+ // Throws an exception on invalid data.
+ stringIndexer.setHandleInvalid(StringIndexerParams.ERROR_INVALID);
+ try {
+ output = stringIndexer.fit(trainTable).transform(predictTable)[0];
+ IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+ fail();
+ } catch (Exception e) {
+ assertEquals(
+ "The input contains unseen string: e. "
+ + "See handleInvalid parameter for more options.",
+ e.getCause().getCause().getCause().getCause().getCause().getMessage());
+ }
+ }
+
+ @Test
+ @SuppressWarnings("unchecked")
+ public void testFitAndPredict() throws Exception {
+ StringIndexer stringIndexer =
+ new StringIndexer()
+ .setInputCols("inputCol1", "inputCol2")
+ .setOutputCols("outputCol1", "outputCol2")
+ .setStringOrderType(StringIndexerParams.ALPHABET_ASC_ORDER)
+ .setHandleInvalid(StringIndexerParams.KEEP_INVALID);
+ Table output = stringIndexer.fit(trainTable).transform(predictTable)[0];
+
+ List<Row> predictedResult =
+ IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+ verifyPredictionResult(expectedAlphabeticAscPredictData, predictedResult);
+ }
+
+ @Test
+ @SuppressWarnings("unchecked")
+ public void testSaveLoadAndPredict() throws Exception {
+ StringIndexer stringIndexer =
+ new StringIndexer()
+ .setInputCols("inputCol1", "inputCol2")
+ .setOutputCols("outputCol1", "outputCol2")
+ .setStringOrderType(StringIndexerParams.ALPHABET_ASC_ORDER)
+ .setHandleInvalid(StringIndexerParams.KEEP_INVALID);
+ stringIndexer =
+ StageTestUtils.saveAndReload(
+ env, stringIndexer, tempFolder.newFolder().getAbsolutePath());
+
+ StringIndexerModel model = stringIndexer.fit(trainTable);
+ model = StageTestUtils.saveAndReload(env, model, tempFolder.newFolder().getAbsolutePath());
+
+ assertEquals(
+ Collections.singletonList("stringArrays"),
+ model.getModelData()[0].getResolvedSchema().getColumnNames());
+
+ Table output = model.transform(predictTable)[0];
+ List<Row> predictedResult =
+ IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+ verifyPredictionResult(expectedAlphabeticAscPredictData, predictedResult);
+ }
+
+ @Test
+ @SuppressWarnings("unchecked")
+ public void testGetModelData() throws Exception {
+ StringIndexerModel model =
+ new StringIndexer()
+ .setInputCols("inputCol1", "inputCol2")
+ .setOutputCols("outputCol1", "outputCol2")
+ .setStringOrderType(StringIndexerParams.ALPHABET_ASC_ORDER)
+ .fit(trainTable);
+ Table modelDataTable = model.getModelData()[0];
+
+ assertEquals(
+ Collections.singletonList("stringArrays"),
+ modelDataTable.getResolvedSchema().getColumnNames());
+
+ List<StringIndexerModelData> collectedModelData =
+ (List<StringIndexerModelData>)
+ (IteratorUtils.toList(
+ StringIndexerModelData.getModelDataStream(modelDataTable)
+ .executeAndCollect()));
+ assertEquals(1, collectedModelData.size());
+
+ StringIndexerModelData modelData = collectedModelData.get(0);
+ assertEquals(2, modelData.stringArrays.length);
+ assertArrayEquals(expectedAlphabeticAscModelData[0], modelData.stringArrays[0]);
+ assertArrayEquals(expectedAlphabeticAscModelData[1], modelData.stringArrays[1]);
+ }
+
+ @Test
+ @SuppressWarnings("unchecked")
+ public void testSetModelData() throws Exception {
+ StringIndexerModel model =
+ new StringIndexer()
+ .setInputCols("inputCol1", "inputCol2")
+ .setOutputCols("outputCol1", "outputCol2")
+ .setStringOrderType(StringIndexerParams.ALPHABET_ASC_ORDER)
+ .setHandleInvalid(StringIndexerParams.KEEP_INVALID)
+ .fit(trainTable);
+
+ StringIndexerModel newModel = new StringIndexerModel();
+ ReadWriteUtils.updateExistingParams(newModel, model.getParamMap());
+ newModel.setModelData(model.getModelData());
+ Table output = newModel.transform(predictTable)[0];
+
+ List<Row> predictedResult =
+ IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+ verifyPredictionResult(expectedAlphabeticAscPredictData, predictedResult);
+ }
+
+ static void verifyPredictionResult(List<Row> expected, List<Row> result) {
+ compareResultCollections(
+ expected,
+ result,
+ (row1, row2) -> {
+ int arity = Math.min(row1.getArity(), row2.getArity());
+ for (int i = 0; i < arity; i++) {
+ int cmp =
+ String.valueOf(row1.getField(i))
+ .compareTo(String.valueOf(row2.getField(i)));
+ if (cmp != 0) {
+ return cmp;
+ }
+ }
+ return 0;
+ });
+ }
+}