You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by di...@apache.org on 2021/05/21 11:07:18 UTC

[flink] branch master updated: [FLINK-22733][python] DataStream.union should handle properly for KeyedStream in Python DataStream API

This is an automated email from the ASF dual-hosted git repository.

dianfu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 09168af  [FLINK-22733][python] DataStream.union should handle properly for KeyedStream in Python DataStream API
09168af is described below

commit 09168af6c8491dfccd85ece9943123c968957e17
Author: Dian Fu <di...@apache.org>
AuthorDate: Fri May 21 13:51:22 2021 +0800

    [FLINK-22733][python] DataStream.union should handle properly for KeyedStream in Python DataStream API
    
    This closes #15981.
---
 flink-python/pyflink/datastream/data_stream.py     |  5 ++++-
 .../pyflink/datastream/tests/test_data_stream.py   | 23 ++++++++++++++++++++--
 2 files changed, 25 insertions(+), 3 deletions(-)

diff --git a/flink-python/pyflink/datastream/data_stream.py b/flink-python/pyflink/datastream/data_stream.py
index ce85de9..a0a9c90 100644
--- a/flink-python/pyflink/datastream/data_stream.py
+++ b/flink-python/pyflink/datastream/data_stream.py
@@ -368,7 +368,10 @@ class DataStream(object):
         """
         j_data_streams = []
         for data_stream in streams:
-            j_data_streams.append(data_stream._j_data_stream)
+            if isinstance(data_stream, KeyedStream):
+                j_data_streams.append(data_stream._values()._j_data_stream)
+            else:
+                j_data_streams.append(data_stream._j_data_stream)
         gateway = get_gateway()
         JDataStream = gateway.jvm.org.apache.flink.streaming.api.datastream.DataStream
         j_data_stream_arr = get_gateway().new_array(JDataStream, len(j_data_streams))
diff --git a/flink-python/pyflink/datastream/tests/test_data_stream.py b/flink-python/pyflink/datastream/tests/test_data_stream.py
index 994cd60..903d75d 100644
--- a/flink-python/pyflink/datastream/tests/test_data_stream.py
+++ b/flink-python/pyflink/datastream/tests/test_data_stream.py
@@ -877,9 +877,9 @@ class StreamingModeDataStreamTests(DataStreamTests, PyFlinkStreamingTestCase):
         ds_2 = self.env.from_collection([4, 5, 6])
         ds_3 = self.env.from_collection([7, 8, 9])
 
-        united_stream = ds_3.union(ds_1, ds_2)
+        unioned_stream = ds_3.union(ds_1, ds_2)
 
-        united_stream.map(lambda x: x + 1).add_sink(self.test_sink)
+        unioned_stream.map(lambda x: x + 1).add_sink(self.test_sink)
         exec_plan = eval(self.env.get_execution_plan())
         source_ids = []
         union_node_pre_ids = []
@@ -894,6 +894,25 @@ class StreamingModeDataStreamTests(DataStreamTests, PyFlinkStreamingTestCase):
         union_node_pre_ids.sort()
         self.assertEqual(source_ids, union_node_pre_ids)
 
+    def test_keyed_stream_union(self):
+        ds_1 = self.env.from_collection([1, 2, 3])
+        ds_2 = self.env.from_collection([4, 5, 6])
+        unioned_stream = ds_1.key_by(lambda x: x).union(ds_2.key_by(lambda x: x))
+        unioned_stream.add_sink(self.test_sink)
+        exec_plan = eval(self.env.get_execution_plan())
+        expected_union_node_pre_ids = []
+        union_node_pre_ids = []
+        for node in exec_plan['nodes']:
+            if node['type'] == '_keyed_stream_values_operator':
+                expected_union_node_pre_ids.append(node['id'])
+            if node['pact'] == 'Data Sink':
+                for pre in node['predecessors']:
+                    union_node_pre_ids.append(pre['id'])
+
+        expected_union_node_pre_ids.sort()
+        union_node_pre_ids.sort()
+        self.assertEqual(expected_union_node_pre_ids, union_node_pre_ids)
+
     def test_project(self):
         ds = self.env.from_collection([[1, 2, 3, 4], [5, 6, 7, 8]],
                                       type_info=Types.TUPLE(