Posted to commits@flink.apache.org by di...@apache.org on 2022/08/08 07:52:45 UTC
[flink] 02/02: [hotfix][python][tests] Split the test cases of connectors & formats into separate files
This is an automated email from the ASF dual-hosted git repository.
dianfu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
commit 4cf0b81d4f492b49d12aa317f3dfbc6e50d72ec8
Author: Dian Fu <di...@apache.org>
AuthorDate: Mon Aug 8 13:22:45 2022 +0800
[hotfix][python][tests] Split the test cases of connectors & formats into separate files
---
.../datastream/connectors/tests/test_cassandra.py | 49 +
.../datastream/connectors/tests/test_connectors.py | 729 -------------
.../connectors/tests/test_elasticsearch.py | 133 +++
.../connectors/tests/test_file_system.py | 1080 +++-----------------
.../datastream/connectors/tests/test_jdbc.py | 61 ++
.../datastream/connectors/tests/test_kafka.py | 53 +-
.../datastream/connectors/tests/test_kinesis.py | 108 ++
.../datastream/connectors/tests/test_pulsar.py | 212 ++++
.../datastream/connectors/tests/test_rabbitmq.py | 48 +
.../datastream/connectors/tests/test_seq_source.py | 36 +
.../pyflink/datastream/formats/tests/__init__.py | 17 +
.../pyflink/datastream/formats/tests/test_avro.py | 450 ++++++++
.../pyflink/datastream/formats/tests/test_csv.py | 354 +++++++
.../datastream/formats/tests/test_parquet.py | 231 +++++
14 files changed, 1890 insertions(+), 1671 deletions(-)
diff --git a/flink-python/pyflink/datastream/connectors/tests/test_cassandra.py b/flink-python/pyflink/datastream/connectors/tests/test_cassandra.py
new file mode 100644
index 00000000000..94e883a2bb9
--- /dev/null
+++ b/flink-python/pyflink/datastream/connectors/tests/test_cassandra.py
@@ -0,0 +1,49 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+from pyflink.common import Types
+from pyflink.datastream.connectors.cassandra import CassandraSink, MapperOptions, ConsistencyLevel
+from pyflink.testing.test_case_utils import PyFlinkStreamingTestCase
+
+
+class CassandraSinkTest(PyFlinkStreamingTestCase):
+
+ def test_cassandra_sink(self):
+ type_info = Types.ROW([Types.STRING(), Types.INT()])
+ ds = self.env.from_collection([('ab', 1), ('bdc', 2), ('cfgs', 3), ('deeefg', 4)],
+ type_info=type_info)
+ cassandra_sink_builder = CassandraSink.add_sink(ds)
+
+ cassandra_sink = cassandra_sink_builder\
+ .set_host('localhost', 9876) \
+ .set_query('query') \
+ .enable_ignore_null_fields() \
+ .set_mapper_options(MapperOptions()
+ .ttl(1)
+ .timestamp(100)
+ .tracing(True)
+ .if_not_exists(False)
+ .consistency_level(ConsistencyLevel.ANY)
+ .save_null_fields(True)) \
+ .set_max_concurrent_requests(1000) \
+ .build()
+
+ cassandra_sink.name('cassandra_sink').set_parallelism(3)
+
+ plan = eval(self.env.get_execution_plan())
+ self.assertEqual("Sink: cassandra_sink", plan['nodes'][1]['type'])
+ self.assertEqual(3, plan['nodes'][1]['parallelism'])
diff --git a/flink-python/pyflink/datastream/connectors/tests/test_connectors.py b/flink-python/pyflink/datastream/connectors/tests/test_connectors.py
deleted file mode 100644
index 1d12083cd3e..00000000000
--- a/flink-python/pyflink/datastream/connectors/tests/test_connectors.py
+++ /dev/null
@@ -1,729 +0,0 @@
-################################################################################
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-################################################################################
-
-from pyflink.common import typeinfo, Duration, WatermarkStrategy, ConfigOptions
-from pyflink.common.serialization import JsonRowDeserializationSchema, \
- JsonRowSerializationSchema, Encoder, SimpleStringSchema
-from pyflink.common.typeinfo import Types
-from pyflink.datastream.connectors import DeliveryGuarantee
-from pyflink.datastream.connectors.cassandra import CassandraSink, MapperOptions, ConsistencyLevel
-from pyflink.datastream.connectors.elasticsearch import Elasticsearch7SinkBuilder, \
- FlushBackoffType, ElasticsearchEmitter
-from pyflink.datastream.connectors.file_system import FileCompactStrategy, FileCompactor, \
- StreamingFileSink, OutputFileConfig, FileSource, StreamFormat, FileEnumeratorProvider, \
- FileSplitAssignerProvider, RollingPolicy, FileSink, BucketAssigner
-from pyflink.datastream.connectors.jdbc import JdbcSink, JdbcConnectionOptions, JdbcExecutionOptions
-from pyflink.datastream.connectors.number_seq import NumberSequenceSource
-from pyflink.datastream.connectors.kafka import FlinkKafkaConsumer, FlinkKafkaProducer
-from pyflink.datastream.connectors.kinesis import PartitionKeyGenerator, FlinkKinesisConsumer, \
- KinesisStreamsSink, KinesisFirehoseSink
-from pyflink.datastream.connectors.pulsar import PulsarSerializationSchema, TopicRoutingMode, \
- MessageDelayer, PulsarSink, PulsarSource, StartCursor, PulsarDeserializationSchema, \
- StopCursor, SubscriptionType
-from pyflink.datastream.connectors.rabbitmq import RMQSink, RMQSource, RMQConnectionConfig
-from pyflink.datastream.tests.test_util import DataStreamTestSinkFunction
-from pyflink.java_gateway import get_gateway
-from pyflink.testing.test_case_utils import invoke_java_object_method, PyFlinkStreamingTestCase
-from pyflink.util.java_utils import load_java_class, get_field_value, is_instance_of
-
-
-class FlinkElasticsearch7Test(PyFlinkStreamingTestCase):
-
- def test_es_sink(self):
- ds = self.env.from_collection(
- [{'name': 'ada', 'id': '1'}, {'name': 'luna', 'id': '2'}],
- type_info=Types.MAP(Types.STRING(), Types.STRING()))
-
- es_sink = Elasticsearch7SinkBuilder() \
- .set_emitter(ElasticsearchEmitter.static_index('foo', 'id')) \
- .set_hosts(['localhost:9200']) \
- .set_delivery_guarantee(DeliveryGuarantee.AT_LEAST_ONCE) \
- .set_bulk_flush_max_actions(1) \
- .set_bulk_flush_max_size_mb(2) \
- .set_bulk_flush_interval(1000) \
- .set_bulk_flush_backoff_strategy(FlushBackoffType.CONSTANT, 3, 3000) \
- .set_connection_username('foo') \
- .set_connection_password('bar') \
- .set_connection_path_prefix('foo-bar') \
- .set_connection_request_timeout(30000) \
- .set_connection_timeout(31000) \
- .set_socket_timeout(32000) \
- .build()
-
- j_emitter = get_field_value(es_sink.get_java_function(), 'emitter')
- self.assertTrue(
- is_instance_of(
- j_emitter,
- 'org.apache.flink.connector.elasticsearch.sink.MapElasticsearchEmitter'))
- self.assertEqual(
- get_field_value(
- es_sink.get_java_function(), 'hosts')[0].toString(), 'http://localhost:9200')
- self.assertEqual(
- get_field_value(
- es_sink.get_java_function(), 'deliveryGuarantee').toString(), 'at-least-once')
-
- j_build_bulk_processor_config = get_field_value(
- es_sink.get_java_function(), 'buildBulkProcessorConfig')
- self.assertEqual(j_build_bulk_processor_config.getBulkFlushMaxActions(), 1)
- self.assertEqual(j_build_bulk_processor_config.getBulkFlushMaxMb(), 2)
- self.assertEqual(j_build_bulk_processor_config.getBulkFlushInterval(), 1000)
- self.assertEqual(j_build_bulk_processor_config.getFlushBackoffType().toString(), 'CONSTANT')
- self.assertEqual(j_build_bulk_processor_config.getBulkFlushBackoffRetries(), 3)
- self.assertEqual(j_build_bulk_processor_config.getBulkFlushBackOffDelay(), 3000)
-
- j_network_client_config = get_field_value(
- es_sink.get_java_function(), 'networkClientConfig')
- self.assertEqual(j_network_client_config.getUsername(), 'foo')
- self.assertEqual(j_network_client_config.getPassword(), 'bar')
- self.assertEqual(j_network_client_config.getConnectionRequestTimeout(), 30000)
- self.assertEqual(j_network_client_config.getConnectionTimeout(), 31000)
- self.assertEqual(j_network_client_config.getSocketTimeout(), 32000)
- self.assertEqual(j_network_client_config.getConnectionPathPrefix(), 'foo-bar')
-
- ds.sink_to(es_sink).name('es sink')
-
- def test_es_sink_dynamic(self):
- ds = self.env.from_collection(
- [{'name': 'ada', 'id': '1'}, {'name': 'luna', 'id': '2'}],
- type_info=Types.MAP(Types.STRING(), Types.STRING()))
-
- es_dynamic_index_sink = Elasticsearch7SinkBuilder() \
- .set_emitter(ElasticsearchEmitter.dynamic_index('name', 'id')) \
- .set_hosts(['localhost:9200']) \
- .build()
-
- j_emitter = get_field_value(es_dynamic_index_sink.get_java_function(), 'emitter')
- self.assertTrue(
- is_instance_of(
- j_emitter,
- 'org.apache.flink.connector.elasticsearch.sink.MapElasticsearchEmitter'))
-
- ds.sink_to(es_dynamic_index_sink).name('es dynamic index sink')
-
- def test_es_sink_key_none(self):
- ds = self.env.from_collection(
- [{'name': 'ada', 'id': '1'}, {'name': 'luna', 'id': '2'}],
- type_info=Types.MAP(Types.STRING(), Types.STRING()))
-
- es_sink = Elasticsearch7SinkBuilder() \
- .set_emitter(ElasticsearchEmitter.static_index('foo')) \
- .set_hosts(['localhost:9200']) \
- .build()
-
- j_emitter = get_field_value(es_sink.get_java_function(), 'emitter')
- self.assertTrue(
- is_instance_of(
- j_emitter,
- 'org.apache.flink.connector.elasticsearch.sink.MapElasticsearchEmitter'))
-
- ds.sink_to(es_sink).name('es sink')
-
- def test_es_sink_dynamic_key_none(self):
- ds = self.env.from_collection(
- [{'name': 'ada', 'id': '1'}, {'name': 'luna', 'id': '2'}],
- type_info=Types.MAP(Types.STRING(), Types.STRING()))
-
- es_dynamic_index_sink = Elasticsearch7SinkBuilder() \
- .set_emitter(ElasticsearchEmitter.dynamic_index('name')) \
- .set_hosts(['localhost:9200']) \
- .build()
-
- j_emitter = get_field_value(es_dynamic_index_sink.get_java_function(), 'emitter')
- self.assertTrue(
- is_instance_of(
- j_emitter,
- 'org.apache.flink.connector.elasticsearch.sink.MapElasticsearchEmitter'))
-
- ds.sink_to(es_dynamic_index_sink).name('es dynamic index sink')
-
-
-class FlinkKafkaTest(PyFlinkStreamingTestCase):
-
- def test_kafka_connector_universal(self):
- self.kafka_connector_assertion(FlinkKafkaConsumer, FlinkKafkaProducer)
-
- def kafka_connector_assertion(self, flink_kafka_consumer_clz, flink_kafka_producer_clz):
- source_topic = 'test_source_topic'
- sink_topic = 'test_sink_topic'
- props = {'bootstrap.servers': 'localhost:9092', 'group.id': 'test_group'}
- type_info = Types.ROW([Types.INT(), Types.STRING()])
-
- # Test for kafka consumer
- deserialization_schema = JsonRowDeserializationSchema.builder() \
- .type_info(type_info=type_info).build()
-
- flink_kafka_consumer = flink_kafka_consumer_clz(source_topic, deserialization_schema, props)
- flink_kafka_consumer.set_start_from_earliest()
- flink_kafka_consumer.set_commit_offsets_on_checkpoints(True)
-
- j_properties = get_field_value(flink_kafka_consumer.get_java_function(), 'properties')
- self.assertEqual('localhost:9092', j_properties.getProperty('bootstrap.servers'))
- self.assertEqual('test_group', j_properties.getProperty('group.id'))
- self.assertTrue(get_field_value(flink_kafka_consumer.get_java_function(),
- 'enableCommitOnCheckpoints'))
- j_start_up_mode = get_field_value(flink_kafka_consumer.get_java_function(), 'startupMode')
-
- j_deserializer = get_field_value(flink_kafka_consumer.get_java_function(), 'deserializer')
- j_deserialize_type_info = invoke_java_object_method(j_deserializer, "getProducedType")
- deserialize_type_info = typeinfo._from_java_type(j_deserialize_type_info)
- self.assertTrue(deserialize_type_info == type_info)
- self.assertTrue(j_start_up_mode.equals(get_gateway().jvm
- .org.apache.flink.streaming.connectors
- .kafka.config.StartupMode.EARLIEST))
- j_topic_desc = get_field_value(flink_kafka_consumer.get_java_function(),
- 'topicsDescriptor')
- j_topics = invoke_java_object_method(j_topic_desc, 'getFixedTopics')
- self.assertEqual(['test_source_topic'], list(j_topics))
-
- # Test for kafka producer
- serialization_schema = JsonRowSerializationSchema.builder().with_type_info(type_info) \
- .build()
- flink_kafka_producer = flink_kafka_producer_clz(sink_topic, serialization_schema, props)
- flink_kafka_producer.set_write_timestamp_to_kafka(False)
-
- j_producer_config = get_field_value(flink_kafka_producer.get_java_function(),
- 'producerConfig')
- self.assertEqual('localhost:9092', j_producer_config.getProperty('bootstrap.servers'))
- self.assertEqual('test_group', j_producer_config.getProperty('group.id'))
- self.assertFalse(get_field_value(flink_kafka_producer.get_java_function(),
- 'writeTimestampToKafka'))
-
-
-class FlinkJdbcSinkTest(PyFlinkStreamingTestCase):
-
- def test_jdbc_sink(self):
- ds = self.env.from_collection([('ab', 1), ('bdc', 2), ('cfgs', 3), ('deeefg', 4)],
- type_info=Types.ROW([Types.STRING(), Types.INT()]))
- jdbc_connection_options = JdbcConnectionOptions.JdbcConnectionOptionsBuilder()\
- .with_driver_name('com.mysql.jdbc.Driver')\
- .with_user_name('root')\
- .with_password('password')\
- .with_url('jdbc:mysql://server-name:server-port/database-name').build()
-
- jdbc_execution_options = JdbcExecutionOptions.builder().with_batch_interval_ms(2000)\
- .with_batch_size(100).with_max_retries(5).build()
- jdbc_sink = JdbcSink.sink("insert into test table", ds.get_type(), jdbc_connection_options,
- jdbc_execution_options)
-
- ds.add_sink(jdbc_sink).name('jdbc sink')
- plan = eval(self.env.get_execution_plan())
- self.assertEqual('Sink: jdbc sink', plan['nodes'][1]['type'])
- j_output_format = get_field_value(jdbc_sink.get_java_function(), 'outputFormat')
-
- connection_options = JdbcConnectionOptions(
- get_field_value(get_field_value(j_output_format, 'connectionProvider'),
- 'jdbcOptions'))
- self.assertEqual(jdbc_connection_options.get_db_url(), connection_options.get_db_url())
- self.assertEqual(jdbc_connection_options.get_driver_name(),
- connection_options.get_driver_name())
- self.assertEqual(jdbc_connection_options.get_password(), connection_options.get_password())
- self.assertEqual(jdbc_connection_options.get_user_name(),
- connection_options.get_user_name())
-
- exec_options = JdbcExecutionOptions(get_field_value(j_output_format, 'executionOptions'))
- self.assertEqual(jdbc_execution_options.get_batch_interval_ms(),
- exec_options.get_batch_interval_ms())
- self.assertEqual(jdbc_execution_options.get_batch_size(),
- exec_options.get_batch_size())
- self.assertEqual(jdbc_execution_options.get_max_retries(),
- exec_options.get_max_retries())
-
-
-class FlinkPulsarTest(PyFlinkStreamingTestCase):
-
- def test_pulsar_source(self):
- TEST_OPTION_NAME = 'pulsar.source.enableAutoAcknowledgeMessage'
- pulsar_source = PulsarSource.builder() \
- .set_service_url('pulsar://localhost:6650') \
- .set_admin_url('http://localhost:8080') \
- .set_topics('ada') \
- .set_start_cursor(StartCursor.earliest()) \
- .set_unbounded_stop_cursor(StopCursor.never()) \
- .set_bounded_stop_cursor(StopCursor.at_publish_time(22)) \
- .set_subscription_name('ff') \
- .set_subscription_type(SubscriptionType.Exclusive) \
- .set_deserialization_schema(
- PulsarDeserializationSchema.flink_type_info(Types.STRING())) \
- .set_deserialization_schema(
- PulsarDeserializationSchema.flink_schema(SimpleStringSchema())) \
- .set_config(TEST_OPTION_NAME, True) \
- .set_properties({'pulsar.source.autoCommitCursorInterval': '1000'}) \
- .build()
-
- ds = self.env.from_source(source=pulsar_source,
- watermark_strategy=WatermarkStrategy.for_monotonous_timestamps(),
- source_name="pulsar source")
- ds.print()
- plan = eval(self.env.get_execution_plan())
- self.assertEqual('Source: pulsar source', plan['nodes'][0]['type'])
-
- configuration = get_field_value(pulsar_source.get_java_function(), "sourceConfiguration")
- self.assertEqual(
- configuration.getString(
- ConfigOptions.key('pulsar.client.serviceUrl')
- .string_type()
- .no_default_value()._j_config_option), 'pulsar://localhost:6650')
- self.assertEqual(
- configuration.getString(
- ConfigOptions.key('pulsar.admin.adminUrl')
- .string_type()
- .no_default_value()._j_config_option), 'http://localhost:8080')
- self.assertEqual(
- configuration.getString(
- ConfigOptions.key('pulsar.consumer.subscriptionName')
- .string_type()
- .no_default_value()._j_config_option), 'ff')
- self.assertEqual(
- configuration.getString(
- ConfigOptions.key('pulsar.consumer.subscriptionType')
- .string_type()
- .no_default_value()._j_config_option), SubscriptionType.Exclusive.name)
- test_option = ConfigOptions.key(TEST_OPTION_NAME).boolean_type().no_default_value()
- self.assertEqual(
- configuration.getBoolean(
- test_option._j_config_option), True)
- self.assertEqual(
- configuration.getLong(
- ConfigOptions.key('pulsar.source.autoCommitCursorInterval')
- .long_type()
- .no_default_value()._j_config_option), 1000)
-
- def test_source_set_topics_with_list(self):
- PulsarSource.builder() \
- .set_service_url('pulsar://localhost:6650') \
- .set_admin_url('http://localhost:8080') \
- .set_topics(['ada', 'beta']) \
- .set_subscription_name('ff') \
- .set_deserialization_schema(
- PulsarDeserializationSchema.flink_schema(SimpleStringSchema())) \
- .build()
-
- def test_source_set_topics_pattern(self):
- PulsarSource.builder() \
- .set_service_url('pulsar://localhost:6650') \
- .set_admin_url('http://localhost:8080') \
- .set_topic_pattern('ada.*') \
- .set_subscription_name('ff') \
- .set_deserialization_schema(
- PulsarDeserializationSchema.flink_schema(SimpleStringSchema())) \
- .build()
-
- def test_source_deprecated_method(self):
- test_option = ConfigOptions.key('pulsar.source.enableAutoAcknowledgeMessage') \
- .boolean_type().no_default_value()
- pulsar_source = PulsarSource.builder() \
- .set_service_url('pulsar://localhost:6650') \
- .set_admin_url('http://localhost:8080') \
- .set_topic_pattern('ada.*') \
- .set_deserialization_schema(
- PulsarDeserializationSchema.flink_type_info(Types.STRING())) \
- .set_unbounded_stop_cursor(StopCursor.at_publish_time(4444)) \
- .set_subscription_name('ff') \
- .set_config(test_option, True) \
- .set_properties({'pulsar.source.autoCommitCursorInterval': '1000'}) \
- .build()
- configuration = get_field_value(pulsar_source.get_java_function(), "sourceConfiguration")
- self.assertEqual(
- configuration.getBoolean(
- test_option._j_config_option), True)
- self.assertEqual(
- configuration.getLong(
- ConfigOptions.key('pulsar.source.autoCommitCursorInterval')
- .long_type()
- .no_default_value()._j_config_option), 1000)
-
- def test_pulsar_sink(self):
- ds = self.env.from_collection([('ab', 1), ('bdc', 2), ('cfgs', 3), ('deeefg', 4)],
- type_info=Types.ROW([Types.STRING(), Types.INT()]))
-
- TEST_OPTION_NAME = 'pulsar.producer.chunkingEnabled'
- pulsar_sink = PulsarSink.builder() \
- .set_service_url('pulsar://localhost:6650') \
- .set_admin_url('http://localhost:8080') \
- .set_producer_name('fo') \
- .set_topics('ada') \
- .set_serialization_schema(
- PulsarSerializationSchema.flink_schema(SimpleStringSchema())) \
- .set_delivery_guarantee(DeliveryGuarantee.AT_LEAST_ONCE) \
- .set_topic_routing_mode(TopicRoutingMode.ROUND_ROBIN) \
- .delay_sending_message(MessageDelayer.fixed(Duration.of_seconds(12))) \
- .set_config(TEST_OPTION_NAME, True) \
- .set_properties({'pulsar.producer.batchingMaxMessages': '100'}) \
- .build()
-
- ds.sink_to(pulsar_sink).name('pulsar sink')
-
- plan = eval(self.env.get_execution_plan())
- self.assertEqual('pulsar sink: Writer', plan['nodes'][1]['type'])
- configuration = get_field_value(pulsar_sink.get_java_function(), "sinkConfiguration")
- self.assertEqual(
- configuration.getString(
- ConfigOptions.key('pulsar.client.serviceUrl')
- .string_type()
- .no_default_value()._j_config_option), 'pulsar://localhost:6650')
- self.assertEqual(
- configuration.getString(
- ConfigOptions.key('pulsar.admin.adminUrl')
- .string_type()
- .no_default_value()._j_config_option), 'http://localhost:8080')
- self.assertEqual(
- configuration.getString(
- ConfigOptions.key('pulsar.producer.producerName')
- .string_type()
- .no_default_value()._j_config_option), 'fo - %s')
-
- j_pulsar_serialization_schema = get_field_value(
- pulsar_sink.get_java_function(), 'serializationSchema')
- j_serialization_schema = get_field_value(
- j_pulsar_serialization_schema, 'serializationSchema')
- self.assertTrue(
- is_instance_of(
- j_serialization_schema,
- 'org.apache.flink.api.common.serialization.SimpleStringSchema'))
-
- self.assertEqual(
- configuration.getString(
- ConfigOptions.key('pulsar.sink.deliveryGuarantee')
- .string_type()
- .no_default_value()._j_config_option), 'at-least-once')
-
- j_topic_router = get_field_value(pulsar_sink.get_java_function(), "topicRouter")
- self.assertTrue(
- is_instance_of(
- j_topic_router,
- 'org.apache.flink.connector.pulsar.sink.writer.router.RoundRobinTopicRouter'))
-
- j_message_delayer = get_field_value(pulsar_sink.get_java_function(), 'messageDelayer')
- delay_duration = get_field_value(j_message_delayer, 'delayDuration')
- self.assertEqual(delay_duration, 12000)
-
- test_option = ConfigOptions.key(TEST_OPTION_NAME).boolean_type().no_default_value()
- self.assertEqual(
- configuration.getBoolean(
- test_option._j_config_option), True)
- self.assertEqual(
- configuration.getLong(
- ConfigOptions.key('pulsar.producer.batchingMaxMessages')
- .long_type()
- .no_default_value()._j_config_option), 100)
-
- def test_sink_set_topics_with_list(self):
- PulsarSink.builder() \
- .set_service_url('pulsar://localhost:6650') \
- .set_admin_url('http://localhost:8080') \
- .set_topics(['ada', 'beta']) \
- .set_serialization_schema(
- PulsarSerializationSchema.flink_schema(SimpleStringSchema())) \
- .build()
-
-
-class RMQTest(PyFlinkStreamingTestCase):
-
- def test_rabbitmq_connectors(self):
- connection_config = RMQConnectionConfig.Builder() \
- .set_host('localhost') \
- .set_port(5672) \
- .set_virtual_host('/') \
- .set_user_name('guest') \
- .set_password('guest') \
- .build()
- type_info = Types.ROW([Types.INT(), Types.STRING()])
- deserialization_schema = JsonRowDeserializationSchema.builder() \
- .type_info(type_info=type_info).build()
-
- rmq_source = RMQSource(
- connection_config, 'source_queue', True, deserialization_schema)
- self.assertEqual(
- get_field_value(rmq_source.get_java_function(), 'queueName'), 'source_queue')
- self.assertTrue(get_field_value(rmq_source.get_java_function(), 'usesCorrelationId'))
-
- serialization_schema = JsonRowSerializationSchema.builder().with_type_info(type_info) \
- .build()
- rmq_sink = RMQSink(connection_config, 'sink_queue', serialization_schema)
- self.assertEqual(
- get_field_value(rmq_sink.get_java_function(), 'queueName'), 'sink_queue')
-
-
-class ConnectorTests(PyFlinkStreamingTestCase):
-
- def setUp(self) -> None:
- super(ConnectorTests, self).setUp()
- self.test_sink = DataStreamTestSinkFunction()
-
- def tearDown(self) -> None:
- super(ConnectorTests, self).tearDown()
- self.test_sink.clear()
-
- def test_stream_file_sink(self):
- self.env.set_parallelism(2)
- ds = self.env.from_collection([('ab', 1), ('bdc', 2), ('cfgs', 3), ('deeefg', 4)],
- type_info=Types.ROW([Types.STRING(), Types.INT()]))
- ds.map(
- lambda a: a[0],
- Types.STRING()).add_sink(
- StreamingFileSink.for_row_format(self.tempdir, Encoder.simple_string_encoder())
- .with_rolling_policy(
- RollingPolicy.default_rolling_policy(
- part_size=1024 * 1024 * 1024,
- rollover_interval=15 * 60 * 1000,
- inactivity_interval=5 * 60 * 1000))
- .with_output_file_config(
- OutputFileConfig.OutputFileConfigBuilder()
- .with_part_prefix("prefix")
- .with_part_suffix("suffix").build()).build())
-
- self.env.execute("test_streaming_file_sink")
-
- results = []
- import os
- for root, dirs, files in os.walk(self.tempdir, topdown=True):
- for file in files:
- self.assertTrue(file.startswith('.prefix'))
- self.assertTrue('suffix' in file)
- path = root + "/" + file
- with open(path) as infile:
- for line in infile:
- results.append(line)
-
- expected = ['deeefg\n', 'bdc\n', 'ab\n', 'cfgs\n']
- results.sort()
- expected.sort()
- self.assertEqual(expected, results)
-
- def test_file_source(self):
- stream_format = StreamFormat.text_line_format()
- paths = ["/tmp/1.txt", "/tmp/2.txt"]
- file_source_builder = FileSource.for_record_stream_format(stream_format, *paths)
- file_source = file_source_builder\
- .monitor_continuously(Duration.of_days(1)) \
- .set_file_enumerator(FileEnumeratorProvider.default_splittable_file_enumerator()) \
- .set_split_assigner(FileSplitAssignerProvider.locality_aware_split_assigner()) \
- .build()
-
- continuous_setting = file_source.get_java_function().getContinuousEnumerationSettings()
- self.assertIsNotNone(continuous_setting)
- self.assertEqual(Duration.of_days(1), Duration(continuous_setting.getDiscoveryInterval()))
-
- input_paths_field = \
- load_java_class("org.apache.flink.connector.file.src.AbstractFileSource"). \
- getDeclaredField("inputPaths")
- input_paths_field.setAccessible(True)
- input_paths = input_paths_field.get(file_source.get_java_function())
- self.assertEqual(len(input_paths), len(paths))
- self.assertEqual(str(input_paths[0]), paths[0])
- self.assertEqual(str(input_paths[1]), paths[1])
-
- def test_file_sink(self):
- base_path = "/tmp/1.txt"
- encoder = Encoder.simple_string_encoder()
- file_sink_builder = FileSink.for_row_format(base_path, encoder)
- file_sink = file_sink_builder\
- .with_bucket_check_interval(1000) \
- .with_bucket_assigner(BucketAssigner.base_path_bucket_assigner()) \
- .with_rolling_policy(RollingPolicy.on_checkpoint_rolling_policy()) \
- .with_output_file_config(
- OutputFileConfig.builder().with_part_prefix("pre").with_part_suffix("suf").build())\
- .enable_compact(FileCompactStrategy.builder()
- .enable_compaction_on_checkpoint(3)
- .set_size_threshold(1024)
- .set_num_compact_threads(2)
- .build(),
- FileCompactor.concat_file_compactor(b'\n')) \
- .build()
-
- buckets_builder_field = \
- load_java_class("org.apache.flink.connector.file.sink.FileSink"). \
- getDeclaredField("bucketsBuilder")
- buckets_builder_field.setAccessible(True)
- buckets_builder = buckets_builder_field.get(file_sink.get_java_function())
-
- self.assertEqual("DefaultRowFormatBuilder", buckets_builder.getClass().getSimpleName())
-
- row_format_builder_clz = load_java_class(
- "org.apache.flink.connector.file.sink.FileSink$RowFormatBuilder")
- encoder_field = row_format_builder_clz.getDeclaredField("encoder")
- encoder_field.setAccessible(True)
- self.assertEqual("SimpleStringEncoder",
- encoder_field.get(buckets_builder).getClass().getSimpleName())
-
- interval_field = row_format_builder_clz.getDeclaredField("bucketCheckInterval")
- interval_field.setAccessible(True)
- self.assertEqual(1000, interval_field.get(buckets_builder))
-
- bucket_assigner_field = row_format_builder_clz.getDeclaredField("bucketAssigner")
- bucket_assigner_field.setAccessible(True)
- self.assertEqual("BasePathBucketAssigner",
- bucket_assigner_field.get(buckets_builder).getClass().getSimpleName())
-
- rolling_policy_field = row_format_builder_clz.getDeclaredField("rollingPolicy")
- rolling_policy_field.setAccessible(True)
- self.assertEqual("OnCheckpointRollingPolicy",
- rolling_policy_field.get(buckets_builder).getClass().getSimpleName())
-
- output_file_config_field = row_format_builder_clz.getDeclaredField("outputFileConfig")
- output_file_config_field.setAccessible(True)
- output_file_config = output_file_config_field.get(buckets_builder)
- self.assertEqual("pre", output_file_config.getPartPrefix())
- self.assertEqual("suf", output_file_config.getPartSuffix())
-
- compact_strategy_field = row_format_builder_clz.getDeclaredField("compactStrategy")
- compact_strategy_field.setAccessible(True)
- compact_strategy = compact_strategy_field.get(buckets_builder)
- self.assertEqual(3, compact_strategy.getNumCheckpointsBeforeCompaction())
- self.assertEqual(1024, compact_strategy.getSizeThreshold())
- self.assertEqual(2, compact_strategy.getNumCompactThreads())
-
- file_compactor_field = row_format_builder_clz.getDeclaredField("fileCompactor")
- file_compactor_field.setAccessible(True)
- file_compactor = file_compactor_field.get(buckets_builder)
- self.assertEqual("ConcatFileCompactor", file_compactor.getClass().getSimpleName())
- concat_file_compactor_clz = load_java_class(
- "org.apache.flink.connector.file.sink.compactor.ConcatFileCompactor"
- )
- file_delimiter_field = concat_file_compactor_clz.getDeclaredField("fileDelimiter")
- file_delimiter_field.setAccessible(True)
- file_delimiter = file_delimiter_field.get(file_compactor)
- self.assertEqual(b'\n', file_delimiter)
-
- def test_seq_source(self):
- seq_source = NumberSequenceSource(1, 10)
-
- seq_source_clz = load_java_class(
- "org.apache.flink.api.connector.source.lib.NumberSequenceSource")
- from_field = seq_source_clz.getDeclaredField("from")
- from_field.setAccessible(True)
- self.assertEqual(1, from_field.get(seq_source.get_java_function()))
-
- to_field = seq_source_clz.getDeclaredField("to")
- to_field.setAccessible(True)
- self.assertEqual(10, to_field.get(seq_source.get_java_function()))
-
-
-class FlinkKinesisTest(PyFlinkStreamingTestCase):
-
- def test_kinesis_source(self):
- consumer_config = {
- 'aws.region': 'us-east-1',
- 'aws.credentials.provider.basic.accesskeyid': 'aws_access_key_id',
- 'aws.credentials.provider.basic.secretkey': 'aws_secret_access_key',
- 'flink.stream.initpos': 'LATEST'
- }
-
- kinesis_source = FlinkKinesisConsumer("stream-1", SimpleStringSchema(), consumer_config)
-
- ds = self.env.add_source(source_func=kinesis_source, source_name="kinesis source")
- ds.print()
- plan = eval(self.env.get_execution_plan())
- self.assertEqual('Source: kinesis source', plan['nodes'][0]['type'])
- self.assertEqual(
- get_field_value(kinesis_source.get_java_function(), 'streams')[0], 'stream-1')
-
- def test_kinesis_streams_sink(self):
- sink_properties = {
- 'aws.region': 'us-east-1',
- 'aws.credentials.provider.basic.secretkey': 'aws_secret_access_key'
- }
-
- ds = self.env.from_collection([('ab', 1), ('bdc', 2), ('cfgs', 3), ('deeefg', 4)],
- type_info=Types.ROW([Types.STRING(), Types.INT()]))
-
- kinesis_streams_sink = KinesisStreamsSink.builder() \
- .set_kinesis_client_properties(sink_properties) \
- .set_serialization_schema(SimpleStringSchema()) \
- .set_partition_key_generator(PartitionKeyGenerator.fixed()) \
- .set_stream_name("stream-1") \
- .set_fail_on_error(False) \
- .set_max_batch_size(500) \
- .set_max_in_flight_requests(50) \
- .set_max_buffered_requests(10000) \
- .set_max_batch_size_in_bytes(5 * 1024 * 1024) \
- .set_max_time_in_buffer_ms(5000) \
- .set_max_record_size_in_bytes(1 * 1024 * 1024) \
- .build()
-
- ds.sink_to(kinesis_streams_sink).name('kinesis streams sink')
- plan = eval(self.env.get_execution_plan())
-
- self.assertEqual('kinesis streams sink: Writer', plan['nodes'][1]['type'])
- self.assertEqual(get_field_value(kinesis_streams_sink.get_java_function(), 'failOnError'),
- False)
- self.assertEqual(
- get_field_value(kinesis_streams_sink.get_java_function(), 'streamName'), 'stream-1')
-
- def test_kinesis_firehose_sink(self):
-
- sink_properties = {
- 'aws.region': 'eu-west-1',
- 'aws.credentials.provider.basic.accesskeyid': 'aws_access_key_id',
- 'aws.credentials.provider.basic.secretkey': 'aws_secret_access_key'
- }
-
- ds = self.env.from_collection([('ab', 1), ('bdc', 2), ('cfgs', 3), ('deeefg', 4)],
- type_info=Types.ROW([Types.STRING(), Types.INT()]))
-
- kinesis_firehose_sink = KinesisFirehoseSink.builder() \
- .set_firehose_client_properties(sink_properties) \
- .set_serialization_schema(SimpleStringSchema()) \
- .set_delivery_stream_name('stream-1') \
- .set_fail_on_error(False) \
- .set_max_batch_size(500) \
- .set_max_in_flight_requests(50) \
- .set_max_buffered_requests(10000) \
- .set_max_batch_size_in_bytes(5 * 1024 * 1024) \
- .set_max_time_in_buffer_ms(5000) \
- .set_max_record_size_in_bytes(1 * 1024 * 1024) \
- .build()
-
- ds.sink_to(kinesis_firehose_sink).name('kinesis firehose sink')
- plan = eval(self.env.get_execution_plan())
-
- self.assertEqual('kinesis firehose sink: Writer', plan['nodes'][1]['type'])
- self.assertEqual(get_field_value(kinesis_firehose_sink.get_java_function(), 'failOnError'),
- False)
- self.assertEqual(
- get_field_value(kinesis_firehose_sink.get_java_function(), 'deliveryStreamName'),
- 'stream-1')
-
-
-class CassandraSinkTest(PyFlinkStreamingTestCase):
-
- def test_cassandra_sink(self):
- type_info = Types.ROW([Types.STRING(), Types.INT()])
- ds = self.env.from_collection([('ab', 1), ('bdc', 2), ('cfgs', 3), ('deeefg', 4)],
- type_info=type_info)
- cassandra_sink_builder = CassandraSink.add_sink(ds)
-
- cassandra_sink = cassandra_sink_builder\
- .set_host('localhost', 9876) \
- .set_query('query') \
- .enable_ignore_null_fields() \
- .set_mapper_options(MapperOptions()
- .ttl(1)
- .timestamp(100)
- .tracing(True)
- .if_not_exists(False)
- .consistency_level(ConsistencyLevel.ANY)
- .save_null_fields(True)) \
- .set_max_concurrent_requests(1000) \
- .build()
-
- cassandra_sink.name('cassandra_sink').set_parallelism(3)
-
- plan = eval(self.env.get_execution_plan())
- self.assertEqual("Sink: cassandra_sink", plan['nodes'][1]['type'])
- self.assertEqual(3, plan['nodes'][1]['parallelism'])
diff --git a/flink-python/pyflink/datastream/connectors/tests/test_elasticsearch.py b/flink-python/pyflink/datastream/connectors/tests/test_elasticsearch.py
new file mode 100644
index 00000000000..b3007e995f6
--- /dev/null
+++ b/flink-python/pyflink/datastream/connectors/tests/test_elasticsearch.py
@@ -0,0 +1,133 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+from pyflink.common import Types
+from pyflink.datastream.connectors import DeliveryGuarantee
+from pyflink.datastream.connectors.elasticsearch import Elasticsearch7SinkBuilder, \
+ FlushBackoffType, ElasticsearchEmitter
+from pyflink.testing.test_case_utils import PyFlinkStreamingTestCase
+from pyflink.util.java_utils import get_field_value, is_instance_of
+
+
+class FlinkElasticsearch7Test(PyFlinkStreamingTestCase):
+
+ def test_es_sink(self):
+ ds = self.env.from_collection(
+ [{'name': 'ada', 'id': '1'}, {'name': 'luna', 'id': '2'}],
+ type_info=Types.MAP(Types.STRING(), Types.STRING()))
+
+ es_sink = Elasticsearch7SinkBuilder() \
+ .set_emitter(ElasticsearchEmitter.static_index('foo', 'id')) \
+ .set_hosts(['localhost:9200']) \
+ .set_delivery_guarantee(DeliveryGuarantee.AT_LEAST_ONCE) \
+ .set_bulk_flush_max_actions(1) \
+ .set_bulk_flush_max_size_mb(2) \
+ .set_bulk_flush_interval(1000) \
+ .set_bulk_flush_backoff_strategy(FlushBackoffType.CONSTANT, 3, 3000) \
+ .set_connection_username('foo') \
+ .set_connection_password('bar') \
+ .set_connection_path_prefix('foo-bar') \
+ .set_connection_request_timeout(30000) \
+ .set_connection_timeout(31000) \
+ .set_socket_timeout(32000) \
+ .build()
+
+ j_emitter = get_field_value(es_sink.get_java_function(), 'emitter')
+ self.assertTrue(
+ is_instance_of(
+ j_emitter,
+ 'org.apache.flink.connector.elasticsearch.sink.MapElasticsearchEmitter'))
+ self.assertEqual(
+ get_field_value(
+ es_sink.get_java_function(), 'hosts')[0].toString(), 'http://localhost:9200')
+ self.assertEqual(
+ get_field_value(
+ es_sink.get_java_function(), 'deliveryGuarantee').toString(), 'at-least-once')
+
+ j_build_bulk_processor_config = get_field_value(
+ es_sink.get_java_function(), 'buildBulkProcessorConfig')
+ self.assertEqual(j_build_bulk_processor_config.getBulkFlushMaxActions(), 1)
+ self.assertEqual(j_build_bulk_processor_config.getBulkFlushMaxMb(), 2)
+ self.assertEqual(j_build_bulk_processor_config.getBulkFlushInterval(), 1000)
+ self.assertEqual(j_build_bulk_processor_config.getFlushBackoffType().toString(), 'CONSTANT')
+ self.assertEqual(j_build_bulk_processor_config.getBulkFlushBackoffRetries(), 3)
+ self.assertEqual(j_build_bulk_processor_config.getBulkFlushBackOffDelay(), 3000)
+
+ j_network_client_config = get_field_value(
+ es_sink.get_java_function(), 'networkClientConfig')
+ self.assertEqual(j_network_client_config.getUsername(), 'foo')
+ self.assertEqual(j_network_client_config.getPassword(), 'bar')
+ self.assertEqual(j_network_client_config.getConnectionRequestTimeout(), 30000)
+ self.assertEqual(j_network_client_config.getConnectionTimeout(), 31000)
+ self.assertEqual(j_network_client_config.getSocketTimeout(), 32000)
+ self.assertEqual(j_network_client_config.getConnectionPathPrefix(), 'foo-bar')
+
+ ds.sink_to(es_sink).name('es sink')
+
+ def test_es_sink_dynamic(self):
+ ds = self.env.from_collection(
+ [{'name': 'ada', 'id': '1'}, {'name': 'luna', 'id': '2'}],
+ type_info=Types.MAP(Types.STRING(), Types.STRING()))
+
+ es_dynamic_index_sink = Elasticsearch7SinkBuilder() \
+ .set_emitter(ElasticsearchEmitter.dynamic_index('name', 'id')) \
+ .set_hosts(['localhost:9200']) \
+ .build()
+
+ j_emitter = get_field_value(es_dynamic_index_sink.get_java_function(), 'emitter')
+ self.assertTrue(
+ is_instance_of(
+ j_emitter,
+ 'org.apache.flink.connector.elasticsearch.sink.MapElasticsearchEmitter'))
+
+ ds.sink_to(es_dynamic_index_sink).name('es dynamic index sink')
+
+ def test_es_sink_key_none(self):
+ ds = self.env.from_collection(
+ [{'name': 'ada', 'id': '1'}, {'name': 'luna', 'id': '2'}],
+ type_info=Types.MAP(Types.STRING(), Types.STRING()))
+
+ es_sink = Elasticsearch7SinkBuilder() \
+ .set_emitter(ElasticsearchEmitter.static_index('foo')) \
+ .set_hosts(['localhost:9200']) \
+ .build()
+
+ j_emitter = get_field_value(es_sink.get_java_function(), 'emitter')
+ self.assertTrue(
+ is_instance_of(
+ j_emitter,
+ 'org.apache.flink.connector.elasticsearch.sink.MapElasticsearchEmitter'))
+
+ ds.sink_to(es_sink).name('es sink')
+
+ def test_es_sink_dynamic_key_none(self):
+ ds = self.env.from_collection(
+ [{'name': 'ada', 'id': '1'}, {'name': 'luna', 'id': '2'}],
+ type_info=Types.MAP(Types.STRING(), Types.STRING()))
+
+ es_dynamic_index_sink = Elasticsearch7SinkBuilder() \
+ .set_emitter(ElasticsearchEmitter.dynamic_index('name')) \
+ .set_hosts(['localhost:9200']) \
+ .build()
+
+ j_emitter = get_field_value(es_dynamic_index_sink.get_java_function(), 'emitter')
+ self.assertTrue(
+ is_instance_of(
+ j_emitter,
+ 'org.apache.flink.connector.elasticsearch.sink.MapElasticsearchEmitter'))
+
+ ds.sink_to(es_dynamic_index_sink).name('es dynamic index sink')
diff --git a/flink-python/pyflink/datastream/connectors/tests/test_file_system.py b/flink-python/pyflink/datastream/connectors/tests/test_file_system.py
index ca3e6d91a69..b8fa964c6a5 100644
--- a/flink-python/pyflink/datastream/connectors/tests/test_file_system.py
+++ b/flink-python/pyflink/datastream/connectors/tests/test_file_system.py
@@ -15,947 +15,147 @@
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
-import glob
-import os
-import tempfile
-import unittest
-from typing import Tuple, List
-from avro.datafile import DataFileReader
-from avro.io import DatumReader
-from py4j.java_gateway import java_import, JavaObject
+from pyflink.common import Duration
+from pyflink.common.serialization import Encoder
+from pyflink.common.typeinfo import Types
+from pyflink.datastream.connectors.file_system import FileCompactStrategy, FileCompactor, \
+ StreamingFileSink, OutputFileConfig, FileSource, StreamFormat, FileEnumeratorProvider, \
+ FileSplitAssignerProvider, RollingPolicy, FileSink, BucketAssigner
-from pyflink.common import Types, Configuration
-from pyflink.common.watermark_strategy import WatermarkStrategy
-from pyflink.datastream.formats.csv import CsvSchema, CsvReaderFormat, CsvBulkWriter
-from pyflink.datastream.functions import MapFunction
-from pyflink.datastream.connectors.file_system import FileSource, FileSink
-from pyflink.datastream.formats.avro import (
- AvroInputFormat,
- AvroSchema,
- AvroWriters,
- GenericRecordAvroTypeInfo,
-)
-from pyflink.datastream.formats.parquet import AvroParquetReaders, ParquetColumnarRowInputFormat, \
- AvroParquetWriters
-from pyflink.datastream.tests.test_util import DataStreamTestSinkFunction
-from pyflink.java_gateway import get_gateway
-from pyflink.table.types import RowType, DataTypes
from pyflink.testing.test_case_utils import PyFlinkStreamingTestCase
-
-
-class FileSourceCsvReaderFormatTests(PyFlinkStreamingTestCase):
-
- def setUp(self):
- super().setUp()
- self.test_sink = DataStreamTestSinkFunction()
- self.csv_file_name = tempfile.mktemp(suffix='.csv', dir=self.tempdir)
-
- def test_csv_primitive_column(self):
- schema, lines = _create_csv_primitive_column_schema_and_lines()
- self._build_csv_job(schema, lines)
- self.env.execute('test_csv_primitive_column')
- _check_csv_primitive_column_results(self, self.test_sink.get_results(True, False))
-
- def test_csv_add_columns_from(self):
- original_schema, lines = _create_csv_primitive_column_schema_and_lines()
- schema = CsvSchema.builder().add_columns_from(original_schema).build()
- self._build_csv_job(schema, lines)
- self.env.execute('test_csv_schema_copy')
- _check_csv_primitive_column_results(self, self.test_sink.get_results(True, False))
-
- def test_csv_array_column(self):
- schema, lines = _create_csv_array_column_schema_and_lines()
- self._build_csv_job(schema, lines)
- self.env.execute('test_csv_array_column')
- _check_csv_array_column_results(self, self.test_sink.get_results(True, False))
-
- def test_csv_allow_comments(self):
- schema, lines = _create_csv_allow_comments_schema_and_lines()
- self._build_csv_job(schema, lines)
- self.env.execute('test_csv_allow_comments')
- _check_csv_allow_comments_results(self, self.test_sink.get_results(True, False))
-
- def test_csv_use_header(self):
- schema, lines = _create_csv_use_header_schema_and_lines()
- self._build_csv_job(schema, lines)
- self.env.execute('test_csv_use_header')
- _check_csv_use_header_results(self, self.test_sink.get_results(True, False))
-
- def test_csv_strict_headers(self):
- schema, lines = _create_csv_strict_headers_schema_and_lines()
- self._build_csv_job(schema, lines)
- self.env.execute('test_csv_strict_headers')
- _check_csv_strict_headers_results(self, self.test_sink.get_results(True, False))
-
- def test_csv_default_quote_char(self):
- schema, lines = _create_csv_default_quote_char_schema_and_lines()
- self._build_csv_job(schema, lines)
- self.env.execute('test_csv_default_quote_char')
- _check_csv_default_quote_char_results(self, self.test_sink.get_results(True, False))
-
- def test_csv_customize_quote_char(self):
- schema, lines = _create_csv_customize_quote_char_schema_lines()
- self._build_csv_job(schema, lines)
- self.env.execute('test_csv_customize_quote_char')
- _check_csv_customize_quote_char_results(self, self.test_sink.get_results(True, False))
-
- def test_csv_use_escape_char(self):
- schema, lines = _create_csv_set_escape_char_schema_and_lines()
- self._build_csv_job(schema, lines)
- self.env.execute('test_csv_use_escape_char')
- _check_csv_set_escape_char_results(self, self.test_sink.get_results(True, False))
-
- def _build_csv_job(self, schema, lines):
- with open(self.csv_file_name, 'w') as f:
- for line in lines:
- f.write(line)
- source = FileSource.for_record_stream_format(
- CsvReaderFormat.for_schema(schema), self.csv_file_name).build()
- ds = self.env.from_source(source, WatermarkStrategy.no_watermarks(), 'csv-source')
- ds.map(PassThroughMapFunction(), output_type=Types.PICKLED_BYTE_ARRAY()) \
- .add_sink(self.test_sink)
-
-
-@unittest.skipIf(os.environ.get('HADOOP_CLASSPATH') is None,
- 'Some Hadoop lib is needed for Parquet Columnar format tests')
-class FileSourceParquetColumnarRowInputFormatTests(PyFlinkStreamingTestCase):
-
- def setUp(self):
- super().setUp()
- self.test_sink = DataStreamTestSinkFunction()
- _import_avro_classes()
-
- def test_parquet_columnar_basic(self):
- parquet_file_name = tempfile.mktemp(suffix='.parquet', dir=self.tempdir)
- schema, records = _create_basic_avro_schema_and_records()
- FileSourceAvroParquetReadersTests._create_parquet_avro_file(
- parquet_file_name, schema, records)
- row_type = DataTypes.ROW([
- DataTypes.FIELD('null', DataTypes.STRING()), # DataTypes.NULL cannot be serialized
- DataTypes.FIELD('boolean', DataTypes.BOOLEAN()),
- DataTypes.FIELD('int', DataTypes.INT()),
- DataTypes.FIELD('long', DataTypes.BIGINT()),
- DataTypes.FIELD('float', DataTypes.FLOAT()),
- DataTypes.FIELD('double', DataTypes.DOUBLE()),
- DataTypes.FIELD('string', DataTypes.STRING()),
- DataTypes.FIELD('unknown', DataTypes.STRING())
- ])
- self._build_parquet_columnar_job(row_type, parquet_file_name)
- self.env.execute('test_parquet_columnar_basic')
- results = self.test_sink.get_results(True, False)
- _check_basic_avro_schema_results(self, results)
- self.assertIsNone(results[0]['unknown'])
- self.assertIsNone(results[1]['unknown'])
-
- def _build_parquet_columnar_job(self, row_type: RowType, parquet_file_name: str):
- source = FileSource.for_bulk_file_format(
- ParquetColumnarRowInputFormat(Configuration(), row_type, 10, True, True),
- parquet_file_name
- ).build()
- ds = self.env.from_source(source, WatermarkStrategy.no_watermarks(), 'parquet-source')
- ds.map(PassThroughMapFunction()).add_sink(self.test_sink)
-
-
-@unittest.skipIf(os.environ.get('HADOOP_CLASSPATH') is None,
- 'Some Hadoop lib is needed for Parquet-Avro format tests')
-class FileSourceAvroParquetReadersTests(PyFlinkStreamingTestCase):
-
- def setUp(self):
- super().setUp()
- self.test_sink = DataStreamTestSinkFunction()
- _import_avro_classes()
-
- def test_parquet_avro_basic(self):
- parquet_file_name = tempfile.mktemp(suffix='.parquet', dir=self.tempdir)
- schema, records = _create_basic_avro_schema_and_records()
- self._create_parquet_avro_file(parquet_file_name, schema, records)
- self._build_parquet_avro_job(schema, parquet_file_name)
- self.env.execute("test_parquet_avro_basic")
- results = self.test_sink.get_results(True, False)
- _check_basic_avro_schema_results(self, results)
-
- def test_parquet_avro_enum(self):
- parquet_file_name = tempfile.mktemp(suffix='.parquet', dir=self.tempdir)
- schema, records = _create_enum_avro_schema_and_records()
- self._create_parquet_avro_file(parquet_file_name, schema, records)
- self._build_parquet_avro_job(schema, parquet_file_name)
- self.env.execute("test_parquet_avro_enum")
- results = self.test_sink.get_results(True, False)
- _check_enum_avro_schema_results(self, results)
-
- def test_parquet_avro_union(self):
- parquet_file_name = tempfile.mktemp(suffix='.parquet', dir=self.tempdir)
- schema, records = _create_union_avro_schema_and_records()
- self._create_parquet_avro_file(parquet_file_name, schema, records)
- self._build_parquet_avro_job(schema, parquet_file_name)
- self.env.execute("test_parquet_avro_union")
- results = self.test_sink.get_results(True, False)
- _check_union_avro_schema_results(self, results)
-
- def test_parquet_avro_array(self):
- parquet_file_name = tempfile.mktemp(suffix='.parquet', dir=self.tempdir)
- schema, records = _create_array_avro_schema_and_records()
- self._create_parquet_avro_file(parquet_file_name, schema, records)
- self._build_parquet_avro_job(schema, parquet_file_name)
- self.env.execute("test_parquet_avro_array")
- results = self.test_sink.get_results(True, False)
- _check_array_avro_schema_results(self, results)
-
- def test_parquet_avro_map(self):
- parquet_file_name = tempfile.mktemp(suffix='.parquet', dir=self.tempdir)
- schema, records = _create_map_avro_schema_and_records()
- self._create_parquet_avro_file(parquet_file_name, schema, records)
- self._build_parquet_avro_job(schema, parquet_file_name)
- self.env.execute("test_parquet_avro_map")
- results = self.test_sink.get_results(True, False)
- _check_map_avro_schema_results(self, results)
-
- def _build_parquet_avro_job(self, record_schema, *parquet_file_name):
- ds = self.env.from_source(
- FileSource.for_record_stream_format(
- AvroParquetReaders.for_generic_record(record_schema),
- *parquet_file_name
- ).build(),
- WatermarkStrategy.for_monotonous_timestamps(),
- "parquet-source"
- )
- ds.map(PassThroughMapFunction()).add_sink(self.test_sink)
-
- @staticmethod
- def _create_parquet_avro_file(file_path: str, schema: AvroSchema, records: list):
- jvm = get_gateway().jvm
- j_path = jvm.org.apache.flink.core.fs.Path(file_path)
- writer = jvm.org.apache.flink.formats.parquet.avro.AvroParquetWriters \
- .forGenericRecord(schema._j_schema) \
- .create(j_path.getFileSystem().create(
- j_path,
- jvm.org.apache.flink.core.fs.FileSystem.WriteMode.OVERWRITE
- ))
- for record in records:
- writer.addElement(record)
- writer.flush()
- writer.finish()
-
-
-class FileSourceAvroInputFormatTests(PyFlinkStreamingTestCase):
-
- def setUp(self):
- super().setUp()
- self.test_sink = DataStreamTestSinkFunction()
- self.avro_file_name = tempfile.mktemp(suffix='.avro', dir=self.tempdir)
- _import_avro_classes()
-
- def test_avro_basic_read(self):
- schema, records = _create_basic_avro_schema_and_records()
- self._create_avro_file(schema, records)
- self._build_avro_job(schema)
- self.env.execute('test_avro_basic_read')
- results = self.test_sink.get_results(True, False)
- _check_basic_avro_schema_results(self, results)
-
- def test_avro_enum_read(self):
- schema, records = _create_enum_avro_schema_and_records()
- self._create_avro_file(schema, records)
- self._build_avro_job(schema)
- self.env.execute('test_avro_enum_read')
- results = self.test_sink.get_results(True, False)
- _check_enum_avro_schema_results(self, results)
-
- def test_avro_union_read(self):
- schema, records = _create_union_avro_schema_and_records()
- self._create_avro_file(schema, records)
- self._build_avro_job(schema)
- self.env.execute('test_avro_union_read')
- results = self.test_sink.get_results(True, False)
- _check_union_avro_schema_results(self, results)
-
- def test_avro_array_read(self):
- schema, records = _create_array_avro_schema_and_records()
- self._create_avro_file(schema, records)
- self._build_avro_job(schema)
- self.env.execute('test_avro_array_read')
- results = self.test_sink.get_results(True, False)
- _check_array_avro_schema_results(self, results)
-
- def test_avro_map_read(self):
- schema, records = _create_map_avro_schema_and_records()
- self._create_avro_file(schema, records)
- self._build_avro_job(schema)
- self.env.execute('test_avro_map_read')
- results = self.test_sink.get_results(True, False)
- _check_map_avro_schema_results(self, results)
-
- def _build_avro_job(self, record_schema):
- ds = self.env.create_input(AvroInputFormat(self.avro_file_name, record_schema))
- ds.map(PassThroughMapFunction()).add_sink(self.test_sink)
-
- def _create_avro_file(self, schema: AvroSchema, records: list):
- jvm = get_gateway().jvm
- j_file = jvm.java.io.File(self.avro_file_name)
- j_datum_writer = jvm.org.apache.flink.avro.shaded.org.apache.avro.generic \
- .GenericDatumWriter()
- j_file_writer = jvm.org.apache.flink.avro.shaded.org.apache.avro.file \
- .DataFileWriter(j_datum_writer)
- j_file_writer.create(schema._j_schema, j_file)
- for r in records:
- j_file_writer.append(r)
- j_file_writer.close()
-
-
-@unittest.skipIf(os.environ.get('HADOOP_CLASSPATH') is None,
- 'Some Hadoop lib is needed for Parquet-Avro format tests')
-class FileSinkAvroParquetWritersTests(PyFlinkStreamingTestCase):
-
- def setUp(self):
- super().setUp()
- # NOTE: parallelism == 1 is required to keep the order of results
- self.env.set_parallelism(1)
- self.parquet_dir_name = tempfile.mktemp(dir=self.tempdir)
- self.test_sink = DataStreamTestSinkFunction()
-
- def test_parquet_avro_basic_write(self):
- schema, objects = _create_basic_avro_schema_and_py_objects()
- self._build_avro_parquet_job(schema, objects)
- self.env.execute('test_parquet_avro_basic_write')
- results = self._read_parquet_avro_file(schema)
- _check_basic_avro_schema_results(self, results)
-
- def test_parquet_avro_enum_write(self):
- schema, objects = _create_enum_avro_schema_and_py_objects()
- self._build_avro_parquet_job(schema, objects)
- self.env.execute('test_parquet_avro_enum_write')
- results = self._read_parquet_avro_file(schema)
- _check_enum_avro_schema_results(self, results)
-
- def test_parquet_avro_union_write(self):
- schema, objects = _create_union_avro_schema_and_py_objects()
- self._build_avro_parquet_job(schema, objects)
- self.env.execute('test_parquet_avro_union_write')
- results = self._read_parquet_avro_file(schema)
- _check_union_avro_schema_results(self, results)
-
- def test_parquet_avro_array_write(self):
- schema, objects = _create_array_avro_schema_and_py_objects()
- self._build_avro_parquet_job(schema, objects)
- self.env.execute('test_parquet_avro_array_write')
- results = self._read_parquet_avro_file(schema)
- _check_array_avro_schema_results(self, results)
-
- def test_parquet_avro_map_write(self):
- schema, objects = _create_map_avro_schema_and_py_objects()
- self._build_avro_parquet_job(schema, objects)
- self.env.execute('test_parquet_avro_map_write')
- results = self._read_parquet_avro_file(schema)
- _check_map_avro_schema_results(self, results)
-
- def _build_avro_parquet_job(self, schema, objects):
- ds = self.env.from_collection(objects)
- avro_type_info = GenericRecordAvroTypeInfo(schema)
- sink = FileSink.for_bulk_format(
- self.parquet_dir_name, AvroParquetWriters.for_generic_record(schema)
- ).build()
- ds.map(lambda e: e, output_type=avro_type_info).sink_to(sink)
-
- def _read_parquet_avro_file(self, schema) -> List[dict]:
- parquet_files = [f for f in glob.glob(self.parquet_dir_name, recursive=True)]
- FileSourceAvroParquetReadersTests._build_parquet_avro_job(self, schema, *parquet_files)
- self.env.execute()
- return self.test_sink.get_results(True, False)
-
-
-class FileSinkAvroWritersTests(PyFlinkStreamingTestCase):
-
- def setUp(self):
- super().setUp()
- # NOTE: parallelism == 1 is required to keep the order of results
- self.env.set_parallelism(1)
- self.avro_dir_name = tempfile.mkdtemp(dir=self.tempdir)
-
- def test_avro_basic_write(self):
- schema, objects = _create_basic_avro_schema_and_py_objects()
- self._build_avro_job(schema, objects)
- self.env.execute('test_avro_basic_write')
- results = self._read_avro_file()
- _check_basic_avro_schema_results(self, results)
-
- def test_avro_enum_write(self):
- schema, objects = _create_enum_avro_schema_and_py_objects()
- self._build_avro_job(schema, objects)
- self.env.execute('test_avro_enum_write')
- results = self._read_avro_file()
- _check_enum_avro_schema_results(self, results)
-
- def test_avro_union_write(self):
- schema, objects = _create_union_avro_schema_and_py_objects()
- self._build_avro_job(schema, objects)
- self.env.execute('test_avro_union_write')
- results = self._read_avro_file()
- _check_union_avro_schema_results(self, results)
-
- def test_avro_array_write(self):
- schema, objects = _create_array_avro_schema_and_py_objects()
- self._build_avro_job(schema, objects)
- self.env.execute('test_avro_array_write')
- results = self._read_avro_file()
- _check_array_avro_schema_results(self, results)
-
- def test_avro_map_write(self):
- schema, objects = _create_map_avro_schema_and_py_objects()
- self._build_avro_job(schema, objects)
- self.env.execute('test_avro_map_write')
- results = self._read_avro_file()
- _check_map_avro_schema_results(self, results)
-
- def _build_avro_job(self, schema, objects):
- ds = self.env.from_collection(objects)
- sink = FileSink.for_bulk_format(
- self.avro_dir_name, AvroWriters.for_generic_record(schema)
- ).build()
- ds.map(lambda e: e, output_type=GenericRecordAvroTypeInfo(schema)).sink_to(sink)
-
- def _read_avro_file(self) -> List[dict]:
- records = []
- for file in glob.glob(os.path.join(os.path.join(self.avro_dir_name, '**/*'))):
- for record in DataFileReader(open(file, 'rb'), DatumReader()):
- records.append(record)
- return records
-
-
-class FileSinkCsvBulkWriterTests(PyFlinkStreamingTestCase):
-
- def setUp(self):
- super().setUp()
- self.env.set_parallelism(1)
- self.csv_file_name = tempfile.mktemp(dir=self.tempdir)
- self.csv_dir_name = tempfile.mkdtemp(dir=self.tempdir)
-
- def test_csv_primitive_column_write(self):
- schema, lines = _create_csv_primitive_column_schema_and_lines()
- self._build_csv_job(schema, lines)
- self.env.execute('test_csv_primitive_column_write')
- results = self._read_csv_file()
- self.assertTrue(len(results) == 1)
- self.assertEqual(
- results[0],
- '127,-32767,2147483647,-9223372036854775808,3.0E38,2.0E-308,2,true,string\n'
+from pyflink.util.java_utils import load_java_class
+
+
+class FileSystemTests(PyFlinkStreamingTestCase):
+
+ def test_stream_file_sink(self):
+ self.env.set_parallelism(2)
+ ds = self.env.from_collection([('ab', 1), ('bdc', 2), ('cfgs', 3), ('deeefg', 4)],
+ type_info=Types.ROW([Types.STRING(), Types.INT()]))
+ ds.map(
+ lambda a: a[0],
+ Types.STRING()).add_sink(
+ StreamingFileSink.for_row_format(self.tempdir, Encoder.simple_string_encoder())
+ .with_rolling_policy(
+ RollingPolicy.default_rolling_policy(
+ part_size=1024 * 1024 * 1024,
+ rollover_interval=15 * 60 * 1000,
+ inactivity_interval=5 * 60 * 1000))
+ .with_output_file_config(
+ OutputFileConfig.OutputFileConfigBuilder()
+ .with_part_prefix("prefix")
+ .with_part_suffix("suffix").build()).build())
+
+ self.env.execute("test_streaming_file_sink")
+
+ results = []
+ import os
+ for root, dirs, files in os.walk(self.tempdir, topdown=True):
+ for file in files:
+ self.assertTrue(file.startswith('.prefix'))
+ self.assertTrue('suffix' in file)
+ path = root + "/" + file
+ with open(path) as infile:
+ for line in infile:
+ results.append(line)
+
+ expected = ['deeefg\n', 'bdc\n', 'ab\n', 'cfgs\n']
+ results.sort()
+ expected.sort()
+ self.assertEqual(expected, results)
+
+ def test_file_source(self):
+ stream_format = StreamFormat.text_line_format()
+ paths = ["/tmp/1.txt", "/tmp/2.txt"]
+ file_source_builder = FileSource.for_record_stream_format(stream_format, *paths)
+ file_source = file_source_builder\
+ .monitor_continuously(Duration.of_days(1)) \
+ .set_file_enumerator(FileEnumeratorProvider.default_splittable_file_enumerator()) \
+ .set_split_assigner(FileSplitAssignerProvider.locality_aware_split_assigner()) \
+ .build()
+
+ continuous_setting = file_source.get_java_function().getContinuousEnumerationSettings()
+ self.assertIsNotNone(continuous_setting)
+ self.assertEqual(Duration.of_days(1), Duration(continuous_setting.getDiscoveryInterval()))
+
+ input_paths_field = \
+ load_java_class("org.apache.flink.connector.file.src.AbstractFileSource"). \
+ getDeclaredField("inputPaths")
+ input_paths_field.setAccessible(True)
+ input_paths = input_paths_field.get(file_source.get_java_function())
+ self.assertEqual(len(input_paths), len(paths))
+ self.assertEqual(str(input_paths[0]), paths[0])
+ self.assertEqual(str(input_paths[1]), paths[1])
+
+ def test_file_sink(self):
+ base_path = "/tmp/1.txt"
+ encoder = Encoder.simple_string_encoder()
+ file_sink_builder = FileSink.for_row_format(base_path, encoder)
+ file_sink = file_sink_builder\
+ .with_bucket_check_interval(1000) \
+ .with_bucket_assigner(BucketAssigner.base_path_bucket_assigner()) \
+ .with_rolling_policy(RollingPolicy.on_checkpoint_rolling_policy()) \
+ .with_output_file_config(
+ OutputFileConfig.builder().with_part_prefix("pre").with_part_suffix("suf").build())\
+ .enable_compact(FileCompactStrategy.builder()
+ .enable_compaction_on_checkpoint(3)
+ .set_size_threshold(1024)
+ .set_num_compact_threads(2)
+ .build(),
+ FileCompactor.concat_file_compactor(b'\n')) \
+ .build()
+
+ buckets_builder_field = \
+ load_java_class("org.apache.flink.connector.file.sink.FileSink"). \
+ getDeclaredField("bucketsBuilder")
+ buckets_builder_field.setAccessible(True)
+ buckets_builder = buckets_builder_field.get(file_sink.get_java_function())
+
+ self.assertEqual("DefaultRowFormatBuilder", buckets_builder.getClass().getSimpleName())
+
+ row_format_builder_clz = load_java_class(
+ "org.apache.flink.connector.file.sink.FileSink$RowFormatBuilder")
+ encoder_field = row_format_builder_clz.getDeclaredField("encoder")
+ encoder_field.setAccessible(True)
+ self.assertEqual("SimpleStringEncoder",
+ encoder_field.get(buckets_builder).getClass().getSimpleName())
+
+ interval_field = row_format_builder_clz.getDeclaredField("bucketCheckInterval")
+ interval_field.setAccessible(True)
+ self.assertEqual(1000, interval_field.get(buckets_builder))
+
+ bucket_assigner_field = row_format_builder_clz.getDeclaredField("bucketAssigner")
+ bucket_assigner_field.setAccessible(True)
+ self.assertEqual("BasePathBucketAssigner",
+ bucket_assigner_field.get(buckets_builder).getClass().getSimpleName())
+
+ rolling_policy_field = row_format_builder_clz.getDeclaredField("rollingPolicy")
+ rolling_policy_field.setAccessible(True)
+ self.assertEqual("OnCheckpointRollingPolicy",
+ rolling_policy_field.get(buckets_builder).getClass().getSimpleName())
+
+ output_file_config_field = row_format_builder_clz.getDeclaredField("outputFileConfig")
+ output_file_config_field.setAccessible(True)
+ output_file_config = output_file_config_field.get(buckets_builder)
+ self.assertEqual("pre", output_file_config.getPartPrefix())
+ self.assertEqual("suf", output_file_config.getPartSuffix())
+
+ compact_strategy_field = row_format_builder_clz.getDeclaredField("compactStrategy")
+ compact_strategy_field.setAccessible(True)
+ compact_strategy = compact_strategy_field.get(buckets_builder)
+ self.assertEqual(3, compact_strategy.getNumCheckpointsBeforeCompaction())
+ self.assertEqual(1024, compact_strategy.getSizeThreshold())
+ self.assertEqual(2, compact_strategy.getNumCompactThreads())
+
+ file_compactor_field = row_format_builder_clz.getDeclaredField("fileCompactor")
+ file_compactor_field.setAccessible(True)
+ file_compactor = file_compactor_field.get(buckets_builder)
+ self.assertEqual("ConcatFileCompactor", file_compactor.getClass().getSimpleName())
+ concat_file_compactor_clz = load_java_class(
+ "org.apache.flink.connector.file.sink.compactor.ConcatFileCompactor"
)
-
- def test_csv_array_column_write(self):
- schema, lines = _create_csv_array_column_schema_and_lines()
- self._build_csv_job(schema, lines)
- self.env.execute('test_csv_array_column_write')
- results = self._read_csv_file()
- self.assertTrue(len(results) == 1)
- self.assertListEqual(results, lines)
-
- def test_csv_default_quote_char_write(self):
- schema, lines = _create_csv_default_quote_char_schema_and_lines()
- self._build_csv_job(schema, lines)
- self.env.execute('test_csv_default_quote_char_write')
- results = self._read_csv_file()
- self.assertTrue(len(results) == 1)
- self.assertListEqual(results, lines)
-
- def test_csv_customize_quote_char_write(self):
- schema, lines = _create_csv_customize_quote_char_schema_lines()
- self._build_csv_job(schema, lines)
- self.env.execute('test_csv_customize_quote_char_write')
- results = self._read_csv_file()
- self.assertTrue(len(results) == 1)
- self.assertListEqual(results, lines)
-
- def test_csv_use_escape_char_write(self):
- schema, lines = _create_csv_set_escape_char_schema_and_lines()
- self._build_csv_job(schema, lines)
- self.env.execute('test_csv_use_escape_char_write')
- results = self._read_csv_file()
- self.assertTrue(len(results) == 1)
- self.assertListEqual(results, ['"string,","""string2"""\n'])
-
- def _build_csv_job(self, schema: CsvSchema, lines):
- with open(self.csv_file_name, 'w') as f:
- for line in lines:
- f.write(line)
- source = FileSource.for_record_stream_format(
- CsvReaderFormat.for_schema(schema), self.csv_file_name
- ).build()
- ds = self.env.from_source(source, WatermarkStrategy.no_watermarks(), 'csv-source')
- sink = FileSink.for_bulk_format(
- self.csv_dir_name, CsvBulkWriter.for_schema(schema)
- ).build()
- ds.map(lambda e: e, output_type=schema.get_type_info()).sink_to(sink)
-
- def _read_csv_file(self) -> List[str]:
- lines = []
- for file in glob.glob(os.path.join(self.csv_dir_name, '**/*')):
- with open(file, 'r') as f:
- lines.extend(f.readlines())
- return lines
-
-
-class PassThroughMapFunction(MapFunction):
-
- def map(self, value):
- return value
-
-
-def _import_avro_classes():
- jvm = get_gateway().jvm
- classes = ['org.apache.avro.generic.GenericData']
- prefix = 'org.apache.flink.avro.shaded.'
- for cls in classes:
- java_import(jvm, prefix + cls)
-
-
-def _create_csv_primitive_column_schema_and_lines() -> Tuple[CsvSchema, List[str]]:
- schema = CsvSchema.builder() \
- .add_number_column('tinyint', DataTypes.TINYINT()) \
- .add_number_column('smallint', DataTypes.SMALLINT()) \
- .add_number_column('int', DataTypes.INT()) \
- .add_number_column('bigint', DataTypes.BIGINT()) \
- .add_number_column('float', DataTypes.FLOAT()) \
- .add_number_column('double', DataTypes.DOUBLE()) \
- .add_number_column('decimal', DataTypes.DECIMAL(2, 0)) \
- .add_boolean_column('boolean') \
- .add_string_column('string') \
- .build()
- lines = [
- '127,'
- '-32767,'
- '2147483647,'
- '-9223372036854775808,'
- '3e38,'
- '2e-308,'
- '1.5,'
- 'true,'
- 'string\n',
- ]
- return schema, lines
-
-
-def _check_csv_primitive_column_results(test, results):
- row = results[0]
- test.assertEqual(row['tinyint'], 127)
- test.assertEqual(row['smallint'], -32767)
- test.assertEqual(row['int'], 2147483647)
- test.assertEqual(row['bigint'], -9223372036854775808)
- test.assertAlmostEqual(row['float'], 3e38, delta=1e31)
- test.assertAlmostEqual(row['double'], 2e-308, delta=2e-301)
- test.assertAlmostEqual(row['decimal'], 2)
- test.assertEqual(row['boolean'], True)
- test.assertEqual(row['string'], 'string')
-
-
-def _create_csv_array_column_schema_and_lines() -> Tuple[CsvSchema, List[str]]:
- schema = CsvSchema.builder() \
- .add_array_column('number_array', separator=';', element_type=DataTypes.INT()) \
- .add_array_column('boolean_array', separator=':', element_type=DataTypes.BOOLEAN()) \
- .add_array_column('string_array', separator=',', element_type=DataTypes.STRING()) \
- .set_column_separator('|') \
- .disable_quote_char() \
- .build()
- lines = [
- '1;2;3|'
- 'true:false|'
- 'a,b,c\n',
- ]
- return schema, lines
-
-
-def _check_csv_array_column_results(test, results):
- row = results[0]
- test.assertListEqual(row['number_array'], [1, 2, 3])
- test.assertListEqual(row['boolean_array'], [True, False])
- test.assertListEqual(row['string_array'], ['a', 'b', 'c'])
-
-
-def _create_csv_allow_comments_schema_and_lines() -> Tuple[CsvSchema, List[str]]:
- schema = CsvSchema.builder() \
- .add_string_column('string') \
- .set_allow_comments() \
- .build()
- lines = [
- 'a\n',
- '# this is comment\n',
- 'b\n',
- ]
- return schema, lines
-
-
-def _check_csv_allow_comments_results(test, results):
- test.assertEqual(results[0]['string'], 'a')
- test.assertEqual(results[1]['string'], 'b')
-
-
-def _create_csv_use_header_schema_and_lines() -> Tuple[CsvSchema, List[str]]:
- schema = CsvSchema.builder() \
- .add_string_column('string') \
- .add_number_column('number') \
- .set_use_header() \
- .build()
- lines = [
- 'h1,h2\n',
- 'string,123\n',
- ]
- return schema, lines
-
-
-def _check_csv_use_header_results(test, results):
- row = results[0]
- test.assertEqual(row['string'], 'string')
- test.assertEqual(row['number'], 123)
-
-
-def _create_csv_strict_headers_schema_and_lines() -> Tuple[CsvSchema, List[str]]:
- schema = CsvSchema.builder() \
- .add_string_column('string') \
- .add_number_column('number') \
- .set_use_header() \
- .set_strict_headers() \
- .build()
- lines = [
- 'string,number\n',
- 'string,123\n',
- ]
- return schema, lines
-
-
-def _check_csv_strict_headers_results(test, results):
- row = results[0]
- test.assertEqual(row['string'], 'string')
- test.assertEqual(row['number'], 123)
-
-
-def _create_csv_default_quote_char_schema_and_lines() -> Tuple[CsvSchema, List[str]]:
- schema = CsvSchema.builder() \
- .add_string_column('string') \
- .add_string_column('string2') \
- .set_column_separator('|') \
- .build()
- lines = [
- '"string"|"string2"\n',
- ]
- return schema, lines
-
-
-def _check_csv_default_quote_char_results(test, results):
- row = results[0]
- test.assertEqual(row['string'], 'string')
-
-
-def _create_csv_customize_quote_char_schema_lines() -> Tuple[CsvSchema, List[str]]:
- schema = CsvSchema.builder() \
- .add_string_column('string') \
- .add_string_column('string2') \
- .set_column_separator('|') \
- .set_quote_char('`') \
- .build()
- lines = [
- '`string`|`string2`\n',
- ]
- return schema, lines
-
-
-def _check_csv_customize_quote_char_results(test, results):
- row = results[0]
- test.assertEqual(row['string'], 'string')
-
-
-def _create_csv_set_escape_char_schema_and_lines() -> Tuple[CsvSchema, List[str]]:
- schema = CsvSchema.builder() \
- .add_string_column('string') \
- .add_string_column('string2') \
- .set_column_separator(',') \
- .set_escape_char('\\') \
- .build()
- lines = [
- 'string\\,,\\"string2\\"\n',
- ]
- return schema, lines
-
-
-def _check_csv_set_escape_char_results(test, results):
- row = results[0]
- test.assertEqual(row['string'], 'string,')
- test.assertEqual(row['string2'], '"string2"')
-
-
-BASIC_SCHEMA = """
-{
- "type": "record",
- "name": "test",
- "fields": [
- { "name": "null", "type": "null" },
- { "name": "boolean", "type": "boolean" },
- { "name": "int", "type": "int" },
- { "name": "long", "type": "long" },
- { "name": "float", "type": "float" },
- { "name": "double", "type": "double" },
- { "name": "string", "type": "string" }
- ]
-}
-"""
-
-
-def _create_basic_avro_schema_and_records() -> Tuple[AvroSchema, List[JavaObject]]:
- schema = AvroSchema.parse_string(BASIC_SCHEMA)
- records = [_create_basic_avro_record(schema, True, 0, 1, 2, 3, 's1'),
- _create_basic_avro_record(schema, False, 4, 5, 6, 7, 's2')]
- return schema, records
-
-
-def _create_basic_avro_schema_and_py_objects() -> Tuple[AvroSchema, List[dict]]:
- schema = AvroSchema.parse_string(BASIC_SCHEMA)
- objects = [
- {'null': None, 'boolean': True, 'int': 0, 'long': 1,
- 'float': 2., 'double': 3., 'string': 's1'},
- {'null': None, 'boolean': False, 'int': 4, 'long': 5,
- 'float': 6., 'double': 7., 'string': 's2'},
- ]
- return schema, objects
-
-
-def _check_basic_avro_schema_results(test, results):
- result1 = results[0]
- result2 = results[1]
- test.assertEqual(result1['null'], None)
- test.assertEqual(result1['boolean'], True)
- test.assertEqual(result1['int'], 0)
- test.assertEqual(result1['long'], 1)
- test.assertAlmostEqual(result1['float'], 2, delta=1e-3)
- test.assertAlmostEqual(result1['double'], 3, delta=1e-3)
- test.assertEqual(result1['string'], 's1')
- test.assertEqual(result2['null'], None)
- test.assertEqual(result2['boolean'], False)
- test.assertEqual(result2['int'], 4)
- test.assertEqual(result2['long'], 5)
- test.assertAlmostEqual(result2['float'], 6, delta=1e-3)
- test.assertAlmostEqual(result2['double'], 7, delta=1e-3)
- test.assertEqual(result2['string'], 's2')
-
-
-ENUM_SCHEMA = """
-{
- "type": "record",
- "name": "test",
- "fields": [
- {
- "name": "suit",
- "type": {
- "type": "enum",
- "name": "Suit",
- "symbols" : ["SPADES", "HEARTS", "DIAMONDS", "CLUBS"]
- }
- }
- ]
-}
-"""
-
-
-def _create_enum_avro_schema_and_records() -> Tuple[AvroSchema, List[JavaObject]]:
- schema = AvroSchema.parse_string(ENUM_SCHEMA)
- records = [_create_enum_avro_record(schema, 'SPADES'),
- _create_enum_avro_record(schema, 'DIAMONDS')]
- return schema, records
-
-
-def _create_enum_avro_schema_and_py_objects() -> Tuple[AvroSchema, List[dict]]:
- schema = AvroSchema.parse_string(ENUM_SCHEMA)
- records = [
- {'suit': 'SPADES'},
- {'suit': 'DIAMONDS'},
- ]
- return schema, records
-
-
-def _check_enum_avro_schema_results(test, results):
- test.assertEqual(results[0]['suit'], 'SPADES')
- test.assertEqual(results[1]['suit'], 'DIAMONDS')
-
-
-UNION_SCHEMA = """
-{
- "type": "record",
- "name": "test",
- "fields": [
- {
- "name": "union",
- "type": [ "int", "double", "null" ]
- }
- ]
-}
-"""
-
-
-def _create_union_avro_schema_and_records() -> Tuple[AvroSchema, List[JavaObject]]:
- schema = AvroSchema.parse_string(UNION_SCHEMA)
- records = [_create_union_avro_record(schema, 1),
- _create_union_avro_record(schema, 2.),
- _create_union_avro_record(schema, None)]
- return schema, records
-
-
-def _create_union_avro_schema_and_py_objects() -> Tuple[AvroSchema, List[dict]]:
- schema = AvroSchema.parse_string(UNION_SCHEMA)
- records = [
- {'union': 1},
- {'union': 2.},
- {'union': None},
- ]
- return schema, records
-
-
-def _check_union_avro_schema_results(test, results):
- test.assertEqual(results[0]['union'], 1)
- test.assertAlmostEqual(results[1]['union'], 2.0, delta=1e-3)
- test.assertEqual(results[2]['union'], None)
-
-
-# It seems there's a bug when an array item record contains only one field, which throws
-# java.lang.ClassCastException: required ... is not a group when reading
-ARRAY_SCHEMA = """
-{
- "type": "record",
- "name": "test",
- "fields": [
- {
- "name": "array",
- "type": {
- "type": "array",
- "items": {
- "type": "record",
- "name": "item",
- "fields": [
- { "name": "int", "type": "int" },
- { "name": "double", "type": "double" }
- ]
- }
- }
- }
- ]
-}
-"""
-
-
-def _create_array_avro_schema_and_records() -> Tuple[AvroSchema, List[JavaObject]]:
- schema = AvroSchema.parse_string(ARRAY_SCHEMA)
- records = [_create_array_avro_record(schema, [(1, 2.), (3, 4.)]),
- _create_array_avro_record(schema, [(5, 6.), (7, 8.)])]
- return schema, records
-
-
-def _create_array_avro_schema_and_py_objects() -> Tuple[AvroSchema, List[dict]]:
- schema = AvroSchema.parse_string(ARRAY_SCHEMA)
- records = [
- {'array': [{'int': 1, 'double': 2.}, {'int': 3, 'double': 4.}]},
- {'array': [{'int': 5, 'double': 6.}, {'int': 7, 'double': 8.}]},
- ]
- return schema, records
-
-
-def _check_array_avro_schema_results(test, results):
- result1 = results[0]
- result2 = results[1]
- test.assertEqual(result1['array'][0]['int'], 1)
- test.assertAlmostEqual(result1['array'][0]['double'], 2., delta=1e-3)
- test.assertEqual(result1['array'][1]['int'], 3)
- test.assertAlmostEqual(result1['array'][1]['double'], 4., delta=1e-3)
- test.assertEqual(result2['array'][0]['int'], 5)
- test.assertAlmostEqual(result2['array'][0]['double'], 6., delta=1e-3)
- test.assertEqual(result2['array'][1]['int'], 7)
- test.assertAlmostEqual(result2['array'][1]['double'], 8., delta=1e-3)
-
-
-MAP_SCHEMA = """
-{
- "type": "record",
- "name": "test",
- "fields": [
- {
- "name": "map",
- "type": {
- "type": "map",
- "values": "long"
- }
- }
- ]
-}
-"""
-
-
-def _create_map_avro_schema_and_records() -> Tuple[AvroSchema, List[JavaObject]]:
- schema = AvroSchema.parse_string(MAP_SCHEMA)
- records = [_create_map_avro_record(schema, {'a': 1, 'b': 2}),
- _create_map_avro_record(schema, {'c': 3, 'd': 4})]
- return schema, records
-
-
-def _create_map_avro_schema_and_py_objects() -> Tuple[AvroSchema, List[dict]]:
- schema = AvroSchema.parse_string(MAP_SCHEMA)
- records = [
- {'map': {'a': 1, 'b': 2}},
- {'map': {'c': 3, 'd': 4}},
- ]
- return schema, records
-
-
-def _check_map_avro_schema_results(test, results):
- result1 = results[0]
- result2 = results[1]
- test.assertEqual(result1['map']['a'], 1)
- test.assertEqual(result1['map']['b'], 2)
- test.assertEqual(result2['map']['c'], 3)
- test.assertEqual(result2['map']['d'], 4)
-
-
-def _create_basic_avro_record(schema: AvroSchema, boolean_value, int_value, long_value,
- float_value, double_value, string_value):
- jvm = get_gateway().jvm
- j_record = jvm.GenericData.Record(schema._j_schema)
- j_record.put('boolean', boolean_value)
- j_record.put('int', int_value)
- j_record.put('long', long_value)
- j_record.put('float', float_value)
- j_record.put('double', double_value)
- j_record.put('string', string_value)
- return j_record
-
-
-def _create_enum_avro_record(schema: AvroSchema, enum_value):
- jvm = get_gateway().jvm
- j_record = jvm.GenericData.Record(schema._j_schema)
- j_enum = jvm.GenericData.EnumSymbol(schema._j_schema.getField('suit').schema(), enum_value)
- j_record.put('suit', j_enum)
- return j_record
-
-
-def _create_union_avro_record(schema, union_value):
- jvm = get_gateway().jvm
- j_record = jvm.GenericData.Record(schema._j_schema)
- j_record.put('union', union_value)
- return j_record
-
-
-def _create_array_avro_record(schema, item_values: list):
- jvm = get_gateway().jvm
- j_record = jvm.GenericData.Record(schema._j_schema)
- item_schema = AvroSchema(schema._j_schema.getField('array').schema().getElementType())
- j_array = jvm.java.util.ArrayList()
- for idx, item_value in enumerate(item_values):
- j_item = jvm.GenericData.Record(item_schema._j_schema)
- j_item.put('int', item_value[0])
- j_item.put('double', item_value[1])
- j_array.add(j_item)
- j_record.put('array', j_array)
- return j_record
-
-
-def _create_map_avro_record(schema, map: dict):
- jvm = get_gateway().jvm
- j_record = jvm.GenericData.Record(schema._j_schema)
- j_map = jvm.java.util.HashMap()
- for k, v in map.items():
- j_map.put(k, v)
- j_record.put('map', j_map)
- return j_record
+ file_delimiter_field = concat_file_compactor_clz.getDeclaredField("fileDelimiter")
+ file_delimiter_field.setAccessible(True)
+ file_delimiter = file_delimiter_field.get(file_compactor)
+ self.assertEqual(b'\n', file_delimiter)
diff --git a/flink-python/pyflink/datastream/connectors/tests/test_jdbc.py b/flink-python/pyflink/datastream/connectors/tests/test_jdbc.py
new file mode 100644
index 00000000000..0912cb0fc55
--- /dev/null
+++ b/flink-python/pyflink/datastream/connectors/tests/test_jdbc.py
@@ -0,0 +1,61 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+from pyflink.common import Types
+from pyflink.datastream.connectors.jdbc import JdbcSink, JdbcConnectionOptions, JdbcExecutionOptions
+from pyflink.testing.test_case_utils import PyFlinkStreamingTestCase
+from pyflink.util.java_utils import get_field_value
+
+
+class FlinkJdbcSinkTest(PyFlinkStreamingTestCase):
+
+ def test_jdbc_sink(self):
+ ds = self.env.from_collection([('ab', 1), ('bdc', 2), ('cfgs', 3), ('deeefg', 4)],
+ type_info=Types.ROW([Types.STRING(), Types.INT()]))
+ jdbc_connection_options = JdbcConnectionOptions.JdbcConnectionOptionsBuilder()\
+ .with_driver_name('com.mysql.jdbc.Driver')\
+ .with_user_name('root')\
+ .with_password('password')\
+ .with_url('jdbc:mysql://server-name:server-port/database-name').build()
+
+ jdbc_execution_options = JdbcExecutionOptions.builder().with_batch_interval_ms(2000)\
+ .with_batch_size(100).with_max_retries(5).build()
+ jdbc_sink = JdbcSink.sink("insert into test table", ds.get_type(), jdbc_connection_options,
+ jdbc_execution_options)
+
+ ds.add_sink(jdbc_sink).name('jdbc sink')
+ plan = eval(self.env.get_execution_plan())
+ self.assertEqual('Sink: jdbc sink', plan['nodes'][1]['type'])
+ j_output_format = get_field_value(jdbc_sink.get_java_function(), 'outputFormat')
+
+ connection_options = JdbcConnectionOptions(
+ get_field_value(get_field_value(j_output_format, 'connectionProvider'),
+ 'jdbcOptions'))
+ self.assertEqual(jdbc_connection_options.get_db_url(), connection_options.get_db_url())
+ self.assertEqual(jdbc_connection_options.get_driver_name(),
+ connection_options.get_driver_name())
+ self.assertEqual(jdbc_connection_options.get_password(), connection_options.get_password())
+ self.assertEqual(jdbc_connection_options.get_user_name(),
+ connection_options.get_user_name())
+
+ exec_options = JdbcExecutionOptions(get_field_value(j_output_format, 'executionOptions'))
+ self.assertEqual(jdbc_execution_options.get_batch_interval_ms(),
+ exec_options.get_batch_interval_ms())
+ self.assertEqual(jdbc_execution_options.get_batch_size(),
+ exec_options.get_batch_size())
+ self.assertEqual(jdbc_execution_options.get_max_retries(),
+ exec_options.get_max_retries())
diff --git a/flink-python/pyflink/datastream/connectors/tests/test_kafka.py b/flink-python/pyflink/datastream/connectors/tests/test_kafka.py
index de02b041333..725932ed7d5 100644
--- a/flink-python/pyflink/datastream/connectors/tests/test_kafka.py
+++ b/flink-python/pyflink/datastream/connectors/tests/test_kafka.py
@@ -19,6 +19,7 @@ import json
from typing import Dict
import pyflink.datastream.data_stream as data_stream
+from pyflink.common import typeinfo
from pyflink.common.configuration import Configuration
from pyflink.common.serialization import SimpleStringSchema, DeserializationSchema, \
@@ -29,14 +30,62 @@ from pyflink.common.types import Row, to_java_data_structure
from pyflink.common.watermark_strategy import WatermarkStrategy
from pyflink.datastream.connectors.base import DeliveryGuarantee
from pyflink.datastream.connectors.kafka import KafkaSource, KafkaTopicPartition, \
- KafkaOffsetsInitializer, KafkaOffsetResetStrategy, KafkaRecordSerializationSchema, KafkaSink
+ KafkaOffsetsInitializer, KafkaOffsetResetStrategy, KafkaRecordSerializationSchema, KafkaSink, \
+ FlinkKafkaProducer, FlinkKafkaConsumer
from pyflink.java_gateway import get_gateway
-from pyflink.testing.test_case_utils import PyFlinkStreamingTestCase, PyFlinkTestCase
+from pyflink.testing.test_case_utils import PyFlinkStreamingTestCase, PyFlinkTestCase, \
+ invoke_java_object_method
from pyflink.util.java_utils import to_jarray, is_instance_of, get_field_value
class KafkaSourceTests(PyFlinkStreamingTestCase):
+ def test_legacy_kafka_connector(self):
+ source_topic = 'test_source_topic'
+ sink_topic = 'test_sink_topic'
+ props = {'bootstrap.servers': 'localhost:9092', 'group.id': 'test_group'}
+ type_info = Types.ROW([Types.INT(), Types.STRING()])
+
+ # Test for kafka consumer
+ deserialization_schema = JsonRowDeserializationSchema.builder() \
+ .type_info(type_info=type_info).build()
+
+ flink_kafka_consumer = FlinkKafkaConsumer(source_topic, deserialization_schema, props)
+ flink_kafka_consumer.set_start_from_earliest()
+ flink_kafka_consumer.set_commit_offsets_on_checkpoints(True)
+
+ j_properties = get_field_value(flink_kafka_consumer.get_java_function(), 'properties')
+ self.assertEqual('localhost:9092', j_properties.getProperty('bootstrap.servers'))
+ self.assertEqual('test_group', j_properties.getProperty('group.id'))
+ self.assertTrue(get_field_value(flink_kafka_consumer.get_java_function(),
+ 'enableCommitOnCheckpoints'))
+ j_start_up_mode = get_field_value(flink_kafka_consumer.get_java_function(), 'startupMode')
+
+ j_deserializer = get_field_value(flink_kafka_consumer.get_java_function(), 'deserializer')
+ j_deserialize_type_info = invoke_java_object_method(j_deserializer, "getProducedType")
+ deserialize_type_info = typeinfo._from_java_type(j_deserialize_type_info)
+ self.assertTrue(deserialize_type_info == type_info)
+ self.assertTrue(j_start_up_mode.equals(get_gateway().jvm
+ .org.apache.flink.streaming.connectors
+ .kafka.config.StartupMode.EARLIEST))
+ j_topic_desc = get_field_value(flink_kafka_consumer.get_java_function(),
+ 'topicsDescriptor')
+ j_topics = invoke_java_object_method(j_topic_desc, 'getFixedTopics')
+ self.assertEqual(['test_source_topic'], list(j_topics))
+
+ # Test for kafka producer
+ serialization_schema = JsonRowSerializationSchema.builder().with_type_info(type_info) \
+ .build()
+ flink_kafka_producer = FlinkKafkaProducer(sink_topic, serialization_schema, props)
+ flink_kafka_producer.set_write_timestamp_to_kafka(False)
+
+ j_producer_config = get_field_value(flink_kafka_producer.get_java_function(),
+ 'producerConfig')
+ self.assertEqual('localhost:9092', j_producer_config.getProperty('bootstrap.servers'))
+ self.assertEqual('test_group', j_producer_config.getProperty('group.id'))
+ self.assertFalse(get_field_value(flink_kafka_producer.get_java_function(),
+ 'writeTimestampToKafka'))
+
def test_compiling(self):
source = KafkaSource.builder() \
.set_bootstrap_servers('localhost:9092') \
diff --git a/flink-python/pyflink/datastream/connectors/tests/test_kinesis.py b/flink-python/pyflink/datastream/connectors/tests/test_kinesis.py
new file mode 100644
index 00000000000..d96f14b3a9c
--- /dev/null
+++ b/flink-python/pyflink/datastream/connectors/tests/test_kinesis.py
@@ -0,0 +1,108 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+from pyflink.common import SimpleStringSchema, Types
+from pyflink.datastream.connectors.kinesis import PartitionKeyGenerator, FlinkKinesisConsumer, \
+ KinesisStreamsSink, KinesisFirehoseSink
+from pyflink.testing.test_case_utils import PyFlinkStreamingTestCase
+from pyflink.util.java_utils import get_field_value
+
+
+class FlinkKinesisTest(PyFlinkStreamingTestCase):
+
+ def test_kinesis_source(self):
+ consumer_config = {
+ 'aws.region': 'us-east-1',
+ 'aws.credentials.provider.basic.accesskeyid': 'aws_access_key_id',
+ 'aws.credentials.provider.basic.secretkey': 'aws_secret_access_key',
+ 'flink.stream.initpos': 'LATEST'
+ }
+
+ kinesis_source = FlinkKinesisConsumer("stream-1", SimpleStringSchema(), consumer_config)
+
+ ds = self.env.add_source(source_func=kinesis_source, source_name="kinesis source")
+ ds.print()
+ plan = eval(self.env.get_execution_plan())
+ self.assertEqual('Source: kinesis source', plan['nodes'][0]['type'])
+ self.assertEqual(
+ get_field_value(kinesis_source.get_java_function(), 'streams')[0], 'stream-1')
+
+ def test_kinesis_streams_sink(self):
+ sink_properties = {
+ 'aws.region': 'us-east-1',
+ 'aws.credentials.provider.basic.secretkey': 'aws_secret_access_key'
+ }
+
+ ds = self.env.from_collection([('ab', 1), ('bdc', 2), ('cfgs', 3), ('deeefg', 4)],
+ type_info=Types.ROW([Types.STRING(), Types.INT()]))
+
+ kinesis_streams_sink = KinesisStreamsSink.builder() \
+ .set_kinesis_client_properties(sink_properties) \
+ .set_serialization_schema(SimpleStringSchema()) \
+ .set_partition_key_generator(PartitionKeyGenerator.fixed()) \
+ .set_stream_name("stream-1") \
+ .set_fail_on_error(False) \
+ .set_max_batch_size(500) \
+ .set_max_in_flight_requests(50) \
+ .set_max_buffered_requests(10000) \
+ .set_max_batch_size_in_bytes(5 * 1024 * 1024) \
+ .set_max_time_in_buffer_ms(5000) \
+ .set_max_record_size_in_bytes(1 * 1024 * 1024) \
+ .build()
+
+ ds.sink_to(kinesis_streams_sink).name('kinesis streams sink')
+ plan = eval(self.env.get_execution_plan())
+
+ self.assertEqual('kinesis streams sink: Writer', plan['nodes'][1]['type'])
+ self.assertEqual(get_field_value(kinesis_streams_sink.get_java_function(), 'failOnError'),
+ False)
+ self.assertEqual(
+ get_field_value(kinesis_streams_sink.get_java_function(), 'streamName'), 'stream-1')
+
+ def test_kinesis_firehose_sink(self):
+
+ sink_properties = {
+ 'aws.region': 'eu-west-1',
+ 'aws.credentials.provider.basic.accesskeyid': 'aws_access_key_id',
+ 'aws.credentials.provider.basic.secretkey': 'aws_secret_access_key'
+ }
+
+ ds = self.env.from_collection([('ab', 1), ('bdc', 2), ('cfgs', 3), ('deeefg', 4)],
+ type_info=Types.ROW([Types.STRING(), Types.INT()]))
+
+ kinesis_firehose_sink = KinesisFirehoseSink.builder() \
+ .set_firehose_client_properties(sink_properties) \
+ .set_serialization_schema(SimpleStringSchema()) \
+ .set_delivery_stream_name('stream-1') \
+ .set_fail_on_error(False) \
+ .set_max_batch_size(500) \
+ .set_max_in_flight_requests(50) \
+ .set_max_buffered_requests(10000) \
+ .set_max_batch_size_in_bytes(5 * 1024 * 1024) \
+ .set_max_time_in_buffer_ms(5000) \
+ .set_max_record_size_in_bytes(1 * 1024 * 1024) \
+ .build()
+
+ ds.sink_to(kinesis_firehose_sink).name('kinesis firehose sink')
+ plan = eval(self.env.get_execution_plan())
+
+ self.assertEqual('kinesis firehose sink: Writer', plan['nodes'][1]['type'])
+ self.assertEqual(get_field_value(kinesis_firehose_sink.get_java_function(), 'failOnError'),
+ False)
+ self.assertEqual(
+ get_field_value(kinesis_firehose_sink.get_java_function(), 'deliveryStreamName'),
+ 'stream-1')
diff --git a/flink-python/pyflink/datastream/connectors/tests/test_pulsar.py b/flink-python/pyflink/datastream/connectors/tests/test_pulsar.py
new file mode 100644
index 00000000000..b27faad9db0
--- /dev/null
+++ b/flink-python/pyflink/datastream/connectors/tests/test_pulsar.py
@@ -0,0 +1,212 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+from pyflink.common import WatermarkStrategy, SimpleStringSchema, Types, ConfigOptions, Duration
+from pyflink.datastream.connectors import DeliveryGuarantee
+from pyflink.datastream.connectors.pulsar import PulsarSerializationSchema, TopicRoutingMode, \
+ MessageDelayer, PulsarSink, PulsarSource, StartCursor, PulsarDeserializationSchema, \
+ StopCursor, SubscriptionType
+from pyflink.testing.test_case_utils import PyFlinkStreamingTestCase
+from pyflink.util.java_utils import get_field_value, is_instance_of
+
+
+class FlinkPulsarTest(PyFlinkStreamingTestCase):
+
+ def test_pulsar_source(self):
+ TEST_OPTION_NAME = 'pulsar.source.enableAutoAcknowledgeMessage'
+ pulsar_source = PulsarSource.builder() \
+ .set_service_url('pulsar://localhost:6650') \
+ .set_admin_url('http://localhost:8080') \
+ .set_topics('ada') \
+ .set_start_cursor(StartCursor.earliest()) \
+ .set_unbounded_stop_cursor(StopCursor.never()) \
+ .set_bounded_stop_cursor(StopCursor.at_publish_time(22)) \
+ .set_subscription_name('ff') \
+ .set_subscription_type(SubscriptionType.Exclusive) \
+ .set_deserialization_schema(
+ PulsarDeserializationSchema.flink_type_info(Types.STRING())) \
+ .set_deserialization_schema(
+ PulsarDeserializationSchema.flink_schema(SimpleStringSchema())) \
+ .set_config(TEST_OPTION_NAME, True) \
+ .set_properties({'pulsar.source.autoCommitCursorInterval': '1000'}) \
+ .build()
+
+ ds = self.env.from_source(source=pulsar_source,
+ watermark_strategy=WatermarkStrategy.for_monotonous_timestamps(),
+ source_name="pulsar source")
+ ds.print()
+ plan = eval(self.env.get_execution_plan())
+ self.assertEqual('Source: pulsar source', plan['nodes'][0]['type'])
+
+ configuration = get_field_value(pulsar_source.get_java_function(), "sourceConfiguration")
+ self.assertEqual(
+ configuration.getString(
+ ConfigOptions.key('pulsar.client.serviceUrl')
+ .string_type()
+ .no_default_value()._j_config_option), 'pulsar://localhost:6650')
+ self.assertEqual(
+ configuration.getString(
+ ConfigOptions.key('pulsar.admin.adminUrl')
+ .string_type()
+ .no_default_value()._j_config_option), 'http://localhost:8080')
+ self.assertEqual(
+ configuration.getString(
+ ConfigOptions.key('pulsar.consumer.subscriptionName')
+ .string_type()
+ .no_default_value()._j_config_option), 'ff')
+ self.assertEqual(
+ configuration.getString(
+ ConfigOptions.key('pulsar.consumer.subscriptionType')
+ .string_type()
+ .no_default_value()._j_config_option), SubscriptionType.Exclusive.name)
+ test_option = ConfigOptions.key(TEST_OPTION_NAME).boolean_type().no_default_value()
+ self.assertEqual(
+ configuration.getBoolean(
+ test_option._j_config_option), True)
+ self.assertEqual(
+ configuration.getLong(
+ ConfigOptions.key('pulsar.source.autoCommitCursorInterval')
+ .long_type()
+ .no_default_value()._j_config_option), 1000)
+
+ def test_source_set_topics_with_list(self):
+ PulsarSource.builder() \
+ .set_service_url('pulsar://localhost:6650') \
+ .set_admin_url('http://localhost:8080') \
+ .set_topics(['ada', 'beta']) \
+ .set_subscription_name('ff') \
+ .set_deserialization_schema(
+ PulsarDeserializationSchema.flink_schema(SimpleStringSchema())) \
+ .build()
+
+ def test_source_set_topics_pattern(self):
+ PulsarSource.builder() \
+ .set_service_url('pulsar://localhost:6650') \
+ .set_admin_url('http://localhost:8080') \
+ .set_topic_pattern('ada.*') \
+ .set_subscription_name('ff') \
+ .set_deserialization_schema(
+ PulsarDeserializationSchema.flink_schema(SimpleStringSchema())) \
+ .build()
+
+ def test_source_deprecated_method(self):
+ test_option = ConfigOptions.key('pulsar.source.enableAutoAcknowledgeMessage') \
+ .boolean_type().no_default_value()
+ pulsar_source = PulsarSource.builder() \
+ .set_service_url('pulsar://localhost:6650') \
+ .set_admin_url('http://localhost:8080') \
+ .set_topic_pattern('ada.*') \
+ .set_deserialization_schema(
+ PulsarDeserializationSchema.flink_type_info(Types.STRING())) \
+ .set_unbounded_stop_cursor(StopCursor.at_publish_time(4444)) \
+ .set_subscription_name('ff') \
+ .set_config(test_option, True) \
+ .set_properties({'pulsar.source.autoCommitCursorInterval': '1000'}) \
+ .build()
+ configuration = get_field_value(pulsar_source.get_java_function(), "sourceConfiguration")
+ self.assertEqual(
+ configuration.getBoolean(
+ test_option._j_config_option), True)
+ self.assertEqual(
+ configuration.getLong(
+ ConfigOptions.key('pulsar.source.autoCommitCursorInterval')
+ .long_type()
+ .no_default_value()._j_config_option), 1000)
+
+ def test_pulsar_sink(self):
+ ds = self.env.from_collection([('ab', 1), ('bdc', 2), ('cfgs', 3), ('deeefg', 4)],
+ type_info=Types.ROW([Types.STRING(), Types.INT()]))
+
+ TEST_OPTION_NAME = 'pulsar.producer.chunkingEnabled'
+ pulsar_sink = PulsarSink.builder() \
+ .set_service_url('pulsar://localhost:6650') \
+ .set_admin_url('http://localhost:8080') \
+ .set_producer_name('fo') \
+ .set_topics('ada') \
+ .set_serialization_schema(
+ PulsarSerializationSchema.flink_schema(SimpleStringSchema())) \
+ .set_delivery_guarantee(DeliveryGuarantee.AT_LEAST_ONCE) \
+ .set_topic_routing_mode(TopicRoutingMode.ROUND_ROBIN) \
+ .delay_sending_message(MessageDelayer.fixed(Duration.of_seconds(12))) \
+ .set_config(TEST_OPTION_NAME, True) \
+ .set_properties({'pulsar.producer.batchingMaxMessages': '100'}) \
+ .build()
+
+ ds.sink_to(pulsar_sink).name('pulsar sink')
+
+ plan = eval(self.env.get_execution_plan())
+ self.assertEqual('pulsar sink: Writer', plan['nodes'][1]['type'])
+ configuration = get_field_value(pulsar_sink.get_java_function(), "sinkConfiguration")
+ self.assertEqual(
+ configuration.getString(
+ ConfigOptions.key('pulsar.client.serviceUrl')
+ .string_type()
+ .no_default_value()._j_config_option), 'pulsar://localhost:6650')
+ self.assertEqual(
+ configuration.getString(
+ ConfigOptions.key('pulsar.admin.adminUrl')
+ .string_type()
+ .no_default_value()._j_config_option), 'http://localhost:8080')
+ self.assertEqual(
+ configuration.getString(
+ ConfigOptions.key('pulsar.producer.producerName')
+ .string_type()
+ .no_default_value()._j_config_option), 'fo - %s')
+
+ j_pulsar_serialization_schema = get_field_value(
+ pulsar_sink.get_java_function(), 'serializationSchema')
+ j_serialization_schema = get_field_value(
+ j_pulsar_serialization_schema, 'serializationSchema')
+ self.assertTrue(
+ is_instance_of(
+ j_serialization_schema,
+ 'org.apache.flink.api.common.serialization.SimpleStringSchema'))
+
+ self.assertEqual(
+ configuration.getString(
+ ConfigOptions.key('pulsar.sink.deliveryGuarantee')
+ .string_type()
+ .no_default_value()._j_config_option), 'at-least-once')
+
+ j_topic_router = get_field_value(pulsar_sink.get_java_function(), "topicRouter")
+ self.assertTrue(
+ is_instance_of(
+ j_topic_router,
+ 'org.apache.flink.connector.pulsar.sink.writer.router.RoundRobinTopicRouter'))
+
+ j_message_delayer = get_field_value(pulsar_sink.get_java_function(), 'messageDelayer')
+ delay_duration = get_field_value(j_message_delayer, 'delayDuration')
+ self.assertEqual(delay_duration, 12000)
+
+ test_option = ConfigOptions.key(TEST_OPTION_NAME).boolean_type().no_default_value()
+ self.assertEqual(
+ configuration.getBoolean(
+ test_option._j_config_option), True)
+ self.assertEqual(
+ configuration.getLong(
+ ConfigOptions.key('pulsar.producer.batchingMaxMessages')
+ .long_type()
+ .no_default_value()._j_config_option), 100)
+
+ def test_sink_set_topics_with_list(self):
+ PulsarSink.builder() \
+ .set_service_url('pulsar://localhost:6650') \
+ .set_admin_url('http://localhost:8080') \
+ .set_topics(['ada', 'beta']) \
+ .set_serialization_schema(
+ PulsarSerializationSchema.flink_schema(SimpleStringSchema())) \
+ .build()
diff --git a/flink-python/pyflink/datastream/connectors/tests/test_rabbitmq.py b/flink-python/pyflink/datastream/connectors/tests/test_rabbitmq.py
new file mode 100644
index 00000000000..de2fca63076
--- /dev/null
+++ b/flink-python/pyflink/datastream/connectors/tests/test_rabbitmq.py
@@ -0,0 +1,48 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+from pyflink.common import Types, JsonRowDeserializationSchema, JsonRowSerializationSchema
+from pyflink.datastream.connectors.rabbitmq import RMQSink, RMQSource, RMQConnectionConfig
+from pyflink.testing.test_case_utils import PyFlinkStreamingTestCase
+from pyflink.util.java_utils import get_field_value
+
+
+class RMQTest(PyFlinkStreamingTestCase):
+
+ def test_rabbitmq_connectors(self):
+ connection_config = RMQConnectionConfig.Builder() \
+ .set_host('localhost') \
+ .set_port(5672) \
+ .set_virtual_host('/') \
+ .set_user_name('guest') \
+ .set_password('guest') \
+ .build()
+ type_info = Types.ROW([Types.INT(), Types.STRING()])
+ deserialization_schema = JsonRowDeserializationSchema.builder() \
+ .type_info(type_info=type_info).build()
+
+ rmq_source = RMQSource(
+ connection_config, 'source_queue', True, deserialization_schema)
+ self.assertEqual(
+ get_field_value(rmq_source.get_java_function(), 'queueName'), 'source_queue')
+ self.assertTrue(get_field_value(rmq_source.get_java_function(), 'usesCorrelationId'))
+
+ serialization_schema = JsonRowSerializationSchema.builder().with_type_info(type_info) \
+ .build()
+ rmq_sink = RMQSink(connection_config, 'sink_queue', serialization_schema)
+ self.assertEqual(
+ get_field_value(rmq_sink.get_java_function(), 'queueName'), 'sink_queue')
diff --git a/flink-python/pyflink/datastream/connectors/tests/test_seq_source.py b/flink-python/pyflink/datastream/connectors/tests/test_seq_source.py
new file mode 100644
index 00000000000..c2eb76c42b5
--- /dev/null
+++ b/flink-python/pyflink/datastream/connectors/tests/test_seq_source.py
@@ -0,0 +1,36 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+from pyflink.datastream.connectors.number_seq import NumberSequenceSource
+from pyflink.testing.test_case_utils import PyFlinkStreamingTestCase
+from pyflink.util.java_utils import load_java_class
+
+
+class SequenceSourceTests(PyFlinkStreamingTestCase):
+
+ def test_seq_source(self):
+ seq_source = NumberSequenceSource(1, 10)
+
+ seq_source_clz = load_java_class(
+ "org.apache.flink.api.connector.source.lib.NumberSequenceSource")
+ from_field = seq_source_clz.getDeclaredField("from")
+ from_field.setAccessible(True)
+ self.assertEqual(1, from_field.get(seq_source.get_java_function()))
+
+ to_field = seq_source_clz.getDeclaredField("to")
+ to_field.setAccessible(True)
+ self.assertEqual(10, to_field.get(seq_source.get_java_function()))
diff --git a/flink-python/pyflink/datastream/formats/tests/__init__.py b/flink-python/pyflink/datastream/formats/tests/__init__.py
new file mode 100644
index 00000000000..65b48d4d79b
--- /dev/null
+++ b/flink-python/pyflink/datastream/formats/tests/__init__.py
@@ -0,0 +1,17 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
diff --git a/flink-python/pyflink/datastream/formats/tests/test_avro.py b/flink-python/pyflink/datastream/formats/tests/test_avro.py
new file mode 100644
index 00000000000..ade29e13dbd
--- /dev/null
+++ b/flink-python/pyflink/datastream/formats/tests/test_avro.py
@@ -0,0 +1,450 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+import glob
+import os
+import tempfile
+from typing import Tuple, List
+
+from avro.datafile import DataFileReader
+from avro.io import DatumReader
+from py4j.java_gateway import JavaObject, java_import
+
+from pyflink.datastream import MapFunction
+from pyflink.datastream.connectors.file_system import FileSink
+from pyflink.datastream.formats import AvroSchema, GenericRecordAvroTypeInfo, AvroWriters, \
+ AvroInputFormat
+from pyflink.datastream.tests.test_util import DataStreamTestSinkFunction
+from pyflink.java_gateway import get_gateway
+from pyflink.testing.test_case_utils import PyFlinkStreamingTestCase
+
+
+class FileSourceAvroInputFormatTests(PyFlinkStreamingTestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.test_sink = DataStreamTestSinkFunction()
+ self.avro_file_name = tempfile.mktemp(suffix='.avro', dir=self.tempdir)
+ _import_avro_classes()
+
+ def test_avro_basic_read(self):
+ schema, records = _create_basic_avro_schema_and_records()
+ self._create_avro_file(schema, records)
+ self._build_avro_job(schema)
+ self.env.execute('test_avro_basic_read')
+ results = self.test_sink.get_results(True, False)
+ _check_basic_avro_schema_results(self, results)
+
+ def test_avro_enum_read(self):
+ schema, records = _create_enum_avro_schema_and_records()
+ self._create_avro_file(schema, records)
+ self._build_avro_job(schema)
+ self.env.execute('test_avro_enum_read')
+ results = self.test_sink.get_results(True, False)
+ _check_enum_avro_schema_results(self, results)
+
+ def test_avro_union_read(self):
+ schema, records = _create_union_avro_schema_and_records()
+ self._create_avro_file(schema, records)
+ self._build_avro_job(schema)
+ self.env.execute('test_avro_union_read')
+ results = self.test_sink.get_results(True, False)
+ _check_union_avro_schema_results(self, results)
+
+ def test_avro_array_read(self):
+ schema, records = _create_array_avro_schema_and_records()
+ self._create_avro_file(schema, records)
+ self._build_avro_job(schema)
+ self.env.execute('test_avro_array_read')
+ results = self.test_sink.get_results(True, False)
+ _check_array_avro_schema_results(self, results)
+
+ def test_avro_map_read(self):
+ schema, records = _create_map_avro_schema_and_records()
+ self._create_avro_file(schema, records)
+ self._build_avro_job(schema)
+ self.env.execute('test_avro_map_read')
+ results = self.test_sink.get_results(True, False)
+ _check_map_avro_schema_results(self, results)
+
+ def _build_avro_job(self, record_schema):
+ ds = self.env.create_input(AvroInputFormat(self.avro_file_name, record_schema))
+ ds.map(PassThroughMapFunction()).add_sink(self.test_sink)
+
+ def _create_avro_file(self, schema: AvroSchema, records: list):
+ jvm = get_gateway().jvm
+ j_file = jvm.java.io.File(self.avro_file_name)
+ j_datum_writer = jvm.org.apache.flink.avro.shaded.org.apache.avro.generic \
+ .GenericDatumWriter()
+ j_file_writer = jvm.org.apache.flink.avro.shaded.org.apache.avro.file \
+ .DataFileWriter(j_datum_writer)
+ j_file_writer.create(schema._j_schema, j_file)
+ for r in records:
+ j_file_writer.append(r)
+ j_file_writer.close()
+
+
+class FileSinkAvroWritersTests(PyFlinkStreamingTestCase):
+
+ def setUp(self):
+ super().setUp()
+ # NOTE: parallelism == 1 is required to keep the order of results
+ self.env.set_parallelism(1)
+ self.avro_dir_name = tempfile.mkdtemp(dir=self.tempdir)
+
+ def test_avro_basic_write(self):
+ schema, objects = _create_basic_avro_schema_and_py_objects()
+ self._build_avro_job(schema, objects)
+ self.env.execute('test_avro_basic_write')
+ results = self._read_avro_file()
+ _check_basic_avro_schema_results(self, results)
+
+ def test_avro_enum_write(self):
+ schema, objects = _create_enum_avro_schema_and_py_objects()
+ self._build_avro_job(schema, objects)
+ self.env.execute('test_avro_enum_write')
+ results = self._read_avro_file()
+ _check_enum_avro_schema_results(self, results)
+
+ def test_avro_union_write(self):
+ schema, objects = _create_union_avro_schema_and_py_objects()
+ self._build_avro_job(schema, objects)
+ self.env.execute('test_avro_union_write')
+ results = self._read_avro_file()
+ _check_union_avro_schema_results(self, results)
+
+ def test_avro_array_write(self):
+ schema, objects = _create_array_avro_schema_and_py_objects()
+ self._build_avro_job(schema, objects)
+ self.env.execute('test_avro_array_write')
+ results = self._read_avro_file()
+ _check_array_avro_schema_results(self, results)
+
+ def test_avro_map_write(self):
+ schema, objects = _create_map_avro_schema_and_py_objects()
+ self._build_avro_job(schema, objects)
+ self.env.execute('test_avro_map_write')
+ results = self._read_avro_file()
+ _check_map_avro_schema_results(self, results)
+
+ def _build_avro_job(self, schema, objects):
+ ds = self.env.from_collection(objects)
+ sink = FileSink.for_bulk_format(
+ self.avro_dir_name, AvroWriters.for_generic_record(schema)
+ ).build()
+ ds.map(lambda e: e, output_type=GenericRecordAvroTypeInfo(schema)).sink_to(sink)
+
+ def _read_avro_file(self) -> List[dict]:
+ records = []
+ for file in glob.glob(os.path.join(os.path.join(self.avro_dir_name, '**/*'))):
+ for record in DataFileReader(open(file, 'rb'), DatumReader()):
+ records.append(record)
+ return records
+
+
+class PassThroughMapFunction(MapFunction):
+
+ def map(self, value):
+ return value
+
+
+def _import_avro_classes():
+ jvm = get_gateway().jvm
+ classes = ['org.apache.avro.generic.GenericData']
+ prefix = 'org.apache.flink.avro.shaded.'
+ for cls in classes:
+ java_import(jvm, prefix + cls)
+
+
+BASIC_SCHEMA = """
+{
+ "type": "record",
+ "name": "test",
+ "fields": [
+ { "name": "null", "type": "null" },
+ { "name": "boolean", "type": "boolean" },
+ { "name": "int", "type": "int" },
+ { "name": "long", "type": "long" },
+ { "name": "float", "type": "float" },
+ { "name": "double", "type": "double" },
+ { "name": "string", "type": "string" }
+ ]
+}
+"""
+
+
+def _create_basic_avro_schema_and_records() -> Tuple[AvroSchema, List[JavaObject]]:
+ schema = AvroSchema.parse_string(BASIC_SCHEMA)
+ records = [_create_basic_avro_record(schema, True, 0, 1, 2, 3, 's1'),
+ _create_basic_avro_record(schema, False, 4, 5, 6, 7, 's2')]
+ return schema, records
+
+
+def _create_basic_avro_schema_and_py_objects() -> Tuple[AvroSchema, List[dict]]:
+ schema = AvroSchema.parse_string(BASIC_SCHEMA)
+ objects = [
+ {'null': None, 'boolean': True, 'int': 0, 'long': 1,
+ 'float': 2., 'double': 3., 'string': 's1'},
+ {'null': None, 'boolean': False, 'int': 4, 'long': 5,
+ 'float': 6., 'double': 7., 'string': 's2'},
+ ]
+ return schema, objects
+
+
+def _check_basic_avro_schema_results(test, results):
+ result1 = results[0]
+ result2 = results[1]
+ test.assertEqual(result1['null'], None)
+ test.assertEqual(result1['boolean'], True)
+ test.assertEqual(result1['int'], 0)
+ test.assertEqual(result1['long'], 1)
+ test.assertAlmostEqual(result1['float'], 2, delta=1e-3)
+ test.assertAlmostEqual(result1['double'], 3, delta=1e-3)
+ test.assertEqual(result1['string'], 's1')
+ test.assertEqual(result2['null'], None)
+ test.assertEqual(result2['boolean'], False)
+ test.assertEqual(result2['int'], 4)
+ test.assertEqual(result2['long'], 5)
+ test.assertAlmostEqual(result2['float'], 6, delta=1e-3)
+ test.assertAlmostEqual(result2['double'], 7, delta=1e-3)
+ test.assertEqual(result2['string'], 's2')
+
+
+ENUM_SCHEMA = """
+{
+ "type": "record",
+ "name": "test",
+ "fields": [
+ {
+ "name": "suit",
+ "type": {
+ "type": "enum",
+ "name": "Suit",
+ "symbols" : ["SPADES", "HEARTS", "DIAMONDS", "CLUBS"]
+ }
+ }
+ ]
+}
+"""
+
+
+def _create_enum_avro_schema_and_records() -> Tuple[AvroSchema, List[JavaObject]]:
+ schema = AvroSchema.parse_string(ENUM_SCHEMA)
+ records = [_create_enum_avro_record(schema, 'SPADES'),
+ _create_enum_avro_record(schema, 'DIAMONDS')]
+ return schema, records
+
+
+def _create_enum_avro_schema_and_py_objects() -> Tuple[AvroSchema, List[dict]]:
+ schema = AvroSchema.parse_string(ENUM_SCHEMA)
+ records = [
+ {'suit': 'SPADES'},
+ {'suit': 'DIAMONDS'},
+ ]
+ return schema, records
+
+
+def _check_enum_avro_schema_results(test, results):
+ test.assertEqual(results[0]['suit'], 'SPADES')
+ test.assertEqual(results[1]['suit'], 'DIAMONDS')
+
+
+UNION_SCHEMA = """
+{
+ "type": "record",
+ "name": "test",
+ "fields": [
+ {
+ "name": "union",
+ "type": [ "int", "double", "null" ]
+ }
+ ]
+}
+"""
+
+
+def _create_union_avro_schema_and_records() -> Tuple[AvroSchema, List[JavaObject]]:
+ schema = AvroSchema.parse_string(UNION_SCHEMA)
+ records = [_create_union_avro_record(schema, 1),
+ _create_union_avro_record(schema, 2.),
+ _create_union_avro_record(schema, None)]
+ return schema, records
+
+
+def _create_union_avro_schema_and_py_objects() -> Tuple[AvroSchema, List[dict]]:
+ schema = AvroSchema.parse_string(UNION_SCHEMA)
+ records = [
+ {'union': 1},
+ {'union': 2.},
+ {'union': None},
+ ]
+ return schema, records
+
+
+def _check_union_avro_schema_results(test, results):
+ test.assertEqual(results[0]['union'], 1)
+ test.assertAlmostEqual(results[1]['union'], 2.0, delta=1e-3)
+ test.assertEqual(results[2]['union'], None)
+
+
+# It seems there's a bug when the array item record contains only one field, which throws
+# java.lang.ClassCastException: required ... is not a group when reading
+ARRAY_SCHEMA = """
+{
+ "type": "record",
+ "name": "test",
+ "fields": [
+ {
+ "name": "array",
+ "type": {
+ "type": "array",
+ "items": {
+ "type": "record",
+ "name": "item",
+ "fields": [
+ { "name": "int", "type": "int" },
+ { "name": "double", "type": "double" }
+ ]
+ }
+ }
+ }
+ ]
+}
+"""
+
+
+def _create_array_avro_schema_and_records() -> Tuple[AvroSchema, List[JavaObject]]:
+ schema = AvroSchema.parse_string(ARRAY_SCHEMA)
+ records = [_create_array_avro_record(schema, [(1, 2.), (3, 4.)]),
+ _create_array_avro_record(schema, [(5, 6.), (7, 8.)])]
+ return schema, records
+
+
+def _create_array_avro_schema_and_py_objects() -> Tuple[AvroSchema, List[dict]]:
+ schema = AvroSchema.parse_string(ARRAY_SCHEMA)
+ records = [
+ {'array': [{'int': 1, 'double': 2.}, {'int': 3, 'double': 4.}]},
+ {'array': [{'int': 5, 'double': 6.}, {'int': 7, 'double': 8.}]},
+ ]
+ return schema, records
+
+
+def _check_array_avro_schema_results(test, results):
+ result1 = results[0]
+ result2 = results[1]
+ test.assertEqual(result1['array'][0]['int'], 1)
+ test.assertAlmostEqual(result1['array'][0]['double'], 2., delta=1e-3)
+ test.assertEqual(result1['array'][1]['int'], 3)
+ test.assertAlmostEqual(result1['array'][1]['double'], 4., delta=1e-3)
+ test.assertEqual(result2['array'][0]['int'], 5)
+ test.assertAlmostEqual(result2['array'][0]['double'], 6., delta=1e-3)
+ test.assertEqual(result2['array'][1]['int'], 7)
+ test.assertAlmostEqual(result2['array'][1]['double'], 8., delta=1e-3)
+
+
+MAP_SCHEMA = """
+{
+ "type": "record",
+ "name": "test",
+ "fields": [
+ {
+ "name": "map",
+ "type": {
+ "type": "map",
+ "values": "long"
+ }
+ }
+ ]
+}
+"""
+
+
+def _create_map_avro_schema_and_records() -> Tuple[AvroSchema, List[JavaObject]]:
+ schema = AvroSchema.parse_string(MAP_SCHEMA)
+ records = [_create_map_avro_record(schema, {'a': 1, 'b': 2}),
+ _create_map_avro_record(schema, {'c': 3, 'd': 4})]
+ return schema, records
+
+
+def _create_map_avro_schema_and_py_objects() -> Tuple[AvroSchema, List[dict]]:
+ schema = AvroSchema.parse_string(MAP_SCHEMA)
+ records = [
+ {'map': {'a': 1, 'b': 2}},
+ {'map': {'c': 3, 'd': 4}},
+ ]
+ return schema, records
+
+
+def _check_map_avro_schema_results(test, results):
+ result1 = results[0]
+ result2 = results[1]
+ test.assertEqual(result1['map']['a'], 1)
+ test.assertEqual(result1['map']['b'], 2)
+ test.assertEqual(result2['map']['c'], 3)
+ test.assertEqual(result2['map']['d'], 4)
+
+
+def _create_basic_avro_record(schema: AvroSchema, boolean_value, int_value, long_value,
+ float_value, double_value, string_value):
+ jvm = get_gateway().jvm
+ j_record = jvm.GenericData.Record(schema._j_schema)
+ j_record.put('boolean', boolean_value)
+ j_record.put('int', int_value)
+ j_record.put('long', long_value)
+ j_record.put('float', float_value)
+ j_record.put('double', double_value)
+ j_record.put('string', string_value)
+ return j_record
+
+
+def _create_enum_avro_record(schema: AvroSchema, enum_value):
+ jvm = get_gateway().jvm
+ j_record = jvm.GenericData.Record(schema._j_schema)
+ j_enum = jvm.GenericData.EnumSymbol(schema._j_schema.getField('suit').schema(), enum_value)
+ j_record.put('suit', j_enum)
+ return j_record
+
+
+def _create_union_avro_record(schema, union_value):
+ jvm = get_gateway().jvm
+ j_record = jvm.GenericData.Record(schema._j_schema)
+ j_record.put('union', union_value)
+ return j_record
+
+
+def _create_array_avro_record(schema, item_values: list):
+ jvm = get_gateway().jvm
+ j_record = jvm.GenericData.Record(schema._j_schema)
+ item_schema = AvroSchema(schema._j_schema.getField('array').schema().getElementType())
+ j_array = jvm.java.util.ArrayList()
+ for idx, item_value in enumerate(item_values):
+ j_item = jvm.GenericData.Record(item_schema._j_schema)
+ j_item.put('int', item_value[0])
+ j_item.put('double', item_value[1])
+ j_array.add(j_item)
+ j_record.put('array', j_array)
+ return j_record
+
+
+def _create_map_avro_record(schema, map_value: dict):
+ jvm = get_gateway().jvm
+ j_record = jvm.GenericData.Record(schema._j_schema)
+ j_map = jvm.java.util.HashMap()
+ for k, v in map_value.items():
+ j_map.put(k, v)
+ j_record.put('map', j_map)
+ return j_record
diff --git a/flink-python/pyflink/datastream/formats/tests/test_csv.py b/flink-python/pyflink/datastream/formats/tests/test_csv.py
new file mode 100644
index 00000000000..326095392e4
--- /dev/null
+++ b/flink-python/pyflink/datastream/formats/tests/test_csv.py
@@ -0,0 +1,354 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+import glob
+import os
+import tempfile
+from typing import Tuple, List
+
+from pyflink.common import WatermarkStrategy, Types
+from pyflink.datastream import MapFunction
+from pyflink.datastream.connectors.file_system import FileSource, FileSink
+from pyflink.datastream.formats import CsvSchema, CsvReaderFormat, CsvBulkWriter
+from pyflink.datastream.tests.test_util import DataStreamTestSinkFunction
+from pyflink.table import DataTypes
+from pyflink.testing.test_case_utils import PyFlinkStreamingTestCase
+
+
+class FileSourceCsvReaderFormatTests(PyFlinkStreamingTestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.test_sink = DataStreamTestSinkFunction()
+ self.csv_file_name = tempfile.mktemp(suffix='.csv', dir=self.tempdir)
+
+ def test_csv_primitive_column(self):
+ schema, lines = _create_csv_primitive_column_schema_and_lines()
+ self._build_csv_job(schema, lines)
+ self.env.execute('test_csv_primitive_column')
+ _check_csv_primitive_column_results(self, self.test_sink.get_results(True, False))
+
+ def test_csv_add_columns_from(self):
+ original_schema, lines = _create_csv_primitive_column_schema_and_lines()
+ schema = CsvSchema.builder().add_columns_from(original_schema).build()
+ self._build_csv_job(schema, lines)
+
+ self.env.execute('test_csv_add_columns_from')
+ _check_csv_primitive_column_results(self, self.test_sink.get_results(True, False))
+
+ def test_csv_array_column(self):
+ schema, lines = _create_csv_array_column_schema_and_lines()
+ self._build_csv_job(schema, lines)
+ self.env.execute('test_csv_array_column')
+ _check_csv_array_column_results(self, self.test_sink.get_results(True, False))
+
+ def test_csv_allow_comments(self):
+ schema, lines = _create_csv_allow_comments_schema_and_lines()
+ self._build_csv_job(schema, lines)
+ self.env.execute('test_csv_allow_comments')
+ _check_csv_allow_comments_results(self, self.test_sink.get_results(True, False))
+
+ def test_csv_use_header(self):
+ schema, lines = _create_csv_use_header_schema_and_lines()
+ self._build_csv_job(schema, lines)
+ self.env.execute('test_csv_use_header')
+ _check_csv_use_header_results(self, self.test_sink.get_results(True, False))
+
+ def test_csv_strict_headers(self):
+ schema, lines = _create_csv_strict_headers_schema_and_lines()
+ self._build_csv_job(schema, lines)
+ self.env.execute('test_csv_strict_headers')
+ _check_csv_strict_headers_results(self, self.test_sink.get_results(True, False))
+
+ def test_csv_default_quote_char(self):
+ schema, lines = _create_csv_default_quote_char_schema_and_lines()
+ self._build_csv_job(schema, lines)
+ self.env.execute('test_csv_default_quote_char')
+ _check_csv_default_quote_char_results(self, self.test_sink.get_results(True, False))
+
+ def test_csv_customize_quote_char(self):
+ schema, lines = _create_csv_customize_quote_char_schema_lines()
+ self._build_csv_job(schema, lines)
+ self.env.execute('test_csv_customize_quote_char')
+ _check_csv_customize_quote_char_results(self, self.test_sink.get_results(True, False))
+
+ def test_csv_use_escape_char(self):
+ schema, lines = _create_csv_set_escape_char_schema_and_lines()
+ self._build_csv_job(schema, lines)
+ self.env.execute('test_csv_use_escape_char')
+ _check_csv_set_escape_char_results(self, self.test_sink.get_results(True, False))
+
+ def _build_csv_job(self, schema, lines):
+ with open(self.csv_file_name, 'w') as f:
+ for line in lines:
+ f.write(line)
+ source = FileSource.for_record_stream_format(
+ CsvReaderFormat.for_schema(schema), self.csv_file_name).build()
+ ds = self.env.from_source(source, WatermarkStrategy.no_watermarks(), 'csv-source')
+ ds.map(PassThroughMapFunction(), output_type=Types.PICKLED_BYTE_ARRAY()) \
+ .add_sink(self.test_sink)
+
+
+class FileSinkCsvBulkWriterTests(PyFlinkStreamingTestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.env.set_parallelism(1)
+ self.csv_file_name = tempfile.mktemp(suffix='.csv', dir=self.tempdir)
+ self.csv_dir_name = tempfile.mkdtemp(dir=self.tempdir)
+
+ def test_csv_primitive_column_write(self):
+ schema, lines = _create_csv_primitive_column_schema_and_lines()
+ self._build_csv_job(schema, lines)
+ self.env.execute('test_csv_primitive_column_write')
+ results = self._read_csv_file()
+ self.assertTrue(len(results) == 1)
+ self.assertEqual(
+ results[0],
+ '127,-32767,2147483647,-9223372036854775808,3.0E38,2.0E-308,2,true,string\n'
+ )
+
+ def test_csv_array_column_write(self):
+ schema, lines = _create_csv_array_column_schema_and_lines()
+ self._build_csv_job(schema, lines)
+ self.env.execute('test_csv_array_column_write')
+ results = self._read_csv_file()
+ self.assertTrue(len(results) == 1)
+ self.assertListEqual(results, lines)
+
+ def test_csv_default_quote_char_write(self):
+ schema, lines = _create_csv_default_quote_char_schema_and_lines()
+ self._build_csv_job(schema, lines)
+ self.env.execute('test_csv_default_quote_char_write')
+ results = self._read_csv_file()
+ self.assertTrue(len(results) == 1)
+ self.assertListEqual(results, lines)
+
+ def test_csv_customize_quote_char_write(self):
+ schema, lines = _create_csv_customize_quote_char_schema_lines()
+ self._build_csv_job(schema, lines)
+ self.env.execute('test_csv_customize_quote_char_write')
+ results = self._read_csv_file()
+ self.assertTrue(len(results) == 1)
+ self.assertListEqual(results, lines)
+
+ def test_csv_use_escape_char_write(self):
+ schema, lines = _create_csv_set_escape_char_schema_and_lines()
+ self._build_csv_job(schema, lines)
+ self.env.execute('test_csv_use_escape_char_write')
+ results = self._read_csv_file()
+ self.assertTrue(len(results) == 1)
+ self.assertListEqual(results, ['"string,","""string2"""\n'])
+
+ def _build_csv_job(self, schema: CsvSchema, lines):
+ with open(self.csv_file_name, 'w') as f:
+ for line in lines:
+ f.write(line)
+ source = FileSource.for_record_stream_format(
+ CsvReaderFormat.for_schema(schema), self.csv_file_name
+ ).build()
+ ds = self.env.from_source(source, WatermarkStrategy.no_watermarks(), 'csv-source')
+ sink = FileSink.for_bulk_format(
+ self.csv_dir_name, CsvBulkWriter.for_schema(schema)
+ ).build()
+ ds.map(lambda e: e, output_type=schema.get_type_info()).sink_to(sink)
+
+ def _read_csv_file(self) -> List[str]:
+ lines = []
+ for file in glob.glob(os.path.join(self.csv_dir_name, '**/*')):
+ with open(file, 'r') as f:
+ lines.extend(f.readlines())
+ return lines
+
+
+class PassThroughMapFunction(MapFunction):
+
+ def map(self, value):
+ return value
+
+
+def _create_csv_primitive_column_schema_and_lines() -> Tuple[CsvSchema, List[str]]:
+ schema = CsvSchema.builder() \
+ .add_number_column('tinyint', DataTypes.TINYINT()) \
+ .add_number_column('smallint', DataTypes.SMALLINT()) \
+ .add_number_column('int', DataTypes.INT()) \
+ .add_number_column('bigint', DataTypes.BIGINT()) \
+ .add_number_column('float', DataTypes.FLOAT()) \
+ .add_number_column('double', DataTypes.DOUBLE()) \
+ .add_number_column('decimal', DataTypes.DECIMAL(2, 0)) \
+ .add_boolean_column('boolean') \
+ .add_string_column('string') \
+ .build()
+ lines = [
+ '127,'
+ '-32767,'
+ '2147483647,'
+ '-9223372036854775808,'
+ '3e38,'
+ '2e-308,'
+ '1.5,'
+ 'true,'
+ 'string\n',
+ ]
+ return schema, lines
+
+
+def _check_csv_primitive_column_results(test, results):
+ row = results[0]
+ test.assertEqual(row['tinyint'], 127)
+ test.assertEqual(row['smallint'], -32767)
+ test.assertEqual(row['int'], 2147483647)
+ test.assertEqual(row['bigint'], -9223372036854775808)
+ test.assertAlmostEqual(row['float'], 3e38, delta=1e31)
+ test.assertAlmostEqual(row['double'], 2e-308, delta=2e-301)
+ test.assertAlmostEqual(row['decimal'], 2)
+ test.assertEqual(row['boolean'], True)
+ test.assertEqual(row['string'], 'string')
+
+
+def _create_csv_array_column_schema_and_lines() -> Tuple[CsvSchema, List[str]]:
+ schema = CsvSchema.builder() \
+ .add_array_column('number_array', separator=';', element_type=DataTypes.INT()) \
+ .add_array_column('boolean_array', separator=':', element_type=DataTypes.BOOLEAN()) \
+ .add_array_column('string_array', separator=',', element_type=DataTypes.STRING()) \
+ .set_column_separator('|') \
+ .disable_quote_char() \
+ .build()
+ lines = [
+ '1;2;3|'
+ 'true:false|'
+ 'a,b,c\n',
+ ]
+ return schema, lines
+
+
+def _check_csv_array_column_results(test, results):
+ row = results[0]
+ test.assertListEqual(row['number_array'], [1, 2, 3])
+ test.assertListEqual(row['boolean_array'], [True, False])
+ test.assertListEqual(row['string_array'], ['a', 'b', 'c'])
+
+
+def _create_csv_allow_comments_schema_and_lines() -> Tuple[CsvSchema, List[str]]:
+ schema = CsvSchema.builder() \
+ .add_string_column('string') \
+ .set_allow_comments() \
+ .build()
+ lines = [
+ 'a\n',
+ '# this is comment\n',
+ 'b\n',
+ ]
+ return schema, lines
+
+
+def _check_csv_allow_comments_results(test, results):
+ test.assertEqual(results[0]['string'], 'a')
+ test.assertEqual(results[1]['string'], 'b')
+
+
+def _create_csv_use_header_schema_and_lines() -> Tuple[CsvSchema, List[str]]:
+ schema = CsvSchema.builder() \
+ .add_string_column('string') \
+ .add_number_column('number') \
+ .set_use_header() \
+ .build()
+ lines = [
+ 'h1,h2\n',
+ 'string,123\n',
+ ]
+ return schema, lines
+
+
+def _check_csv_use_header_results(test, results):
+ row = results[0]
+ test.assertEqual(row['string'], 'string')
+ test.assertEqual(row['number'], 123)
+
+
+def _create_csv_strict_headers_schema_and_lines() -> Tuple[CsvSchema, List[str]]:
+ schema = CsvSchema.builder() \
+ .add_string_column('string') \
+ .add_number_column('number') \
+ .set_use_header() \
+ .set_strict_headers() \
+ .build()
+ lines = [
+ 'string,number\n',
+ 'string,123\n',
+ ]
+ return schema, lines
+
+
+def _check_csv_strict_headers_results(test, results):
+ row = results[0]
+ test.assertEqual(row['string'], 'string')
+ test.assertEqual(row['number'], 123)
+
+
+def _create_csv_default_quote_char_schema_and_lines() -> Tuple[CsvSchema, List[str]]:
+ schema = CsvSchema.builder() \
+ .add_string_column('string') \
+ .add_string_column('string2') \
+ .set_column_separator('|') \
+ .build()
+ lines = [
+ '"string"|"string2"\n',
+ ]
+ return schema, lines
+
+
+def _check_csv_default_quote_char_results(test, results):
+ row = results[0]
+ test.assertEqual(row['string'], 'string')
+
+
+def _create_csv_customize_quote_char_schema_lines() -> Tuple[CsvSchema, List[str]]:
+ schema = CsvSchema.builder() \
+ .add_string_column('string') \
+ .add_string_column('string2') \
+ .set_column_separator('|') \
+ .set_quote_char('`') \
+ .build()
+ lines = [
+ '`string`|`string2`\n',
+ ]
+ return schema, lines
+
+
+def _check_csv_customize_quote_char_results(test, results):
+ row = results[0]
+ test.assertEqual(row['string'], 'string')
+
+
+def _create_csv_set_escape_char_schema_and_lines() -> Tuple[CsvSchema, List[str]]:
+ schema = CsvSchema.builder() \
+ .add_string_column('string') \
+ .add_string_column('string2') \
+ .set_column_separator(',') \
+ .set_escape_char('\\') \
+ .build()
+ lines = [
+ 'string\\,,\\"string2\\"\n',
+ ]
+ return schema, lines
+
+
+def _check_csv_set_escape_char_results(test, results):
+ row = results[0]
+ test.assertEqual(row['string'], 'string,')
+ test.assertEqual(row['string2'], '"string2"')
diff --git a/flink-python/pyflink/datastream/formats/tests/test_parquet.py b/flink-python/pyflink/datastream/formats/tests/test_parquet.py
new file mode 100644
index 00000000000..97e8ffaf0a4
--- /dev/null
+++ b/flink-python/pyflink/datastream/formats/tests/test_parquet.py
@@ -0,0 +1,231 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+import glob
+import os
+import tempfile
+import unittest
+from typing import List
+
+from pyflink.common import Configuration
+from pyflink.common.watermark_strategy import WatermarkStrategy
+from pyflink.datastream import MapFunction
+from pyflink.datastream.connectors.file_system import FileSource, FileSink
+from pyflink.datastream.formats.tests.test_avro import \
+ _create_basic_avro_schema_and_py_objects, _check_basic_avro_schema_results, \
+ _create_enum_avro_schema_and_py_objects, _check_enum_avro_schema_results, \
+ _create_union_avro_schema_and_py_objects, _check_union_avro_schema_results, \
+ _create_array_avro_schema_and_py_objects, _check_array_avro_schema_results, \
+ _create_map_avro_schema_and_py_objects, _check_map_avro_schema_results, \
+ _create_map_avro_schema_and_records, _create_array_avro_schema_and_records, \
+ _create_union_avro_schema_and_records, _create_enum_avro_schema_and_records, \
+ _create_basic_avro_schema_and_records, _import_avro_classes
+from pyflink.datastream.formats import GenericRecordAvroTypeInfo, AvroSchema
+from pyflink.datastream.formats.parquet import AvroParquetReaders, ParquetColumnarRowInputFormat, \
+ AvroParquetWriters
+from pyflink.datastream.tests.test_util import DataStreamTestSinkFunction
+from pyflink.java_gateway import get_gateway
+from pyflink.table.types import RowType, DataTypes
+from pyflink.testing.test_case_utils import PyFlinkStreamingTestCase
+
+
+@unittest.skipIf(os.environ.get('HADOOP_CLASSPATH') is None,
+ 'Hadoop libraries are required for Parquet-Avro format tests')
+class FileSourceAvroParquetReadersTests(PyFlinkStreamingTestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.test_sink = DataStreamTestSinkFunction()
+ _import_avro_classes()
+
+ def test_parquet_avro_basic(self):
+ parquet_file_name = tempfile.mktemp(suffix='.parquet', dir=self.tempdir)
+ schema, records = _create_basic_avro_schema_and_records()
+ self._create_parquet_avro_file(parquet_file_name, schema, records)
+ self._build_parquet_avro_job(schema, parquet_file_name)
+ self.env.execute("test_parquet_avro_basic")
+ results = self.test_sink.get_results(True, False)
+ _check_basic_avro_schema_results(self, results)
+
+ def test_parquet_avro_enum(self):
+ parquet_file_name = tempfile.mktemp(suffix='.parquet', dir=self.tempdir)
+ schema, records = _create_enum_avro_schema_and_records()
+ self._create_parquet_avro_file(parquet_file_name, schema, records)
+ self._build_parquet_avro_job(schema, parquet_file_name)
+ self.env.execute("test_parquet_avro_enum")
+ results = self.test_sink.get_results(True, False)
+ _check_enum_avro_schema_results(self, results)
+
+ def test_parquet_avro_union(self):
+ parquet_file_name = tempfile.mktemp(suffix='.parquet', dir=self.tempdir)
+ schema, records = _create_union_avro_schema_and_records()
+ self._create_parquet_avro_file(parquet_file_name, schema, records)
+ self._build_parquet_avro_job(schema, parquet_file_name)
+ self.env.execute("test_parquet_avro_union")
+ results = self.test_sink.get_results(True, False)
+ _check_union_avro_schema_results(self, results)
+
+ def test_parquet_avro_array(self):
+ parquet_file_name = tempfile.mktemp(suffix='.parquet', dir=self.tempdir)
+ schema, records = _create_array_avro_schema_and_records()
+ self._create_parquet_avro_file(parquet_file_name, schema, records)
+ self._build_parquet_avro_job(schema, parquet_file_name)
+ self.env.execute("test_parquet_avro_array")
+ results = self.test_sink.get_results(True, False)
+ _check_array_avro_schema_results(self, results)
+
+ def test_parquet_avro_map(self):
+ parquet_file_name = tempfile.mktemp(suffix='.parquet', dir=self.tempdir)
+ schema, records = _create_map_avro_schema_and_records()
+ self._create_parquet_avro_file(parquet_file_name, schema, records)
+ self._build_parquet_avro_job(schema, parquet_file_name)
+ self.env.execute("test_parquet_avro_map")
+ results = self.test_sink.get_results(True, False)
+ _check_map_avro_schema_results(self, results)
+
+ def _build_parquet_avro_job(self, record_schema, *parquet_file_name):
+ ds = self.env.from_source(
+ FileSource.for_record_stream_format(
+ AvroParquetReaders.for_generic_record(record_schema),
+ *parquet_file_name
+ ).build(),
+ WatermarkStrategy.for_monotonous_timestamps(),
+ "parquet-source"
+ )
+ ds.map(PassThroughMapFunction()).add_sink(self.test_sink)
+
+ @staticmethod
+ def _create_parquet_avro_file(file_path: str, schema: AvroSchema, records: list):
+ jvm = get_gateway().jvm
+ j_path = jvm.org.apache.flink.core.fs.Path(file_path)
+ writer = jvm.org.apache.flink.formats.parquet.avro.AvroParquetWriters \
+ .forGenericRecord(schema._j_schema) \
+ .create(j_path.getFileSystem().create(
+ j_path,
+ jvm.org.apache.flink.core.fs.FileSystem.WriteMode.OVERWRITE
+ ))
+ for record in records:
+ writer.addElement(record)
+ writer.flush()
+ writer.finish()
+
+
+@unittest.skipIf(os.environ.get('HADOOP_CLASSPATH') is None,
+ 'Hadoop libraries are required for Parquet-Avro format tests')
+class FileSinkAvroParquetWritersTests(PyFlinkStreamingTestCase):
+
+ def setUp(self):
+ super().setUp()
+ # NOTE: parallelism == 1 is required to keep the order of results
+ self.env.set_parallelism(1)
+ self.parquet_dir_name = tempfile.mkdtemp(dir=self.tempdir)
+ self.test_sink = DataStreamTestSinkFunction()
+
+ def test_parquet_avro_basic_write(self):
+ schema, objects = _create_basic_avro_schema_and_py_objects()
+ self._build_avro_parquet_job(schema, objects)
+ self.env.execute('test_parquet_avro_basic_write')
+ results = self._read_parquet_avro_file(schema)
+ _check_basic_avro_schema_results(self, results)
+
+ def test_parquet_avro_enum_write(self):
+ schema, objects = _create_enum_avro_schema_and_py_objects()
+ self._build_avro_parquet_job(schema, objects)
+ self.env.execute('test_parquet_avro_enum_write')
+ results = self._read_parquet_avro_file(schema)
+ _check_enum_avro_schema_results(self, results)
+
+ def test_parquet_avro_union_write(self):
+ schema, objects = _create_union_avro_schema_and_py_objects()
+ self._build_avro_parquet_job(schema, objects)
+ self.env.execute('test_parquet_avro_union_write')
+ results = self._read_parquet_avro_file(schema)
+ _check_union_avro_schema_results(self, results)
+
+ def test_parquet_avro_array_write(self):
+ schema, objects = _create_array_avro_schema_and_py_objects()
+ self._build_avro_parquet_job(schema, objects)
+ self.env.execute('test_parquet_avro_array_write')
+ results = self._read_parquet_avro_file(schema)
+ _check_array_avro_schema_results(self, results)
+
+ def test_parquet_avro_map_write(self):
+ schema, objects = _create_map_avro_schema_and_py_objects()
+ self._build_avro_parquet_job(schema, objects)
+ self.env.execute('test_parquet_avro_map_write')
+ results = self._read_parquet_avro_file(schema)
+ _check_map_avro_schema_results(self, results)
+
+ def _build_avro_parquet_job(self, schema, objects):
+ ds = self.env.from_collection(objects)
+ avro_type_info = GenericRecordAvroTypeInfo(schema)
+ sink = FileSink.for_bulk_format(
+ self.parquet_dir_name, AvroParquetWriters.for_generic_record(schema)
+ ).build()
+ ds.map(lambda e: e, output_type=avro_type_info).sink_to(sink)
+
+ def _read_parquet_avro_file(self, schema) -> List[dict]:
+ parquet_files = glob.glob(self.parquet_dir_name, recursive=True)
+ FileSourceAvroParquetReadersTests._build_parquet_avro_job(self, schema, *parquet_files)
+ self.env.execute()
+ return self.test_sink.get_results(True, False)
+
+
+@unittest.skipIf(os.environ.get('HADOOP_CLASSPATH') is None,
+ 'Hadoop libraries are required for Parquet Columnar format tests')
+class FileSourceParquetColumnarRowInputFormatTests(PyFlinkStreamingTestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.test_sink = DataStreamTestSinkFunction()
+ _import_avro_classes()
+
+ def test_parquet_columnar_basic(self):
+ parquet_file_name = tempfile.mktemp(suffix='.parquet', dir=self.tempdir)
+ schema, records = _create_basic_avro_schema_and_records()
+ FileSourceAvroParquetReadersTests._create_parquet_avro_file(
+ parquet_file_name, schema, records)
+ row_type = DataTypes.ROW([
+ DataTypes.FIELD('null', DataTypes.STRING()), # DataTypes.NULL cannot be serialized
+ DataTypes.FIELD('boolean', DataTypes.BOOLEAN()),
+ DataTypes.FIELD('int', DataTypes.INT()),
+ DataTypes.FIELD('long', DataTypes.BIGINT()),
+ DataTypes.FIELD('float', DataTypes.FLOAT()),
+ DataTypes.FIELD('double', DataTypes.DOUBLE()),
+ DataTypes.FIELD('string', DataTypes.STRING()),
+ DataTypes.FIELD('unknown', DataTypes.STRING())
+ ])
+ self._build_parquet_columnar_job(row_type, parquet_file_name)
+ self.env.execute('test_parquet_columnar_basic')
+ results = self.test_sink.get_results(True, False)
+ _check_basic_avro_schema_results(self, results)
+ self.assertIsNone(results[0]['unknown'])
+ self.assertIsNone(results[1]['unknown'])
+
+ def _build_parquet_columnar_job(self, row_type: RowType, parquet_file_name: str):
+ source = FileSource.for_bulk_file_format(
+ ParquetColumnarRowInputFormat(Configuration(), row_type, 10, True, True),
+ parquet_file_name
+ ).build()
+ ds = self.env.from_source(source, WatermarkStrategy.no_watermarks(), 'parquet-source')
+ ds.map(PassThroughMapFunction()).add_sink(self.test_sink)
+
+
+class PassThroughMapFunction(MapFunction):
+
+ def map(self, value):
+ return value