Posted to commits@flink.apache.org by hx...@apache.org on 2021/02/08 01:50:11 UTC

[flink] branch release-1.11 updated: [FLINK-21208][python] Make Arrow Coder serialize schema info in every batch (#14859)

This is an automated email from the ASF dual-hosted git repository.

hxb pushed a commit to branch release-1.11
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/release-1.11 by this push:
     new 6f7fbb3  [FLINK-21208][python] Make Arrow Coder serialize schema info in every batch (#14859)
6f7fbb3 is described below

commit 6f7fbb3702af65c4a63f532a1e8b1e25f6a479d4
Author: HuangXingBo <hx...@gmail.com>
AuthorDate: Mon Feb 8 09:49:56 2021 +0800

    [FLINK-21208][python] Make Arrow Coder serialize schema info in every batch (#14859)
    
    * [FLINK-21208][python] Make Arrow Coder serialize schema info in every batch
    
    * fix
    
    * fix2
---
 flink-python/pyflink/fn_execution/ResettableIO.py             |  8 ++++----
 flink-python/pyflink/fn_execution/coder_impl.py               | 10 +++++-----
 flink-python/pyflink/table/tests/test_udf.py                  |  1 -
 flink-python/pyflink/table/tests/test_udtf.py                 |  1 -
 flink-python/pyflink/testing/test_case_utils.py               |  8 ++++++++
 .../python/arrow/ArrowPythonScalarFunctionFlatMap.java        | 11 ++++++++---
 .../scalar/arrow/ArrowPythonScalarFunctionOperator.java       | 11 ++++++++---
 .../arrow/RowDataArrowPythonScalarFunctionOperator.java       | 11 ++++++++---
 .../scalar/arrow/AbstractArrowPythonScalarFunctionRunner.java |  7 +++++++
 9 files changed, 48 insertions(+), 20 deletions(-)

diff --git a/flink-python/pyflink/fn_execution/ResettableIO.py b/flink-python/pyflink/fn_execution/ResettableIO.py
index ecca3d3..d63815e 100644
--- a/flink-python/pyflink/fn_execution/ResettableIO.py
+++ b/flink-python/pyflink/fn_execution/ResettableIO.py
@@ -26,15 +26,15 @@ class ResettableIO(io.RawIOBase):
     def set_input_bytes(self, b):
         self._input_bytes = b
         self._input_offset = 0
+        self._size = len(b)
 
     def readinto(self, b):
         """
         Read up to len(b) bytes into the writable buffer *b* and return
         the number of bytes read. If no bytes are available, None is returned.
         """
-        input_len = len(self._input_bytes)
         output_buffer_len = len(b)
-        remaining = input_len - self._input_offset
+        remaining = self._size - self._input_offset
 
         if remaining >= output_buffer_len:
             b[:] = self._input_bytes[self._input_offset:self._input_offset + output_buffer_len]
@@ -42,7 +42,7 @@ class ResettableIO(io.RawIOBase):
             return output_buffer_len
         elif remaining > 0:
             b[:remaining] = self._input_bytes[self._input_offset:self._input_offset + remaining]
-            self._input_offset = input_len
+            self._input_offset = self._size
             return remaining
         else:
             return None
@@ -66,7 +66,7 @@ class ResettableIO(io.RawIOBase):
         return False
 
     def readable(self):
-        return True
+        return self._size - self._input_offset
 
     def writable(self):
         return True
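
The readable() override above drives the new per-batch decode loop in
ArrowCoderImpl: it now returns the number of unconsumed bytes, so it is
truthy while input remains and 0 once the current chunk is drained. A
minimal sketch of that consumption pattern, assuming batch_bytes holds one
or more serialized Arrow IPC streams:

    import pyarrow as pa
    from pyflink.fn_execution.ResettableIO import ResettableIO

    io_wrapper = ResettableIO()
    io_wrapper.set_input_bytes(batch_bytes)  # assumed input: one IPC stream per batch
    while io_wrapper.readable():             # remaining bytes; 0 (falsy) when drained
        reader = pa.ipc.open_stream(io_wrapper)
        batch = reader.read_next_batch()     # each stream carries exactly one batch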
diff --git a/flink-python/pyflink/fn_execution/coder_impl.py b/flink-python/pyflink/fn_execution/coder_impl.py
index 907e7ca..2bfab69 100644
--- a/flink-python/pyflink/fn_execution/coder_impl.py
+++ b/flink-python/pyflink/fn_execution/coder_impl.py
@@ -432,14 +432,14 @@ class ArrowCoderImpl(StreamCoderImpl):
         self._timezone = timezone
         self._resettable_io = ResettableIO()
         self._batch_reader = ArrowCoderImpl._load_from_stream(self._resettable_io)
-        self._batch_writer = pa.RecordBatchStreamWriter(self._resettable_io, self._schema)
         self.data_out_stream = create_OutputStream()
         self._resettable_io.set_output_stream(self.data_out_stream)
 
     def encode_to_stream(self, iter_cols, out_stream, nested):
         data_out_stream = self.data_out_stream
         for cols in iter_cols:
-            self._batch_writer.write_batch(
+            batch_writer = pa.RecordBatchStreamWriter(self._resettable_io, self._schema)
+            batch_writer.write_batch(
                 pandas_to_arrow(self._schema, self._timezone, self._field_types, cols))
             out_stream.write_var_int64(data_out_stream.size())
             out_stream.write(data_out_stream.get())
@@ -451,9 +451,9 @@ class ArrowCoderImpl(StreamCoderImpl):
 
     @staticmethod
     def _load_from_stream(stream):
-        reader = pa.ipc.open_stream(stream)
-        for batch in reader:
-            yield batch
+        while stream.readable():
+            reader = pa.ipc.open_stream(stream)
+            yield reader.read_next_batch()
 
     def _decode_one_batch_from_stream(self, in_stream: create_InputStream) -> List:
         self._resettable_io.set_input_bytes(in_stream.read_all(True))
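
This is the heart of the fix: a fresh RecordBatchStreamWriter per batch
means every encoded chunk is a complete Arrow IPC stream that carries the
schema, and the decoder opens a new stream reader per chunk instead of
iterating a single long-lived reader. A self-contained round trip, sketched
with plain BytesIO standing in for ResettableIO:

    import io
    import pyarrow as pa

    schema = pa.schema([("a", pa.int64())])

    # Encode side: one writer per batch, so each chunk starts with the schema.
    chunks = []
    for value in (1, 2):
        sink = io.BytesIO()
        writer = pa.RecordBatchStreamWriter(sink, schema)
        writer.write_batch(pa.record_batch([pa.array([value])], schema=schema))
        chunks.append(sink.getvalue())

    # Decode side: a new reader per chunk, mirroring _load_from_stream.
    for chunk in chunks:
        assert pa.ipc.open_stream(chunk).read_next_batch().num_rows == 1

Previously the writer was created once in __init__, so only the first chunk
contained the schema and later chunks were not decodable on their own.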
diff --git a/flink-python/pyflink/table/tests/test_udf.py b/flink-python/pyflink/table/tests/test_udf.py
index bc9860e..75926d2 100644
--- a/flink-python/pyflink/table/tests/test_udf.py
+++ b/flink-python/pyflink/table/tests/test_udf.py
@@ -629,7 +629,6 @@ class Subtract(ScalarFunction, unittest.TestCase):
         # counter
         self.counter.inc(i)
         self.counter_sum += i
-        self.assertEqual(self.counter_sum, self.counter.get_count())
         return i - self.subtracted_value
 
 
diff --git a/flink-python/pyflink/table/tests/test_udtf.py b/flink-python/pyflink/table/tests/test_udtf.py
index 3f335f3..608d449 100644
--- a/flink-python/pyflink/table/tests/test_udtf.py
+++ b/flink-python/pyflink/table/tests/test_udtf.py
@@ -124,7 +124,6 @@ class MultiEmit(TableFunction, unittest.TestCase):
     def eval(self, x, y):
         self.counter.inc(y)
         self.counter_sum += y
-        self.assertEqual(self.counter_sum, self.counter.get_count())
         for i in range(y):
             yield x, i
 
diff --git a/flink-python/pyflink/testing/test_case_utils.py b/flink-python/pyflink/testing/test_case_utils.py
index ac562db..571180d 100644
--- a/flink-python/pyflink/testing/test_case_utils.py
+++ b/flink-python/pyflink/testing/test_case_utils.py
@@ -130,6 +130,8 @@ class PyFlinkStreamTableTestCase(PyFlinkTestCase):
                 .in_streaming_mode().use_old_planner().build())
         self.t_env.get_config().get_configuration().set_string(
             "taskmanager.memory.task.off-heap.size", "80mb")
+        self.t_env.get_config().get_configuration().set_string(
+            "python.fn-execution.bundle.size", "1")
 
 
 class PyFlinkBatchTableTestCase(PyFlinkTestCase):
@@ -144,6 +146,8 @@ class PyFlinkBatchTableTestCase(PyFlinkTestCase):
         self.t_env = BatchTableEnvironment.create(self.env, TableConfig())
         self.t_env.get_config().get_configuration().set_string(
             "taskmanager.memory.task.off-heap.size", "80mb")
+        self.t_env.get_config().get_configuration().set_string(
+            "python.fn-execution.bundle.size", "1")
 
     def collect(self, table):
         j_table = table._j_table
@@ -168,6 +172,8 @@ class PyFlinkBlinkStreamTableTestCase(PyFlinkTestCase):
                 .in_streaming_mode().use_blink_planner().build())
         self.t_env.get_config().get_configuration().set_string(
             "taskmanager.memory.task.off-heap.size", "80mb")
+        self.t_env.get_config().get_configuration().set_string(
+            "python.fn-execution.bundle.size", "1")
 
 
 class PyFlinkBlinkBatchTableTestCase(PyFlinkTestCase):
@@ -183,6 +189,8 @@ class PyFlinkBlinkBatchTableTestCase(PyFlinkTestCase):
         self.t_env.get_config().get_configuration().set_string(
             "taskmanager.memory.task.off-heap.size", "80mb")
         self.t_env._j_tenv.getPlanner().getExecEnv().setParallelism(2)
+        self.t_env.get_config().get_configuration().set_string(
+            "python.fn-execution.bundle.size", "1")
 
 
 class PythonAPICompletenessTestCase(object):
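
The test harness now pins python.fn-execution.bundle.size to 1 so that at
most one element goes into each bundle, which caps every Arrow batch at a
single row and exercises the per-batch schema path on every element. The
same knob is available to user code; a sketch for a blink streaming job
(values are illustrative):

    from pyflink.datastream import StreamExecutionEnvironment
    from pyflink.table import EnvironmentSettings, StreamTableEnvironment

    env = StreamExecutionEnvironment.get_execution_environment()
    t_env = StreamTableEnvironment.create(
        env,
        environment_settings=EnvironmentSettings.new_instance()
            .in_streaming_mode().use_blink_planner().build())
    # Flush a bundle after every element; larger values pack more rows
    # into each Arrow payload.
    t_env.get_config().get_configuration().set_string(
        "python.fn-execution.bundle.size", "1")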
diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/functions/python/arrow/ArrowPythonScalarFunctionFlatMap.java b/flink-python/src/main/java/org/apache/flink/table/runtime/functions/python/arrow/ArrowPythonScalarFunctionFlatMap.java
index c6534ae..341963f 100644
--- a/flink-python/src/main/java/org/apache/flink/table/runtime/functions/python/arrow/ArrowPythonScalarFunctionFlatMap.java
+++ b/flink-python/src/main/java/org/apache/flink/table/runtime/functions/python/arrow/ArrowPythonScalarFunctionFlatMap.java
@@ -113,12 +113,17 @@ public final class ArrowPythonScalarFunctionFlatMap extends AbstractPythonScalar
             bais.setBuffer(udfResult, 0, udfResult.length);
             reader.loadNextBatch();
             VectorSchemaRoot root = reader.getVectorSchemaRoot();
-            if (arrowReader == null) {
-                arrowReader = ArrowUtils.createRowArrowReader(root, userDefinedFunctionOutputType);
-            }
+            arrowReader = ArrowUtils.createRowArrowReader(root, userDefinedFunctionOutputType);
             for (int i = 0; i < root.getRowCount(); i++) {
                 resultCollector.collect(Row.join(forwardedInputQueue.poll(), arrowReader.read(i)));
             }
+            resetReader();
         }
     }
+
+    private void resetReader() throws IOException {
+        arrowReader = null;
+        reader.close();
+        reader = new ArrowStreamReader(bais, allocator);
+    }
 }
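
The Java operators mirror the Python change: each payload from the Python
worker is now a complete IPC stream, and a single long-lived
ArrowStreamReader stops at the first stream's end-of-stream marker, so the
operator rebuilds the reader after every batch. The failure mode the reset
avoids, demonstrated with pyarrow:

    import io
    import pyarrow as pa

    schema = pa.schema([("a", pa.int64())])

    def one_stream(value):
        sink = io.BytesIO()
        writer = pa.RecordBatchStreamWriter(sink, schema)
        writer.write_batch(pa.record_batch([pa.array([value])], schema=schema))
        writer.close()           # writes the end-of-stream marker
        return sink.getvalue()

    two_streams = one_stream(1) + one_stream(2)
    table = pa.ipc.open_stream(two_streams).read_all()
    assert table.num_rows == 1   # one reader never sees the second stream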
diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/scalar/arrow/ArrowPythonScalarFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/scalar/arrow/ArrowPythonScalarFunctionOperator.java
index 508c6e8..f1d42b9 100644
--- a/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/scalar/arrow/ArrowPythonScalarFunctionOperator.java
+++ b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/scalar/arrow/ArrowPythonScalarFunctionOperator.java
@@ -109,14 +109,19 @@ public class ArrowPythonScalarFunctionOperator extends AbstractRowPythonScalarFu
             bais.setBuffer(udfResult, 0, udfResult.length);
             reader.loadNextBatch();
             VectorSchemaRoot root = reader.getVectorSchemaRoot();
-            if (arrowReader == null) {
-                arrowReader = ArrowUtils.createRowArrowReader(root, outputType);
-            }
+            arrowReader = ArrowUtils.createRowArrowReader(root, outputType);
             for (int i = 0; i < root.getRowCount(); i++) {
                 CRow input = forwardedInputQueue.poll();
                 cRowWrapper.setChange(input.change());
                 cRowWrapper.collect(Row.join(input.row(), arrowReader.read(i)));
             }
+            resetReader();
         }
     }
+
+    private void resetReader() throws IOException {
+        arrowReader = null;
+        reader.close();
+        reader = new ArrowStreamReader(bais, allocator);
+    }
 }
diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/scalar/arrow/RowDataArrowPythonScalarFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/scalar/arrow/RowDataArrowPythonScalarFunctionOperator.java
index 94a1fe2..017fbbe 100644
--- a/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/scalar/arrow/RowDataArrowPythonScalarFunctionOperator.java
+++ b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/scalar/arrow/RowDataArrowPythonScalarFunctionOperator.java
@@ -109,14 +109,19 @@ public class RowDataArrowPythonScalarFunctionOperator
             bais.setBuffer(udfResult, 0, udfResult.length);
             reader.loadNextBatch();
             VectorSchemaRoot root = reader.getVectorSchemaRoot();
-            if (arrowReader == null) {
-                arrowReader = ArrowUtils.createRowDataArrowReader(root, outputType);
-            }
+            arrowReader = ArrowUtils.createRowDataArrowReader(root, outputType);
             for (int i = 0; i < root.getRowCount(); i++) {
                 RowData input = forwardedInputQueue.poll();
                 reuseJoinedRow.setRowKind(input.getRowKind());
                 rowDataWrapper.collect(reuseJoinedRow.replace(input, arrowReader.read(i)));
             }
+            resetReader();
         }
     }
+
+    private void resetReader() throws IOException {
+        arrowReader = null;
+        reader.close();
+        reader = new ArrowStreamReader(bais, allocator);
+    }
 }
diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/runners/python/scalar/arrow/AbstractArrowPythonScalarFunctionRunner.java b/flink-python/src/main/java/org/apache/flink/table/runtime/runners/python/scalar/arrow/AbstractArrowPythonScalarFunctionRunner.java
index 22afb0c..c19fa32 100644
--- a/flink-python/src/main/java/org/apache/flink/table/runtime/runners/python/scalar/arrow/AbstractArrowPythonScalarFunctionRunner.java
+++ b/flink-python/src/main/java/org/apache/flink/table/runtime/runners/python/scalar/arrow/AbstractArrowPythonScalarFunctionRunner.java
@@ -37,6 +37,7 @@ import org.apache.beam.runners.fnexecution.control.OutputReceiverFactory;
 import org.apache.beam.sdk.fn.data.FnDataReceiver;
 import org.apache.beam.sdk.util.WindowedValue;
 
+import java.io.IOException;
 import java.util.Map;
 
 /**
@@ -173,7 +174,13 @@ public abstract class AbstractArrowPythonScalarFunctionRunner<IN>
 
             mainInputReceiver.accept(WindowedValue.valueInGlobalWindow(baos.toByteArray()));
             baos.reset();
+            resetWriter();
         }
         currentBatchCount = 0;
     }
+
+    private void resetWriter() throws IOException {
+        arrowStreamWriter = new ArrowStreamWriter(root, null, baos);
+        arrowStreamWriter.start();
+    }
 }
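
On the write side, resetWriter() plays the same role as the per-batch
RecordBatchStreamWriter in coder_impl.py above: ArrowStreamWriter#start()
emits the schema message, so every payload handed to mainInputReceiver
after a reset is again a self-contained stream that a fresh reader can
decode independently.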