You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by se...@apache.org on 2015/09/21 23:34:45 UTC

[4/4] flink git commit: [FLINK-2710] [streaming] Fix and improve SocketTextStreamFunction and SocketTextStreamFunctionTest

[FLINK-2710] [streaming] Fix and improve SocketTextStreamFunction and SocketTextStreamFunctionTest


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/bd74baef
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/bd74baef
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/bd74baef

Branch: refs/heads/master
Commit: bd74baef5d0d23ce29ff71bb67f12bd32219ff7b
Parents: 63d9800
Author: Stephan Ewen <se...@apache.org>
Authored: Mon Sep 21 19:37:18 2015 +0200
Committer: Stephan Ewen <se...@apache.org>
Committed: Mon Sep 21 23:18:04 2015 +0200

----------------------------------------------------------------------
 .../source/SocketTextStreamFunction.java        | 189 ++++-----
 .../source/SocketTextStreamFunctionTest.java    | 420 +++++++++++++------
 2 files changed, 383 insertions(+), 226 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/bd74baef/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/functions/source/SocketTextStreamFunction.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/functions/source/SocketTextStreamFunction.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/functions/source/SocketTextStreamFunction.java
index 6e7bcf6..9310b71 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/functions/source/SocketTextStreamFunction.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/functions/source/SocketTextStreamFunction.java
@@ -17,133 +17,130 @@
 
 package org.apache.flink.streaming.api.functions.source;
 
-import java.io.BufferedReader;
-import java.io.IOException;
-import java.io.InputStreamReader;
-import java.net.ConnectException;
-import java.net.InetSocketAddress;
-import java.net.Socket;
-import java.net.SocketException;
+import org.apache.flink.runtime.util.IOUtils;
 
-import org.apache.flink.configuration.Configuration;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-public class SocketTextStreamFunction extends RichSourceFunction<String> {
+import java.io.BufferedReader;
+import java.io.InputStreamReader;
+import java.net.InetSocketAddress;
+import java.net.Socket;
 
-	protected static final Logger LOG = LoggerFactory.getLogger(SocketTextStreamFunction.class);
+import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.base.Preconditions.checkNotNull;
+
+/**
+ * A source function that reads strings from a socket. The source will read bytes from the socket stream
+ * and convert them to characters, each byte individually. When the delimiter character is received,
+ * the function will output the current string, and begin a new string.
+ * <p>
+ * The function strips trailing <i>carriage return</i> characters (\r) when the delimiter is the
+ * newline character (\n).
+ * <p>
+ * The function can be set to reconnect to the server socket in case that the stream is closed on the server side.
+ */
+public class SocketTextStreamFunction implements SourceFunction<String> {
 
 	private static final long serialVersionUID = 1L;
+	
+	private static final Logger LOG = LoggerFactory.getLogger(SocketTextStreamFunction.class);
 
-	private String hostname;
-	private int port;
-	private char delimiter;
-	private long maxRetry;
-	private boolean retryForever;
-	private Socket socket;
-	private static final int CONNECTION_TIMEOUT_TIME = 0;
-	static int CONNECTION_RETRY_SLEEP = 1000;
-	protected long retries;
-
-	private volatile boolean isRunning;
+	/** Default delay between successive connection attempts */
+	private static final int DEFAULT_CONNECTION_RETRY_SLEEP = 500;
 
-	public SocketTextStreamFunction(String hostname, int port, char delimiter, long maxRetry) {
-		this.hostname = hostname;
+	/** Default connection timeout when connecting to the server socket (infinite) */
+	private static final int CONNECTION_TIMEOUT_TIME = 0;
+	
+	
+	private final String hostname;
+	private final int port;
+	private final char delimiter;
+	private final long maxNumRetries;
+	private final long delayBetweenRetries;
+	
+	private transient Socket currentSocket;
+	
+	private volatile boolean isRunning = true;
+
+	
+	public SocketTextStreamFunction(String hostname, int port, char delimiter, long maxNumRetries) {
+		this(hostname, port, delimiter, maxNumRetries, DEFAULT_CONNECTION_RETRY_SLEEP);
+	}
+	
+	public SocketTextStreamFunction(String hostname, int port, char delimiter, long maxNumRetries, long delayBetweenRetries) {
+		checkArgument(port > 0 && port < 65536, "port is out of range");
+		checkArgument(maxNumRetries >= -1, "maxNumRetries must be zero or larger (num retries), or -1 (infinite retries)");
+		checkArgument(delayBetweenRetries >= 0, "delayBetweenRetries must be zero or positive");
+		
+		this.hostname = checkNotNull(hostname, "hostname must not be null");
 		this.port = port;
 		this.delimiter = delimiter;
-		this.maxRetry = maxRetry;
-		this.retryForever = maxRetry < 0;
-	}
-
-	@Override
-	public void open(Configuration parameters) throws Exception {
-		super.open(parameters);
-		socket = new Socket();
-		socket.connect(new InetSocketAddress(hostname, port), CONNECTION_TIMEOUT_TIME);
-		isRunning = true;
+		this.maxNumRetries = maxNumRetries;
+		this.delayBetweenRetries = delayBetweenRetries;
 	}
 
 	@Override
 	public void run(SourceContext<String> ctx) throws Exception {
-		streamFromSocket(ctx, socket);
-	}
+		final StringBuilder buffer = new StringBuilder();
+		long attempt = 0;
+		
+		while (isRunning) {
+			
+			try (Socket socket = new Socket()) {
+				currentSocket = socket;
+				
+				LOG.info("Connecting to server socket " + hostname + ':' + port);
+				socket.connect(new InetSocketAddress(hostname, port), CONNECTION_TIMEOUT_TIME);
+				BufferedReader reader = new BufferedReader(new InputStreamReader(socket.getInputStream()));
 
-	private void streamFromSocket(SourceContext<String> ctx, Socket socket) throws Exception {
-		try {
-			StringBuilder buffer = new StringBuilder();
-			BufferedReader reader = new BufferedReader(new InputStreamReader(
-					socket.getInputStream()));
-
-			while (isRunning) {
 				int data;
-				try {
-					data = reader.read();
-				} catch (SocketException e) {
-					if (!isRunning) {
-						break;
-					} else {
-						throw e;
+				while (isRunning && (data = reader.read()) != -1) {
+					// check if the string is complete
+					if (data != delimiter) {
+						buffer.append((char) data);
 					}
-				}
-
-				if (data == -1) {
-					socket.close();
-					boolean success = false;
-					retries = 0;
-					while ((retries < maxRetry  || retryForever) && !success) {
-						if (!retryForever) {
-							retries++;
-						}
-						LOG.warn("Lost connection to server socket. Retrying in "
-								+ (CONNECTION_RETRY_SLEEP / 1000) + " seconds...");
-						try {
-							socket = new Socket();
-							socket.connect(new InetSocketAddress(hostname, port),
-									CONNECTION_TIMEOUT_TIME);
-							success = true;
-						} catch (ConnectException ce) {
-							Thread.sleep(CONNECTION_RETRY_SLEEP);
-							socket.close();
+					else {
+						// truncate trailing carriage return
+						if (delimiter == '\n' && buffer.length() > 0 && buffer.charAt(buffer.length() - 1) == '\r') {
+							buffer.setLength(buffer.length() - 1);
 						}
+						ctx.collect(buffer.toString());
+						buffer.setLength(0);
 					}
-
-					if (success) {
-						LOG.info("Server socket is reconnected.");
-					} else {
-						LOG.error("Could not reconnect to server socket.");
-						break;
-					}
-					reader = new BufferedReader(new InputStreamReader(socket.getInputStream()));
-					continue;
 				}
+			}
 
-				if (data == delimiter) {
-					ctx.collect(buffer.toString());
-					buffer = new StringBuilder();
-				} else if (data != '\r') { // ignore carriage return
-					buffer.append((char) data);
+			// if we dropped out of this loop due to an EOF, sleep and retry
+			if (isRunning) {
+				attempt++;
+				if (maxNumRetries == -1 || attempt < maxNumRetries) {
+					LOG.warn("Lost connection to server socket. Retrying in " + delayBetweenRetries + " msecs...");
+					Thread.sleep(delayBetweenRetries);
+				}
+				else {
+					// an exception should probably be thrown here, but some examples expect a simple exit of the stream source
+					// throw new EOFException("Reached end of stream and reconnects are not enabled.");
+					break;
 				}
 			}
+		}
 
-			if (buffer.length() > 0) {
-				ctx.collect(buffer.toString());
-			}
-		} finally {
-			socket.close();
+		// collect trailing data
+		if (buffer.length() > 0) {
+			ctx.collect(buffer.toString());
 		}
 	}
 
 	@Override
 	public void cancel() {
 		isRunning = false;
-		if (socket != null && !socket.isClosed()) {
-			try {
-				socket.close();
-			} catch (IOException e) {
-				if (LOG.isErrorEnabled()) {
-					LOG.error("Could not close open socket");
-				}
-			}
+		
+		// we need to close the socket as well, because the Thread.interrupt() function will
+		// not wake the thread in the socketStream.read() method when blocked.
+		Socket theSocket = this.currentSocket;
+		if (theSocket != null) {
+			IOUtils.closeSocket(theSocket);
 		}
 	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/bd74baef/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/api/functions/source/SocketTextStreamFunctionTest.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/api/functions/source/SocketTextStreamFunctionTest.java b/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/api/functions/source/SocketTextStreamFunctionTest.java
index 5f16c00..3398451 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/api/functions/source/SocketTextStreamFunctionTest.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/api/functions/source/SocketTextStreamFunctionTest.java
@@ -18,172 +18,332 @@
 
 package org.apache.flink.streaming.api.functions.source;
 
-import java.io.DataOutputStream;
-import java.net.Socket;
+import org.apache.commons.io.IOUtils;
 
-import org.apache.flink.configuration.Configuration;
-import org.junit.Test;
-import org.mockito.ArgumentCaptor;
-import org.mockito.Mockito;
+import org.apache.flink.streaming.api.watermark.Watermark;
 
-import static java.lang.Thread.sleep;
-import static org.junit.Assert.*;
-import static org.mockito.Mockito.verify;
+import org.junit.Test;
 
+import java.io.EOFException;
+import java.io.OutputStreamWriter;
 import java.net.ServerSocket;
-import java.util.concurrent.atomic.AtomicReference;
+import java.net.Socket;
+
+import static org.junit.Assert.*;
 
 /**
  * Tests for the {@link org.apache.flink.streaming.api.functions.source.SocketTextStreamFunction}.
  */
-public class SocketTextStreamFunctionTest{
+public class SocketTextStreamFunctionTest {
 
-	final AtomicReference<Throwable> error = new AtomicReference<Throwable>();
-	private final String host = "127.0.0.1";
-	private final SourceFunction.SourceContext<String> ctx = Mockito.mock(SourceFunction.SourceContext.class);
+	private static final String LOCALHOST = "127.0.0.1";
 
-	public SocketTextStreamFunctionTest() {
+
+	@Test
+	public void testSocketSourceSimpleOutput() throws Exception {
+		ServerSocket server = new ServerSocket(0);
+		Socket channel = null;
+		
+		try {
+			SocketTextStreamFunction source = new SocketTextStreamFunction(LOCALHOST, server.getLocalPort(), '\n', 0);
+	
+			SocketSourceThread runner = new SocketSourceThread(source, "test1", "check");
+			runner.start();
+	
+			channel = server.accept();
+			OutputStreamWriter writer = new OutputStreamWriter(channel.getOutputStream());
+			
+			writer.write("test1\n");
+			writer.write("check\n");
+			writer.flush();
+			runner.waitForNumElements(2);
+
+			runner.cancel();
+			runner.interrupt();
+			
+			runner.waitUntilDone();
+			
+			channel.close();
+		}
+		finally {
+			if (channel != null) {
+				IOUtils.closeQuietly(channel);
+			}
+			IOUtils.closeQuietly(server);
+		}
 	}
 
-	class SocketSource extends Thread {
+	@Test
+	public void testExitNoRetries() throws Exception {
+		ServerSocket server = new ServerSocket(0);
+		Socket channel = null;
 
-		SocketTextStreamFunction socketSource;
+		try {
+			SocketTextStreamFunction source = new SocketTextStreamFunction(LOCALHOST, server.getLocalPort(), '\n', 0);
 
-		public SocketSource(ServerSocket serverSo, int maxRetry) throws Exception {
-			this.socketSource =  new SocketTextStreamFunction(host, serverSo.getLocalPort(), '\n', maxRetry);
-		}
+			SocketSourceThread runner = new SocketSourceThread(source);
+			runner.start();
 
-		public void run() {
+			channel = server.accept();
+			channel.close();
+			
 			try {
-				this.socketSource.open(new Configuration());
-				this.socketSource.run(ctx);
-			}catch(Exception e){
-				error.set(e);
+				runner.waitUntilDone();
+			}
+			catch (Exception e) {
+				assertTrue(e.getCause() instanceof EOFException);
 			}
 		}
-
-		public void cancel(){
-			this.socketSource.cancel();
+		finally {
+			if (channel != null) {
+				IOUtils.closeQuietly(channel);
+			}
+			IOUtils.closeQuietly(server);
 		}
 	}
 
 	@Test
-	public void testSocketSourceRetryForever() throws Exception{
-		error.set(null);
-		ServerSocket serverSo = new ServerSocket(0);
-		SocketSource source = new SocketSource(serverSo, -1);
-		source.start();
-
-		int count = 0;
-		Socket channel;
-		while (count < 100) {
-			channel = serverSo.accept();
-			count++;
+	public void testSocketSourceOutputWithRetries() throws Exception {
+		ServerSocket server = new ServerSocket(0);
+		Socket channel = null;
+
+		try {
+			SocketTextStreamFunction source = new SocketTextStreamFunction(LOCALHOST, server.getLocalPort(), '\n', 10, 100);
+
+			SocketSourceThread runner = new SocketSourceThread(source, "test1", "check");
+			runner.start();
+
+			// first connection: nothing
+			channel = server.accept();
 			channel.close();
-			assertEquals(0, source.socketSource.retries);
-		}
-		source.cancel();
 
-		if (error.get() != null) {
-			Throwable t = error.get();
-			t.printStackTrace();
-			fail("Error in spawned thread: " + t.getMessage());
-		}
+			// second connection: first string
+			channel = server.accept();
+			OutputStreamWriter writer = new OutputStreamWriter(channel.getOutputStream());
+			writer.write("test1\n");
+			writer.close();
+			channel.close();
+
+			// third connection: nothing
+			channel = server.accept();
+			channel.close();
+
+			// fourth connection: second string
+			channel = server.accept();
+			writer = new OutputStreamWriter(channel.getOutputStream());
+			writer.write("check\n");
+			writer.flush();
 
-		assertEquals(100, count);
+			runner.waitForNumElements(2);
+			runner.cancel();
+			runner.waitUntilDone();
+		}
+		finally {
+			if (channel != null) {
+				IOUtils.closeQuietly(channel);
+			}
+			IOUtils.closeQuietly(server);
+		}
 	}
 
 	@Test
-	public void testSocketSourceRetryTenTimes() throws Exception{
-		error.set(null);
-		ServerSocket serverSo = new ServerSocket(0);
-		SocketSource source = new SocketSource(serverSo, 10);
-		source.socketSource.CONNECTION_RETRY_SLEEP = 200;
-
-		assertEquals(0, source.socketSource.retries);
-
-		source.start();
-
-		Socket channel;
-		channel = serverSo.accept();
-		channel.close();
-		serverSo.close();
-		while(source.socketSource.retries < 10){
-			long lastRetry = source.socketSource.retries;
-			sleep(100);
-			assertTrue(source.socketSource.retries >= lastRetry);
-		};
-		assertEquals(10, source.socketSource.retries);
-		source.cancel();
-
-		if (error.get() != null) {
-			Throwable t = error.get();
-			t.printStackTrace();
-			fail("Error in spawned thread: " + t.getMessage());
-		}
+	public void testSocketSourceOutputInfiniteRetries() throws Exception {
+		ServerSocket server = new ServerSocket(0);
+		Socket channel = null;
+
+		try {
+			SocketTextStreamFunction source = new SocketTextStreamFunction(LOCALHOST, server.getLocalPort(), '\n', -1, 100);
+
+			SocketSourceThread runner = new SocketSourceThread(source, "test1", "check");
+			runner.start();
+
+			// first connection: nothing
+			channel = server.accept();
+			channel.close();
+
+			// second connection: first string
+			channel = server.accept();
+			OutputStreamWriter writer = new OutputStreamWriter(channel.getOutputStream());
+			writer.write("test1\n");
+			writer.close();
+			channel.close();
+
+			// third connection: nothing
+			channel = server.accept();
+			channel.close();
 
-		assertEquals(10, source.socketSource.retries);
+			// fourth connection: second string
+			channel = server.accept();
+			writer = new OutputStreamWriter(channel.getOutputStream());
+			writer.write("check\n");
+			writer.flush();
+
+			runner.waitForNumElements(2);
+			runner.cancel();
+			runner.waitUntilDone();
+		}
+		finally {
+			if (channel != null) {
+				IOUtils.closeQuietly(channel);
+			}
+			IOUtils.closeQuietly(server);
+		}
 	}
 
 	@Test
-	public void testSocketSourceNeverRetry() throws Exception{
-		error.set(null);
-		ServerSocket serverSo = new ServerSocket(0);
-		SocketSource source = new SocketSource(serverSo, 0);
-		source.start();
-
-		Socket channel;
-		channel = serverSo.accept();
-		channel.close();
-		serverSo.close();
-		sleep(2000);
-		source.cancel();
-
-		if (error.get() != null) {
-			Throwable t = error.get();
-			t.printStackTrace();
-			fail("Error in spawned thread: " + t.getMessage());
-		}
+	public void testSocketSourceOutputAcrossRetries() throws Exception {
+		ServerSocket server = new ServerSocket(0);
+		Socket channel = null;
+
+		try {
+			SocketTextStreamFunction source = new SocketTextStreamFunction(LOCALHOST, server.getLocalPort(), '\n', 10, 100);
+
+			SocketSourceThread runner = new SocketSourceThread(source, "test1", "check1", "check2");
+			runner.start();
+
+			// first connection: nothing
+			channel = server.accept();
+			channel.close();
+
+			// second connection: first part of the first string
+			channel = server.accept();
+			OutputStreamWriter writer = new OutputStreamWriter(channel.getOutputStream());
+			writer.write("te");
+			writer.close();
+			channel.close();
+
+			// third connection: nothing
+			channel = server.accept();
+			channel.close();
 
-		assertEquals(0, source.socketSource.retries);
+			// fourth connection: rest of the first string, then two more strings
+			channel = server.accept();
+			writer = new OutputStreamWriter(channel.getOutputStream());
+			writer.write("st1\n");
+			writer.write("check1\n");
+			writer.write("check2\n");
+			writer.flush();
+
+			runner.waitForNumElements(2);
+			runner.cancel();
+			runner.waitUntilDone();
+		}
+		finally {
+			if (channel != null) {
+				IOUtils.closeQuietly(channel);
+			}
+			IOUtils.closeQuietly(server);
+		}
 	}
+	
+	// ------------------------------------------------------------------------
+
+	private static class SocketSourceThread extends Thread {
+		
+		private final Object sync = new Object();
+		
+		private final SocketTextStreamFunction socketSource;
+		
+		private final String[] expectedData;
+		
+		private volatile Throwable error;
+		private volatile int numElementsReceived;
+		private volatile boolean canceled;
+		private volatile boolean done;
+		
+		public SocketSourceThread(SocketTextStreamFunction socketSource, String... expectedData) {
+			this.socketSource = socketSource;
+			this.expectedData = expectedData;
+		}
 
-	@Test
-	public void testSocketSourceRetryTenTimesWithFirstPass() throws Exception{
-		ArgumentCaptor<String> argument = ArgumentCaptor.forClass(String.class);
-
-		error.set(null);
-		ServerSocket serverSo = new ServerSocket(0);
-		SocketSource source = new SocketSource(serverSo, 10);
-		source.socketSource.CONNECTION_RETRY_SLEEP = 200;
-
-		assertEquals(0, source.socketSource.retries);
-
-		source.start();
-
-		Socket channel;
-		channel = serverSo.accept();
-		DataOutputStream dataOutputStream = new DataOutputStream(channel.getOutputStream());
-		dataOutputStream.write("testFirstSocketpass\n".getBytes());
-		channel.close();
-		serverSo.close();
-		while(source.socketSource.retries < 10){
-			long lastRetry = source.socketSource.retries;
-			sleep(100);
-			assertTrue(source.socketSource.retries >= lastRetry);
-		};
-		assertEquals(10, source.socketSource.retries);
-		source.cancel();
-
-		verify(ctx).collect(argument.capture());
-
-		if (error.get() != null) {
-			Throwable t = error.get();
-			t.printStackTrace();
-			fail("Error in spawned thread: " + t.getMessage());
+		public void run() {
+			try {
+				SourceFunction.SourceContext<String> ctx = new SourceFunction.SourceContext<String>() {
+					
+					private final Object lock = new Object();
+					
+					@Override
+					public void collect(String element) {
+						int pos = numElementsReceived;
+						
+						// make sure any waiting thread is notified of the new element
+						synchronized (sync) {
+							numElementsReceived++;
+							sync.notifyAll();
+						}
+						
+						if (expectedData != null && expectedData.length > pos) {
+							assertEquals(expectedData[pos], element);
+						}
+					}
+
+					@Override
+					public void collectWithTimestamp(String element, long timestamp) {
+						collect(element);
+					}
+
+					@Override
+					public void emitWatermark(Watermark mark) {}
+
+					@Override
+					public Object getCheckpointLock() {
+						return lock;
+					}
+
+					@Override
+					public void close() {}
+				};
+				
+				socketSource.run(ctx);
+			}
+			catch (Throwable t) {
+				synchronized (sync) {
+					if (!canceled) {
+						error = t;
+					}
+					sync.notifyAll();
+				}
+			}
+			finally {
+				synchronized (sync) {
+					done = true;
+					sync.notifyAll();
+				}
+			}
+		}
+		
+		public void cancel() {
+			synchronized (sync) {
+				canceled = true;
+				socketSource.cancel();
+				interrupt();
+			}
 		}
 
-		assertEquals("testFirstSocketpass", argument.getValue());
-		assertEquals(10, source.socketSource.retries);
+		public void waitForNumElements(int numElements) throws InterruptedException {
+			synchronized (sync) {
+				while (error == null && !canceled && !done && numElementsReceived < numElements) {
+					sync.wait();
+				}
+
+				if (error != null) {
+					throw new RuntimeException("Error in source thread", error);
+				}
+				if (canceled) {
+					throw new RuntimeException("canceled");
+				}
+				if (done) {
+					throw new RuntimeException("Exited cleanly before expected number of elements");
+				}
+			}
+		}
+
+		public void waitUntilDone() throws InterruptedException {
+			join();
+
+			if (error != null) {
+				throw new RuntimeException("Error in source thread", error);
+			}
+		}
 	}
 }
\ No newline at end of file