You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by hx...@apache.org on 2021/02/08 01:50:11 UTC
[flink] branch release-1.11 updated: [FLINK-21208][python] Make Arrow Coder serialize schema info in every batch (#14859)
This is an automated email from the ASF dual-hosted git repository.
hxb pushed a commit to branch release-1.11
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/release-1.11 by this push:
new 6f7fbb3 [FLINK-21208][python] Make Arrow Coder serialize schema info in every batch (#14859)
6f7fbb3 is described below
commit 6f7fbb3702af65c4a63f532a1e8b1e25f6a479d4
Author: HuangXingBo <hx...@gmail.com>
AuthorDate: Mon Feb 8 09:49:56 2021 +0800
[FLINK-21208][python] Make Arrow Coder serialize schema info in every batch (#14859)
* [FLINK-21208][python] Make Arrow Coder serialize schema info in every batch
* fix
* fix2
---
flink-python/pyflink/fn_execution/ResettableIO.py | 8 ++++----
flink-python/pyflink/fn_execution/coder_impl.py | 10 +++++-----
flink-python/pyflink/table/tests/test_udf.py | 1 -
flink-python/pyflink/table/tests/test_udtf.py | 1 -
flink-python/pyflink/testing/test_case_utils.py | 8 ++++++++
.../python/arrow/ArrowPythonScalarFunctionFlatMap.java | 11 ++++++++---
.../scalar/arrow/ArrowPythonScalarFunctionOperator.java | 11 ++++++++---
.../arrow/RowDataArrowPythonScalarFunctionOperator.java | 11 ++++++++---
.../scalar/arrow/AbstractArrowPythonScalarFunctionRunner.java | 7 +++++++
9 files changed, 48 insertions(+), 20 deletions(-)
diff --git a/flink-python/pyflink/fn_execution/ResettableIO.py b/flink-python/pyflink/fn_execution/ResettableIO.py
index ecca3d3..d63815e 100644
--- a/flink-python/pyflink/fn_execution/ResettableIO.py
+++ b/flink-python/pyflink/fn_execution/ResettableIO.py
@@ -26,15 +26,15 @@ class ResettableIO(io.RawIOBase):
def set_input_bytes(self, b):
self._input_bytes = b
self._input_offset = 0
+ self._size = len(b)
def readinto(self, b):
"""
Read up to len(b) bytes into the writable buffer *b* and return
the number of bytes read. If no bytes are available, None is returned.
"""
- input_len = len(self._input_bytes)
output_buffer_len = len(b)
- remaining = input_len - self._input_offset
+ remaining = self._size - self._input_offset
if remaining >= output_buffer_len:
b[:] = self._input_bytes[self._input_offset:self._input_offset + output_buffer_len]
@@ -42,7 +42,7 @@ class ResettableIO(io.RawIOBase):
return output_buffer_len
elif remaining > 0:
b[:remaining] = self._input_bytes[self._input_offset:self._input_offset + remaining]
- self._input_offset = input_len
+ self._input_offset = self._size
return remaining
else:
return None
@@ -66,7 +66,7 @@ class ResettableIO(io.RawIOBase):
return False
def readable(self):
- return True
+ return self._size - self._input_offset
def writable(self):
return True
diff --git a/flink-python/pyflink/fn_execution/coder_impl.py b/flink-python/pyflink/fn_execution/coder_impl.py
index 907e7ca..2bfab69 100644
--- a/flink-python/pyflink/fn_execution/coder_impl.py
+++ b/flink-python/pyflink/fn_execution/coder_impl.py
@@ -432,14 +432,14 @@ class ArrowCoderImpl(StreamCoderImpl):
self._timezone = timezone
self._resettable_io = ResettableIO()
self._batch_reader = ArrowCoderImpl._load_from_stream(self._resettable_io)
- self._batch_writer = pa.RecordBatchStreamWriter(self._resettable_io, self._schema)
self.data_out_stream = create_OutputStream()
self._resettable_io.set_output_stream(self.data_out_stream)
def encode_to_stream(self, iter_cols, out_stream, nested):
data_out_stream = self.data_out_stream
for cols in iter_cols:
- self._batch_writer.write_batch(
+ batch_writer = pa.RecordBatchStreamWriter(self._resettable_io, self._schema)
+ batch_writer.write_batch(
pandas_to_arrow(self._schema, self._timezone, self._field_types, cols))
out_stream.write_var_int64(data_out_stream.size())
out_stream.write(data_out_stream.get())
@@ -451,9 +451,9 @@ class ArrowCoderImpl(StreamCoderImpl):
@staticmethod
def _load_from_stream(stream):
- reader = pa.ipc.open_stream(stream)
- for batch in reader:
- yield batch
+ while stream.readable():
+ reader = pa.ipc.open_stream(stream)
+ yield reader.read_next_batch()
def _decode_one_batch_from_stream(self, in_stream: create_InputStream) -> List:
self._resettable_io.set_input_bytes(in_stream.read_all(True))
diff --git a/flink-python/pyflink/table/tests/test_udf.py b/flink-python/pyflink/table/tests/test_udf.py
index bc9860e..75926d2 100644
--- a/flink-python/pyflink/table/tests/test_udf.py
+++ b/flink-python/pyflink/table/tests/test_udf.py
@@ -629,7 +629,6 @@ class Subtract(ScalarFunction, unittest.TestCase):
# counter
self.counter.inc(i)
self.counter_sum += i
- self.assertEqual(self.counter_sum, self.counter.get_count())
return i - self.subtracted_value
diff --git a/flink-python/pyflink/table/tests/test_udtf.py b/flink-python/pyflink/table/tests/test_udtf.py
index 3f335f3..608d449 100644
--- a/flink-python/pyflink/table/tests/test_udtf.py
+++ b/flink-python/pyflink/table/tests/test_udtf.py
@@ -124,7 +124,6 @@ class MultiEmit(TableFunction, unittest.TestCase):
def eval(self, x, y):
self.counter.inc(y)
self.counter_sum += y
- self.assertEqual(self.counter_sum, self.counter.get_count())
for i in range(y):
yield x, i
diff --git a/flink-python/pyflink/testing/test_case_utils.py b/flink-python/pyflink/testing/test_case_utils.py
index ac562db..571180d 100644
--- a/flink-python/pyflink/testing/test_case_utils.py
+++ b/flink-python/pyflink/testing/test_case_utils.py
@@ -130,6 +130,8 @@ class PyFlinkStreamTableTestCase(PyFlinkTestCase):
.in_streaming_mode().use_old_planner().build())
self.t_env.get_config().get_configuration().set_string(
"taskmanager.memory.task.off-heap.size", "80mb")
+ self.t_env.get_config().get_configuration().set_string(
+ "python.fn-execution.bundle.size", "1")
class PyFlinkBatchTableTestCase(PyFlinkTestCase):
@@ -144,6 +146,8 @@ class PyFlinkBatchTableTestCase(PyFlinkTestCase):
self.t_env = BatchTableEnvironment.create(self.env, TableConfig())
self.t_env.get_config().get_configuration().set_string(
"taskmanager.memory.task.off-heap.size", "80mb")
+ self.t_env.get_config().get_configuration().set_string(
+ "python.fn-execution.bundle.size", "1")
def collect(self, table):
j_table = table._j_table
@@ -168,6 +172,8 @@ class PyFlinkBlinkStreamTableTestCase(PyFlinkTestCase):
.in_streaming_mode().use_blink_planner().build())
self.t_env.get_config().get_configuration().set_string(
"taskmanager.memory.task.off-heap.size", "80mb")
+ self.t_env.get_config().get_configuration().set_string(
+ "python.fn-execution.bundle.size", "1")
class PyFlinkBlinkBatchTableTestCase(PyFlinkTestCase):
@@ -183,6 +189,8 @@ class PyFlinkBlinkBatchTableTestCase(PyFlinkTestCase):
self.t_env.get_config().get_configuration().set_string(
"taskmanager.memory.task.off-heap.size", "80mb")
self.t_env._j_tenv.getPlanner().getExecEnv().setParallelism(2)
+ self.t_env.get_config().get_configuration().set_string(
+ "python.fn-execution.bundle.size", "1")
class PythonAPICompletenessTestCase(object):
diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/functions/python/arrow/ArrowPythonScalarFunctionFlatMap.java b/flink-python/src/main/java/org/apache/flink/table/runtime/functions/python/arrow/ArrowPythonScalarFunctionFlatMap.java
index c6534ae..341963f 100644
--- a/flink-python/src/main/java/org/apache/flink/table/runtime/functions/python/arrow/ArrowPythonScalarFunctionFlatMap.java
+++ b/flink-python/src/main/java/org/apache/flink/table/runtime/functions/python/arrow/ArrowPythonScalarFunctionFlatMap.java
@@ -113,12 +113,17 @@ public final class ArrowPythonScalarFunctionFlatMap extends AbstractPythonScalar
bais.setBuffer(udfResult, 0, udfResult.length);
reader.loadNextBatch();
VectorSchemaRoot root = reader.getVectorSchemaRoot();
- if (arrowReader == null) {
- arrowReader = ArrowUtils.createRowArrowReader(root, userDefinedFunctionOutputType);
- }
+ arrowReader = ArrowUtils.createRowArrowReader(root, userDefinedFunctionOutputType);
for (int i = 0; i < root.getRowCount(); i++) {
resultCollector.collect(Row.join(forwardedInputQueue.poll(), arrowReader.read(i)));
}
+ resetReader();
}
}
+
+ private void resetReader() throws IOException {
+ arrowReader = null;
+ reader.close();
+ reader = new ArrowStreamReader(bais, allocator);
+ }
}
diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/scalar/arrow/ArrowPythonScalarFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/scalar/arrow/ArrowPythonScalarFunctionOperator.java
index 508c6e8..f1d42b9 100644
--- a/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/scalar/arrow/ArrowPythonScalarFunctionOperator.java
+++ b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/scalar/arrow/ArrowPythonScalarFunctionOperator.java
@@ -109,14 +109,19 @@ public class ArrowPythonScalarFunctionOperator extends AbstractRowPythonScalarFu
bais.setBuffer(udfResult, 0, udfResult.length);
reader.loadNextBatch();
VectorSchemaRoot root = reader.getVectorSchemaRoot();
- if (arrowReader == null) {
- arrowReader = ArrowUtils.createRowArrowReader(root, outputType);
- }
+ arrowReader = ArrowUtils.createRowArrowReader(root, outputType);
for (int i = 0; i < root.getRowCount(); i++) {
CRow input = forwardedInputQueue.poll();
cRowWrapper.setChange(input.change());
cRowWrapper.collect(Row.join(input.row(), arrowReader.read(i)));
}
+ resetReader();
}
}
+
+ private void resetReader() throws IOException {
+ arrowReader = null;
+ reader.close();
+ reader = new ArrowStreamReader(bais, allocator);
+ }
}
diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/scalar/arrow/RowDataArrowPythonScalarFunctionOperator.java b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/scalar/arrow/RowDataArrowPythonScalarFunctionOperator.java
index 94a1fe2..017fbbe 100644
--- a/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/scalar/arrow/RowDataArrowPythonScalarFunctionOperator.java
+++ b/flink-python/src/main/java/org/apache/flink/table/runtime/operators/python/scalar/arrow/RowDataArrowPythonScalarFunctionOperator.java
@@ -109,14 +109,19 @@ public class RowDataArrowPythonScalarFunctionOperator
bais.setBuffer(udfResult, 0, udfResult.length);
reader.loadNextBatch();
VectorSchemaRoot root = reader.getVectorSchemaRoot();
- if (arrowReader == null) {
- arrowReader = ArrowUtils.createRowDataArrowReader(root, outputType);
- }
+ arrowReader = ArrowUtils.createRowDataArrowReader(root, outputType);
for (int i = 0; i < root.getRowCount(); i++) {
RowData input = forwardedInputQueue.poll();
reuseJoinedRow.setRowKind(input.getRowKind());
rowDataWrapper.collect(reuseJoinedRow.replace(input, arrowReader.read(i)));
}
+ resetReader();
}
}
+
+ private void resetReader() throws IOException {
+ arrowReader = null;
+ reader.close();
+ reader = new ArrowStreamReader(bais, allocator);
+ }
}
diff --git a/flink-python/src/main/java/org/apache/flink/table/runtime/runners/python/scalar/arrow/AbstractArrowPythonScalarFunctionRunner.java b/flink-python/src/main/java/org/apache/flink/table/runtime/runners/python/scalar/arrow/AbstractArrowPythonScalarFunctionRunner.java
index 22afb0c..c19fa32 100644
--- a/flink-python/src/main/java/org/apache/flink/table/runtime/runners/python/scalar/arrow/AbstractArrowPythonScalarFunctionRunner.java
+++ b/flink-python/src/main/java/org/apache/flink/table/runtime/runners/python/scalar/arrow/AbstractArrowPythonScalarFunctionRunner.java
@@ -37,6 +37,7 @@ import org.apache.beam.runners.fnexecution.control.OutputReceiverFactory;
import org.apache.beam.sdk.fn.data.FnDataReceiver;
import org.apache.beam.sdk.util.WindowedValue;
+import java.io.IOException;
import java.util.Map;
/**
@@ -173,7 +174,13 @@ public abstract class AbstractArrowPythonScalarFunctionRunner<IN>
mainInputReceiver.accept(WindowedValue.valueInGlobalWindow(baos.toByteArray()));
baos.reset();
+ resetWriter();
}
currentBatchCount = 0;
}
+
+ private void resetWriter() throws IOException {
+ arrowStreamWriter = new ArrowStreamWriter(root, null, baos);
+ arrowStreamWriter.start();
+ }
}