You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by li...@apache.org on 2022/11/29 09:17:32 UTC
[flink-ml] branch master updated: [FLINK-30124] Support collecting model data with arrays and maps
This is an automated email from the ASF dual-hosted git repository.
lindong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push:
new d19b76c [FLINK-30124] Support collecting model data with arrays and maps
d19b76c is described below
commit d19b76cface06604c3656e1dfbaa194c5c9c041c
Author: yunfengzhou-hub <yu...@outlook.com>
AuthorDate: Tue Nov 29 17:17:26 2022 +0800
[FLINK-30124] Support collecting model data with arrays and maps
This closes #181.
---
flink-ml-dist/pom.xml | 6 +
flink-ml-dist/src/main/assemblies/bin.xml | 1 +
.../ml/classification/naivebayes/NaiveBayes.java | 18 +-
.../apache/flink/ml/feature/imputer/Imputer.java | 10 +-
.../ml/feature/vectorindexer/VectorIndexer.java | 13 +-
flink-ml-python/pom.xml | 14 ++
flink-ml-python/pyflink/ml/core/wrapper.py | 38 +++-
.../ml/lib/classification/tests/test_naivebayes.py | 23 ++-
.../pyflink/ml/lib/clustering/tests/test_kmeans.py | 13 +-
.../pyflink/ml/lib/feature/tests/test_idf.py | 11 +-
.../pyflink/ml/lib/feature/tests/test_imputer.py | 9 +-
.../lib/feature/tests/test_indextostringmodel.py | 7 +-
.../ml/lib/feature/tests/test_kbinsdiscretizer.py | 11 +-
.../ml/lib/feature/tests/test_stringindexer.py | 7 +-
.../tests/test_variancethresholdselector.py | 7 +-
.../ml/lib/feature/tests/test_vectorindexer.py | 8 +-
.../apache/flink/ml/python/PythonBridgeUtils.java | 226 +++++++++++++++++++++
17 files changed, 402 insertions(+), 20 deletions(-)
diff --git a/flink-ml-dist/pom.xml b/flink-ml-dist/pom.xml
index 1e04c50..20c8464 100644
--- a/flink-ml-dist/pom.xml
+++ b/flink-ml-dist/pom.xml
@@ -46,6 +46,12 @@ under the License.
<version>${project.version}</version>
</dependency>
+ <dependency>
+ <groupId>org.apache.flink</groupId>
+ <artifactId>flink-ml-python</artifactId>
+ <version>${project.version}</version>
+ </dependency>
+
<!-- Stateful Functions Dependencies -->
<dependency>
diff --git a/flink-ml-dist/src/main/assemblies/bin.xml b/flink-ml-dist/src/main/assemblies/bin.xml
index 987e27b..516c55f 100644
--- a/flink-ml-dist/src/main/assemblies/bin.xml
+++ b/flink-ml-dist/src/main/assemblies/bin.xml
@@ -37,6 +37,7 @@ under the License.
<include>org.apache.flink:statefun-flink-core</include>
<include>org.apache.flink:flink-ml-uber</include>
<include>org.apache.flink:flink-ml-examples</include>
+ <include>org.apache.flink:flink-ml-python</include>
</includes>
</dependencySet>
</dependencySets>
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java
index a4803c7..e5bd995 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java
@@ -28,10 +28,13 @@ import org.apache.flink.ml.api.Estimator;
import org.apache.flink.ml.common.datastream.DataStreamUtils;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
@@ -104,7 +107,20 @@ public class NaiveBayes
aggregatedArrays, new GenerateModelFunction(smoothing));
modelData.getTransformation().setParallelism(1);
- NaiveBayesModel model = new NaiveBayesModel().setModelData(tEnv.fromDataStream(modelData));
+ Schema schema =
+ Schema.newBuilder()
+ .column(
+ "theta",
+ DataTypes.ARRAY(
+ DataTypes.ARRAY(
+ DataTypes.MAP(
+ DataTypes.DOUBLE(), DataTypes.DOUBLE()))))
+ .column("piArray", DataTypes.of(DenseVectorTypeInfo.INSTANCE))
+ .column("labels", DataTypes.of(DenseVectorTypeInfo.INSTANCE))
+ .build();
+
+ NaiveBayesModel model =
+ new NaiveBayesModel().setModelData(tEnv.fromDataStream(modelData, schema));
ReadWriteUtils.updateExistingParams(model, paramMap);
return model;
}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/imputer/Imputer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/imputer/Imputer.java
index b0b7e67..9ac4272 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/imputer/Imputer.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/imputer/Imputer.java
@@ -27,6 +27,8 @@ import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
@@ -90,7 +92,13 @@ public class Imputer implements Estimator<Imputer, ImputerModel>, ImputerParams<
default:
throw new RuntimeException("Unsupported strategy of Imputer: " + getStrategy());
}
- ImputerModel model = new ImputerModel().setModelData(tEnv.fromDataStream(modelData));
+
+ Schema schema =
+ Schema.newBuilder()
+ .column("surrogates", DataTypes.MAP(DataTypes.STRING(), DataTypes.DOUBLE()))
+ .build();
+ ImputerModel model =
+ new ImputerModel().setModelData(tEnv.fromDataStream(modelData, schema));
ReadWriteUtils.updateExistingParams(model, getParamMap());
return model;
}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexer.java
index 74d6a6b..810a820 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexer.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexer.java
@@ -39,6 +39,8 @@ import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.BoundedOneInput;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
@@ -117,8 +119,17 @@ public class VectorIndexer
distinctDoubles.map(new ModelGenerator(maxCategories));
modelData.getTransformation().setParallelism(1);
+ Schema schema =
+ Schema.newBuilder()
+ .column(
+ "categoryMaps",
+ DataTypes.MAP(
+ DataTypes.INT(),
+ DataTypes.MAP(DataTypes.DOUBLE(), DataTypes.INT())))
+ .build();
+
VectorIndexerModel model =
- new VectorIndexerModel().setModelData(tEnv.fromDataStream(modelData));
+ new VectorIndexerModel().setModelData(tEnv.fromDataStream(modelData, schema));
ReadWriteUtils.updateExistingParams(model, paramMap);
return model;
}
diff --git a/flink-ml-python/pom.xml b/flink-ml-python/pom.xml
index 316ec40..dc5459e 100644
--- a/flink-ml-python/pom.xml
+++ b/flink-ml-python/pom.xml
@@ -30,6 +30,20 @@ under the License.
<packaging>jar</packaging>
<dependencies>
+ <dependency>
+ <groupId>org.apache.flink</groupId>
+ <artifactId>flink-table-common</artifactId>
+ <version>${flink.version}</version>
+ <scope>provided</scope>
+ </dependency>
+
+ <dependency>
+ <groupId>org.apache.flink</groupId>
+ <artifactId>flink-python_2.12</artifactId>
+ <version>${flink.version}</version>
+ <scope>provided</scope>
+ </dependency>
+
<dependency>
<groupId>org.apache.flink</groupId>
<artifactId>flink-runtime</artifactId>
diff --git a/flink-ml-python/pyflink/ml/core/wrapper.py b/flink-ml-python/pyflink/ml/core/wrapper.py
index e965c0c..83acd9d 100644
--- a/flink-ml-python/pyflink/ml/core/wrapper.py
+++ b/flink-ml-python/pyflink/ml/core/wrapper.py
@@ -15,12 +15,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
+import pickle
from abc import ABC, abstractmethod
from typing import List, Dict, Any
from py4j.java_gateway import JavaObject, get_java_class
-from pyflink.common import typeinfo, Time
-from pyflink.common.typeinfo import _from_java_type, TypeInformation, _is_instance_of
+from pyflink.common import typeinfo, Time, Row, RowKind
+from pyflink.common.typeinfo import _from_java_type, TypeInformation, _is_instance_of, Types, \
+ ExternalTypeInfo, RowTypeInfo, TupleTypeInfo
+from pyflink.datastream import utils
+from pyflink.datastream.utils import pickled_bytes_to_python_converter
from pyflink.java_gateway import get_gateway
from pyflink.table import Table, StreamTableEnvironment, Expression
from pyflink.util.java_utils import to_jarray
@@ -55,6 +59,36 @@ def _from_java_type_wrapper(j_type_info: JavaObject) -> TypeInformation:
typeinfo._from_java_type = _from_java_type_wrapper
+# TODO: Remove this function after Flink ML depends on a Flink version
+# with FLINK-30168 and FLINK-29477 fixed.
+def convert_to_python_obj_wrapper(data, type_info):
+ if type_info == Types.PICKLED_BYTE_ARRAY():
+ return pickle.loads(data)  # data is already pickled: unpickle it without calling into Java
+ elif isinstance(type_info, ExternalTypeInfo):
+ return convert_to_python_obj_wrapper(data, type_info._type_info)  # unwrap and retry
+ else:
+ gateway = get_gateway()  # pickling is delegated to Flink ML's Java PythonBridgeUtils
+ pickle_bytes = gateway.jvm.org.apache.flink.ml.python.PythonBridgeUtils. \
+ getPickledBytesFromJavaObject(data, type_info.get_java_type_info())
+ if isinstance(type_info, RowTypeInfo) or isinstance(type_info, TupleTypeInfo):
+ field_data = zip(list(pickle_bytes[1:]), type_info.get_field_types())  # NOTE(review): the Java side prepends a row-kind element only for rows; for tuples this slice looks like it drops the first field - confirm
+ fields = []
+ for data, field_type in field_data:  # rebinds the `data` parameter to each field's bytes
+ if len(data) == 0:  # the Java side returns an empty byte array for a null value
+ fields.append(None)
+ else:
+ fields.append(pickled_bytes_to_python_converter(data, field_type))
+ if isinstance(type_info, RowTypeInfo):
+ return Row.of_kind(RowKind(int.from_bytes(pickle_bytes[0], 'little')), *fields)  # element 0 holds the row-kind byte
+ else:
+ return tuple(fields)
+ else:
+ return pickled_bytes_to_python_converter(pickle_bytes, type_info)
+
+
+utils.convert_to_python_obj = convert_to_python_obj_wrapper  # replace PyFlink's converter with the wrapper
+
+
class JavaWrapper(ABC):
"""
Wrapper class for a Java object
diff --git a/flink-ml-python/pyflink/ml/lib/classification/tests/test_naivebayes.py b/flink-ml-python/pyflink/ml/lib/classification/tests/test_naivebayes.py
index 16f0c8c..b735a71 100644
--- a/flink-ml-python/pyflink/ml/lib/classification/tests/test_naivebayes.py
+++ b/flink-ml-python/pyflink/ml/lib/classification/tests/test_naivebayes.py
@@ -113,12 +113,31 @@ class NaiveBayesTest(PyFlinkMLTestCase):
self.assertEqual(self.expected_output, actual_output)
def test_get_model_data(self):
- model = self.estimator.fit(self.train_data)
+ train_data = self.t_env.from_data_stream(
+ self.env.from_collection([
+ (Vectors.dense([1, 1.]), 11.),
+ (Vectors.dense([2, 1]), 11.),
+ ],
+ type_info=Types.ROW_NAMED(
+ ['features', 'label'],
+ [DenseVectorTypeInfo(), Types.DOUBLE()])))
+
+ model = self.estimator.fit(train_data)
model_data = model.get_model_data()[0]
expected_field_names = ['theta', 'piArray', 'labels']
self.assertEqual(expected_field_names, model_data.get_schema().get_field_names())
- # TODO: Add test to collect and verify the model data results after FLINK-30124 is resolved.
+ model_rows = [result for result in
+ self.t_env.to_data_stream(model_data).execute_and_collect()]
+ self.assertEqual(1, len(model_rows))
+ self.assertListAlmostEqual(
+ [11.], model_rows[0][expected_field_names.index('labels')].to_array(), delta=1e-5)
+ self.assertListAlmostEqual(
+ [0.], model_rows[0][expected_field_names.index('piArray')].to_array(), delta=1e-5)
+ theta = model_rows[0][expected_field_names.index('theta')]
+ self.assertAlmostEqual(-0.6931471805599453, theta[0][0].get(1.0), delta=1e-5)
+ self.assertAlmostEqual(-0.6931471805599453, theta[0][0].get(2.0), delta=1e-5)
+ self.assertAlmostEqual(0.0, theta[0][1].get(1.0), delta=1e-5)
def test_set_model_data(self):
model_a = self.estimator.fit(self.train_data)
diff --git a/flink-ml-python/pyflink/ml/lib/clustering/tests/test_kmeans.py b/flink-ml-python/pyflink/ml/lib/clustering/tests/test_kmeans.py
index c63b26b..6b816a1 100644
--- a/flink-ml-python/pyflink/ml/lib/clustering/tests/test_kmeans.py
+++ b/flink-ml-python/pyflink/ml/lib/clustering/tests/test_kmeans.py
@@ -47,7 +47,7 @@ class KMeansTest(PyFlinkMLTestCase):
self.env.from_collection([
(Vectors.dense([0.0, 0.0]),),
(Vectors.dense([0.0, 0.3]),),
- (Vectors.dense([0.3, 3.0]),),
+ (Vectors.dense([0.3, 0.0]),),
(Vectors.dense([9.0, 0.0]),),
(Vectors.dense([9.0, 0.6]),),
(Vectors.dense([9.6, 0.0]),),
@@ -56,7 +56,7 @@ class KMeansTest(PyFlinkMLTestCase):
['features'],
[DenseVectorTypeInfo()])))
self.expected_groups = [
- {DenseVector([0.0, 0.3]), DenseVector([0.3, 3.0]), DenseVector([0.0, 0.0])},
+ {DenseVector([0.0, 0.3]), DenseVector([0.3, 0.0]), DenseVector([0.0, 0.0])},
{DenseVector([9.6, 0.0]), DenseVector([9.0, 0.0]), DenseVector([9.0, 0.6])}]
def test_param(self):
@@ -153,7 +153,14 @@ class KMeansTest(PyFlinkMLTestCase):
expected_field_names = ['centroids', 'weights']
self.assertEqual(expected_field_names, model_data.get_schema().get_field_names())
- # TODO: Add test to collect and verify the model data results after FLINK-30122 is resolved.
+ model_rows = [result for result in
+ self.t_env.to_data_stream(model_data).execute_and_collect()]
+ self.assertEqual(1, len(model_rows))
+ centroids = model_rows[0][expected_field_names.index('centroids')]
+ self.assertEqual(2, len(centroids))
+ centroids.sort(key=lambda x: x.get(0))
+ self.assertListAlmostEqual([0.1, 0.1], centroids[0], delta=1e-5)
+ self.assertListAlmostEqual([9.2, 0.2], centroids[1], delta=1e-5)
def test_set_model_data(self):
kmeans = KMeans().set_max_iter(2).set_k(2)
diff --git a/flink-ml-python/pyflink/ml/lib/feature/tests/test_idf.py b/flink-ml-python/pyflink/ml/lib/feature/tests/test_idf.py
index 11f1887..7d276d5 100644
--- a/flink-ml-python/pyflink/ml/lib/feature/tests/test_idf.py
+++ b/flink-ml-python/pyflink/ml/lib/feature/tests/test_idf.py
@@ -119,8 +119,15 @@ class IDFTest(PyFlinkMLTestCase):
expected_field_names = ['idf', 'docFreq', 'numDocs']
self.assertEqual(expected_field_names, model_data.get_schema().get_field_names())
- # TODO: Add test to collect and verify the model data results after Flink dependency
- # is upgraded to 1.15.3, 1.16.0 or a higher version. Related ticket: FLINK-29477
+ model_rows = [result for result in
+ self.t_env.to_data_stream(model_data).execute_and_collect()]
+ self.assertEqual(1, len(model_rows))
+ self.assertEqual(3, model_rows[0][expected_field_names.index('numDocs')])
+ self.assertListEqual([0, 3, 1, 2], model_rows[0][expected_field_names.index('docFreq')])
+ self.assertListAlmostEqual(
+ [1.3862943, 0, 0.6931471, 0.2876820],
+ model_rows[0][expected_field_names.index('idf')].to_array(),
+ delta=self.tolerance)
def test_set_model_data(self):
idf = IDF()
diff --git a/flink-ml-python/pyflink/ml/lib/feature/tests/test_imputer.py b/flink-ml-python/pyflink/ml/lib/feature/tests/test_imputer.py
index 8f045ac..c47b755 100644
--- a/flink-ml-python/pyflink/ml/lib/feature/tests/test_imputer.py
+++ b/flink-ml-python/pyflink/ml/lib/feature/tests/test_imputer.py
@@ -69,6 +69,7 @@ class ImputerTest(PyFlinkMLTestCase):
'median': self.expected_median_strategy_output,
'most_frequent': self.expected_most_frequent_strategy_output
}
+ self.eps = 1e-5
def test_param(self):
imputer = Imputer().\
@@ -116,7 +117,13 @@ class ImputerTest(PyFlinkMLTestCase):
expected_field_names = ['surrogates']
self.assertEqual(expected_field_names, model_data.get_schema().get_field_names())
- # TODO: Add test to collect and verify the model data results after FLINK-30124 is resolved.
+ model_rows = [result for result in
+ self.t_env.to_data_stream(model_data).execute_and_collect()]
+ self.assertEqual(1, len(model_rows))
+ surrogates = model_rows[0][expected_field_names.index('surrogates')]
+ self.assertAlmostEqual(2.0, surrogates['f1'], delta=self.eps)
+ self.assertAlmostEqual(6.8, surrogates['f2'], delta=self.eps)
+ self.assertAlmostEqual(2.0, surrogates['f3'], delta=self.eps)
def test_set_model_data(self):
imputer = Imputer().\
diff --git a/flink-ml-python/pyflink/ml/lib/feature/tests/test_indextostringmodel.py b/flink-ml-python/pyflink/ml/lib/feature/tests/test_indextostringmodel.py
index 41ef7bd..bf13ddc 100644
--- a/flink-ml-python/pyflink/ml/lib/feature/tests/test_indextostringmodel.py
+++ b/flink-ml-python/pyflink/ml/lib/feature/tests/test_indextostringmodel.py
@@ -85,7 +85,12 @@ class IndexToStringModelTest(PyFlinkMLTestCase):
expected_field_names = ['stringArrays']
self.assertEqual(expected_field_names, model_data.get_schema().get_field_names())
- # TODO: Add test to collect and verify the model data results after FLINK-30122 is resolved.
+ model_rows = [result for result in
+ self.t_env.to_data_stream(model_data).execute_and_collect()]
+ self.assertEqual(1, len(model_rows))
+ string_arrays = model_rows[0][expected_field_names.index('stringArrays')]
+ self.assertListEqual(["a", "b", "c", "d"], string_arrays[0])
+ self.assertListEqual(["-1.0", "0.0", "1.0", "2.0"], string_arrays[1])
def test_save_load_and_predict(self):
model = IndexToStringModel() \
diff --git a/flink-ml-python/pyflink/ml/lib/feature/tests/test_kbinsdiscretizer.py b/flink-ml-python/pyflink/ml/lib/feature/tests/test_kbinsdiscretizer.py
index 70bc9a5..4d50c56 100644
--- a/flink-ml-python/pyflink/ml/lib/feature/tests/test_kbinsdiscretizer.py
+++ b/flink-ml-python/pyflink/ml/lib/feature/tests/test_kbinsdiscretizer.py
@@ -88,6 +88,8 @@ class KBinsDiscretizerTest(PyFlinkMLTestCase):
Vectors.dense(2, 0, 2),
]
+ self.eps = 1e-7
+
def test_param(self):
k_bins_discretizer = KBinsDiscretizer()
@@ -163,7 +165,14 @@ class KBinsDiscretizerTest(PyFlinkMLTestCase):
expected_field_names = ['binEdges']
self.assertEqual(expected_field_names, model_data.get_schema().get_field_names())
- # TODO: Add test to collect and verify the model data results after FLINK-30122 is resolved.
+ model_rows = [result for result in
+ self.t_env.to_data_stream(model_data).execute_and_collect()]
+ self.assertEqual(1, len(model_rows))
+ bin_edges = model_rows[0][expected_field_names.index('binEdges')]
+ self.assertEqual(3, len(bin_edges))
+ self.assertListEqual([1, 5, 9, 13], bin_edges[0])
+ self.assertListEqual([4.9e-324, 1.7976931348623157e+308], bin_edges[1])
+ self.assertListEqual([0, 1, 2, 3], bin_edges[2])
def test_set_model_data(self):
k_bins_discretizer = KBinsDiscretizer().set_num_bins(3).set_strategy('uniform')
diff --git a/flink-ml-python/pyflink/ml/lib/feature/tests/test_stringindexer.py b/flink-ml-python/pyflink/ml/lib/feature/tests/test_stringindexer.py
index 0f6204b..673fac6 100644
--- a/flink-ml-python/pyflink/ml/lib/feature/tests/test_stringindexer.py
+++ b/flink-ml-python/pyflink/ml/lib/feature/tests/test_stringindexer.py
@@ -129,7 +129,12 @@ class StringIndexerTest(PyFlinkMLTestCase):
expected_field_names = ['stringArrays']
self.assertEqual(expected_field_names, model_data.get_schema().get_field_names())
- # TODO: Add test to collect and verify the model data results after FLINK-30122 is resolved.
+ model_rows = [result for result in
+ self.t_env.to_data_stream(model_data).execute_and_collect()]
+ self.assertEqual(1, len(model_rows))
+ string_arrays = model_rows[0][expected_field_names.index('stringArrays')]
+ self.assertListEqual(["a", "b", "c", "d"], string_arrays[0])
+ self.assertListEqual(["-1.0", "0.0", "1.0", "2.0"], string_arrays[1])
def test_set_model_data(self):
string_indexer = StringIndexer() \
diff --git a/flink-ml-python/pyflink/ml/lib/feature/tests/test_variancethresholdselector.py b/flink-ml-python/pyflink/ml/lib/feature/tests/test_variancethresholdselector.py
index 013b383..428cb79 100644
--- a/flink-ml-python/pyflink/ml/lib/feature/tests/test_variancethresholdselector.py
+++ b/flink-ml-python/pyflink/ml/lib/feature/tests/test_variancethresholdselector.py
@@ -121,8 +121,11 @@ class VarianceThresholdSelectorTest(PyFlinkMLTestCase):
expected_field_names = ['numOfFeatures', 'indices']
self.assertEqual(expected_field_names, model_data.get_schema().get_field_names())
- # TODO: Add test to collect and verify the model data results after Flink dependency
- # is upgraded to 1.15.3, 1.16.0 or a higher version. Related ticket: FLINK-29477
+ model_rows = [result for result in
+ self.t_env.to_data_stream(model_data).execute_and_collect()]
+ self.assertEqual(1, len(model_rows))
+ self.assertEqual(6, model_rows[0][expected_field_names.index('numOfFeatures')])
+ self.assertListEqual([0, 3, 5], model_rows[0][expected_field_names.index('indices')])
def test_set_model_data(self):
variance_threshold_selector = VarianceThresholdSelector() \
diff --git a/flink-ml-python/pyflink/ml/lib/feature/tests/test_vectorindexer.py b/flink-ml-python/pyflink/ml/lib/feature/tests/test_vectorindexer.py
index 9ca499b..0a1a418 100644
--- a/flink-ml-python/pyflink/ml/lib/feature/tests/test_vectorindexer.py
+++ b/flink-ml-python/pyflink/ml/lib/feature/tests/test_vectorindexer.py
@@ -103,13 +103,17 @@ class VectorIndexerTest(PyFlinkMLTestCase):
self.assertEqual(self.expected_output, predicted_results)
def test_get_model_data(self):
- vector_indexer = VectorIndexer().set_handle_invalid('keep')
+ vector_indexer = VectorIndexer().set_max_categories(3)
model = vector_indexer.fit(self.train_table)
model_data = model.get_model_data()[0]
expected_field_names = ['categoryMaps']
self.assertEqual(expected_field_names, model_data.get_schema().get_field_names())
- # TODO: Add test to collect and verify the model data results after FLINK-30124 is resolved.
+ model_rows = [result for result in
+ self.t_env.to_data_stream(model_data).execute_and_collect()]
+ self.assertEqual(1, len(model_rows))
+ self.assertEqual(
+ {1: {-1.: 1, 0.: 0, 1.: 2}}, model_rows[0][expected_field_names.index('categoryMaps')])
def test_set_model_data(self):
vector_indexer = VectorIndexer().set_handle_invalid('keep')
diff --git a/flink-ml-python/src/main/java/org/apache/flink/ml/python/PythonBridgeUtils.java b/flink-ml-python/src/main/java/org/apache/flink/ml/python/PythonBridgeUtils.java
new file mode 100644
index 0000000..9b9596e
--- /dev/null
+++ b/flink-ml-python/src/main/java/org/apache/flink/ml/python/PythonBridgeUtils.java
@@ -0,0 +1,226 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.python;
+
+import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.SqlTimeTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.java.tuple.Tuple;
+import org.apache.flink.api.java.typeutils.ListTypeInfo;
+import org.apache.flink.api.java.typeutils.MapTypeInfo;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfoBase;
+import org.apache.flink.api.python.shaded.net.razorvine.pickle.Pickler;
+import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.streaming.api.typeinfo.python.PickledByteArrayTypeInfo;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.sql.Date;
+import java.sql.Time;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+
+import static org.apache.flink.api.common.typeinfo.BasicTypeInfo.FLOAT_TYPE_INFO;
+import static org.apache.flink.api.common.typeinfo.SqlTimeTypeInfo.DATE;
+import static org.apache.flink.api.common.typeinfo.SqlTimeTypeInfo.TIME;
+
+/**
+ * Utility functions that override PyFlink methods as a temporary workaround for PyFlink bugs;
+ * they pickle row, tuple, array, map, list and basic typed Java values for the Python side.
+ */
+// TODO: Remove this class after Flink ML depends on a Flink version with FLINK-30168 and
+// FLINK-29477 fixed.
+public class PythonBridgeUtils {
+ public static Object getPickledBytesFromJavaObject(Object obj, TypeInformation<?> dataType)
+ throws IOException { // result: pickled bytes, or for rows/tuples a List of per-field results
+ Pickler pickler = new Pickler();
+
+ // Call Flink's own PythonBridgeUtils with nulls; the result is discarded and the call is
+ org.apache.flink.api.common.python.PythonBridgeUtils.getPickledBytesFromJavaObject(
+ null, null); // made only to trigger that class's initialization process
+
+ if (obj == null) {
+ return new byte[0]; // a null value is encoded as an empty byte array
+ } else {
+ if (dataType instanceof SqlTimeTypeInfo) {
+ SqlTimeTypeInfo<?> sqlTimeTypeInfo =
+ SqlTimeTypeInfo.getInfoFor(dataType.getTypeClass());
+ if (sqlTimeTypeInfo == DATE) {
+ return pickler.dumps(((Date) obj).toLocalDate().toEpochDay()); // epoch day number
+ } else if (sqlTimeTypeInfo == TIME) {
+ return pickler.dumps(((Time) obj).toLocalTime().toNanoOfDay() / 1000); // nanos / 1000
+ } // any other SqlTimeTypeInfo falls through to the generic handling further below
+ } else if (dataType instanceof RowTypeInfo || dataType instanceof TupleTypeInfo) {
+ TypeInformation<?>[] fieldTypes = ((TupleTypeInfoBase<?>) dataType).getFieldTypes();
+ int arity =
+ dataType instanceof RowTypeInfo
+ ? ((Row) obj).getArity()
+ : ((Tuple) obj).getArity();
+
+ List<Object> fieldBytes = new ArrayList<>(arity + 1); // +1 leaves room for the row kind
+ if (dataType instanceof RowTypeInfo) {
+ fieldBytes.add(new byte[] {((Row) obj).getKind().toByteValue()}); // rows only; tuples get no prefix
+ }
+ for (int i = 0; i < arity; i++) {
+ Object field =
+ dataType instanceof RowTypeInfo
+ ? ((Row) obj).getField(i)
+ : ((Tuple) obj).getField(i);
+ fieldBytes.add(getPickledBytesFromJavaObject(field, fieldTypes[i])); // recurse per field
+ }
+ return fieldBytes; // returned as a List, not pickled as a whole
+ } else if (dataType instanceof BasicArrayTypeInfo
+ || dataType instanceof PrimitiveArrayTypeInfo
+ || dataType instanceof ObjectArrayTypeInfo) {
+ Object[] objects;
+ TypeInformation<?> elementType;
+ if (dataType instanceof BasicArrayTypeInfo) {
+ objects = (Object[]) obj;
+ elementType = ((BasicArrayTypeInfo<?, ?>) dataType).getComponentInfo();
+ } else if (dataType instanceof PrimitiveArrayTypeInfo) {
+ objects = primitiveArrayConverter(obj, dataType); // box primitives into an Object[]
+ elementType = ((PrimitiveArrayTypeInfo<?>) dataType).getComponentType();
+ } else {
+ objects = (Object[]) obj;
+ elementType = ((ObjectArrayTypeInfo<?, ?>) dataType).getComponentInfo();
+ }
+
+ List<Object> serializedElements = new ArrayList<>(objects.length);
+
+ for (Object object : objects) {
+ serializedElements.add(getPickledBytesFromJavaObject(object, elementType));
+ }
+ return pickler.dumps(serializedElements); // pickle the list of per-element results
+ } else if (dataType instanceof MapTypeInfo) {
+ List<List<Object>> serializedMapKV = new ArrayList<>(2); // [keys, values]
+ Map<Object, Object> mapObj = (Map) obj;
+ List<Object> keyBytesList = new ArrayList<>(mapObj.size());
+ List<Object> valueBytesList = new ArrayList<>(mapObj.size());
+ for (Map.Entry entry : mapObj.entrySet()) {
+ keyBytesList.add(
+ getPickledBytesFromJavaObject(
+ entry.getKey(), ((MapTypeInfo) dataType).getKeyTypeInfo()));
+ valueBytesList.add(
+ getPickledBytesFromJavaObject(
+ entry.getValue(), ((MapTypeInfo) dataType).getValueTypeInfo()));
+ }
+ serializedMapKV.add(keyBytesList); // index 0: keys, in entry iteration order
+ serializedMapKV.add(valueBytesList); // index 1: values, in the same order as the keys
+ return pickler.dumps(serializedMapKV);
+ } else if (dataType instanceof ListTypeInfo) {
+ List objects = (List) obj;
+ List<Object> serializedElements = new ArrayList<>(objects.size());
+ TypeInformation elementType = ((ListTypeInfo) dataType).getElementTypeInfo();
+ for (Object object : objects) {
+ serializedElements.add(getPickledBytesFromJavaObject(object, elementType));
+ }
+ return pickler.dumps(serializedElements); // same encoding as the array branch
+ }
+ if (dataType instanceof BasicTypeInfo
+ && BasicTypeInfo.getInfoFor(dataType.getTypeClass()) == FLOAT_TYPE_INFO) {
+ // Pickling a float directly loses precision, so its string representation is pickled.
+ return pickler.dumps(String.valueOf(obj));
+ } else if (dataType instanceof PickledByteArrayTypeInfo
+ || dataType instanceof BasicTypeInfo) {
+ return pickler.dumps(obj); // pickled as-is
+ } else {
+ // All other type infos: serialize with the type's own serializer, then pickle the bytes.
+ TypeSerializer serializer = dataType.createSerializer(null);
+ ByteArrayOutputStreamWithPos baos = new ByteArrayOutputStreamWithPos();
+ DataOutputViewStreamWrapper baosWrapper = new DataOutputViewStreamWrapper(baos);
+ serializer.serialize(obj, baosWrapper);
+ return pickler.dumps(baos.toByteArray());
+ }
+ }
+ }
+
+ private static Object[] primitiveArrayConverter(
+ Object array, TypeInformation<?> arrayTypeInfo) { // boxes a primitive array into Object[]
+ Preconditions.checkArgument(arrayTypeInfo instanceof PrimitiveArrayTypeInfo);
+ Preconditions.checkArgument(array.getClass().isArray());
+ Object[] objects;
+ if (PrimitiveArrayTypeInfo.BOOLEAN_PRIMITIVE_ARRAY_TYPE_INFO.equals(arrayTypeInfo)) {
+ boolean[] booleans = (boolean[]) array;
+ objects = new Object[booleans.length];
+ for (int i = 0; i < booleans.length; i++) {
+ objects[i] = booleans[i]; // autoboxed; the branches below follow the same pattern
+ }
+ } else if (PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO.equals(arrayTypeInfo)) {
+ byte[] bytes = (byte[]) array;
+ objects = new Object[bytes.length];
+ for (int i = 0; i < bytes.length; i++) {
+ objects[i] = bytes[i];
+ }
+ } else if (PrimitiveArrayTypeInfo.SHORT_PRIMITIVE_ARRAY_TYPE_INFO.equals(arrayTypeInfo)) {
+ short[] shorts = (short[]) array;
+ objects = new Object[shorts.length];
+ for (int i = 0; i < shorts.length; i++) {
+ objects[i] = shorts[i];
+ }
+ } else if (PrimitiveArrayTypeInfo.INT_PRIMITIVE_ARRAY_TYPE_INFO.equals(arrayTypeInfo)) {
+ int[] ints = (int[]) array;
+ objects = new Object[ints.length];
+ for (int i = 0; i < ints.length; i++) {
+ objects[i] = ints[i];
+ }
+ } else if (PrimitiveArrayTypeInfo.LONG_PRIMITIVE_ARRAY_TYPE_INFO.equals(arrayTypeInfo)) {
+ long[] longs = (long[]) array;
+ objects = new Object[longs.length];
+ for (int i = 0; i < longs.length; i++) {
+ objects[i] = longs[i];
+ }
+ } else if (PrimitiveArrayTypeInfo.FLOAT_PRIMITIVE_ARRAY_TYPE_INFO.equals(arrayTypeInfo)) {
+ float[] floats = (float[]) array;
+ objects = new Object[floats.length];
+ for (int i = 0; i < floats.length; i++) {
+ objects[i] = floats[i];
+ }
+ } else if (PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO.equals(arrayTypeInfo)) {
+ double[] doubles = (double[]) array;
+ objects = new Object[doubles.length];
+ for (int i = 0; i < doubles.length; i++) {
+ objects[i] = doubles[i];
+ }
+ } else if (PrimitiveArrayTypeInfo.CHAR_PRIMITIVE_ARRAY_TYPE_INFO.equals(arrayTypeInfo)) {
+ char[] chars = (char[]) array;
+ objects = new Object[chars.length];
+ for (int i = 0; i < chars.length; i++) {
+ objects[i] = chars[i];
+ }
+ } else { // a PrimitiveArrayTypeInfo that matched none of the constants checked above
+ throw new UnsupportedOperationException(
+ String.format(
+ "Primitive array of %s is not supported in PyFlink yet",
+ ((PrimitiveArrayTypeInfo<?>) arrayTypeInfo)
+ .getComponentType()
+ .getTypeClass()
+ .getSimpleName()));
+ }
+ return objects;
+ }
+}