Posted to commits@flink.apache.org by li...@apache.org on 2022/03/23 02:28:37 UTC

[flink-ml] branch master updated: [FLINK-25527] Add Transformer and Estimator for StringIndexer

This is an automated email from the ASF dual-hosted git repository.

lindong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git


The following commit(s) were added to refs/heads/master by this push:
     new 8571317  [FLINK-25527] Add Transformer and Estimator for StringIndexer
8571317 is described below

commit 857131751e34e3bb1df4c063d9d1a4c57f5fd761
Author: Zhipeng Zhang <zh...@gmail.com>
AuthorDate: Wed Mar 23 10:28:33 2022 +0800

    [FLINK-25527] Add Transformer and Estimator for StringIndexer
    
    This closes #52.
---
 .../ml/common/datastream/DataStreamUtils.java      |  11 +-
 .../ml/common/datastream/DataStreamUtilsTest.java  |   5 +-
 .../feature/stringindexer/IndexToStringModel.java  | 174 +++++++++++
 .../stringindexer/IndexToStringModelParams.java    |  29 ++
 .../ml/feature/stringindexer/StringIndexer.java    | 238 +++++++++++++++
 .../feature/stringindexer/StringIndexerModel.java  | 203 +++++++++++++
 .../stringindexer/StringIndexerModelData.java      | 125 ++++++++
 .../stringindexer/StringIndexerModelParams.java    |  62 ++++
 .../feature/stringindexer/StringIndexerParams.java |  68 +++++
 .../stringindexer/IndexToStringModelTest.java      | 174 +++++++++++
 .../feature/stringindexer/StringIndexerTest.java   | 335 +++++++++++++++++++++
 11 files changed, 1417 insertions(+), 7 deletions(-)

diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
index 58eae62..7eea6b0 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
@@ -25,7 +25,7 @@ import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.java.typeutils.TypeExtractor;
 import org.apache.flink.runtime.state.StateInitializationContext;
 import org.apache.flink.streaming.api.datastream.DataStream;
-import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
 import org.apache.flink.streaming.api.operators.BoundedOneInput;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
 import org.apache.flink.streaming.api.operators.TimestampedCollector;
@@ -71,15 +71,14 @@ public class DataStreamUtils {
      * A stream operator to apply {@link MapPartitionFunction} on each partition of the input
      * bounded data stream.
      */
-    private static class MapPartitionOperator<IN, OUT> extends AbstractStreamOperator<OUT>
+    private static class MapPartitionOperator<IN, OUT>
+            extends AbstractUdfStreamOperator<OUT, MapPartitionFunction<IN, OUT>>
             implements OneInputStreamOperator<IN, OUT>, BoundedOneInput {
 
-        private final MapPartitionFunction<IN, OUT> mapPartitionFunc;
-
         private ListState<IN> valuesState;
 
         public MapPartitionOperator(MapPartitionFunction<IN, OUT> mapPartitionFunc) {
-            this.mapPartitionFunc = mapPartitionFunc;
+            super(mapPartitionFunc);
         }
 
         @Override
@@ -95,7 +94,7 @@ public class DataStreamUtils {
 
         @Override
         public void endInput() throws Exception {
-            mapPartitionFunc.mapPartition(valuesState.get(), new TimestampedCollector<>(output));
+            userFunction.mapPartition(valuesState.get(), new TimestampedCollector<>(output));
             valuesState.clear();
         }
 
diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java b/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java
index 64c1d70..7933859 100644
--- a/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java
+++ b/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java
@@ -19,6 +19,7 @@
 package org.apache.flink.ml.common.datastream;
 
 import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
 import org.apache.flink.api.common.restartstrategy.RestartStrategies;
 import org.apache.flink.api.common.typeinfo.Types;
 import org.apache.flink.configuration.Configuration;
@@ -35,6 +36,7 @@ import org.junit.Test;
 import java.util.List;
 
 import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertNotNull;
 
 /** Tests the {@link DataStreamUtils}. */
 public class DataStreamUtilsTest {
@@ -62,9 +64,10 @@ public class DataStreamUtilsTest {
     }
 
     /** A simple implementation for a {@link MapPartitionFunction}. */
-    private static class TestMapPartitionFunc implements MapPartitionFunction<Long, Integer> {
+    private static class TestMapPartitionFunc extends RichMapPartitionFunction<Long, Integer> {
 
         public void mapPartition(Iterable<Long> values, Collector<Integer> out) {
+            assertNotNull(getRuntimeContext());
             int cnt = 0;
             for (long ignored : values) {
                 cnt++;
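
For context, a minimal sketch (not part of this patch) of how the updated MapPartitionOperator is typically exercised through DataStreamUtils.mapPartition; the bounded source and the counting function are illustrative assumptions. Because the operator now extends AbstractUdfStreamOperator, a RichMapPartitionFunction sees its RuntimeContext, which is what the new assertion in the test above checks.

    import org.apache.flink.api.common.functions.RichMapPartitionFunction;
    import org.apache.flink.ml.common.datastream.DataStreamUtils;
    import org.apache.flink.streaming.api.datastream.DataStream;
    import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
    import org.apache.flink.util.Collector;

    public class MapPartitionExample {
        public static void main(String[] args) throws Exception {
            StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
            DataStream<Long> input = env.fromSequence(0L, 99L);

            // Counts the records in each partition of the bounded input.
            DataStream<Integer> counts =
                    DataStreamUtils.mapPartition(
                            input,
                            new RichMapPartitionFunction<Long, Integer>() {
                                @Override
                                public void mapPartition(Iterable<Long> values, Collector<Integer> out) {
                                    // Available now that the operator wraps the user function.
                                    getRuntimeContext();
                                    int cnt = 0;
                                    for (Long ignored : values) {
                                        cnt++;
                                    }
                                    out.collect(cnt);
                                }
                            });

            counts.executeAndCollect().forEachRemaining(System.out::println);
        }
    }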
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/IndexToStringModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/IndexToStringModel.java
new file mode 100644
index 0000000..b66a337
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/IndexToStringModel.java
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.stringindexer;
+
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * A Model which transforms input index column(s) to string column(s) using the model data computed
+ * by {@link StringIndexer}. It is a reverse operation of {@link StringIndexerModel}.
+ */
+public class IndexToStringModel
+        implements Model<IndexToStringModel>, IndexToStringModelParams<IndexToStringModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public IndexToStringModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+        ReadWriteUtils.saveModelData(
+                StringIndexerModelData.getModelDataStream(modelDataTable),
+                path,
+                new StringIndexerModelData.ModelDataEncoder());
+    }
+
+    public static IndexToStringModel load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+        IndexToStringModel model = ReadWriteUtils.loadStageParam(path);
+        DataStream<StringIndexerModelData> modelData =
+                ReadWriteUtils.loadModelData(
+                        env, path, new StringIndexerModelData.ModelDataDecoder());
+
+        return model.setModelData(tEnv.fromDataStream(modelData));
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    @Override
+    public IndexToStringModel setModelData(Table... inputs) {
+        modelDataTable = inputs[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    @SuppressWarnings("unchecked, rawtypes")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        String[] inputCols = getInputCols();
+        String[] outputCols = getOutputCols();
+        Preconditions.checkArgument(inputCols.length == outputCols.length);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) modelDataTable).getTableEnvironment();
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        TypeInformation<?>[] outputTypes = new TypeInformation[outputCols.length];
+        Arrays.fill(outputTypes, BasicTypeInfo.STRING_TYPE_INFO);
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), outputTypes),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), outputCols));
+
+        final String broadcastModelKey = "broadcastModelKey";
+        DataStream<StringIndexerModelData> modelDataStream =
+                StringIndexerModelData.getModelDataStream(modelDataTable);
+
+        DataStream<Row> result =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(tEnv.toDataStream(inputs[0])),
+                        Collections.singletonMap(broadcastModelKey, modelDataStream),
+                        inputList -> {
+                            DataStream inputData = inputList.get(0);
+                            return inputData.flatMap(
+                                    new Index2String(broadcastModelKey, inputCols), outputTypeInfo);
+                        });
+
+        return new Table[] {tEnv.fromDataStream(result)};
+    }
+
+    /** Maps the input index values to string values according to the model data. */
+    private static class Index2String extends RichFlatMapFunction<Row, Row> {
+        private String[][] stringArrays;
+        private final String broadcastModelKey;
+        private final String[] inputCols;
+
+        public Index2String(String broadcastModelKey, String[] inputCols) {
+            this.broadcastModelKey = broadcastModelKey;
+            this.inputCols = inputCols;
+        }
+
+        @Override
+        public void flatMap(Row input, Collector<Row> out) {
+            if (stringArrays == null) {
+                StringIndexerModelData modelData =
+                        (StringIndexerModelData)
+                                getRuntimeContext().getBroadcastVariable(broadcastModelKey).get(0);
+                stringArrays = modelData.stringArrays;
+            }
+
+            Row outputStrings = new Row(inputCols.length);
+            int stringId;
+            for (int i = 0; i < inputCols.length; i++) {
+                try {
+                    stringId = (Integer) input.getField(inputCols[i]);
+                } catch (Exception e) {
+                    throw new RuntimeException(
+                            "The input contains non-integer value: "
+                                    + input.getField(inputCols[i]) + ".");
+                }
+                if (stringId < stringArrays[i].length && stringId >= 0) {
+                    outputStrings.setField(i, stringArrays[i][stringId]);
+                } else {
+                    throw new RuntimeException(
+                            "The input contains unseen index: " + stringId + ".");
+                }
+            }
+
+            out.collect(Row.join(input, outputStrings));
+        }
+    }
+}
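
For illustration, a short usage sketch of IndexToStringModel (the model data, column names and input rows are taken from IndexToStringModelTest further below; imports are the same as in that test):

    StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
    StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

    // Model data: the ordered strings of each of the two input columns.
    String[][] stringArrays = {{"a", "b", "c", "d"}, {"-1.0", "0.0", "1.0", "2.0"}};
    Table modelTable =
            tEnv.fromDataStream(env.fromElements(new StringIndexerModelData(stringArrays)))
                    .as("stringArrays");

    Table predictTable =
            tEnv.fromDataStream(env.fromCollection(Arrays.asList(Row.of(0, 3), Row.of(1, 2))))
                    .as("inputCol1", "inputCol2");

    // Maps each index back to its string, e.g. (0, 3) -> ("a", "2.0"); an index that is
    // not covered by the model data fails with "The input contains unseen index: ...".
    Table output =
            new IndexToStringModel()
                    .setInputCols("inputCol1", "inputCol2")
                    .setOutputCols("outputCol1", "outputCol2")
                    .setModelData(modelTable)
                    .transform(predictTable)[0];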
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/IndexToStringModelParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/IndexToStringModelParams.java
new file mode 100644
index 0000000..2b29083
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/IndexToStringModelParams.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.stringindexer;
+
+import org.apache.flink.ml.common.param.HasInputCols;
+import org.apache.flink.ml.common.param.HasOutputCols;
+
+/**
+ * Params for {@link IndexToStringModel}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface IndexToStringModelParams<T> extends HasInputCols<T>, HasOutputCols<T> {}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexer.java
new file mode 100644
index 0000000..0a8f7e9
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexer.java
@@ -0,0 +1,238 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.stringindexer;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the string indexing algorithm.
+ *
+ * <p>A string indexer maps one or more columns (string/numerical value) of the input to one or more
+ * indexed output columns (integer value). The output indices of two data points are the same iff
+ * their corresponding input columns are the same. The indices are in [0,
+ * numDistinctValuesInThisColumn].
+ *
+ * <p>The input columns are cast to string if they are numeric values. By default, the output model
+ * is arbitrarily ordered. Users can control this by setting {@link
+ * StringIndexerParams#STRING_ORDER_TYPE}.
+ */
+public class StringIndexer
+        implements Estimator<StringIndexer, StringIndexerModel>,
+                StringIndexerParams<StringIndexer> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public StringIndexer() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static StringIndexer load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    @Override
+    public StringIndexerModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        String[] inputCols = getInputCols();
+        String[] outputCols = getOutputCols();
+        Preconditions.checkArgument(inputCols.length == outputCols.length);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<Tuple2<Integer, String>> columnIdAndString =
+                tEnv.toDataStream(inputs[0]).flatMap(new ExtractColumnIdAndString(inputCols));
+
+        DataStream<Tuple3<Integer, String, Long>> columnIdAndStringAndCnt =
+                DataStreamUtils.mapPartition(
+                        columnIdAndString.keyBy(
+                                (KeySelector<Tuple2<Integer, String>, Integer>) Tuple2::hashCode),
+                        new CountStringsByColumn(inputCols.length));
+
+        DataStream<StringIndexerModelData> modelData =
+                DataStreamUtils.mapPartition(
+                        columnIdAndStringAndCnt,
+                        new GenerateModel(inputCols.length, getStringOrderType()));
+        modelData.getTransformation().setParallelism(1);
+
+        StringIndexerModel model =
+                new StringIndexerModel().setModelData(tEnv.fromDataStream(modelData));
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /**
+     * Merges all the extracted strings and generates the {@link StringIndexerModelData} according
+     * to the specified string order type.
+     */
+    private static class GenerateModel
+            implements MapPartitionFunction<Tuple3<Integer, String, Long>, StringIndexerModelData> {
+        private final int numCols;
+        private final String stringOrderType;
+
+        public GenerateModel(int numCols, String stringOrderType) {
+            this.numCols = numCols;
+            this.stringOrderType = stringOrderType;
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void mapPartition(
+                Iterable<Tuple3<Integer, String, Long>> values,
+                Collector<StringIndexerModelData> out) {
+            String[][] stringArrays = new String[numCols][];
+            ArrayList<Tuple2<String, Long>>[] stringsAndCntsByColumn = new ArrayList[numCols];
+            for (int i = 0; i < numCols; i++) {
+                stringsAndCntsByColumn[i] = new ArrayList<>();
+            }
+
+            for (Tuple3<Integer, String, Long> colIdAndStringAndCnt : values) {
+                stringsAndCntsByColumn[colIdAndStringAndCnt.f0].add(
+                        Tuple2.of(colIdAndStringAndCnt.f1, colIdAndStringAndCnt.f2));
+            }
+
+            for (int i = 0; i < stringsAndCntsByColumn.length; i++) {
+                List<Tuple2<String, Long>> stringsAndCnts = stringsAndCntsByColumn[i];
+                switch (stringOrderType) {
+                    case StringIndexerParams.ALPHABET_ASC_ORDER:
+                        stringsAndCnts.sort(Comparator.comparing(valAndCnt -> valAndCnt.f0));
+                        break;
+                    case StringIndexerParams.ALPHABET_DESC_ORDER:
+                        stringsAndCnts.sort(
+                                (valAndCnt1, valAndCnt2) ->
+                                        -valAndCnt1.f0.compareTo(valAndCnt2.f0));
+                        break;
+                    case StringIndexerParams.FREQUENCY_ASC_ORDER:
+                        stringsAndCnts.sort(Comparator.comparing(valAndCnt -> valAndCnt.f1));
+                        break;
+                    case StringIndexerParams.FREQUENCY_DESC_ORDER:
+                        stringsAndCnts.sort(
+                                (valAndCnt1, valAndCnt2) ->
+                                        -valAndCnt1.f1.compareTo(valAndCnt2.f1));
+                        break;
+                    case StringIndexerParams.ARBITRARY_ORDER:
+                        break;
+                    default:
+                        throw new IllegalStateException(
+                                "Unsupported string order type: " + stringOrderType);
+                }
+
+                stringArrays[i] = new String[stringsAndCnts.size()];
+                for (int stringId = 0; stringId < stringArrays[i].length; stringId++) {
+                    stringArrays[i][stringId] = stringsAndCnts.get(stringId).f0;
+                }
+            }
+
+            out.collect(new StringIndexerModelData(stringArrays));
+        }
+    }
+
+    /** Computes the frequency of strings in each column. */
+    private static class CountStringsByColumn
+            implements MapPartitionFunction<
+                    Tuple2<Integer, String>, Tuple3<Integer, String, Long>> {
+        private final int numCols;
+
+        public CountStringsByColumn(int numCols) {
+            this.numCols = numCols;
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void mapPartition(
+                Iterable<Tuple2<Integer, String>> values,
+                Collector<Tuple3<Integer, String, Long>> out) {
+            HashMap<String, Long>[] string2CntByColumn = new HashMap[numCols];
+            for (int i = 0; i < numCols; i++) {
+                string2CntByColumn[i] = new HashMap<>();
+            }
+            for (Tuple2<Integer, String> columnIdAndString : values) {
+                int colId = columnIdAndString.f0;
+                String stringVal = columnIdAndString.f1;
+                long cnt = string2CntByColumn[colId].getOrDefault(stringVal, 0L) + 1;
+                string2CntByColumn[colId].put(stringVal, cnt);
+            }
+            for (int i = 0; i < numCols; i++) {
+                for (Map.Entry<String, Long> entry : string2CntByColumn[i].entrySet()) {
+                    out.collect(Tuple3.of(i, entry.getKey(), entry.getValue()));
+                }
+            }
+        }
+    }
+
+    /** Extracts strings by column. */
+    private static class ExtractColumnIdAndString
+            implements FlatMapFunction<Row, Tuple2<Integer, String>> {
+        private final String[] inputCols;
+
+        public ExtractColumnIdAndString(String[] inputCols) {
+            this.inputCols = inputCols;
+        }
+
+        @Override
+        public void flatMap(Row row, Collector<Tuple2<Integer, String>> out) {
+            for (int i = 0; i < inputCols.length; i++) {
+                Object objVal = row.getField(inputCols[i]);
+                String stringVal;
+                if (objVal instanceof String) {
+                    stringVal = (String) objVal;
+                } else if (objVal instanceof Number) {
+                    stringVal = String.valueOf(objVal);
+                } else {
+                    throw new RuntimeException(
+                            "The input column only supports string and numeric type.");
+                }
+                out.collect(Tuple2.of(i, stringVal));
+            }
+        }
+    }
+}
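
For illustration, a fitting sketch for StringIndexer (training data and column names are illustrative and follow StringIndexerTest further below; env and tEnv are set up as in the sketch after IndexToStringModel.java above):

    Table trainTable =
            tEnv.fromDataStream(
                            env.fromCollection(
                                    Arrays.asList(
                                            Row.of("a", 1.0),
                                            Row.of("b", 1.0),
                                            Row.of("b", 2.0),
                                            Row.of("c", 0.0))))
                    .as("inputCol1", "inputCol2");

    // Fits a model whose indices follow ascending alphabetical order in each column;
    // numeric values in inputCol2 are cast to strings before indexing.
    StringIndexerModel model =
            new StringIndexer()
                    .setInputCols("inputCol1", "inputCol2")
                    .setOutputCols("outputCol1", "outputCol2")
                    .setStringOrderType(StringIndexerParams.ALPHABET_ASC_ORDER)
                    .fit(trainTable);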
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexerModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexerModel.java
new file mode 100644
index 0000000..37307ba
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexerModel.java
@@ -0,0 +1,203 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.stringindexer;
+
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * A Model which transforms input string/numeric column(s) to integer column(s) using the model data
+ * computed by {@link StringIndexer}.
+ */
+public class StringIndexerModel
+        implements Model<StringIndexerModel>, StringIndexerModelParams<StringIndexerModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public StringIndexerModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+        ReadWriteUtils.saveModelData(
+                StringIndexerModelData.getModelDataStream(modelDataTable),
+                path,
+                new StringIndexerModelData.ModelDataEncoder());
+    }
+
+    public static StringIndexerModel load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+        StringIndexerModel model = ReadWriteUtils.loadStageParam(path);
+        DataStream<StringIndexerModelData> modelData =
+                ReadWriteUtils.loadModelData(
+                        env, path, new StringIndexerModelData.ModelDataDecoder());
+
+        return model.setModelData(tEnv.fromDataStream(modelData));
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    @Override
+    public StringIndexerModel setModelData(Table... inputs) {
+        modelDataTable = inputs[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    @SuppressWarnings("unchecked, rawtypes")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        String[] inputCols = getInputCols();
+        String[] outputCols = getOutputCols();
+        Preconditions.checkArgument(inputCols.length == outputCols.length);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) modelDataTable).getTableEnvironment();
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        TypeInformation<?>[] outputTypes = new TypeInformation[outputCols.length];
+        Arrays.fill(outputTypes, BasicTypeInfo.INT_TYPE_INFO);
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), outputTypes),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCols()));
+
+        final String broadcastModelKey = "broadcastModelKey";
+        DataStream<StringIndexerModelData> modelDataStream =
+                StringIndexerModelData.getModelDataStream(modelDataTable);
+
+        DataStream<Row> result =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(tEnv.toDataStream(inputs[0])),
+                        Collections.singletonMap(broadcastModelKey, modelDataStream),
+                        inputList -> {
+                            DataStream inputData = inputList.get(0);
+                            return inputData.flatMap(
+                                    new String2Index(
+                                            broadcastModelKey, inputCols, getHandleInvalid()),
+                                    outputTypeInfo);
+                        });
+
+        return new Table[] {tEnv.fromDataStream(result)};
+    }
+
+    /** Maps the input columns to integer values according to the model data. */
+    private static class String2Index extends RichFlatMapFunction<Row, Row> {
+        private HashMap<String, Integer>[] modelDataMap;
+        private final String broadcastModelKey;
+        private final String[] inputCols;
+        private final String handleInValid;
+
+        public String2Index(String broadcastModelKey, String[] inputCols, String handleInValid) {
+            this.broadcastModelKey = broadcastModelKey;
+            this.inputCols = inputCols;
+            this.handleInValid = handleInValid;
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void flatMap(Row input, Collector<Row> out) {
+            if (modelDataMap == null) {
+                modelDataMap = new HashMap[inputCols.length];
+                StringIndexerModelData modelData =
+                        (StringIndexerModelData)
+                                getRuntimeContext().getBroadcastVariable(broadcastModelKey).get(0);
+                String[][] stringsArray = modelData.stringArrays;
+                for (int i = 0; i < stringsArray.length; i++) {
+                    int idx = 0;
+                    modelDataMap[i] = new HashMap<>(stringsArray[i].length);
+                    for (String string : stringsArray[i]) {
+                        modelDataMap[i].put(string, idx++);
+                    }
+                }
+            }
+
+            Row outputIndices = new Row(inputCols.length);
+            for (int i = 0; i < inputCols.length; i++) {
+                Object objVal = input.getField(inputCols[i]);
+                String stringVal;
+                if (objVal instanceof String) {
+                    stringVal = (String) objVal;
+                } else if (objVal instanceof Number) {
+                    stringVal = String.valueOf(objVal);
+                } else {
+                    throw new RuntimeException(
+                            "The input column only supports string and numeric type.");
+                }
+
+                if (modelDataMap[i].containsKey(stringVal)) {
+                    outputIndices.setField(i, modelDataMap[i].get(stringVal));
+                } else {
+                    switch (handleInValid) {
+                        case StringIndexerModelParams.SKIP_INVALID:
+                            return;
+                        case StringIndexerModelParams.ERROR_INVALID:
+                            throw new RuntimeException(
+                                    "The input contains unseen string: "
+                                            + stringVal
+                                            + ". See handleInvalid parameter for more options.");
+                        case StringIndexerModelParams.KEEP_INVALID:
+                            outputIndices.setField(i, modelDataMap[i].size());
+                            break;
+                        default:
+                            throw new IllegalStateException(
+                                    "Unsupported types of handling invalid data: " + handleInValid);
+                    }
+                }
+            }
+
+            out.collect(Row.join(input, outputIndices));
+        }
+    }
+}
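
Continuing the sketch after StringIndexer.java above, a prediction sketch for StringIndexerModel (input rows are illustrative):

    Table predictTable =
            tEnv.fromDataStream(
                            env.fromCollection(Arrays.asList(Row.of("a", 2.0), Row.of("b", 1.0))))
                    .as("inputCol1", "inputCol2");

    // transform() appends one integer index column per input column, so the output
    // schema here is (inputCol1, inputCol2, outputCol1, outputCol2). Values not seen
    // during fit() are handled according to the handleInvalid parameter.
    Table output = model.transform(predictTable)[0];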
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexerModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexerModelData.java
new file mode 100644
index 0000000..2de8d38
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexerModelData.java
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.stringindexer;
+
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.api.common.typeutils.base.array.StringArraySerializer;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+
+import java.io.EOFException;
+import java.io.IOException;
+import java.io.OutputStream;
+
+/**
+ * Model data of {@link StringIndexerModel} and {@link IndexToStringModel}.
+ *
+ * <p>This class also provides methods to convert model data from Table to DataStream, and classes
+ * to save/load model data.
+ */
+public class StringIndexerModelData {
+    /** Ordered strings of each input column. */
+    public String[][] stringArrays;
+
+    public StringIndexerModelData(String[][] stringArrays) {
+        this.stringArrays = stringArrays;
+    }
+
+    public StringIndexerModelData() {}
+
+    /**
+     * Converts the model data table to a data stream of {@link StringIndexerModelData}.
+     *
+     * @param modelData The model data table.
+     * @return The model data stream.
+     */
+    public static DataStream<StringIndexerModelData> getModelDataStream(Table modelData) {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) modelData).getTableEnvironment();
+        return tEnv.toDataStream(modelData)
+                .map(x -> new StringIndexerModelData((String[][]) x.getField(0)));
+    }
+
+    /** Data encoder for {@link StringIndexerModel} and {@link IndexToStringModel}. */
+    public static class ModelDataEncoder implements Encoder<StringIndexerModelData> {
+
+        @Override
+        public void encode(StringIndexerModelData modelData, OutputStream outputStream)
+                throws IOException {
+            DataOutputViewStreamWrapper outputViewStreamWrapper =
+                    new DataOutputViewStreamWrapper(outputStream);
+
+            IntSerializer.INSTANCE.serialize(
+                    modelData.stringArrays.length, outputViewStreamWrapper);
+            for (String[] strings : modelData.stringArrays) {
+                StringArraySerializer.INSTANCE.serialize(strings, outputViewStreamWrapper);
+            }
+        }
+    }
+
+    /** Data decoder for {@link StringIndexerModel} and {@link IndexToStringModel}. */
+    public static class ModelDataDecoder extends SimpleStreamFormat<StringIndexerModelData> {
+
+        @Override
+        public Reader<StringIndexerModelData> createReader(
+                Configuration configuration, FSDataInputStream inputStream) {
+            return new Reader<StringIndexerModelData>() {
+
+                @Override
+                public StringIndexerModelData read() throws IOException {
+                    try {
+                        DataInputViewStreamWrapper inputViewStreamWrapper =
+                                new DataInputViewStreamWrapper(inputStream);
+
+                        int numCols = IntSerializer.INSTANCE.deserialize(inputViewStreamWrapper);
+                        String[][] stringsArray = new String[numCols][];
+                        for (int i = 0; i < numCols; i++) {
+                            stringsArray[i] =
+                                    StringArraySerializer.INSTANCE.deserialize(
+                                            inputViewStreamWrapper);
+                        }
+
+                        return new StringIndexerModelData(stringsArray);
+                    } catch (EOFException e) {
+                        return null;
+                    }
+                }
+
+                @Override
+                public void close() throws IOException {
+                    inputStream.close();
+                }
+            };
+        }
+
+        @Override
+        public TypeInformation<StringIndexerModelData> getProducedType() {
+            return TypeInformation.of(StringIndexerModelData.class);
+        }
+    }
+}
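
The model data uses a simple length-prefixed layout: an int with the number of columns, followed by one String[] per column written with StringArraySerializer; ModelDataDecoder reads the same layout back and returns null at end of file. A minimal encoding sketch (the column contents are illustrative, and the call site is assumed to declare throws IOException):

    ByteArrayOutputStream bytes = new ByteArrayOutputStream();
    new StringIndexerModelData.ModelDataEncoder()
            .encode(new StringIndexerModelData(new String[][] {{"a", "b"}, {"x"}}), bytes);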
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexerModelParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexerModelParams.java
new file mode 100644
index 0000000..c4d81ad
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexerModelParams.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.stringindexer;
+
+import org.apache.flink.ml.common.param.HasInputCols;
+import org.apache.flink.ml.common.param.HasOutputCols;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringParam;
+
+/**
+ * Params of {@link StringIndexerModel}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface StringIndexerModelParams<T> extends HasInputCols<T>, HasOutputCols<T> {
+    String ERROR_INVALID = "error";
+    String SKIP_INVALID = "skip";
+    String KEEP_INVALID = "keep";
+
+    /**
+     * Supported options and the corresponding behavior to handle invalid entries in
+     * StringIndexerModel are listed as follows.
+     *
+     * <ul>
+     *   <li>error: raise an exception.
+     *   <li>skip: filter out rows with bad values.
+     *   <li>keep: put the invalid entries in a special bucket, whose index is the number of
+     *       distinct values in this column.
+     * </ul>
+     */
+    Param<String> HANDLE_INVALID =
+            new StringParam(
+                    "handleInvalid",
+                    "Strategy to handle invalid entries.",
+                    ERROR_INVALID,
+                    ParamValidators.inArray(ERROR_INVALID, SKIP_INVALID, KEEP_INVALID));
+
+    default String getHandleInvalid() {
+        return get(HANDLE_INVALID);
+    }
+
+    default T setHandleInvalid(String value) {
+        return set(HANDLE_INVALID, value);
+    }
+}
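
To make the strategies concrete, an illustrative sketch for a column fit on {"a", "b", "c"} (indices 0..2) when an unseen value "z" arrives at prediction time:

    //   "error" (default) -> the job fails with "The input contains unseen string: z. ..."
    //   "skip"            -> the whole row is filtered out
    //   "keep"            -> "z" is mapped to index 3, the number of distinct values
    model.setHandleInvalid(StringIndexerModelParams.KEEP_INVALID);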
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexerParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexerParams.java
new file mode 100644
index 0000000..61c23cf
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexerParams.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.stringindexer;
+
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringParam;
+
+/**
+ * Params of {@link StringIndexer}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface StringIndexerParams<T> extends StringIndexerModelParams<T> {
+    String ARBITRARY_ORDER = "arbitrary";
+    String FREQUENCY_DESC_ORDER = "frequencyDesc";
+    String FREQUENCY_ASC_ORDER = "frequencyAsc";
+    String ALPHABET_DESC_ORDER = "alphabetDesc";
+    String ALPHABET_ASC_ORDER = "alphabetAsc";
+
+    /**
+     * Supported options to decide the order of strings in each column are listed as follows. (The
+     * first label after ordering is assigned an index of 0).
+     *
+     * <ul>
+     *   <li>arbitrary: the order of strings is arbitrary and depends on each execution.
+     *   <li>frequencyDesc: descending order by the frequency.
+     *   <li>frequencyAsc: ascending order by the frequency.
+     *   <li>alphabetDesc: descending alphabetical order.
+     *   <li>alphabetAsc: ascending alphabetical order.
+     * </ul>
+     */
+    Param<String> STRING_ORDER_TYPE =
+            new StringParam(
+                    "stringOrderType",
+                    "How to order strings of each column.",
+                    ARBITRARY_ORDER,
+                    ParamValidators.inArray(
+                            ARBITRARY_ORDER,
+                            FREQUENCY_DESC_ORDER,
+                            FREQUENCY_ASC_ORDER,
+                            ALPHABET_DESC_ORDER,
+                            ALPHABET_ASC_ORDER));
+
+    default String getStringOrderType() {
+        return get(STRING_ORDER_TYPE);
+    }
+
+    default T setStringOrderType(String value) {
+        return set(STRING_ORDER_TYPE, value);
+    }
+}
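
To make the order types concrete, an illustrative example for a column whose values occur with frequencies a:3, b:4, c:2, d:1 (matching the training data in StringIndexerTest below):

    //   alphabetAsc   -> a=0, b=1, c=2, d=3
    //   alphabetDesc  -> d=0, c=1, b=2, a=3
    //   frequencyDesc -> b=0, a=1, c=2, d=3
    //   frequencyAsc  -> d=0, c=1, a=2, b=3
    //   arbitrary     -> indices depend on the execution
    StringIndexer indexer =
            new StringIndexer().setStringOrderType(StringIndexerParams.FREQUENCY_DESC_ORDER);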
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/stringindexer/IndexToStringModelTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/stringindexer/IndexToStringModelTest.java
new file mode 100644
index 0000000..e02276a
--- /dev/null
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/stringindexer/IndexToStringModelTest.java
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.stringindexer;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
+
+/** Tests the {@link IndexToStringModel}. */
+public class IndexToStringModelTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table predictTable;
+    private Table modelTable;
+    private Table predictTableWithUnseenValues;
+
+    private final List<Row> expectedPrediction =
+            Arrays.asList(Row.of(0, 3, "a", "2.0"), Row.of(1, 2, "b", "1.0"));
+    private final String[][] stringArrays =
+            new String[][] {{"a", "b", "c", "d"}, {"-1.0", "0.0", "1.0", "2.0"}};
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        modelTable =
+                tEnv.fromDataStream(env.fromElements(new StringIndexerModelData(stringArrays)))
+                        .as("stringArrays");
+        predictTable =
+                tEnv.fromDataStream(env.fromCollection(Arrays.asList(Row.of(0, 3), Row.of(1, 2))))
+                        .as("inputCol1", "inputCol2");
+        predictTableWithUnseenValues =
+                tEnv.fromDataStream(
+                                env.fromCollection(
+                                        Arrays.asList(Row.of(0, 3), Row.of(1, 2), Row.of(4, 1))))
+                        .as("inputCol1", "inputCol2");
+    }
+
+    @Test
+    public void testPredictParam() {
+        IndexToStringModel indexToStringModel =
+                new IndexToStringModel()
+                        .setInputCols("inputCol1", "inputCol2")
+                        .setOutputCols("outputCol1", "outputCol2")
+                        .setModelData(modelTable);
+        Table output = indexToStringModel.transform(predictTable)[0];
+
+        assertEquals(
+                Arrays.asList("inputCol1", "inputCol2", "outputCol1", "outputCol2"),
+                output.getResolvedSchema().getColumnNames());
+    }
+
+    @Test
+    public void testInputWithUnseenValues() {
+        IndexToStringModel indexToStringModel =
+                new IndexToStringModel()
+                        .setInputCols("inputCol1", "inputCol2")
+                        .setOutputCols("outputCol1", "outputCol2")
+                        .setModelData(modelTable);
+        Table output = indexToStringModel.transform(predictTableWithUnseenValues)[0];
+
+        try {
+            IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+            fail();
+        } catch (Exception e) {
+            assertEquals(
+                    "The input contains unseen index: 4.",
+                    e.getCause().getCause().getCause().getCause().getCause().getMessage());
+        }
+    }
+
+    @Test
+    @SuppressWarnings("unchecked")
+    public void testPredict() throws Exception {
+        IndexToStringModel indexToStringModel =
+                new IndexToStringModel()
+                        .setInputCols("inputCol1", "inputCol2")
+                        .setOutputCols("outputCol1", "outputCol2")
+                        .setModelData(modelTable);
+        Table output = indexToStringModel.transform(predictTable)[0];
+
+        List<Row> predictedResult =
+                IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        StringIndexerTest.verifyPredictionResult(expectedPrediction, predictedResult);
+    }
+
+    @Test
+    @SuppressWarnings("unchecked")
+    public void testSaveLoadAndPredict() throws Exception {
+        IndexToStringModel model =
+                new IndexToStringModel()
+                        .setInputCols("inputCol1", "inputCol2")
+                        .setOutputCols("outputCol1", "outputCol2")
+                        .setModelData(modelTable);
+        model = StageTestUtils.saveAndReload(env, model, tempFolder.newFolder().getAbsolutePath());
+
+        assertEquals(
+                Collections.singletonList("stringArrays"),
+                model.getModelData()[0].getResolvedSchema().getColumnNames());
+
+        Table output = model.transform(predictTable)[0];
+
+        List<Row> predictedResult =
+                IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        StringIndexerTest.verifyPredictionResult(expectedPrediction, predictedResult);
+    }
+
+    @Test
+    @SuppressWarnings("unchecked")
+    public void testGetModelData() throws Exception {
+        IndexToStringModel model =
+                new IndexToStringModel()
+                        .setInputCols("inputCol1", "inputCol2")
+                        .setOutputCols("outputCol1", "outputCol2")
+                        .setModelData(modelTable);
+
+        List<StringIndexerModelData> collectedModelData =
+                (List<StringIndexerModelData>)
+                        (IteratorUtils.toList(
+                                StringIndexerModelData.getModelDataStream(model.getModelData()[0])
+                                        .executeAndCollect()));
+
+        assertEquals(1, collectedModelData.size());
+
+        StringIndexerModelData modelData = collectedModelData.get(0);
+        assertEquals(2, modelData.stringArrays.length);
+        assertArrayEquals(stringArrays[0], modelData.stringArrays[0]);
+        assertArrayEquals(stringArrays[1], modelData.stringArrays[1]);
+    }
+}
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/stringindexer/StringIndexerTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/stringindexer/StringIndexerTest.java
new file mode 100644
index 0000000..805da5b
--- /dev/null
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/stringindexer/StringIndexerTest.java
@@ -0,0 +1,335 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.stringindexer;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+/** Tests the {@link StringIndexer} and {@link StringIndexerModel}. */
+public class StringIndexerTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table trainTable;
+    private Table predictTable;
+
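+    // Each expected prediction row is (inputCol1, inputCol2, index for inputCol1, index for
+    // inputCol2); the string "e" is unseen during training and maps to index 4 under KEEP_INVALID.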
+    private final String[][] expectedAlphabeticAscModelData =
+            new String[][] {{"a", "b", "c", "d"}, {"-1.0", "0.0", "1.0", "2.0"}};
+    private final List<Row> expectedAlphabeticAscPredictData =
+            Arrays.asList(Row.of("a", 2.0, 0, 3), Row.of("b", 1.0, 1, 2), Row.of("e", 2.0, 4, 3));
+    private final List<Row> expectedAlphabeticDescPredictData =
+            Arrays.asList(Row.of("a", 2.0, 3, 0), Row.of("b", 1.0, 2, 1), Row.of("e", 2.0, 4, 0));
+    private final List<Row> expectedFreqAscPredictData =
+            Arrays.asList(Row.of("a", 2.0, 2, 3), Row.of("b", 1.0, 3, 1), Row.of("e", 2.0, 4, 3));
+    private final List<Row> expectedFreqDescPredictData =
+            Arrays.asList(Row.of("a", 2.0, 1, 0), Row.of("b", 1.0, 0, 2), Row.of("e", 2.0, 4, 0));
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
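+        // Training data with a string column and a numeric column; both columns are indexed, and
+        // the numeric values appear as strings in the model data (expectedAlphabeticAscModelData).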
+        List<Row> trainData =
+                Arrays.asList(
+                        Row.of("a", 1.0),
+                        Row.of("b", 1.0),
+                        Row.of("b", 2.0),
+                        Row.of("c", 0.0),
+                        Row.of("d", 2.0),
+                        Row.of("a", 2.0),
+                        Row.of("b", 2.0),
+                        Row.of("b", -1.0),
+                        Row.of("a", -1.0),
+                        Row.of("c", -1.0));
+        trainTable =
+                tEnv.fromDataStream(env.fromCollection(trainData)).as("inputCol1", "inputCol2");
+
+        List<Row> predictData = Arrays.asList(Row.of("a", 2.0), Row.of("b", 1.0), Row.of("e", 2.0));
+        predictTable =
+                tEnv.fromDataStream(env.fromCollection(predictData)).as("inputCol1", "inputCol2");
+    }
+
+    @Test
+    public void testFitParam() {
+        StringIndexer stringIndexer = new StringIndexer();
+        assertEquals(StringIndexerParams.ARBITRARY_ORDER, stringIndexer.getStringOrderType());
+        assertEquals(StringIndexerParams.ERROR_INVALID, stringIndexer.getHandleInvalid());
+
+        stringIndexer
+                .setInputCols("inputCol1", "inputCol2")
+                .setOutputCols("outputCol1", "outputCol2")
+                .setStringOrderType(StringIndexerParams.ALPHABET_ASC_ORDER)
+                .setHandleInvalid(StringIndexerParams.SKIP_INVALID);
+
+        assertArrayEquals(new String[] {"inputCol1", "inputCol2"}, stringIndexer.getInputCols());
+        assertArrayEquals(new String[] {"outputCol1", "outputCol2"}, stringIndexer.getOutputCols());
+        assertEquals(StringIndexerParams.ALPHABET_ASC_ORDER, stringIndexer.getStringOrderType());
+        assertEquals(StringIndexerParams.SKIP_INVALID, stringIndexer.getHandleInvalid());
+    }
+
+    @Test
+    public void testPredictParam() {
+        StringIndexer stringIndexer =
+                new StringIndexer()
+                        .setInputCols("inputCol1", "inputCol2")
+                        .setOutputCols("outputCol1", "outputCol2")
+                        .setStringOrderType(StringIndexerParams.ALPHABET_ASC_ORDER)
+                        .setHandleInvalid(StringIndexerParams.SKIP_INVALID);
+        Table output = stringIndexer.fit(trainTable).transform(predictTable)[0];
+
+        assertEquals(
+                Arrays.asList("inputCol1", "inputCol2", "outputCol1", "outputCol2"),
+                output.getResolvedSchema().getColumnNames());
+    }
+
+    @Test
+    @SuppressWarnings("all")
+    public void testStringOrderType() throws Exception {
+        StringIndexer stringIndexer =
+                new StringIndexer()
+                        .setInputCols("inputCol1", "inputCol2")
+                        .setOutputCols("outputCol1", "outputCol2")
+                        .setHandleInvalid(StringIndexerParams.KEEP_INVALID);
+        Table output;
+        List<Row> predictedResult;
+
+        // AlphabetAsc order.
+        stringIndexer.setStringOrderType(StringIndexerParams.ALPHABET_ASC_ORDER);
+        output = stringIndexer.fit(trainTable).transform(predictTable)[0];
+        predictedResult = IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        verifyPredictionResult(expectedAlphabeticAscPredictData, predictedResult);
+
+        // AlphabetDesc order.
+        stringIndexer.setStringOrderType(StringIndexerParams.ALPHABET_DESC_ORDER);
+        output = stringIndexer.fit(trainTable).transform(predictTable)[0];
+        predictedResult = IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        verifyPredictionResult(expectedAlphabeticDescPredictData, predictedResult);
+
+        // FrequencyAsc order.
+        stringIndexer.setStringOrderType(StringIndexerParams.FREQUENCY_ASC_ORDER);
+        output = stringIndexer.fit(trainTable).transform(predictTable)[0];
+        predictedResult = IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        verifyPredictionResult(expectedFreqAscPredictData, predictedResult);
+
+        // FrequencyDesc order.
+        stringIndexer.setStringOrderType(StringIndexerParams.FREQUENCY_DESC_ORDER);
+        output = stringIndexer.fit(trainTable).transform(predictTable)[0];
+        predictedResult = IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        verifyPredictionResult(expectedFreqDescPredictData, predictedResult);
+
+        // Arbitrary order.
+        stringIndexer.setStringOrderType(StringIndexerParams.ARBITRARY_ORDER);
+        output = stringIndexer.fit(trainTable).transform(predictTable)[0];
+        predictedResult = IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+
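+        // The arbitrary order assigns nondeterministic indices, so only verify that each index
+        // falls in the valid range and that distinct input values receive distinct indices.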
+        Set<Integer> distinctStringsCol1 = new HashSet<>();
+        Set<Integer> distinctStringsCol2 = new HashSet<>();
+        int index;
+        for (Row r : predictedResult) {
+            index = (Integer) r.getField(2);
+            distinctStringsCol1.add(index);
+            assertTrue(index >= 0 && index <= 4);
+            index = (Integer) r.getField(3);
+            assertTrue(index >= 0 && index <= 3);
+            distinctStringsCol2.add(index);
+        }
+
+        assertEquals(3, distinctStringsCol1.size());
+        assertEquals(2, distinctStringsCol2.size());
+    }
+
+    @Test
+    @SuppressWarnings("unchecked")
+    public void testHandleInvalid() throws Exception {
+        StringIndexer stringIndexer =
+                new StringIndexer()
+                        .setInputCols("inputCol1", "inputCol2")
+                        .setOutputCols("outputCol1", "outputCol2")
+                        .setStringOrderType(StringIndexerParams.ALPHABET_ASC_ORDER);
+
+        Table output;
+        List<Row> expectedResult;
+
+        // Keeps invalid data.
+        stringIndexer.setHandleInvalid(StringIndexerParams.KEEP_INVALID);
+        output = stringIndexer.fit(trainTable).transform(predictTable)[0];
+        List<Row> predictedResult =
+                IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        verifyPredictionResult(expectedAlphabeticAscPredictData, predictedResult);
+
+        // Skips invalid data.
+        stringIndexer.setHandleInvalid(StringIndexerParams.SKIP_INVALID);
+        output = stringIndexer.fit(trainTable).transform(predictTable)[0];
+        predictedResult = IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        expectedResult = Arrays.asList(Row.of("a", 2.0, 0, 3), Row.of("b", 1.0, 1, 2));
+        verifyPredictionResult(expectedResult, predictedResult);
+
+        // Throws an exception on invalid data.
+        stringIndexer.setHandleInvalid(StringIndexerParams.ERROR_INVALID);
+        try {
+            output = stringIndexer.fit(trainTable).transform(predictTable)[0];
+            IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+            fail();
+        } catch (Exception e) {
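+            // The expected message is nested inside several layers of job-failure causes.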
+            assertEquals(
+                    "The input contains unseen string: e. "
+                            + "See handleInvalid parameter for more options.",
+                    e.getCause().getCause().getCause().getCause().getCause().getMessage());
+        }
+    }
+
+    @Test
+    @SuppressWarnings("unchecked")
+    public void testFitAndPredict() throws Exception {
+        StringIndexer stringIndexer =
+                new StringIndexer()
+                        .setInputCols("inputCol1", "inputCol2")
+                        .setOutputCols("outputCol1", "outputCol2")
+                        .setStringOrderType(StringIndexerParams.ALPHABET_ASC_ORDER)
+                        .setHandleInvalid(StringIndexerParams.KEEP_INVALID);
+        Table output = stringIndexer.fit(trainTable).transform(predictTable)[0];
+
+        List<Row> predictedResult =
+                IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        verifyPredictionResult(expectedAlphabeticAscPredictData, predictedResult);
+    }
+
+    @Test
+    @SuppressWarnings("unchecked")
+    public void testSaveLoadAndPredict() throws Exception {
+        StringIndexer stringIndexer =
+                new StringIndexer()
+                        .setInputCols("inputCol1", "inputCol2")
+                        .setOutputCols("outputCol1", "outputCol2")
+                        .setStringOrderType(StringIndexerParams.ALPHABET_ASC_ORDER)
+                        .setHandleInvalid(StringIndexerParams.KEEP_INVALID);
+        stringIndexer =
+                StageTestUtils.saveAndReload(
+                        env, stringIndexer, tempFolder.newFolder().getAbsolutePath());
+
+        StringIndexerModel model = stringIndexer.fit(trainTable);
+        model = StageTestUtils.saveAndReload(env, model, tempFolder.newFolder().getAbsolutePath());
+
+        assertEquals(
+                Collections.singletonList("stringArrays"),
+                model.getModelData()[0].getResolvedSchema().getColumnNames());
+
+        Table output = model.transform(predictTable)[0];
+        List<Row> predictedResult =
+                IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        verifyPredictionResult(expectedAlphabeticAscPredictData, predictedResult);
+    }
+
+    @Test
+    @SuppressWarnings("unchecked")
+    public void testGetModelData() throws Exception {
+        StringIndexerModel model =
+                new StringIndexer()
+                        .setInputCols("inputCol1", "inputCol2")
+                        .setOutputCols("outputCol1", "outputCol2")
+                        .setStringOrderType(StringIndexerParams.ALPHABET_ASC_ORDER)
+                        .fit(trainTable);
+        Table modelDataTable = model.getModelData()[0];
+
+        assertEquals(
+                Collections.singletonList("stringArrays"),
+                modelDataTable.getResolvedSchema().getColumnNames());
+
+        List<StringIndexerModelData> collectedModelData =
+                (List<StringIndexerModelData>)
+                        (IteratorUtils.toList(
+                                StringIndexerModelData.getModelDataStream(modelDataTable)
+                                        .executeAndCollect()));
+        assertEquals(1, collectedModelData.size());
+
+        StringIndexerModelData modelData = collectedModelData.get(0);
+        assertEquals(2, modelData.stringArrays.length);
+        assertArrayEquals(expectedAlphabeticAscModelData[0], modelData.stringArrays[0]);
+        assertArrayEquals(expectedAlphabeticAscModelData[1], modelData.stringArrays[1]);
+    }
+
+    @Test
+    @SuppressWarnings("unchecked")
+    public void testSetModelData() throws Exception {
+        StringIndexerModel model =
+                new StringIndexer()
+                        .setInputCols("inputCol1", "inputCol2")
+                        .setOutputCols("outputCol1", "outputCol2")
+                        .setStringOrderType(StringIndexerParams.ALPHABET_ASC_ORDER)
+                        .setHandleInvalid(StringIndexerParams.KEEP_INVALID)
+                        .fit(trainTable);
+
+        StringIndexerModel newModel = new StringIndexerModel();
+        ReadWriteUtils.updateExistingParams(newModel, model.getParamMap());
+        newModel.setModelData(model.getModelData());
+        Table output = newModel.transform(predictTable)[0];
+
+        List<Row> predictedResult =
+                IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        verifyPredictionResult(expectedAlphabeticAscPredictData, predictedResult);
+    }
+
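+    /**
+     * Verifies that the predicted rows match the expected rows regardless of row order; rows are
+     * sorted by comparing their fields as strings before the element-wise comparison.
+     */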
+    static void verifyPredictionResult(List<Row> expected, List<Row> result) {
+        compareResultCollections(
+                expected,
+                result,
+                (row1, row2) -> {
+                    int arity = Math.min(row1.getArity(), row2.getArity());
+                    for (int i = 0; i < arity; i++) {
+                        int cmp =
+                                String.valueOf(row1.getField(i))
+                                        .compareTo(String.valueOf(row2.getField(i)));
+                        if (cmp != 0) {
+                            return cmp;
+                        }
+                    }
+                    return 0;
+                });
+    }
+}