Posted to commits@flink.apache.org by di...@apache.org on 2022/08/03 11:28:11 UTC
[flink] branch master updated: [FLINK-28740][python][format] Support CsvBulkWriter
This is an automated email from the ASF dual-hosted git repository.
dianfu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new 17e1920bcd2 [FLINK-28740][python][format] Support CsvBulkWriter
17e1920bcd2 is described below
commit 17e1920bcd2780e12f5ad318faa0db68c4f7fec0
Author: Juntao Hu <ma...@gmail.com>
AuthorDate: Mon Jul 25 15:59:45 2022 +0800
[FLINK-28740][python][format] Support CsvBulkWriter
This closes #20391.
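For reviewers, the new API added by this commit can be exercised roughly as follows. This is a minimal sketch based on the documentation added below; the output directory, schema, and sample data are illustrative:

```python
from pyflink.common import Types
from pyflink.datastream import StreamExecutionEnvironment
from pyflink.datastream.connectors.file_system import FileSink
from pyflink.datastream.formats.csv import CsvBulkWriter, CsvSchema
from pyflink.table import DataTypes

env = StreamExecutionEnvironment.get_execution_environment()

# Illustrative schema: one BIGINT column and one STRING column.
schema = CsvSchema.builder() \
    .add_number_column('id', number_type=DataTypes.BIGINT()) \
    .add_string_column('name') \
    .build()

sink = FileSink.for_bulk_format(
    '/tmp/csv-out', CsvBulkWriter.for_schema(schema)).build()

# Row-producing streams can be written directly; streams producing RowData
# records (e.g. from a CSV source) need an identity map first, see the docs below.
ds = env.from_collection(
    [(1, 'a'), (2, 'b')],
    type_info=Types.ROW_NAMED(['id', 'name'], [Types.LONG(), Types.STRING()]))
ds.sink_to(sink)
env.execute('csv_bulk_writer_sketch')
```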
---
.../docs/connectors/datastream/formats/csv.md | 18 +
.../docs/connectors/datastream/formats/csv.md | 18 +
flink-python/pyflink/datastream/__init__.py | 4 +-
.../pyflink/datastream/connectors/__init__.py | 20 +-
flink-python/pyflink/datastream/connectors/base.py | 12 +-
.../pyflink/datastream/connectors/file_system.py | 54 ++-
.../pyflink/datastream/connectors/kafka.py | 23 +-
.../connectors/tests/test_file_system.py | 393 +++++++++++++++------
.../datastream/connectors/tests/test_kafka.py | 7 +-
flink-python/pyflink/datastream/data_stream.py | 8 +-
.../pyflink/datastream/formats/__init__.py | 4 +-
flink-python/pyflink/datastream/formats/csv.py | 67 +++-
...eaderFormatFactory.java => PythonCsvUtils.java} | 34 +-
.../flink/python/util/PythonConnectorUtils.java | 32 ++
14 files changed, 529 insertions(+), 165 deletions(-)
diff --git a/docs/content.zh/docs/connectors/datastream/formats/csv.md b/docs/content.zh/docs/connectors/datastream/formats/csv.md
index dbf8bd6be6c..568757f7137 100644
--- a/docs/content.zh/docs/connectors/datastream/formats/csv.md
+++ b/docs/content.zh/docs/connectors/datastream/formats/csv.md
@@ -137,3 +137,21 @@ The corresponding CSV file:
```
Similarly to the `TextLineInputFormat`, `CsvReaderFormat` can be used in both continuous and batch modes (see [TextLineInputFormat]({{< ref "docs/connectors/datastream/formats/text_files" >}}) for examples).
+
+For PyFlink users, `CsvBulkWriter` can be used to create a `BulkWriterFactory` that writes `Row` records to files in CSV format.
+Note that if the operator preceding the sink produces `RowData` records, e.g. a CSV source, those records need to be converted to `Row` records before they are written to the sink.
+```python
+schema = CsvSchema.builder() \
+ .add_number_column('id', number_type=DataTypes.BIGINT()) \
+ .add_array_column('array', separator='#', element_type=DataTypes.INT()) \
+ .set_column_separator(',') \
+ .build()
+
+sink = FileSink.for_bulk_format(
+ OUTPUT_DIR, CsvBulkWriter.for_schema(schema)).build()
+
+# If ds is a source stream producing RowData records, an identity map can be added to convert them into Row records.
+ds.map(lambda e: e, output_type=schema.get_type_info()).sink_to(sink)
+# Else
+ds.sink_to(sink)
+```
diff --git a/docs/content/docs/connectors/datastream/formats/csv.md b/docs/content/docs/connectors/datastream/formats/csv.md
index dbf8bd6be6c..3b9a374635d 100644
--- a/docs/content/docs/connectors/datastream/formats/csv.md
+++ b/docs/content/docs/connectors/datastream/formats/csv.md
@@ -137,3 +137,21 @@ The corresponding CSV file:
```
Similarly to the `TextLineInputFormat`, `CsvReaderFormat` can be used in both continuous and batch modes (see [TextLineInputFormat]({{< ref "docs/connectors/datastream/formats/text_files" >}}) for examples).
+
+For PyFlink users, `CsvBulkWriter` can be used to create a `BulkWriterFactory` that writes `Row` records to files in CSV format.
+Note that if the operator preceding the sink produces `RowData` records, e.g. a CSV source, those records need to be converted to `Row` records before they are written to the sink.
+```python
+schema = CsvSchema.builder() \
+ .add_number_column('id', number_type=DataTypes.BIGINT()) \
+ .add_array_column('array', separator='#', element_type=DataTypes.INT()) \
+ .set_column_separator(',') \
+ .build()
+
+sink = FileSink.for_bulk_format(
+ OUTPUT_DIR, CsvBulkWriter.for_schema(schema)).build()
+
+# If ds is a source stream producing RowData records, an identity map can be added to convert them into Row records.
+ds.map(lambda e: e, output_type=schema.get_type_info()).sink_to(sink)
+# Else
+ds.sink_to(sink)
+```
diff --git a/flink-python/pyflink/datastream/__init__.py b/flink-python/pyflink/datastream/__init__.py
index 982f67c3ad6..85035363b5f 100644
--- a/flink-python/pyflink/datastream/__init__.py
+++ b/flink-python/pyflink/datastream/__init__.py
@@ -200,7 +200,9 @@ Classes to define source & sink:
Classes to define formats used together with source & sink:
- :class:`formats.CsvReaderFormat`:
- A :class:`connectors.StreamFormat` to read csv files into Row data.
+ A :class:`connectors.StreamFormat` to read CSV files into Row data.
+ - :class:`formats.CsvBulkWriter`:
+ Creates :class:`connectors.BulkWriterFactory` to write Row data into CSV files.
- :class:`formats.GenericRecordAvroTypeInfo`:
A :class:`TypeInformation` to indicate vanilla Python records will be translated to
GenericRecordAvroTypeInfo on the Java side.
diff --git a/flink-python/pyflink/datastream/connectors/__init__.py b/flink-python/pyflink/datastream/connectors/__init__.py
index 653507d4db2..c7581cbdf34 100644
--- a/flink-python/pyflink/datastream/connectors/__init__.py
+++ b/flink-python/pyflink/datastream/connectors/__init__.py
@@ -18,11 +18,20 @@
from pyflink.datastream.connectors.base import Sink, Source, DeliveryGuarantee
from pyflink.datastream.connectors.elasticsearch import (Elasticsearch6SinkBuilder,
Elasticsearch7SinkBuilder)
-from pyflink.datastream.connectors.file_system import (FileEnumeratorProvider, FileSink, FileSource,
- BucketAssigner, FileSourceBuilder,
- FileSplitAssignerProvider, OutputFileConfig,
- RollingPolicy,
- StreamFormat, StreamingFileSink, BulkFormat)
+from pyflink.datastream.connectors.file_system import (
+ BucketAssigner,
+ BulkFormat,
+ BulkWriterFactory,
+ FileEnumeratorProvider,
+ FileSink,
+ FileSplitAssignerProvider,
+ FileSource,
+ FileSourceBuilder,
+ OutputFileConfig,
+ RollingPolicy,
+ StreamFormat,
+ StreamingFileSink,
+)
from pyflink.datastream.connectors.jdbc import JdbcSink, JdbcConnectionOptions, JdbcExecutionOptions
from pyflink.datastream.connectors.kafka import (
FlinkKafkaConsumer,
@@ -94,6 +103,7 @@ __all__ = [
'StopCursor',
'BulkFormat',
'StreamFormat',
+ 'BulkWriterFactory',
'StreamingFileSink',
'FlinkKinesisConsumer',
'KinesisStreamsSink',
diff --git a/flink-python/pyflink/datastream/connectors/base.py b/flink-python/pyflink/datastream/connectors/base.py
index a229155bd79..fa8732a2cc6 100644
--- a/flink-python/pyflink/datastream/connectors/base.py
+++ b/flink-python/pyflink/datastream/connectors/base.py
@@ -17,7 +17,7 @@
################################################################################
from abc import ABC, abstractmethod
from enum import Enum
-from typing import Union
+from typing import Union, Optional
from py4j.java_gateway import JavaObject
@@ -53,21 +53,17 @@ class Sink(JavaFunctionWrapper):
super(Sink, self).__init__(sink)
-class TransformAppender(ABC):
+class StreamTransformer(ABC):
@abstractmethod
def apply(self, ds):
pass
-class SupportPreprocessing(ABC):
+class SupportsPreprocessing(ABC):
@abstractmethod
- def need_preprocessing(self) -> bool:
- pass
-
- @abstractmethod
- def get_preprocessing(self) -> 'TransformAppender':
+ def get_transformer(self) -> Optional[StreamTransformer]:
pass
diff --git a/flink-python/pyflink/datastream/connectors/file_system.py b/flink-python/pyflink/datastream/connectors/file_system.py
index d73639bbe2c..df4c360992a 100644
--- a/flink-python/pyflink/datastream/connectors/file_system.py
+++ b/flink-python/pyflink/datastream/connectors/file_system.py
@@ -17,9 +17,15 @@
################################################################################
import warnings
+from typing import TYPE_CHECKING, Optional
+
+if TYPE_CHECKING:
+ from pyflink.table.types import RowType
+
from pyflink.common import Duration, Encoder
-from pyflink.datastream.functions import SinkFunction
from pyflink.datastream.connectors import Source, Sink
+from pyflink.datastream.connectors.base import SupportsPreprocessing, StreamTransformer
+from pyflink.datastream.functions import SinkFunction
from pyflink.datastream.utils import JavaObjectWrapper
from pyflink.java_gateway import get_gateway
from pyflink.util.java_utils import to_jarray
@@ -161,6 +167,20 @@ class BulkWriterFactory(JavaObjectWrapper):
super().__init__(j_bulk_writer_factory)
+class RowDataBulkWriterFactory(BulkWriterFactory):
+ """
+ A :class:`BulkWriterFactory` that receives records of RowData type. It indicates that Row
+ records from Python must first be converted to RowData.
+ """
+
+ def __init__(self, j_bulk_writer_factory, row_type: 'RowType'):
+ super().__init__(j_bulk_writer_factory)
+ self._row_type = row_type
+
+ def get_row_type(self) -> 'RowType':
+ return self._row_type
+
+
class FileSourceBuilder(object):
"""
The builder for the :class:`~pyflink.datastream.connectors.FileSource`, to configure the
@@ -479,7 +499,7 @@ class OutputFileConfig(object):
return OutputFileConfig(self.part_prefix, self.part_suffix)
-class FileSink(Sink):
+class FileSink(Sink, SupportsPreprocessing):
"""
A unified sink that emits its input elements to FileSystem files within buckets. This
sink achieves exactly-once semantics for both BATCH and STREAMING.
@@ -526,8 +546,12 @@ class FileSink(Sink):
the checkpoint from which we restore.
"""
- def __init__(self, j_file_sink):
+ def __init__(self, j_file_sink, transformer: Optional[StreamTransformer] = None):
super(FileSink, self).__init__(sink=j_file_sink)
+ self._transformer = transformer
+
+ def get_transformer(self) -> Optional[StreamTransformer]:
+ return self._transformer
class RowFormatBuilder(object):
"""
@@ -575,6 +599,7 @@ class FileSink(Sink):
def __init__(self, j_bulk_format_builder):
self._j_bulk_format_builder = j_bulk_format_builder
+ self._transformer = None
def with_bucket_check_interval(self, interval: int) -> 'FileSink.BulkFormatBuilder':
"""
@@ -599,8 +624,23 @@ class FileSink(Sink):
output_file_config._j_output_file_config)
return self
+ def _with_row_type(self, row_type: 'RowType') -> 'FileSink.BulkFormatBuilder':
+ from pyflink.datastream.data_stream import DataStream
+ from pyflink.table.types import _to_java_data_type
+
+ class RowRowTransformer(StreamTransformer):
+
+ def apply(self, ds):
+ jvm = get_gateway().jvm
+ j_map_function = jvm.org.apache.flink.python.util.PythonConnectorUtils \
+ .RowRowMapper(_to_java_data_type(row_type))
+ return DataStream(ds._j_data_stream.process(j_map_function))
+
+ self._transformer = RowRowTransformer()
+ return self
+
def build(self) -> 'FileSink':
- return FileSink(self._j_bulk_format_builder.build())
+ return FileSink(self._j_bulk_format_builder.build(), self._transformer)
@staticmethod
def for_bulk_format(base_path: str, writer_factory: BulkWriterFactory) \
@@ -609,9 +649,13 @@ class FileSink(Sink):
j_path = jvm.org.apache.flink.core.fs.Path(base_path)
JFileSink = jvm.org.apache.flink.connector.file.sink.FileSink
- return FileSink.BulkFormatBuilder(
+ builder = FileSink.BulkFormatBuilder(
JFileSink.forBulkFormat(j_path, writer_factory.get_java_object())
)
+ if isinstance(writer_factory, RowDataBulkWriterFactory):
+ return builder._with_row_type(writer_factory.get_row_type())
+ else:
+ return builder
# ---- StreamingFileSink ----
diff --git a/flink-python/pyflink/datastream/connectors/kafka.py b/flink-python/pyflink/datastream/connectors/kafka.py
index 226d3585125..707dd679046 100644
--- a/flink-python/pyflink/datastream/connectors/kafka.py
+++ b/flink-python/pyflink/datastream/connectors/kafka.py
@@ -25,8 +25,8 @@ from py4j.java_gateway import JavaObject, get_java_class
from pyflink.common import DeserializationSchema, TypeInformation, typeinfo, SerializationSchema, \
Types, Row
from pyflink.datastream.connectors import Source, Sink
-from pyflink.datastream.connectors.base import DeliveryGuarantee, SupportPreprocessing, \
- TransformAppender
+from pyflink.datastream.connectors.base import DeliveryGuarantee, SupportsPreprocessing, \
+ StreamTransformer
from pyflink.datastream.functions import SinkFunction, SourceFunction
from pyflink.java_gateway import get_gateway
from pyflink.util.java_utils import to_jarray, get_field, get_field_value
@@ -828,7 +828,7 @@ class KafkaOffsetsInitializer(object):
j_map_wrapper.asMap(), offset_reset_strategy._to_j_offset_reset_strategy()))
-class KafkaSink(Sink, SupportPreprocessing):
+class KafkaSink(Sink, SupportsPreprocessing):
"""
Flink Sink to produce data into a Kafka topic. The sink supports all delivery guarantees
described by :class:`DeliveryGuarantee`.
@@ -853,9 +853,9 @@ class KafkaSink(Sink, SupportPreprocessing):
.. versionadded:: 1.16.0
"""
- def __init__(self, j_kafka_sink, preprocessing: TransformAppender = None):
+ def __init__(self, j_kafka_sink, transformer: Optional[StreamTransformer] = None):
super().__init__(j_kafka_sink)
- self._preprocessing = preprocessing
+ self._transformer = transformer
@staticmethod
def builder() -> 'KafkaSinkBuilder':
@@ -864,11 +864,8 @@ class KafkaSink(Sink, SupportPreprocessing):
"""
return KafkaSinkBuilder()
- def need_preprocessing(self) -> bool:
- return self._preprocessing is not None
-
- def get_preprocessing(self) -> TransformAppender:
- return self._preprocessing
+ def get_transformer(self) -> Optional[StreamTransformer]:
+ return self._transformer
class KafkaSinkBuilder(object):
@@ -1020,8 +1017,8 @@ class KafkaRecordSerializationSchema(SerializationSchema):
_wrap_schema('keySerializationSchema')
_wrap_schema('valueSerializationSchema')
- def _build_preprocessing(self) -> TransformAppender:
- class TopicSelectorTransformAppender(TransformAppender):
+ def _build_preprocessing(self) -> StreamTransformer:
+ class SelectTopicTransformer(StreamTransformer):
def __init__(self, topic_selector: KafkaTopicSelector):
self._topic_selector = topic_selector
@@ -1031,7 +1028,7 @@ class KafkaRecordSerializationSchema(SerializationSchema):
return ds.map(lambda v: Row(self._topic_selector.apply(v), v),
output_type=output_type)
- return TopicSelectorTransformAppender(self._topic_selector)
+ return SelectTopicTransformer(self._topic_selector)
class KafkaRecordSerializationSchemaBuilder(object):
diff --git a/flink-python/pyflink/datastream/connectors/tests/test_file_system.py b/flink-python/pyflink/datastream/connectors/tests/test_file_system.py
index 2e89dd01073..ca3e6d91a69 100644
--- a/flink-python/pyflink/datastream/connectors/tests/test_file_system.py
+++ b/flink-python/pyflink/datastream/connectors/tests/test_file_system.py
@@ -27,7 +27,7 @@ from py4j.java_gateway import java_import, JavaObject
from pyflink.common import Types, Configuration
from pyflink.common.watermark_strategy import WatermarkStrategy
-from pyflink.datastream.formats.csv import CsvSchema, CsvReaderFormat
+from pyflink.datastream.formats.csv import CsvSchema, CsvReaderFormat, CsvBulkWriter
from pyflink.datastream.functions import MapFunction
from pyflink.datastream.connectors.file_system import FileSource, FileSink
from pyflink.datastream.formats.avro import (
@@ -52,140 +52,64 @@ class FileSourceCsvReaderFormatTests(PyFlinkStreamingTestCase):
self.csv_file_name = tempfile.mktemp(suffix='.csv', dir=self.tempdir)
def test_csv_primitive_column(self):
- schema = CsvSchema.builder() \
- .add_number_column('tinyint', DataTypes.TINYINT()) \
- .add_number_column('smallint', DataTypes.SMALLINT()) \
- .add_number_column('int', DataTypes.INT()) \
- .add_number_column('bigint', DataTypes.BIGINT()) \
- .add_number_column('float', DataTypes.FLOAT()) \
- .add_number_column('double', DataTypes.DOUBLE()) \
- .add_number_column('decimal', DataTypes.DECIMAL(2, 0)) \
- .add_boolean_column('boolean') \
- .add_string_column('string') \
- .build()
- with open(self.csv_file_name, 'w') as f:
- f.write('127,')
- f.write('-32767,')
- f.write('2147483647,')
- f.write('-9223372036854775808,')
- f.write('3e38,')
- f.write('2e-308,')
- f.write('1.5,')
- f.write('true,')
- f.write('string\n')
- self._build_csv_job(schema)
+ schema, lines = _create_csv_primitive_column_schema_and_lines()
+ self._build_csv_job(schema, lines)
self.env.execute('test_csv_primitive_column')
- row = self.test_sink.get_results(True, False)[0]
- self.assertEqual(row['tinyint'], 127)
- self.assertEqual(row['smallint'], -32767)
- self.assertEqual(row['int'], 2147483647)
- self.assertEqual(row['bigint'], -9223372036854775808)
- self.assertAlmostEqual(row['float'], 3e38, delta=1e31)
- self.assertAlmostEqual(row['double'], 2e-308, delta=2e-301)
- self.assertAlmostEqual(row['decimal'], 2)
- self.assertEqual(row['boolean'], True)
- self.assertEqual(row['string'], 'string')
+ _check_csv_primitive_column_results(self, self.test_sink.get_results(True, False))
+
+ def test_csv_add_columns_from(self):
+ original_schema, lines = _create_csv_primitive_column_schema_and_lines()
+ schema = CsvSchema.builder().add_columns_from(original_schema).build()
+ self._build_csv_job(schema, lines)
+ self.env.execute('test_csv_schema_copy')
+ _check_csv_primitive_column_results(self, self.test_sink.get_results(True, False))
def test_csv_array_column(self):
- schema = CsvSchema.builder() \
- .add_array_column('number_array', separator=';', element_type=DataTypes.INT()) \
- .add_array_column('boolean_array', separator=':', element_type=DataTypes.BOOLEAN()) \
- .add_array_column('string_array', separator=',', element_type=DataTypes.STRING()) \
- .set_column_separator('|') \
- .build()
- with open(self.csv_file_name, 'w') as f:
- f.write('1;2;3|')
- f.write('true:false|')
- f.write('a,b,c\n')
- self._build_csv_job(schema)
+ schema, lines = _create_csv_array_column_schema_and_lines()
+ self._build_csv_job(schema, lines)
self.env.execute('test_csv_array_column')
- row = self.test_sink.get_results(True, False)[0]
- self.assertListEqual(row['number_array'], [1, 2, 3])
- self.assertListEqual(row['boolean_array'], [True, False])
- self.assertListEqual(row['string_array'], ['a', 'b', 'c'])
+ _check_csv_array_column_results(self, self.test_sink.get_results(True, False))
def test_csv_allow_comments(self):
- schema = CsvSchema.builder() \
- .add_string_column('string') \
- .set_allow_comments() \
- .build()
- with open(self.csv_file_name, 'w') as f:
- f.write('a\n')
- f.write('# this is comment\n')
- f.write('b\n')
- self._build_csv_job(schema)
+ schema, lines = _create_csv_allow_comments_schema_and_lines()
+ self._build_csv_job(schema, lines)
self.env.execute('test_csv_allow_comments')
- rows = self.test_sink.get_results(True, False)
- self.assertEqual(rows[0]['string'], 'a')
- self.assertEqual(rows[1]['string'], 'b')
+ _check_csv_allow_comments_results(self, self.test_sink.get_results(True, False))
def test_csv_use_header(self):
- schema = CsvSchema.builder() \
- .add_string_column('string') \
- .add_number_column('number') \
- .set_use_header() \
- .build()
- with open(self.csv_file_name, 'w') as f:
- f.write('h1,h2\n')
- f.write('string,123\n')
- self._build_csv_job(schema)
+ schema, lines = _create_csv_use_header_schema_and_lines()
+ self._build_csv_job(schema, lines)
self.env.execute('test_csv_use_header')
- row = self.test_sink.get_results(True, False)[0]
- self.assertEqual(row['string'], 'string')
- self.assertEqual(row['number'], 123)
+ _check_csv_use_header_results(self, self.test_sink.get_results(True, False))
def test_csv_strict_headers(self):
- schema = CsvSchema.builder() \
- .add_string_column('string') \
- .add_number_column('number') \
- .set_use_header() \
- .set_strict_headers() \
- .build()
- with open(self.csv_file_name, 'w') as f:
- f.write('string,number\n')
- f.write('string,123\n')
- self._build_csv_job(schema)
+ schema, lines = _create_csv_strict_headers_schema_and_lines()
+ self._build_csv_job(schema, lines)
self.env.execute('test_csv_strict_headers')
- row = self.test_sink.get_results(True, False)[0]
- self.assertEqual(row['string'], 'string')
- self.assertEqual(row['number'], 123)
+ _check_csv_strict_headers_results(self, self.test_sink.get_results(True, False))
def test_csv_default_quote_char(self):
- schema = CsvSchema.builder() \
- .add_string_column('string') \
- .build()
- with open(self.csv_file_name, 'w') as f:
- f.write('"string"\n')
- self._build_csv_job(schema)
+ schema, lines = _create_csv_default_quote_char_schema_and_lines()
+ self._build_csv_job(schema, lines)
self.env.execute('test_csv_default_quote_char')
- row = self.test_sink.get_results(True, False)[0]
- self.assertEqual(row['string'], 'string')
+ _check_csv_default_quote_char_results(self, self.test_sink.get_results(True, False))
def test_csv_customize_quote_char(self):
- schema = CsvSchema.builder() \
- .add_string_column('string') \
- .set_quote_char('`') \
- .build()
- with open(self.csv_file_name, 'w') as f:
- f.write('`string`\n')
- self._build_csv_job(schema)
+ schema, lines = _create_csv_customize_quote_char_schema_lines()
+ self._build_csv_job(schema, lines)
self.env.execute('test_csv_customize_quote_char')
- row = self.test_sink.get_results(True, False)[0]
- self.assertEqual(row['string'], 'string')
+ _check_csv_customize_quote_char_results(self, self.test_sink.get_results(True, False))
def test_csv_use_escape_char(self):
- schema = CsvSchema.builder() \
- .add_string_column('string') \
- .set_escape_char('\\') \
- .build()
- with open(self.csv_file_name, 'w') as f:
- f.write('\\"string\\"\n')
- self._build_csv_job(schema)
+ schema, lines = _create_csv_set_escape_char_schema_and_lines()
+ self._build_csv_job(schema, lines)
self.env.execute('test_csv_use_escape_char')
- row = self.test_sink.get_results(True, False)[0]
- self.assertEqual(row['string'], '"string"')
+ _check_csv_set_escape_char_results(self, self.test_sink.get_results(True, False))
- def _build_csv_job(self, schema):
+ def _build_csv_job(self, schema, lines):
+ with open(self.csv_file_name, 'w') as f:
+ for line in lines:
+ f.write(line)
source = FileSource.for_record_stream_format(
CsvReaderFormat.for_schema(schema), self.csv_file_name).build()
ds = self.env.from_source(source, WatermarkStrategy.no_watermarks(), 'csv-source')
@@ -498,6 +422,78 @@ class FileSinkAvroWritersTests(PyFlinkStreamingTestCase):
return records
+class FileSinkCsvBulkWriterTests(PyFlinkStreamingTestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.env.set_parallelism(1)
+ self.csv_file_name = tempfile.mktemp(dir=self.tempdir)
+ self.csv_dir_name = tempfile.mkdtemp(dir=self.tempdir)
+
+ def test_csv_primitive_column_write(self):
+ schema, lines = _create_csv_primitive_column_schema_and_lines()
+ self._build_csv_job(schema, lines)
+ self.env.execute('test_csv_primitive_column_write')
+ results = self._read_csv_file()
+ self.assertTrue(len(results) == 1)
+ self.assertEqual(
+ results[0],
+ '127,-32767,2147483647,-9223372036854775808,3.0E38,2.0E-308,2,true,string\n'
+ )
+
+ def test_csv_array_column_write(self):
+ schema, lines = _create_csv_array_column_schema_and_lines()
+ self._build_csv_job(schema, lines)
+ self.env.execute('test_csv_array_column_write')
+ results = self._read_csv_file()
+ self.assertTrue(len(results) == 1)
+ self.assertListEqual(results, lines)
+
+ def test_csv_default_quote_char_write(self):
+ schema, lines = _create_csv_default_quote_char_schema_and_lines()
+ self._build_csv_job(schema, lines)
+ self.env.execute('test_csv_default_quote_char_write')
+ results = self._read_csv_file()
+ self.assertTrue(len(results) == 1)
+ self.assertListEqual(results, lines)
+
+ def test_csv_customize_quote_char_write(self):
+ schema, lines = _create_csv_customize_quote_char_schema_lines()
+ self._build_csv_job(schema, lines)
+ self.env.execute('test_csv_customize_quote_char_write')
+ results = self._read_csv_file()
+ self.assertTrue(len(results) == 1)
+ self.assertListEqual(results, lines)
+
+ def test_csv_use_escape_char_write(self):
+ schema, lines = _create_csv_set_escape_char_schema_and_lines()
+ self._build_csv_job(schema, lines)
+ self.env.execute('test_csv_use_escape_char_write')
+ results = self._read_csv_file()
+ self.assertTrue(len(results) == 1)
+ self.assertListEqual(results, ['"string,","""string2"""\n'])
+
+ def _build_csv_job(self, schema: CsvSchema, lines):
+ with open(self.csv_file_name, 'w') as f:
+ for line in lines:
+ f.write(line)
+ source = FileSource.for_record_stream_format(
+ CsvReaderFormat.for_schema(schema), self.csv_file_name
+ ).build()
+ ds = self.env.from_source(source, WatermarkStrategy.no_watermarks(), 'csv-source')
+ sink = FileSink.for_bulk_format(
+ self.csv_dir_name, CsvBulkWriter.for_schema(schema)
+ ).build()
+ ds.map(lambda e: e, output_type=schema.get_type_info()).sink_to(sink)
+
+ def _read_csv_file(self) -> List[str]:
+ lines = []
+ for file in glob.glob(os.path.join(self.csv_dir_name, '**/*')):
+ with open(file, 'r') as f:
+ lines.extend(f.readlines())
+ return lines
+
+
class PassThroughMapFunction(MapFunction):
def map(self, value):
@@ -512,6 +508,179 @@ def _import_avro_classes():
java_import(jvm, prefix + cls)
+def _create_csv_primitive_column_schema_and_lines() -> Tuple[CsvSchema, List[str]]:
+ schema = CsvSchema.builder() \
+ .add_number_column('tinyint', DataTypes.TINYINT()) \
+ .add_number_column('smallint', DataTypes.SMALLINT()) \
+ .add_number_column('int', DataTypes.INT()) \
+ .add_number_column('bigint', DataTypes.BIGINT()) \
+ .add_number_column('float', DataTypes.FLOAT()) \
+ .add_number_column('double', DataTypes.DOUBLE()) \
+ .add_number_column('decimal', DataTypes.DECIMAL(2, 0)) \
+ .add_boolean_column('boolean') \
+ .add_string_column('string') \
+ .build()
+ lines = [
+ '127,'
+ '-32767,'
+ '2147483647,'
+ '-9223372036854775808,'
+ '3e38,'
+ '2e-308,'
+ '1.5,'
+ 'true,'
+ 'string\n',
+ ]
+ return schema, lines
+
+
+def _check_csv_primitive_column_results(test, results):
+ row = results[0]
+ test.assertEqual(row['tinyint'], 127)
+ test.assertEqual(row['smallint'], -32767)
+ test.assertEqual(row['int'], 2147483647)
+ test.assertEqual(row['bigint'], -9223372036854775808)
+ test.assertAlmostEqual(row['float'], 3e38, delta=1e31)
+ test.assertAlmostEqual(row['double'], 2e-308, delta=2e-301)
+ test.assertAlmostEqual(row['decimal'], 2)
+ test.assertEqual(row['boolean'], True)
+ test.assertEqual(row['string'], 'string')
+
+
+def _create_csv_array_column_schema_and_lines() -> Tuple[CsvSchema, List[str]]:
+ schema = CsvSchema.builder() \
+ .add_array_column('number_array', separator=';', element_type=DataTypes.INT()) \
+ .add_array_column('boolean_array', separator=':', element_type=DataTypes.BOOLEAN()) \
+ .add_array_column('string_array', separator=',', element_type=DataTypes.STRING()) \
+ .set_column_separator('|') \
+ .disable_quote_char() \
+ .build()
+ lines = [
+ '1;2;3|'
+ 'true:false|'
+ 'a,b,c\n',
+ ]
+ return schema, lines
+
+
+def _check_csv_array_column_results(test, results):
+ row = results[0]
+ test.assertListEqual(row['number_array'], [1, 2, 3])
+ test.assertListEqual(row['boolean_array'], [True, False])
+ test.assertListEqual(row['string_array'], ['a', 'b', 'c'])
+
+
+def _create_csv_allow_comments_schema_and_lines() -> Tuple[CsvSchema, List[str]]:
+ schema = CsvSchema.builder() \
+ .add_string_column('string') \
+ .set_allow_comments() \
+ .build()
+ lines = [
+ 'a\n',
+ '# this is comment\n',
+ 'b\n',
+ ]
+ return schema, lines
+
+
+def _check_csv_allow_comments_results(test, results):
+ test.assertEqual(results[0]['string'], 'a')
+ test.assertEqual(results[1]['string'], 'b')
+
+
+def _create_csv_use_header_schema_and_lines() -> Tuple[CsvSchema, List[str]]:
+ schema = CsvSchema.builder() \
+ .add_string_column('string') \
+ .add_number_column('number') \
+ .set_use_header() \
+ .build()
+ lines = [
+ 'h1,h2\n',
+ 'string,123\n',
+ ]
+ return schema, lines
+
+
+def _check_csv_use_header_results(test, results):
+ row = results[0]
+ test.assertEqual(row['string'], 'string')
+ test.assertEqual(row['number'], 123)
+
+
+def _create_csv_strict_headers_schema_and_lines() -> Tuple[CsvSchema, List[str]]:
+ schema = CsvSchema.builder() \
+ .add_string_column('string') \
+ .add_number_column('number') \
+ .set_use_header() \
+ .set_strict_headers() \
+ .build()
+ lines = [
+ 'string,number\n',
+ 'string,123\n',
+ ]
+ return schema, lines
+
+
+def _check_csv_strict_headers_results(test, results):
+ row = results[0]
+ test.assertEqual(row['string'], 'string')
+ test.assertEqual(row['number'], 123)
+
+
+def _create_csv_default_quote_char_schema_and_lines() -> Tuple[CsvSchema, List[str]]:
+ schema = CsvSchema.builder() \
+ .add_string_column('string') \
+ .add_string_column('string2') \
+ .set_column_separator('|') \
+ .build()
+ lines = [
+ '"string"|"string2"\n',
+ ]
+ return schema, lines
+
+
+def _check_csv_default_quote_char_results(test, results):
+ row = results[0]
+ test.assertEqual(row['string'], 'string')
+
+
+def _create_csv_customize_quote_char_schema_lines() -> Tuple[CsvSchema, List[str]]:
+ schema = CsvSchema.builder() \
+ .add_string_column('string') \
+ .add_string_column('string2') \
+ .set_column_separator('|') \
+ .set_quote_char('`') \
+ .build()
+ lines = [
+ '`string`|`string2`\n',
+ ]
+ return schema, lines
+
+
+def _check_csv_customize_quote_char_results(test, results):
+ row = results[0]
+ test.assertEqual(row['string'], 'string')
+
+
+def _create_csv_set_escape_char_schema_and_lines() -> Tuple[CsvSchema, List[str]]:
+ schema = CsvSchema.builder() \
+ .add_string_column('string') \
+ .add_string_column('string2') \
+ .set_column_separator(',') \
+ .set_escape_char('\\') \
+ .build()
+ lines = [
+ 'string\\,,\\"string2\\"\n',
+ ]
+ return schema, lines
+
+
+def _check_csv_set_escape_char_results(test, results):
+ row = results[0]
+ test.assertEqual(row['string'], 'string,')
+ test.assertEqual(row['string2'], '"string2"')
+
+
BASIC_SCHEMA = """
{
"type": "record",
diff --git a/flink-python/pyflink/datastream/connectors/tests/test_kafka.py b/flink-python/pyflink/datastream/connectors/tests/test_kafka.py
index 033b5643f67..a0b7761d6a9 100644
--- a/flink-python/pyflink/datastream/connectors/tests/test_kafka.py
+++ b/flink-python/pyflink/datastream/connectors/tests/test_kafka.py
@@ -588,8 +588,7 @@ class MockDataStream(data_stream.DataStream):
def sink_to(self, sink):
ds = self
- from pyflink.datastream.connectors.base import SupportPreprocessing
- if isinstance(sink, SupportPreprocessing):
- if sink.need_preprocessing():
- ds = sink.get_preprocessing().apply(self)
+ from pyflink.datastream.connectors.base import SupportsPreprocessing
+ if isinstance(sink, SupportsPreprocessing) and sink.get_transformer() is not None:
+ ds = sink.get_transformer().apply(self)
return ds
diff --git a/flink-python/pyflink/datastream/data_stream.py b/flink-python/pyflink/datastream/data_stream.py
index 11a96b3acfe..845cd156afc 100644
--- a/flink-python/pyflink/datastream/data_stream.py
+++ b/flink-python/pyflink/datastream/data_stream.py
@@ -830,11 +830,9 @@ class DataStream(object):
"""
ds = self
- from pyflink.datastream.connectors.base import SupportPreprocessing
- if isinstance(sink, SupportPreprocessing):
- preprocessing_sink = cast(SupportPreprocessing, sink)
- if preprocessing_sink.need_preprocessing():
- ds = preprocessing_sink.get_preprocessing().apply(self)
+ from pyflink.datastream.connectors.base import SupportsPreprocessing
+ if isinstance(sink, SupportsPreprocessing) and sink.get_transformer() is not None:
+ ds = sink.get_transformer().apply(self)
return DataStreamSink(ds._j_data_stream.sinkTo(sink.get_java_function()))
diff --git a/flink-python/pyflink/datastream/formats/__init__.py b/flink-python/pyflink/datastream/formats/__init__.py
index 74bea254b73..48d582cfaff 100644
--- a/flink-python/pyflink/datastream/formats/__init__.py
+++ b/flink-python/pyflink/datastream/formats/__init__.py
@@ -16,7 +16,7 @@
# limitations under the License.
################################################################################
from .avro import AvroSchema, AvroInputFormat, AvroWriters, GenericRecordAvroTypeInfo
-from .csv import CsvSchema, CsvReaderFormat
+from .csv import CsvSchema, CsvSchemaBuilder, CsvReaderFormat, CsvBulkWriter
from .parquet import AvroParquetReaders, AvroParquetWriters, ParquetColumnarRowInputFormat
__all__ = [
@@ -25,8 +25,10 @@ __all__ = [
'AvroParquetWriters',
'AvroSchema',
'AvroWriters',
+ 'CsvBulkWriter',
'CsvReaderFormat',
'CsvSchema',
+ 'CsvSchemaBuilder',
'GenericRecordAvroTypeInfo',
'ParquetColumnarRowInputFormat'
]
diff --git a/flink-python/pyflink/datastream/formats/csv.py b/flink-python/pyflink/datastream/formats/csv.py
index 20b8509664b..e475684035c 100644
--- a/flink-python/pyflink/datastream/formats/csv.py
+++ b/flink-python/pyflink/datastream/formats/csv.py
@@ -15,9 +15,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
-from typing import Optional, cast
+from typing import Optional
+from pyflink.common.typeinfo import _from_java_type
from pyflink.datastream.connectors import StreamFormat
+from pyflink.datastream.connectors.file_system import BulkWriterFactory, RowDataBulkWriterFactory
from pyflink.java_gateway import get_gateway
from pyflink.table.types import DataType, DataTypes, _to_java_data_type, RowType, NumericType
@@ -30,9 +32,10 @@ class CsvSchema(object):
.. versionadded:: 1.16.0
"""
- def __init__(self, j_schema, data_type: DataType):
+ def __init__(self, j_schema, row_type: RowType):
self._j_schema = j_schema
- self._data_type = data_type
+ self._row_type = row_type # type: RowType
+ self._type_info = None
@staticmethod
def builder() -> 'CsvSchemaBuilder':
@@ -41,6 +44,14 @@ class CsvSchema(object):
"""
return CsvSchemaBuilder()
+ def get_type_info(self):
+ if self._type_info is None:
+ jvm = get_gateway().jvm
+ j_type_info = jvm.org.apache.flink.table.types.utils.LegacyTypeInfoDataTypeConverter \
+ .toLegacyTypeInfo(_to_java_data_type(self._row_type))
+ self._type_info = _from_java_type(j_type_info)
+ return self._type_info
+
def size(self):
return self._j_schema.size()
@@ -131,7 +142,7 @@ class CsvSchemaBuilder(object):
:param schema: Another :class:`CsvSchema`.
"""
self._j_schema_builder.addColumnsFrom(schema._j_schema)
- for field in cast(RowType, schema._data_type):
+ for field in schema._row_type:
self._fields.append(field)
return self
@@ -293,9 +304,53 @@ class CsvReaderFormat(StreamFormat):
Builds a :class:`CsvReaderFormat` using `CsvSchema`.
"""
jvm = get_gateway().jvm
- j_csv_format = jvm.org.apache.flink.formats.csv.CsvReaderFormatFactory \
+ j_csv_format = jvm.org.apache.flink.formats.csv.PythonCsvUtils \
.createCsvReaderFormat(
schema._j_schema,
- _to_java_data_type(schema._data_type)
+ _to_java_data_type(schema._row_type)
)
return CsvReaderFormat(j_csv_format)
+
+
+class CsvBulkWriter(object):
+ """
+ CsvBulkWriter builds a :class:`BulkWriterFactory` for writing Rows with a predefined CSV
+ schema to partitioned files in a bulk fashion.
+
+ Example:
+ ::
+
+ >>> schema = CsvSchema.builder() \\
+ ... .add_number_column('id', number_type=DataTypes.INT()) \\
+ ... .add_string_column('name') \\
+ ... .add_array_column('list', ',', element_type=DataTypes.STRING()) \\
+ ... .set_column_separator('|') \\
+ ... .build()
+ >>> sink = FileSink.for_bulk_format(
+ ... OUTPUT_DIR, CsvBulkWriter.for_schema(schema)).build()
+ >>> # If ds produces RowData records (e.g. from a CSV source), an identity map before the sink is required
+ >>> ds.map(lambda e: e, output_type=schema.get_type_info()).sink_to(sink)
+
+ .. versionadded:: 1.16.0
+ """
+
+ @staticmethod
+ def for_schema(schema: 'CsvSchema') -> 'BulkWriterFactory':
+ """
+ Builds a :class:`BulkWriterFactory` for writing records to files in CSV format.
+ """
+ jvm = get_gateway().jvm
+ jackson = jvm.org.apache.flink.shaded.jackson2.com.fasterxml.jackson
+ csv = jvm.org.apache.flink.formats.csv
+
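+ # Converts each internal RowData record into a Jackson node for CSV serialization.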
+ j_converter = csv.RowDataToCsvConverters.createRowConverter(
+ _to_java_data_type(schema._row_type).getLogicalType())
+ j_mapper = jackson.dataformat.csv.CsvMapper()
+ j_container = j_mapper.createObjectNode()
+ j_context = csv.PythonCsvUtils.createRowDataToCsvFormatConverterContext(
+ j_mapper, j_container)
+
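+ # The Java-side factory wraps CsvBulkWriter.forSchema; returning a RowDataBulkWriterFactory
+ # lets FileSink prepend the Row-to-RowData conversion (see file_system.py).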
+ j_factory = csv.PythonCsvUtils.createCsvBulkWriterFactory(
+ j_mapper, schema._j_schema, j_converter, j_context
+ )
+ return RowDataBulkWriterFactory(j_factory, schema._row_type)
diff --git a/flink-python/src/main/java/org/apache/flink/formats/csv/CsvReaderFormatFactory.java b/flink-python/src/main/java/org/apache/flink/formats/csv/PythonCsvUtils.java
similarity index 56%
rename from flink-python/src/main/java/org/apache/flink/formats/csv/CsvReaderFormatFactory.java
rename to flink-python/src/main/java/org/apache/flink/formats/csv/PythonCsvUtils.java
index f5c845a6190..8202d46c800 100644
--- a/flink-python/src/main/java/org/apache/flink/formats/csv/CsvReaderFormatFactory.java
+++ b/flink-python/src/main/java/org/apache/flink/formats/csv/PythonCsvUtils.java
@@ -17,6 +17,8 @@
package org.apache.flink.formats.csv;
+import org.apache.flink.api.common.serialization.BulkWriter;
+import org.apache.flink.formats.common.Converter;
import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.logical.RowType;
@@ -24,16 +26,20 @@ import org.apache.flink.table.types.logical.utils.LogicalTypeUtils;
import org.apache.flink.util.Preconditions;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.JsonNode;
+import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.node.ContainerNode;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.dataformat.csv.CsvMapper;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.dataformat.csv.CsvSchema;
-/** Util for creating a {@link CsvReaderFormat}. */
-public class CsvReaderFormatFactory {
- public static CsvReaderFormat createCsvReaderFormat(CsvSchema schema, DataType dataType) {
+/** Utilities for using CSV format in PyFlink. */
+public class PythonCsvUtils {
+
+ /** Util for creating a {@link CsvReaderFormat}. */
+ public static CsvReaderFormat<Object> createCsvReaderFormat(
+ CsvSchema schema, DataType dataType) {
Preconditions.checkArgument(dataType.getLogicalType() instanceof RowType);
- return new CsvReaderFormat(
- () -> new CsvMapper(),
+ return new CsvReaderFormat<>(
+ CsvMapper::new,
ignored -> schema,
JsonNode.class,
new CsvToRowDataConverters(false)
@@ -42,4 +48,22 @@ public class CsvReaderFormatFactory {
InternalTypeInfo.of(dataType.getLogicalType()),
false);
}
+
+ /**
+ * Util for creating a {@link
+ * RowDataToCsvConverters.RowDataToCsvConverter.RowDataToCsvFormatConverterContext}.
+ */
+ public static RowDataToCsvConverters.RowDataToCsvConverter.RowDataToCsvFormatConverterContext
+ createRowDataToCsvFormatConverterContext(CsvMapper mapper, ContainerNode<?> container) {
+ return new RowDataToCsvConverters.RowDataToCsvConverter.RowDataToCsvFormatConverterContext(
+ mapper, container);
+ }
+
+ /**
+ * Util for creating a {@link BulkWriter.Factory} that wraps {@link CsvBulkWriter#forSchema}.
+ */
+ public static <T, R, C> BulkWriter.Factory<T> createCsvBulkWriterFactory(
+ CsvMapper mapper, CsvSchema schema, Converter<T, R, C> converter, C converterContext) {
+ return (out) -> CsvBulkWriter.forSchema(mapper, schema, converter, converterContext, out);
+ }
}
diff --git a/flink-python/src/main/java/org/apache/flink/python/util/PythonConnectorUtils.java b/flink-python/src/main/java/org/apache/flink/python/util/PythonConnectorUtils.java
index 63ac0bac890..987c8acf7c5 100644
--- a/flink-python/src/main/java/org/apache/flink/python/util/PythonConnectorUtils.java
+++ b/flink-python/src/main/java/org/apache/flink/python/util/PythonConnectorUtils.java
@@ -18,7 +18,13 @@
package org.apache.flink.python.util;
import org.apache.flink.api.common.serialization.SerializationSchema;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.streaming.api.functions.ProcessFunction;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.data.conversion.RowRowConverter;
+import org.apache.flink.table.types.DataType;
import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
import org.apache.flink.util.Preconditions;
import java.io.Serializable;
@@ -94,4 +100,30 @@ public class PythonConnectorUtils {
return wrappedSchema.serialize((T) row.getField(1));
}
}
+
+ /** A {@link ProcessFunction} that converts {@link Row} to {@link RowData}. */
+ public static class RowRowMapper extends ProcessFunction<Row, RowData> {
+
+ private static final long serialVersionUID = 1L;
+ private final DataType dataType;
+ private transient RowRowConverter converter;
+
+ public RowRowMapper(DataType dataType) {
+ this.dataType = dataType;
+ }
+
+ @Override
+ public void open(Configuration parameters) throws Exception {
+ super.open(parameters);
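+ // The converter is transient, so it is created and opened here with the
+ // user-code classloader rather than being serialized with the function.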
+ converter = RowRowConverter.create(dataType);
+ converter.open(getRuntimeContext().getUserCodeClassLoader());
+ }
+
+ @Override
+ public void processElement(
+ Row row, ProcessFunction<Row, RowData>.Context ctx, Collector<RowData> out)
+ throws Exception {
+ out.collect(converter.toInternal(row));
+ }
+ }
}