You are viewing a plain-text version of this content; the hyperlink to the canonical version was not preserved in this copy.
Posted to commits@flink.apache.org by se...@apache.org on 2020/05/16 21:04:52 UTC

[flink] 01/06: [FLINK-7267][connectors/rabbitmq] Allow overriding RMQSource connection

This is an automated email from the ASF dual-hosted git repository.

sewen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit cce715b78ca85b8eee258f32b1e6fb366ca56998
Author: austin ce <au...@gmail.com>
AuthorDate: Mon May 11 19:40:51 2020 -0400

    [FLINK-7267][connectors/rabbitmq] Allow overriding RMQSource connection
---
 .../streaming/connectors/rabbitmq/RMQSource.java   | 13 ++++++++---
 .../connectors/rabbitmq/RMQSourceTest.java         | 26 +++++++++++++++++++++-
 2 files changed, 35 insertions(+), 4 deletions(-)

diff --git a/flink-connectors/flink-connector-rabbitmq/src/main/java/org/apache/flink/streaming/connectors/rabbitmq/RMQSource.java b/flink-connectors/flink-connector-rabbitmq/src/main/java/org/apache/flink/streaming/connectors/rabbitmq/RMQSource.java
index ab7bdc4..e32ce2d 100644
--- a/flink-connectors/flink-connector-rabbitmq/src/main/java/org/apache/flink/streaming/connectors/rabbitmq/RMQSource.java
+++ b/flink-connectors/flink-connector-rabbitmq/src/main/java/org/apache/flink/streaming/connectors/rabbitmq/RMQSource.java
@@ -127,13 +127,21 @@ public class RMQSource<OUT> extends MultipleIdsMessageAcknowledgingSourceBase<OU
 
 	/**
 	 * Initializes the connection to RMQ with a default connection factory. The user may override
-	 * this method to setup and configure their own ConnectionFactory.
+	 * this method to set up and configure their own {@link ConnectionFactory}.
 	 */
 	protected ConnectionFactory setupConnectionFactory() throws Exception {
 		return rmqConnectionConfig.getConnectionFactory();
 	}
 
 	/**
+	 * Initializes the connection to RMQ using the connection factory returned by {@link #setupConnectionFactory()}.
+	 * The user may override this method to set up and configure their own {@link Connection}.
+	 */
+	protected Connection setupConnection() throws Exception {
+		return setupConnectionFactory().newConnection();
+	}
+
+	/**
 	 * Sets up the queue. The default implementation just declares the queue. The user may override
 	 * this method to have a custom setup for the queue (i.e. binding the queue to an exchange or
 	 * defining custom queue parameters)
@@ -145,9 +153,8 @@ public class RMQSource<OUT> extends MultipleIdsMessageAcknowledgingSourceBase<OU
 	@Override
 	public void open(Configuration config) throws Exception {
 		super.open(config);
-		ConnectionFactory factory = setupConnectionFactory();
 		try {
-			connection = factory.newConnection();
+			connection = setupConnection();
 			channel = connection.createChannel();
 			if (channel == null) {
 				throw new RuntimeException("None of RabbitMQ channels are available");
diff --git a/flink-connectors/flink-connector-rabbitmq/src/test/java/org/apache/flink/streaming/connectors/rabbitmq/RMQSourceTest.java b/flink-connectors/flink-connector-rabbitmq/src/test/java/org/apache/flink/streaming/connectors/rabbitmq/RMQSourceTest.java
index 925457f..b53723c 100644
--- a/flink-connectors/flink-connector-rabbitmq/src/test/java/org/apache/flink/streaming/connectors/rabbitmq/RMQSourceTest.java
+++ b/flink-connectors/flink-connector-rabbitmq/src/test/java/org/apache/flink/streaming/connectors/rabbitmq/RMQSourceTest.java
@@ -329,6 +329,26 @@ public class RMQSourceTest {
 		assertThat("Open method was not called", deserializationSchema.isOpenCalled(), is(true));
 	}
 
+	@Test
+	public void testOverrideConnection() throws Exception {
+		final Connection mockConnection = Mockito.mock(Connection.class);
+		Channel channel = Mockito.mock(Channel.class);
+		Mockito.when(mockConnection.createChannel()).thenReturn(channel);
+
+		RMQMockedRuntimeTestSource source = new RMQMockedRuntimeTestSource() {
+			@Override
+			protected Connection setupConnection() throws Exception {
+				return mockConnection;
+			}
+		};
+
+		FunctionInitializationContext mockContext = getMockContext();
+		source.initializeState(mockContext);
+		source.open(new Configuration());
+
+		Mockito.verify(mockConnection, Mockito.times(1)).createChannel();
+	}
+
 	private static class ConstructorTestClass extends RMQSource<String> {
 
 		private ConnectionFactory factory;
@@ -411,6 +431,10 @@ public class RMQSourceTest {
 			this(connectionConfig, new StringDeserializationScheme());
 		}
 
+		public RMQMockedRuntimeTestSource() {
+			this(new StringDeserializationScheme());
+		}
+
 		@Override
 		public RuntimeContext getRuntimeContext() {
 			return runtimeContext;
@@ -421,7 +445,7 @@ public class RMQSourceTest {
 		private ArrayDeque<Tuple2<Long, Set<String>>> restoredState;
 
 		public RMQTestSource() {
-			this(new StringDeserializationScheme());
+			super();
 		}
 
 		public RMQTestSource(DeserializationSchema<String> deserializationSchema) {