Posted to commits@flink.apache.org by li...@apache.org on 2022/11/29 09:17:32 UTC

[flink-ml] branch master updated: [FLINK-30124] Support collecting model data with arrays and maps

This is an automated email from the ASF dual-hosted git repository.

lindong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git


The following commit(s) were added to refs/heads/master by this push:
     new d19b76c  [FLINK-30124] Support collecting model data with arrays and maps
d19b76c is described below

commit d19b76cface06604c3656e1dfbaa194c5c9c041c
Author: yunfengzhou-hub <yu...@outlook.com>
AuthorDate: Tue Nov 29 17:17:26 2022 +0800

    [FLINK-30124] Support collecting model data with arrays and maps
    
    This closes #181.
---
 flink-ml-dist/pom.xml                              |   6 +
 flink-ml-dist/src/main/assemblies/bin.xml          |   1 +
 .../ml/classification/naivebayes/NaiveBayes.java   |  18 +-
 .../apache/flink/ml/feature/imputer/Imputer.java   |  10 +-
 .../ml/feature/vectorindexer/VectorIndexer.java    |  13 +-
 flink-ml-python/pom.xml                            |  14 ++
 flink-ml-python/pyflink/ml/core/wrapper.py         |  38 +++-
 .../ml/lib/classification/tests/test_naivebayes.py |  23 ++-
 .../pyflink/ml/lib/clustering/tests/test_kmeans.py |  13 +-
 .../pyflink/ml/lib/feature/tests/test_idf.py       |  11 +-
 .../pyflink/ml/lib/feature/tests/test_imputer.py   |   9 +-
 .../lib/feature/tests/test_indextostringmodel.py   |   7 +-
 .../ml/lib/feature/tests/test_kbinsdiscretizer.py  |  11 +-
 .../ml/lib/feature/tests/test_stringindexer.py     |   7 +-
 .../tests/test_variancethresholdselector.py        |   7 +-
 .../ml/lib/feature/tests/test_vectorindexer.py     |   8 +-
 .../apache/flink/ml/python/PythonBridgeUtils.java  | 226 +++++++++++++++++++++
 17 files changed, 402 insertions(+), 20 deletions(-)
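
The Java changes below share one pattern: instead of the plain
tEnv.fromDataStream(modelData) call, each estimator now passes an explicit
Schema so that array- and map-typed model-data columns get structured table
types the Python bridge can decode, and the Python tests then exercise the
round trip end to end. As a rough sketch of what those tests do (names here
are generic and assume the PyFlinkMLTestCase setup that provides `t_env`):

    # Hedged sketch of the round trip this patch enables: fit an estimator,
    # fetch its model data as a Table, and collect rows whose fields are
    # arrays, maps, or vectors.
    model = estimator.fit(train_table)             # any Flink ML Estimator
    model_data = model.get_model_data()[0]         # model data as a Table
    field_names = model_data.get_schema().get_field_names()
    rows = list(t_env.to_data_stream(model_data).execute_and_collect())
    surrogates = rows[0][field_names.index('surrogates')]  # e.g. Imputer: a dict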

diff --git a/flink-ml-dist/pom.xml b/flink-ml-dist/pom.xml
index 1e04c50..20c8464 100644
--- a/flink-ml-dist/pom.xml
+++ b/flink-ml-dist/pom.xml
@@ -46,6 +46,12 @@ under the License.
             <version>${project.version}</version>
         </dependency>
 
+        <dependency>
+            <groupId>org.apache.flink</groupId>
+            <artifactId>flink-ml-python</artifactId>
+            <version>${project.version}</version>
+        </dependency>
+
         <!-- Stateful Functions Dependencies -->
 
         <dependency>
diff --git a/flink-ml-dist/src/main/assemblies/bin.xml b/flink-ml-dist/src/main/assemblies/bin.xml
index 987e27b..516c55f 100644
--- a/flink-ml-dist/src/main/assemblies/bin.xml
+++ b/flink-ml-dist/src/main/assemblies/bin.xml
@@ -37,6 +37,7 @@ under the License.
                 <include>org.apache.flink:statefun-flink-core</include>
                 <include>org.apache.flink:flink-ml-uber</include>
                 <include>org.apache.flink:flink-ml-examples</include>
+                <include>org.apache.flink:flink-ml-python</include>
             </includes>
         </dependencySet>
     </dependencySets>
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java
index a4803c7..e5bd995 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java
@@ -28,10 +28,13 @@ import org.apache.flink.ml.api.Estimator;
 import org.apache.flink.ml.common.datastream.DataStreamUtils;
 import org.apache.flink.ml.linalg.Vector;
 import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
 import org.apache.flink.ml.param.Param;
 import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.ReadWriteUtils;
 import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
 import org.apache.flink.table.api.Table;
 import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
 import org.apache.flink.table.api.internal.TableImpl;
@@ -104,7 +107,20 @@ public class NaiveBayes
                         aggregatedArrays, new GenerateModelFunction(smoothing));
         modelData.getTransformation().setParallelism(1);
 
-        NaiveBayesModel model = new NaiveBayesModel().setModelData(tEnv.fromDataStream(modelData));
+        Schema schema =
+                Schema.newBuilder()
+                        .column(
+                                "theta",
+                                DataTypes.ARRAY(
+                                        DataTypes.ARRAY(
+                                                DataTypes.MAP(
+                                                        DataTypes.DOUBLE(), DataTypes.DOUBLE()))))
+                        .column("piArray", DataTypes.of(DenseVectorTypeInfo.INSTANCE))
+                        .column("labels", DataTypes.of(DenseVectorTypeInfo.INSTANCE))
+                        .build();
+
+        NaiveBayesModel model =
+                new NaiveBayesModel().setModelData(tEnv.fromDataStream(modelData, schema));
         ReadWriteUtils.updateExistingParams(model, paramMap);
         return model;
     }
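
Without the explicit Schema, fromDataStream derives column types from the
stream's TypeInformation, and the nested arrays of maps in `theta` end up as
types the Python collection path cannot decode; DataTypes.of(...) likewise
bridges the custom DenseVectorTypeInfo into a table column type. A hedged
PyFlink analogue of declaring such a schema from Python (Schema and DataTypes
here are the pyflink.table classes, Flink 1.15+):

    from pyflink.table import Schema, DataTypes

    # Analogue of the Java builder above, for the map/array column only; the
    # vector columns rely on DataTypes.of over a custom TypeInformation,
    # which is a Java-side construct.
    schema = Schema.new_builder() \
        .column('theta', DataTypes.ARRAY(DataTypes.ARRAY(
            DataTypes.MAP(DataTypes.DOUBLE(), DataTypes.DOUBLE())))) \
        .build()
    # table = t_env.from_data_stream(model_stream, schema)
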
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/imputer/Imputer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/imputer/Imputer.java
index b0b7e67..9ac4272 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/imputer/Imputer.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/imputer/Imputer.java
@@ -27,6 +27,8 @@ import org.apache.flink.ml.param.Param;
 import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.ReadWriteUtils;
 import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
 import org.apache.flink.table.api.Table;
 import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
 import org.apache.flink.table.api.internal.TableImpl;
@@ -90,7 +92,13 @@ public class Imputer implements Estimator<Imputer, ImputerModel>, ImputerParams<
             default:
                 throw new RuntimeException("Unsupported strategy of Imputer: " + getStrategy());
         }
-        ImputerModel model = new ImputerModel().setModelData(tEnv.fromDataStream(modelData));
+
+        Schema schema =
+                Schema.newBuilder()
+                        .column("surrogates", DataTypes.MAP(DataTypes.STRING(), DataTypes.DOUBLE()))
+                        .build();
+        ImputerModel model =
+                new ImputerModel().setModelData(tEnv.fromDataStream(modelData, schema));
         ReadWriteUtils.updateExistingParams(model, getParamMap());
         return model;
     }
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexer.java
index 74d6a6b..810a820 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexer.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexer.java
@@ -39,6 +39,8 @@ import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
 import org.apache.flink.streaming.api.operators.BoundedOneInput;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
 import org.apache.flink.table.api.Table;
 import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
 import org.apache.flink.table.api.internal.TableImpl;
@@ -117,8 +119,17 @@ public class VectorIndexer
                 distinctDoubles.map(new ModelGenerator(maxCategories));
         modelData.getTransformation().setParallelism(1);
 
+        Schema schema =
+                Schema.newBuilder()
+                        .column(
+                                "categoryMaps",
+                                DataTypes.MAP(
+                                        DataTypes.INT(),
+                                        DataTypes.MAP(DataTypes.DOUBLE(), DataTypes.INT())))
+                        .build();
+
         VectorIndexerModel model =
-                new VectorIndexerModel().setModelData(tEnv.fromDataStream(modelData));
+                new VectorIndexerModel().setModelData(tEnv.fromDataStream(modelData, schema));
         ReadWriteUtils.updateExistingParams(model, paramMap);
         return model;
     }
diff --git a/flink-ml-python/pom.xml b/flink-ml-python/pom.xml
index 316ec40..dc5459e 100644
--- a/flink-ml-python/pom.xml
+++ b/flink-ml-python/pom.xml
@@ -30,6 +30,20 @@ under the License.
     <packaging>jar</packaging>
 
     <dependencies>
+        <dependency>
+            <groupId>org.apache.flink</groupId>
+            <artifactId>flink-table-common</artifactId>
+            <version>${flink.version}</version>
+            <scope>provided</scope>
+        </dependency>
+
+        <dependency>
+            <groupId>org.apache.flink</groupId>
+            <artifactId>flink-python_2.12</artifactId>
+            <version>${flink.version}</version>
+            <scope>provided</scope>
+        </dependency>
+
         <dependency>
             <groupId>org.apache.flink</groupId>
             <artifactId>flink-runtime</artifactId>
diff --git a/flink-ml-python/pyflink/ml/core/wrapper.py b/flink-ml-python/pyflink/ml/core/wrapper.py
index e965c0c..83acd9d 100644
--- a/flink-ml-python/pyflink/ml/core/wrapper.py
+++ b/flink-ml-python/pyflink/ml/core/wrapper.py
@@ -15,12 +15,16 @@
 #  See the License for the specific language governing permissions and
 # limitations under the License.
 ################################################################################
+import pickle
 from abc import ABC, abstractmethod
 from typing import List, Dict, Any
 
 from py4j.java_gateway import JavaObject, get_java_class
-from pyflink.common import typeinfo, Time
-from pyflink.common.typeinfo import _from_java_type, TypeInformation, _is_instance_of
+from pyflink.common import typeinfo, Time, Row, RowKind
+from pyflink.common.typeinfo import _from_java_type, TypeInformation, _is_instance_of, Types, \
+    ExternalTypeInfo, RowTypeInfo, TupleTypeInfo
+from pyflink.datastream import utils
+from pyflink.datastream.utils import pickled_bytes_to_python_converter
 from pyflink.java_gateway import get_gateway
 from pyflink.table import Table, StreamTableEnvironment, Expression
 from pyflink.util.java_utils import to_jarray
@@ -55,6 +59,36 @@ def _from_java_type_wrapper(j_type_info: JavaObject) -> TypeInformation:
 typeinfo._from_java_type = _from_java_type_wrapper
 
 
+# TODO: Remove this class after Flink ML depends on a Flink version
+#  with FLINK-30168 and FLINK-29477 fixed.
+def convert_to_python_obj_wrapper(data, type_info):
+    if type_info == Types.PICKLED_BYTE_ARRAY():
+        return pickle.loads(data)
+    elif isinstance(type_info, ExternalTypeInfo):
+        return convert_to_python_obj_wrapper(data, type_info._type_info)
+    else:
+        gateway = get_gateway()
+        pickle_bytes = gateway.jvm.org.apache.flink.ml.python.PythonBridgeUtils. \
+            getPickledBytesFromJavaObject(data, type_info.get_java_type_info())
+        if isinstance(type_info, RowTypeInfo) or isinstance(type_info, TupleTypeInfo):
+            field_data = zip(list(pickle_bytes[1:]), type_info.get_field_types())
+            fields = []
+            for data, field_type in field_data:
+                if len(data) == 0:
+                    fields.append(None)
+                else:
+                    fields.append(pickled_bytes_to_python_converter(data, field_type))
+            if isinstance(type_info, RowTypeInfo):
+                return Row.of_kind(RowKind(int.from_bytes(pickle_bytes[0], 'little')), *fields)
+            else:
+                return tuple(fields)
+        else:
+            return pickled_bytes_to_python_converter(pickle_bytes, type_info)
+
+
+utils.convert_to_python_obj = convert_to_python_obj_wrapper
+
+
 class JavaWrapper(ABC):
     """
     Wrapper class for a Java object
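
The module-level assignment above monkey-patches
pyflink.datastream.utils.convert_to_python_obj at import time, so every
execute_and_collect() call in this package picks up the fixed conversion
without touching PyFlink itself. It cooperates with the new Java class
PythonBridgeUtils (added at the end of this patch): a MAP column, for
instance, arrives as pickled [[pickled keys...], [pickled values...]]. A
hedged sketch of that decoding for a MAP(DOUBLE, DOUBLE) column (the real
code defers to pyflink's pickled_bytes_to_python_converter, which also
handles nested and date/time types):

    import pickle

    def decode_double_map(pickled_map_bytes):
        # Outer layer: a pickled two-element list [keys, values].
        keys_bytes, values_bytes = pickle.loads(pickled_map_bytes)
        # Inner layer: each key/value is itself pickled (doubles pickle
        # directly; FLOAT columns instead arrive as strings to keep precision).
        return {pickle.loads(k): pickle.loads(v)
                for k, v in zip(keys_bytes, values_bytes)}
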
diff --git a/flink-ml-python/pyflink/ml/lib/classification/tests/test_naivebayes.py b/flink-ml-python/pyflink/ml/lib/classification/tests/test_naivebayes.py
index 16f0c8c..b735a71 100644
--- a/flink-ml-python/pyflink/ml/lib/classification/tests/test_naivebayes.py
+++ b/flink-ml-python/pyflink/ml/lib/classification/tests/test_naivebayes.py
@@ -113,12 +113,31 @@ class NaiveBayesTest(PyFlinkMLTestCase):
         self.assertEqual(self.expected_output, actual_output)
 
     def test_get_model_data(self):
-        model = self.estimator.fit(self.train_data)
+        train_data = self.t_env.from_data_stream(
+            self.env.from_collection([
+                (Vectors.dense([1, 1.]), 11.),
+                (Vectors.dense([2, 1]), 11.),
+            ],
+                type_info=Types.ROW_NAMED(
+                    ['features', 'label'],
+                    [DenseVectorTypeInfo(), Types.DOUBLE()])))
+
+        model = self.estimator.fit(train_data)
         model_data = model.get_model_data()[0]
         expected_field_names = ['theta', 'piArray', 'labels']
         self.assertEqual(expected_field_names, model_data.get_schema().get_field_names())
 
-        # TODO: Add test to collect and verify the model data results after FLINK-30124 is resolved.
+        model_rows = [result for result in
+                      self.t_env.to_data_stream(model_data).execute_and_collect()]
+        self.assertEqual(1, len(model_rows))
+        self.assertListAlmostEqual(
+            [11.], model_rows[0][expected_field_names.index('labels')].to_array(), delta=1e-5)
+        self.assertListAlmostEqual(
+            [0.], model_rows[0][expected_field_names.index('piArray')].to_array(), delta=1e-5)
+        theta = model_rows[0][expected_field_names.index('theta')]
+        self.assertAlmostEqual(-0.6931471805599453, theta[0][0].get(1.0), delta=1e-5)
+        self.assertAlmostEqual(-0.6931471805599453, theta[0][0].get(2.0), delta=1e-5)
+        self.assertAlmostEqual(0.0, theta[0][1].get(1.0), delta=1e-5)
 
     def test_set_model_data(self):
         model_a = self.estimator.fit(self.train_data)
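
The asserted constants follow from the two training rows above: with the
single label 11.0, feature 0 takes the values 1.0 and 2.0 once each, so each
conditional probability is 1/2 and its log is ln(1/2) = -0.6931...; feature 1
is always 1.0, so its log-probability is ln(1) = 0, and the one-class prior
in piArray is likewise ln(1) = 0 (the default smoothing yields the same
values for these counts). A quick check of the arithmetic:

    import math
    assert abs(math.log(0.5) - (-0.6931471805599453)) < 1e-12
    assert math.log(1.0) == 0.0
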
diff --git a/flink-ml-python/pyflink/ml/lib/clustering/tests/test_kmeans.py b/flink-ml-python/pyflink/ml/lib/clustering/tests/test_kmeans.py
index c63b26b..6b816a1 100644
--- a/flink-ml-python/pyflink/ml/lib/clustering/tests/test_kmeans.py
+++ b/flink-ml-python/pyflink/ml/lib/clustering/tests/test_kmeans.py
@@ -47,7 +47,7 @@ class KMeansTest(PyFlinkMLTestCase):
             self.env.from_collection([
                 (Vectors.dense([0.0, 0.0]),),
                 (Vectors.dense([0.0, 0.3]),),
-                (Vectors.dense([0.3, 3.0]),),
+                (Vectors.dense([0.3, 0.0]),),
                 (Vectors.dense([9.0, 0.0]),),
                 (Vectors.dense([9.0, 0.6]),),
                 (Vectors.dense([9.6, 0.0]),),
@@ -56,7 +56,7 @@ class KMeansTest(PyFlinkMLTestCase):
                     ['features'],
                     [DenseVectorTypeInfo()])))
         self.expected_groups = [
-            {DenseVector([0.0, 0.3]), DenseVector([0.3, 3.0]), DenseVector([0.0, 0.0])},
+            {DenseVector([0.0, 0.3]), DenseVector([0.3, 0.0]), DenseVector([0.0, 0.0])},
             {DenseVector([9.6, 0.0]), DenseVector([9.0, 0.0]), DenseVector([9.0, 0.6])}]
 
     def test_param(self):
@@ -153,7 +153,14 @@ class KMeansTest(PyFlinkMLTestCase):
         expected_field_names = ['centroids', 'weights']
         self.assertEqual(expected_field_names, model_data.get_schema().get_field_names())
 
-        # TODO: Add test to collect and verify the model data results after FLINK-30122 is resolved.
+        model_rows = [result for result in
+                      self.t_env.to_data_stream(model_data).execute_and_collect()]
+        self.assertEqual(1, len(model_rows))
+        centroids = model_rows[0][expected_field_names.index('centroids')]
+        self.assertEqual(2, len(centroids))
+        centroids.sort(key=lambda x: x.get(0))
+        self.assertListAlmostEqual([0.1, 0.1], centroids[0], delta=1e-5)
+        self.assertListAlmostEqual([9.2, 0.2], centroids[1], delta=1e-5)
 
     def test_set_model_data(self):
         kmeans = KMeans().set_max_iter(2).set_k(2)
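
The expected centroids are just the coordinate-wise means of the two point
groups defined in setUp above: (0 + 0 + 0.3) / 3 = 0.1 per coordinate for the
first group, and ((9 + 9 + 9.6) / 3, (0 + 0.6 + 0) / 3) = (9.2, 0.2) for the
second. A quick check:

    group_a = [(0.0, 0.0), (0.0, 0.3), (0.3, 0.0)]
    group_b = [(9.0, 0.0), (9.0, 0.6), (9.6, 0.0)]
    mean = lambda pts: [sum(c) / len(pts) for c in zip(*pts)]
    assert all(abs(a - e) < 1e-9 for a, e in zip(mean(group_a), [0.1, 0.1]))
    assert all(abs(a - e) < 1e-9 for a, e in zip(mean(group_b), [9.2, 0.2]))
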
diff --git a/flink-ml-python/pyflink/ml/lib/feature/tests/test_idf.py b/flink-ml-python/pyflink/ml/lib/feature/tests/test_idf.py
index 11f1887..7d276d5 100644
--- a/flink-ml-python/pyflink/ml/lib/feature/tests/test_idf.py
+++ b/flink-ml-python/pyflink/ml/lib/feature/tests/test_idf.py
@@ -119,8 +119,15 @@ class IDFTest(PyFlinkMLTestCase):
         expected_field_names = ['idf', 'docFreq', 'numDocs']
         self.assertEqual(expected_field_names, model_data.get_schema().get_field_names())
 
-        # TODO: Add test to collect and verify the model data results after Flink dependency
-        #  is upgraded to 1.15.3, 1.16.0 or a higher version. Related ticket: FLINK-29477
+        model_rows = [result for result in
+                      self.t_env.to_data_stream(model_data).execute_and_collect()]
+        self.assertEqual(1, len(model_rows))
+        self.assertEqual(3, model_rows[0][expected_field_names.index('numDocs')])
+        self.assertListEqual([0, 3, 1, 2], model_rows[0][expected_field_names.index('docFreq')])
+        self.assertListAlmostEqual(
+            [1.3862943, 0, 0.6931471, 0.2876820],
+            model_rows[0][expected_field_names.index('idf')].to_array(),
+            delta=self.tolerance)
 
     def test_set_model_data(self):
         idf = IDF()
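
The expected idf vector is consistent with the smoothed formula
idf(t) = ln((numDocs + 1) / (docFreq(t) + 1)): with numDocs = 3 and
docFreq = [0, 3, 1, 2] this gives ln(4), ln(1), ln(2) and ln(4/3). A quick
check of the four constants:

    import math
    num_docs, doc_freq = 3, [0, 3, 1, 2]
    idf = [math.log((num_docs + 1) / (d + 1)) for d in doc_freq]
    expected = [1.3862943, 0, 0.6931471, 0.2876820]
    assert all(abs(a - e) < 1e-6 for a, e in zip(idf, expected))
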
diff --git a/flink-ml-python/pyflink/ml/lib/feature/tests/test_imputer.py b/flink-ml-python/pyflink/ml/lib/feature/tests/test_imputer.py
index 8f045ac..c47b755 100644
--- a/flink-ml-python/pyflink/ml/lib/feature/tests/test_imputer.py
+++ b/flink-ml-python/pyflink/ml/lib/feature/tests/test_imputer.py
@@ -69,6 +69,7 @@ class ImputerTest(PyFlinkMLTestCase):
             'median': self.expected_median_strategy_output,
             'most_frequent': self.expected_most_frequent_strategy_output
         }
+        self.eps = 1e-5
 
     def test_param(self):
         imputer = Imputer().\
@@ -116,7 +117,13 @@ class ImputerTest(PyFlinkMLTestCase):
         expected_field_names = ['surrogates']
         self.assertEqual(expected_field_names, model_data.get_schema().get_field_names())
 
-        # TODO: Add test to collect and verify the model data results after FLINK-30124 is resolved.
+        model_rows = [result for result in
+                      self.t_env.to_data_stream(model_data).execute_and_collect()]
+        self.assertEqual(1, len(model_rows))
+        surrogates = model_rows[0][expected_field_names.index('surrogates')]
+        self.assertAlmostEqual(2.0, surrogates['f1'], delta=self.eps)
+        self.assertAlmostEqual(6.8, surrogates['f2'], delta=self.eps)
+        self.assertAlmostEqual(2.0, surrogates['f3'], delta=self.eps)
 
     def test_set_model_data(self):
         imputer = Imputer().\
diff --git a/flink-ml-python/pyflink/ml/lib/feature/tests/test_indextostringmodel.py b/flink-ml-python/pyflink/ml/lib/feature/tests/test_indextostringmodel.py
index 41ef7bd..bf13ddc 100644
--- a/flink-ml-python/pyflink/ml/lib/feature/tests/test_indextostringmodel.py
+++ b/flink-ml-python/pyflink/ml/lib/feature/tests/test_indextostringmodel.py
@@ -85,7 +85,12 @@ class IndexToStringModelTest(PyFlinkMLTestCase):
         expected_field_names = ['stringArrays']
         self.assertEqual(expected_field_names, model_data.get_schema().get_field_names())
 
-        # TODO: Add test to collect and verify the model data results after FLINK-30122 is resolved.
+        model_rows = [result for result in
+                      self.t_env.to_data_stream(model_data).execute_and_collect()]
+        self.assertEqual(1, len(model_rows))
+        string_arrays = model_rows[0][expected_field_names.index('stringArrays')]
+        self.assertListEqual(["a", "b", "c", "d"], string_arrays[0])
+        self.assertListEqual(["-1.0", "0.0", "1.0", "2.0"], string_arrays[1])
 
     def test_save_load_and_predict(self):
         model = IndexToStringModel() \
diff --git a/flink-ml-python/pyflink/ml/lib/feature/tests/test_kbinsdiscretizer.py b/flink-ml-python/pyflink/ml/lib/feature/tests/test_kbinsdiscretizer.py
index 70bc9a5..4d50c56 100644
--- a/flink-ml-python/pyflink/ml/lib/feature/tests/test_kbinsdiscretizer.py
+++ b/flink-ml-python/pyflink/ml/lib/feature/tests/test_kbinsdiscretizer.py
@@ -88,6 +88,8 @@ class KBinsDiscretizerTest(PyFlinkMLTestCase):
             Vectors.dense(2, 0, 2),
         ]
 
+        self.eps = 1e-7
+
     def test_param(self):
         k_bins_discretizer = KBinsDiscretizer()
 
@@ -163,7 +165,14 @@ class KBinsDiscretizerTest(PyFlinkMLTestCase):
         expected_field_names = ['binEdges']
         self.assertEqual(expected_field_names, model_data.get_schema().get_field_names())
 
-        # TODO: Add test to collect and verify the model data results after FLINK-30122 is resolved.
+        model_rows = [result for result in
+                      self.t_env.to_data_stream(model_data).execute_and_collect()]
+        self.assertEqual(1, len(model_rows))
+        bin_edges = model_rows[0][expected_field_names.index('binEdges')]
+        self.assertEqual(3, len(bin_edges))
+        self.assertListEqual([1, 5, 9, 13], bin_edges[0])
+        self.assertListEqual([4.9e-324, 1.7976931348623157e+308], bin_edges[1])
+        self.assertListEqual([0, 1, 2, 3], bin_edges[2])
 
     def test_set_model_data(self):
         k_bins_discretizer = KBinsDiscretizer().set_num_bins(3).set_strategy('uniform')
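
The second expected row covers a degenerate case: presumably the second
feature is constant in the training data, so its single bin spans the full
double range. 4.9e-324 parses to the smallest positive subnormal double
(Java's Double.MIN_VALUE) and 1.7976931348623157e+308 is Double.MAX_VALUE:

    import sys
    assert sys.float_info.max == 1.7976931348623157e+308
    # 4.9e-324 rounds to the smallest positive subnormal, 2 ** -1074:
    assert 4.9e-324 == 5e-324 == 2 ** -1074
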
diff --git a/flink-ml-python/pyflink/ml/lib/feature/tests/test_stringindexer.py b/flink-ml-python/pyflink/ml/lib/feature/tests/test_stringindexer.py
index 0f6204b..673fac6 100644
--- a/flink-ml-python/pyflink/ml/lib/feature/tests/test_stringindexer.py
+++ b/flink-ml-python/pyflink/ml/lib/feature/tests/test_stringindexer.py
@@ -129,7 +129,12 @@ class StringIndexerTest(PyFlinkMLTestCase):
         expected_field_names = ['stringArrays']
         self.assertEqual(expected_field_names, model_data.get_schema().get_field_names())
 
-        # TODO: Add test to collect and verify the model data results after FLINK-30122 is resolved.
+        model_rows = [result for result in
+                      self.t_env.to_data_stream(model_data).execute_and_collect()]
+        self.assertEqual(1, len(model_rows))
+        string_arrays = model_rows[0][expected_field_names.index('stringArrays')]
+        self.assertListEqual(["a", "b", "c", "d"], string_arrays[0])
+        self.assertListEqual(["-1.0", "0.0", "1.0", "2.0"], string_arrays[1])
 
     def test_set_model_data(self):
         string_indexer = StringIndexer() \
diff --git a/flink-ml-python/pyflink/ml/lib/feature/tests/test_variancethresholdselector.py b/flink-ml-python/pyflink/ml/lib/feature/tests/test_variancethresholdselector.py
index 013b383..428cb79 100644
--- a/flink-ml-python/pyflink/ml/lib/feature/tests/test_variancethresholdselector.py
+++ b/flink-ml-python/pyflink/ml/lib/feature/tests/test_variancethresholdselector.py
@@ -121,8 +121,11 @@ class VarianceThresholdSelectorTest(PyFlinkMLTestCase):
         expected_field_names = ['numOfFeatures', 'indices']
         self.assertEqual(expected_field_names, model_data.get_schema().get_field_names())
 
-        # TODO: Add test to collect and verify the model data results after Flink dependency
-        #  is upgraded to 1.15.3, 1.16.0 or a higher version. Related ticket: FLINK-29477
+        model_rows = [result for result in
+                      self.t_env.to_data_stream(model_data).execute_and_collect()]
+        self.assertEqual(1, len(model_rows))
+        self.assertEqual(6, model_rows[0][expected_field_names.index('numOfFeatures')])
+        self.assertListEqual([0, 3, 5], model_rows[0][expected_field_names.index('indices')])
 
     def test_set_model_data(self):
         variance_threshold_selector = VarianceThresholdSelector() \
diff --git a/flink-ml-python/pyflink/ml/lib/feature/tests/test_vectorindexer.py b/flink-ml-python/pyflink/ml/lib/feature/tests/test_vectorindexer.py
index 9ca499b..0a1a418 100644
--- a/flink-ml-python/pyflink/ml/lib/feature/tests/test_vectorindexer.py
+++ b/flink-ml-python/pyflink/ml/lib/feature/tests/test_vectorindexer.py
@@ -103,13 +103,17 @@ class VectorIndexerTest(PyFlinkMLTestCase):
         self.assertEqual(self.expected_output, predicted_results)
 
     def test_get_model_data(self):
-        vector_indexer = VectorIndexer().set_handle_invalid('keep')
+        vector_indexer = VectorIndexer().set_max_categories(3)
         model = vector_indexer.fit(self.train_table)
         model_data = model.get_model_data()[0]
         expected_field_names = ['categoryMaps']
         self.assertEqual(expected_field_names, model_data.get_schema().get_field_names())
 
-        # TODO: Add test to collect and verify the model data results after FLINK-30124 is resolved.
+        model_rows = [result for result in
+                      self.t_env.to_data_stream(model_data).execute_and_collect()]
+        self.assertEqual(1, len(model_rows))
+        self.assertEqual(
+            {1: {-1.: 1, 0.: 0, 1.: 2}}, model_rows[0][expected_field_names.index('categoryMaps')])
 
     def test_set_model_data(self):
         vector_indexer = VectorIndexer().set_handle_invalid('keep')
diff --git a/flink-ml-python/src/main/java/org/apache/flink/ml/python/PythonBridgeUtils.java b/flink-ml-python/src/main/java/org/apache/flink/ml/python/PythonBridgeUtils.java
new file mode 100644
index 0000000..9b9596e
--- /dev/null
+++ b/flink-ml-python/src/main/java/org/apache/flink/ml/python/PythonBridgeUtils.java
@@ -0,0 +1,226 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.python;
+
+import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.SqlTimeTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.java.tuple.Tuple;
+import org.apache.flink.api.java.typeutils.ListTypeInfo;
+import org.apache.flink.api.java.typeutils.MapTypeInfo;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfoBase;
+import org.apache.flink.api.python.shaded.net.razorvine.pickle.Pickler;
+import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.streaming.api.typeinfo.python.PickledByteArrayTypeInfo;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.sql.Date;
+import java.sql.Time;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+
+import static org.apache.flink.api.common.typeinfo.BasicTypeInfo.FLOAT_TYPE_INFO;
+import static org.apache.flink.api.common.typeinfo.SqlTimeTypeInfo.DATE;
+import static org.apache.flink.api.common.typeinfo.SqlTimeTypeInfo.TIME;
+
+/**
+ * Utility functions used to override PyFlink methods to provide a temporary solution for certain
+ * bugs.
+ */
+// TODO: Remove this class after Flink ML depends on a Flink version with FLINK-30168 and
+// FLINK-29477 fixed.
+public class PythonBridgeUtils {
+    public static Object getPickledBytesFromJavaObject(Object obj, TypeInformation<?> dataType)
+            throws IOException {
+        Pickler pickler = new Pickler();
+
+        // triggers the initialization process
+        org.apache.flink.api.common.python.PythonBridgeUtils.getPickledBytesFromJavaObject(
+                null, null);
+
+        if (obj == null) {
+            return new byte[0];
+        } else {
+            if (dataType instanceof SqlTimeTypeInfo) {
+                SqlTimeTypeInfo<?> sqlTimeTypeInfo =
+                        SqlTimeTypeInfo.getInfoFor(dataType.getTypeClass());
+                if (sqlTimeTypeInfo == DATE) {
+                    return pickler.dumps(((Date) obj).toLocalDate().toEpochDay());
+                } else if (sqlTimeTypeInfo == TIME) {
+                    return pickler.dumps(((Time) obj).toLocalTime().toNanoOfDay() / 1000);
+                }
+            } else if (dataType instanceof RowTypeInfo || dataType instanceof TupleTypeInfo) {
+                TypeInformation<?>[] fieldTypes = ((TupleTypeInfoBase<?>) dataType).getFieldTypes();
+                int arity =
+                        dataType instanceof RowTypeInfo
+                                ? ((Row) obj).getArity()
+                                : ((Tuple) obj).getArity();
+
+                List<Object> fieldBytes = new ArrayList<>(arity + 1);
+                if (dataType instanceof RowTypeInfo) {
+                    fieldBytes.add(new byte[] {((Row) obj).getKind().toByteValue()});
+                }
+                for (int i = 0; i < arity; i++) {
+                    Object field =
+                            dataType instanceof RowTypeInfo
+                                    ? ((Row) obj).getField(i)
+                                    : ((Tuple) obj).getField(i);
+                    fieldBytes.add(getPickledBytesFromJavaObject(field, fieldTypes[i]));
+                }
+                return fieldBytes;
+            } else if (dataType instanceof BasicArrayTypeInfo
+                    || dataType instanceof PrimitiveArrayTypeInfo
+                    || dataType instanceof ObjectArrayTypeInfo) {
+                Object[] objects;
+                TypeInformation<?> elementType;
+                if (dataType instanceof BasicArrayTypeInfo) {
+                    objects = (Object[]) obj;
+                    elementType = ((BasicArrayTypeInfo<?, ?>) dataType).getComponentInfo();
+                } else if (dataType instanceof PrimitiveArrayTypeInfo) {
+                    objects = primitiveArrayConverter(obj, dataType);
+                    elementType = ((PrimitiveArrayTypeInfo<?>) dataType).getComponentType();
+                } else {
+                    objects = (Object[]) obj;
+                    elementType = ((ObjectArrayTypeInfo<?, ?>) dataType).getComponentInfo();
+                }
+
+                List<Object> serializedElements = new ArrayList<>(objects.length);
+
+                for (Object object : objects) {
+                    serializedElements.add(getPickledBytesFromJavaObject(object, elementType));
+                }
+                return pickler.dumps(serializedElements);
+            } else if (dataType instanceof MapTypeInfo) {
+                List<List<Object>> serializedMapKV = new ArrayList<>(2);
+                Map<Object, Object> mapObj = (Map) obj;
+                List<Object> keyBytesList = new ArrayList<>(mapObj.size());
+                List<Object> valueBytesList = new ArrayList<>(mapObj.size());
+                for (Map.Entry entry : mapObj.entrySet()) {
+                    keyBytesList.add(
+                            getPickledBytesFromJavaObject(
+                                    entry.getKey(), ((MapTypeInfo) dataType).getKeyTypeInfo()));
+                    valueBytesList.add(
+                            getPickledBytesFromJavaObject(
+                                    entry.getValue(), ((MapTypeInfo) dataType).getValueTypeInfo()));
+                }
+                serializedMapKV.add(keyBytesList);
+                serializedMapKV.add(valueBytesList);
+                return pickler.dumps(serializedMapKV);
+            } else if (dataType instanceof ListTypeInfo) {
+                List objects = (List) obj;
+                List<Object> serializedElements = new ArrayList<>(objects.size());
+                TypeInformation elementType = ((ListTypeInfo) dataType).getElementTypeInfo();
+                for (Object object : objects) {
+                    serializedElements.add(getPickledBytesFromJavaObject(object, elementType));
+                }
+                return pickler.dumps(serializedElements);
+            }
+            if (dataType instanceof BasicTypeInfo
+                    && BasicTypeInfo.getInfoFor(dataType.getTypeClass()) == FLOAT_TYPE_INFO) {
+                // Serialization of float type with pickler loses precision.
+                return pickler.dumps(String.valueOf(obj));
+            } else if (dataType instanceof PickledByteArrayTypeInfo
+                    || dataType instanceof BasicTypeInfo) {
+                return pickler.dumps(obj);
+            } else {
+                // other typeinfos will use the corresponding serializer to serialize data.
+                TypeSerializer serializer = dataType.createSerializer(null);
+                ByteArrayOutputStreamWithPos baos = new ByteArrayOutputStreamWithPos();
+                DataOutputViewStreamWrapper baosWrapper = new DataOutputViewStreamWrapper(baos);
+                serializer.serialize(obj, baosWrapper);
+                return pickler.dumps(baos.toByteArray());
+            }
+        }
+    }
+
+    private static Object[] primitiveArrayConverter(
+            Object array, TypeInformation<?> arrayTypeInfo) {
+        Preconditions.checkArgument(arrayTypeInfo instanceof PrimitiveArrayTypeInfo);
+        Preconditions.checkArgument(array.getClass().isArray());
+        Object[] objects;
+        if (PrimitiveArrayTypeInfo.BOOLEAN_PRIMITIVE_ARRAY_TYPE_INFO.equals(arrayTypeInfo)) {
+            boolean[] booleans = (boolean[]) array;
+            objects = new Object[booleans.length];
+            for (int i = 0; i < booleans.length; i++) {
+                objects[i] = booleans[i];
+            }
+        } else if (PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO.equals(arrayTypeInfo)) {
+            byte[] bytes = (byte[]) array;
+            objects = new Object[bytes.length];
+            for (int i = 0; i < bytes.length; i++) {
+                objects[i] = bytes[i];
+            }
+        } else if (PrimitiveArrayTypeInfo.SHORT_PRIMITIVE_ARRAY_TYPE_INFO.equals(arrayTypeInfo)) {
+            short[] shorts = (short[]) array;
+            objects = new Object[shorts.length];
+            for (int i = 0; i < shorts.length; i++) {
+                objects[i] = shorts[i];
+            }
+        } else if (PrimitiveArrayTypeInfo.INT_PRIMITIVE_ARRAY_TYPE_INFO.equals(arrayTypeInfo)) {
+            int[] ints = (int[]) array;
+            objects = new Object[ints.length];
+            for (int i = 0; i < ints.length; i++) {
+                objects[i] = ints[i];
+            }
+        } else if (PrimitiveArrayTypeInfo.LONG_PRIMITIVE_ARRAY_TYPE_INFO.equals(arrayTypeInfo)) {
+            long[] longs = (long[]) array;
+            objects = new Object[longs.length];
+            for (int i = 0; i < longs.length; i++) {
+                objects[i] = longs[i];
+            }
+        } else if (PrimitiveArrayTypeInfo.FLOAT_PRIMITIVE_ARRAY_TYPE_INFO.equals(arrayTypeInfo)) {
+            float[] floats = (float[]) array;
+            objects = new Object[floats.length];
+            for (int i = 0; i < floats.length; i++) {
+                objects[i] = floats[i];
+            }
+        } else if (PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO.equals(arrayTypeInfo)) {
+            double[] doubles = (double[]) array;
+            objects = new Object[doubles.length];
+            for (int i = 0; i < doubles.length; i++) {
+                objects[i] = doubles[i];
+            }
+        } else if (PrimitiveArrayTypeInfo.CHAR_PRIMITIVE_ARRAY_TYPE_INFO.equals(arrayTypeInfo)) {
+            char[] chars = (char[]) array;
+            objects = new Object[chars.length];
+            for (int i = 0; i < chars.length; i++) {
+                objects[i] = chars[i];
+            }
+        } else {
+            throw new UnsupportedOperationException(
+                    String.format(
+                            "Primitive array of %s is not supported in PyFlink yet",
+                            ((PrimitiveArrayTypeInfo<?>) arrayTypeInfo)
+                                    .getComponentType()
+                                    .getTypeClass()
+                                    .getSimpleName()));
+        }
+        return objects;
+    }
+}
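
Taken together, getPickledBytesFromJavaObject defines the wire format that
convert_to_python_obj_wrapper in wrapper.py consumes: a Row becomes a list
whose first element is a single RowKind byte and whose remaining elements are
the recursively pickled fields (an empty byte array encodes NULL); arrays,
maps and lists nest the same encoding; FLOAT values are sent as strings to
avoid precision loss in the pickler; anything else falls back to Flink's own
TypeSerializer bytes. A hedged Python-side sketch for a Row of basic-typed
fields (the real decoding goes through pyflink's
pickled_bytes_to_python_converter, which also post-processes dates, times and
nested types):

    import pickle

    def decode_basic_row(field_bytes_list):
        # field_bytes_list mirrors the Java List<Object> returned above.
        row_kind = int.from_bytes(field_bytes_list[0], 'little')
        fields = [None if len(b) == 0 else pickle.loads(b)
                  for b in field_bytes_list[1:]]
        return row_kind, fields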