Skip to content

Commit

Permalink
Merge pull request #517 from komamitsu/disconn-before-ack
Browse files Browse the repository at this point in the history
Improve error handling of disconnection when receiving response in SSLSender
  • Loading branch information
komamitsu authored Jul 3, 2022
2 parents d86ae20 + 16295c9 commit 04b4fae
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,24 @@

import org.komamitsu.fluency.fluentd.ingester.sender.failuredetect.FailureDetector;
import org.komamitsu.fluency.validation.Validatable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.SocketException;
import java.nio.ByteBuffer;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;

public class SSLSender
extends NetworkSender<SSLSocket>
{
private static final Logger LOG = LoggerFactory.getLogger(TCPSender.class);
private final AtomicReference<SSLSocket> socket = new AtomicReference<>();
private final SSLSocketBuilder socketBuilder;
private final Config config;
Expand Down Expand Up @@ -99,6 +103,9 @@ protected void recvResponse(SSLSocket sslSocket, ByteBuffer buffer)
InputStream inputStream = sslSocket.getInputStream();
byte[] tempBuf = new byte[buffer.remaining()];
int read = inputStream.read(tempBuf);
if (read < 0) {
throw new SocketException("The connection is disconnected by the peer");
}
buffer.put(tempBuf, 0, read);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.SocketException;
import java.nio.ByteBuffer;
import java.nio.channels.SocketChannel;
import java.util.List;
Expand Down Expand Up @@ -69,7 +70,7 @@ protected SocketChannel getOrCreateSocketInternal()
socketChannel.socket().setSoTimeout(config.getReadTimeoutMilli());
}
catch (Throwable e) {
// In case of java.net.UnknownHostException and so on, the internal socket can be leak.
// In case of java.net.UnknownHostException and so on, the internal socket can be leaked.
// So the SocketChannel should be closed here to avoid a socket leak.
socketChannel.close();
throw e;
Expand All @@ -90,7 +91,10 @@ protected void sendBuffers(SocketChannel socketChannel, List<ByteBuffer> buffers
protected void recvResponse(SocketChannel socketChannel, ByteBuffer buffer)
throws IOException
{
socketChannel.read(buffer);
int read = socketChannel.read(buffer);
if (read < 0) {
throw new SocketException("The connection is disconnected by the peer");
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import java.net.ServerSocket;
import java.net.Socket;
import java.nio.channels.ClosedByInterruptException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.RejectedExecutionException;
Expand All @@ -42,6 +44,8 @@ public class MockTCPServer
private final AtomicInteger threadSeqNum = new AtomicInteger();
private ExecutorService executorService;
private ServerTask serverTask;
// TODO Make this class immutable
private List<Runnable> tasks = new ArrayList<>();

public MockTCPServer(boolean sslEnabled)
{
Expand Down Expand Up @@ -85,8 +89,9 @@ public Thread newThread(Runnable r)

if (serverTask == null) {
serverTask = new ServerTask(executorService, lastEventTimeStampMilli, getEventHandler(),
sslEnabled ? SSLTestSocketFactories.createServerSocket() : new ServerSocket());
sslEnabled ? SSLTestSocketFactories.createServerSocket() : new ServerSocket(), tasks);
executorService.execute(serverTask);
tasks.add(serverTask);
}

for (int i = 0; i < 10; i++) {
Expand Down Expand Up @@ -127,6 +132,12 @@ public int getLocalPort()

public synchronized void stop()
throws IOException
{
stop(false);
}

public synchronized void stop(boolean immediate)
throws IOException
{
if (executorService == null) {
return;
Expand All @@ -135,13 +146,26 @@ public synchronized void stop()
executorService.shutdown();
try {
if (!executorService.awaitTermination(1000, TimeUnit.MILLISECONDS)) {
LOG.debug("Shutting down MockTCPServer and child tasks... {}", this);
executorService.shutdownNow();
}
}
catch (InterruptedException e) {
LOG.warn("ExecutorService.shutdown() was failed: {}", this, e);
Thread.currentThread().interrupt();
}

if (immediate) {
LOG.debug("Closing related sockets {}", this);
for (Runnable runnable : tasks) {
if (runnable instanceof ServerTask) {
((ServerTask) runnable).close();
} else if (runnable instanceof ServerTask.AcceptTask) {
((ServerTask.AcceptTask) runnable).close();
}
}
}

executorService = null;
serverTask = null;
}
Expand All @@ -162,12 +186,15 @@ private static class ServerTask
private final ExecutorService serverExecutorService;
private final EventHandler eventHandler;
private final AtomicLong lastEventTimeStampMilli;
private final List<Runnable> tasks;

private ServerTask(
ExecutorService executorService,
AtomicLong lastEventTimeStampMilli,
EventHandler eventHandler,
ServerSocket serverSocket)
ServerSocket serverSocket,
List<Runnable> tasks
)
throws IOException
{
this.serverExecutorService = executorService;
Expand All @@ -177,6 +204,7 @@ private ServerTask(
if (!serverSocket.isBound()) {
serverSocket.bind(null);
}
this.tasks = tasks;
}

public int getLocalPort()
Expand All @@ -200,8 +228,9 @@ public void run()
LOG.debug("ServerTask: accepting... this={}, local.port={}", this, getLocalPort());
Socket acceptSocket = serverSocket.accept();
LOG.debug("ServerTask: accepted. this={}, local.port={}, remote.port={}", this, getLocalPort(), acceptSocket.getPort());
serverExecutorService.execute(
new AcceptTask(serverExecutorService, lastEventTimeStampMilli, eventHandler, acceptSocket));
AcceptTask acceptTask = new AcceptTask(serverExecutorService, lastEventTimeStampMilli, eventHandler, acceptSocket);
serverExecutorService.execute(acceptTask);
tasks.add(acceptTask);
}
catch (RejectedExecutionException e) {
LOG.debug("ServerTask: ServerSocket.accept() failed[{}]: this={}", e.getMessage(), this);
Expand All @@ -222,6 +251,12 @@ public void run()
LOG.info("ServerTask: Finishing ServerTask...: this={}", this);
}

private void close()
throws IOException
{
serverSocket.close();
}

private static class AcceptTask
implements Runnable
{
Expand All @@ -238,6 +273,12 @@ private AcceptTask(ExecutorService serverExecutorService, AtomicLong lastEventTi
this.acceptSocket = acceptSocket;
}

private void close()
throws IOException
{
acceptSocket.close();
}

@Override
public void run()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,14 @@
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.net.SocketException;
import java.net.SocketTimeoutException;
import java.nio.ByteBuffer;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.Locale;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicLong;

import static org.hamcrest.MatcherAssert.assertThat;
Expand Down Expand Up @@ -212,6 +210,55 @@ void testReadTimeout()
}
}

private Throwable extractRootCause(Throwable exception)
{
Throwable e = exception;
while (e.getCause() != null) {
e = e.getCause();
}
return e;
}

@Test
void testDisconnBeforeRecv()
throws Exception
{
final MockTCPServer server = new MockTCPServer(true);
server.start();

try {
final CountDownLatch latch = new CountDownLatch(1);
ExecutorService executorService = Executors.newSingleThreadExecutor();
executorService.execute(() -> {
SSLSender.Config senderConfig = new SSLSender.Config();
senderConfig.setPort(server.getLocalPort());
senderConfig.setReadTimeoutMilli(4000);
senderConfig.setSslSocketFactory(SSL_CLIENT_SOCKET_FACTORY);
SSLSender sender = new SSLSender(senderConfig);
try {
sender.sendWithAck(Arrays.asList(ByteBuffer.wrap("hello, world".getBytes(StandardCharsets.UTF_8))), "Waiting ack forever");
}
catch (Throwable e) {
Throwable rootCause = extractRootCause(e);
if (rootCause instanceof SocketException && rootCause.getMessage().toLowerCase().contains("disconnected")) {
latch.countDown();
}
else {
throw new RuntimeException(e);
}
}
});

TimeUnit.MILLISECONDS.sleep(1000);
server.stop(true);

assertTrue(latch.await(8000, TimeUnit.MILLISECONDS));
}
finally {
server.stop();
}
}

@Test
void testClose()
throws Exception
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.net.SocketException;
import java.net.SocketTimeoutException;
import java.nio.ByteBuffer;
import java.nio.charset.Charset;
Expand All @@ -46,6 +47,7 @@
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.komamitsu.fluency.fluentd.SSLTestSocketFactories.SSL_CLIENT_SOCKET_FACTORY;

class TCPSenderTest
{
Expand Down Expand Up @@ -206,6 +208,54 @@ void testReadTimeout()
}
}

private Throwable extractRootCause(Throwable exception)
{
Throwable e = exception;
while (e.getCause() != null) {
e = e.getCause();
}
return e;
}

@Test
void testDisconnBeforeRecv()
throws Exception
{
final MockTCPServer server = new MockTCPServer(false);
server.start();

try {
final CountDownLatch latch = new CountDownLatch(1);
ExecutorService executorService = Executors.newSingleThreadExecutor();
executorService.execute(() -> {
TCPSender.Config senderConfig = new TCPSender.Config();
senderConfig.setPort(server.getLocalPort());
senderConfig.setReadTimeoutMilli(4000);
TCPSender sender = new TCPSender(senderConfig);
try {
sender.sendWithAck(Arrays.asList(ByteBuffer.wrap("hello, world".getBytes(StandardCharsets.UTF_8))), "Waiting ack forever");
}
catch (Throwable e) {
Throwable rootCause = extractRootCause(e);
if (rootCause instanceof SocketException && rootCause.getMessage().toLowerCase().contains("disconnected")) {
latch.countDown();
}
else {
throw new RuntimeException(e);
}
}
});

TimeUnit.MILLISECONDS.sleep(1000);
server.stop(true);

assertTrue(latch.await(8000, TimeUnit.MILLISECONDS));
}
finally {
server.stop();
}
}

@Test
void testClose()
throws Exception
Expand Down

0 comments on commit 04b4fae

Please sign in to comment.