Skip to content

Commit

Permalink
AME-11744 Close WebSocket connections if no pong received after 60 se…
Browse files Browse the repository at this point in the history
…conds
  • Loading branch information
joebandenburg committed Nov 4, 2016
1 parent bdaf347 commit 670bfc3
Show file tree
Hide file tree
Showing 3 changed files with 194 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.forgerock.openam.notifications.integration;

import java.util.concurrent.ExecutorService;
import java.util.concurrent.ScheduledExecutorService;

import javax.inject.Inject;
import javax.inject.Named;
Expand Down Expand Up @@ -67,6 +68,15 @@ ExecutorService executorService(ExecutorServiceFactory factory) {
return factory.createFixedThreadPool(1);
}

@Provides
@Inject
@Exposed
@Singleton
@Named("webSocketScheduledExecutorService")
ScheduledExecutorService scheduledExecutorService(ExecutorServiceFactory factory) {
return factory.createScheduledService(5);
}

@Provides
@Exposed
@Inject
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,21 @@
import static org.forgerock.json.JsonValue.object;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;

import javax.inject.Inject;
import javax.inject.Named;
import javax.websocket.DecodeException;
import javax.websocket.EncodeException;
import javax.websocket.EndpointConfig;
import javax.websocket.OnClose;
import javax.websocket.OnError;
import javax.websocket.OnMessage;
import javax.websocket.OnOpen;
import javax.websocket.PongMessage;
import javax.websocket.Session;
import javax.websocket.server.ServerEndpoint;

Expand All @@ -38,6 +45,7 @@
import org.forgerock.openam.notifications.Subscription;
import org.forgerock.openam.notifications.Topic;
import org.forgerock.util.Reject;
import org.forgerock.util.time.TimeService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -64,16 +72,23 @@
public final class NotificationsWebSocket {

private static final Logger logger = LoggerFactory.getLogger(NotificationsWebSocket.class);
private static final long TIMEOUT_MILLISECONDS = 1000 * 60;

private final NotificationBroker broker;
private final TimeService timeService;
private final ScheduledExecutorService executorService;
private Subscription subscription;
private long lastMessageTime;
private ScheduledFuture<?> pingFuture;

/**
* No args constructor as required by JSR-356. Instances created
* with this constructor are not expected to be used.
*/
public NotificationsWebSocket() {
broker = null;
timeService = null;
executorService = null;
}

/**
Expand All @@ -82,9 +97,12 @@ public NotificationsWebSocket() {
* @param broker the notification broker
*/
@Inject
public NotificationsWebSocket(NotificationBroker broker) {
public NotificationsWebSocket(NotificationBroker broker, TimeService timeService,
@Named("webSocketScheduledExecutorService") ScheduledExecutorService executorService) {
Reject.ifNull(broker, "Broker must not be null");
this.broker = broker;
this.timeService = timeService;
this.executorService = executorService;
}

/**
Expand All @@ -93,9 +111,23 @@ public NotificationsWebSocket(NotificationBroker broker) {
* @param session the websocket session
*/
@OnOpen
public void open(Session session) {
public void open(final Session session) {
Reject.ifNull(session, "Session must not be null");
subscription = broker.subscribe(new WebSocketConsumer(session));
session.setMaxIdleTimeout(TIMEOUT_MILLISECONDS);
lastMessageTime = timeService.now();
pingFuture = executorService.scheduleAtFixedRate(new Runnable() {
@Override
public void run() {
try {
if (session.isOpen()) {
session.getAsyncRemote().sendPing(ByteBuffer.wrap("ping".getBytes()));
}
} catch (IOException e) {
logger.info("Failed to send ping to client", e);
}
}
}, TIMEOUT_MILLISECONDS / 2, TIMEOUT_MILLISECONDS / 2, TimeUnit.MILLISECONDS);
}

/**
Expand All @@ -104,10 +136,13 @@ public void open(Session session) {
@OnClose
public void close() {
subscription.close();
if (pingFuture != null) {
pingFuture.cancel(true);
}
}

/**
* See {@link OnMessage}.
* Call when the server receives a normal message.
*
* @param session the websocket session
* @param json the json message
Expand All @@ -117,6 +152,8 @@ public void message(Session session, JsonValue json) {
Reject.ifNull(session, "Session must not be null");
Reject.ifNull(json, "Json must not be null");

lastMessageTime = timeService.now();

if (json.isDefined("id") && !json.get("id").isString()) {
sendError(session, null, "\"id\" must be a string");
return;
Expand Down Expand Up @@ -154,6 +191,16 @@ public void message(Session session, JsonValue json) {
sendMessage(session, id, topic, "subscription registered");
}

/**
* Called when the server receives a pong.
*
* @param message Unused, but required to register the correct listener.
*/
@OnMessage
public void pong(PongMessage message) {
lastMessageTime = timeService.now();
}

/**
* See {@link javax.websocket.Endpoint#onError(Session, Throwable)}.
*
Expand All @@ -162,7 +209,11 @@ public void message(Session session, JsonValue json) {
*/
@OnError
public void error(Session session, Throwable error) {
sendError(session, null, error.getMessage());
if (error instanceof DecodeException) {
sendError(session, null, error.getMessage());
} else {
logger.info("WebSocket error", error);
}
}

private void sendMessage(Session session, String id, String topic, String message) {
Expand All @@ -177,7 +228,7 @@ private void sendMessage(Session session, String id, String topic, String messag

session.getBasicRemote().sendObject(json);
} catch (IOException | EncodeException e) {
logger.error("Unable to send message to client Message was \"" + message + "\"", e);
logger.warn("Unable to send message to client. Message was \"" + message + "\"", e);
}
}

Expand All @@ -191,11 +242,11 @@ private void sendError(Session session, String id, String message) {

session.getBasicRemote().sendObject(json);
} catch (IOException | EncodeException e) {
logger.error("Unable to send error to client. Error was \"" + message + "\"", e);
logger.warn("Unable to send error message to client. Error was \"" + message + "\"", e);
}
}

private static final class WebSocketConsumer implements Consumer {
private final class WebSocketConsumer implements Consumer {

private final Session session;

Expand All @@ -207,10 +258,17 @@ private WebSocketConsumer(Session session) {
public void accept(JsonValue notification) {
Reject.ifNull(notification);

try {
session.getBasicRemote().sendObject(notification);
} catch (IOException | EncodeException e) {
logger.error("Unable to send notification", e);
if (timeService.since(lastMessageTime) > TIMEOUT_MILLISECONDS) {
try {
session.close();
} catch (IOException e) {
logger.warn("Failed to close WebSocket connection", e);
}
return;
}

if (session.isOpen()) {
session.getAsyncRemote().sendObject(notification);
}
}

Expand Down
Loading

0 comments on commit 670bfc3

Please sign in to comment.