diff --git a/broker/src/main/java/io/moquette/BrokerConstants.java b/broker/src/main/java/io/moquette/BrokerConstants.java index 086d032c7..d0e3e92c4 100644 --- a/broker/src/main/java/io/moquette/BrokerConstants.java +++ b/broker/src/main/java/io/moquette/BrokerConstants.java @@ -22,6 +22,7 @@ public final class BrokerConstants { public static final String INTERCEPT_HANDLER_PROPERTY_NAME = "intercept.handler"; public static final String BROKER_INTERCEPTOR_THREAD_POOL_SIZE = "intercept.thread_pool.size"; + public static final String AUTH_THREAD_POOL_SIZE = "auth.intercept.thread_pool.size"; public static final String PERSISTENT_STORE_PROPERTY_NAME = "persistent_store"; public static final String AUTOSAVE_INTERVAL_PROPERTY_NAME = "autosave_interval"; public static final String PASSWORD_FILE_PROPERTY_NAME = "password_file"; diff --git a/broker/src/main/java/io/moquette/broker/MQTTConnection.java b/broker/src/main/java/io/moquette/broker/MQTTConnection.java index e82b99194..9f176eaaf 100644 --- a/broker/src/main/java/io/moquette/broker/MQTTConnection.java +++ b/broker/src/main/java/io/moquette/broker/MQTTConnection.java @@ -15,8 +15,8 @@ */ package io.moquette.broker; -import io.moquette.broker.subscriptions.Topic; import io.moquette.broker.security.IAuthenticator; +import io.moquette.broker.subscriptions.Topic; import io.netty.buffer.ByteBuf; import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; @@ -28,7 +28,9 @@ import java.net.InetSocketAddress; import java.nio.charset.StandardCharsets; -import java.util.*; +import java.util.List; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; @@ -36,7 +38,8 @@ import static io.netty.channel.ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE; import static io.netty.handler.codec.mqtt.MqttConnectReturnCode.*; import static io.netty.handler.codec.mqtt.MqttMessageIdVariableHeader.from; -import static io.netty.handler.codec.mqtt.MqttQoS.*; +import static io.netty.handler.codec.mqtt.MqttQoS.AT_LEAST_ONCE; +import static io.netty.handler.codec.mqtt.MqttQoS.AT_MOST_ONCE; final class MQTTConnection { @@ -158,27 +161,36 @@ void processConnect(MqttConnectMessage msg) { username, channel); } - if (!login(msg, clientId)) { - abortConnection(CONNECTION_REFUSED_BAD_USER_NAME_OR_PASSWORD); - channel.close().addListener(CLOSE_ON_FAILURE); - return; - } - - try { - LOG.trace("Binding MQTTConnection (channel: {}) to session", channel); - sessionRegistry.bindToSession(this, msg, clientId); - - initializeKeepAliveTimeout(channel, msg, clientId); - setupInflightResender(channel); - - NettyUtils.clientID(channel, clientId); - LOG.trace("CONNACK sent, channel: {}", channel); - postOffice.dispatchConnection(msg); - LOG.trace("dispatch connection: {}", msg.toString()); - } catch (SessionCorruptedException scex) { - LOG.warn("MQTT session for client ID {} cannot be created, channel: {}", clientId, channel); - abortConnection(CONNECTION_REFUSED_SERVER_UNAVAILABLE); - } + final String newClientId = clientId; + CompletableFuture future = login(msg, newClientId); + future.whenComplete((status, t) -> { + if (t == null) { + if (status) { + try { + LOG.trace("Binding MQTTConnection (channel: {}) to session", channel); + sessionRegistry.bindToSession(this, msg, newClientId); + + initializeKeepAliveTimeout(channel, msg, newClientId); + setupInflightResender(channel); + + NettyUtils.clientID(channel, newClientId); + LOG.trace("CONNACK sent, channel: {}", channel); + postOffice.dispatchConnection(msg); + LOG.trace("dispatch connection: {}", msg.toString()); + } catch (SessionCorruptedException scex) { + LOG.warn("MQTT session for client ID {} cannot be created, channel: {}", newClientId, channel); + abortConnection(CONNECTION_REFUSED_SERVER_UNAVAILABLE); + } + } else { + abortConnection(CONNECTION_REFUSED_BAD_USER_NAME_OR_PASSWORD); + channel.close().addListener(CLOSE_ON_FAILURE); + } + } else { + LOG.warn("MQTT connection for client ID {} cannot be created, channel: {}. Error message {}", newClientId, channel, t.getMessage()); + abortConnection(CONNECTION_REFUSED_BAD_USER_NAME_OR_PASSWORD); + channel.close().addListener(CLOSE_ON_FAILURE); + } + }); } private void setupInflightResender(Channel channel) { @@ -222,7 +234,7 @@ private MqttConnAckMessage connAck(MqttConnectReturnCode returnCode, boolean ses return new MqttConnAckMessage(mqttFixedHeader, mqttConnAckVariableHeader); } - private boolean login(MqttConnectMessage msg, final String clientId) { + private CompletableFuture login(MqttConnectMessage msg, final String clientId) { // handle user authentication if (msg.variableHeader().hasUserName()) { byte[] pwd = null; @@ -230,19 +242,27 @@ private boolean login(MqttConnectMessage msg, final String clientId) { pwd = msg.payload().password().getBytes(StandardCharsets.UTF_8); } else if (!brokerConfig.isAllowAnonymous()) { LOG.error("Client didn't supply any password and MQTT anonymous mode is disabled CId={}", clientId); - return false; + return CompletableFuture.completedFuture(false); } final String login = msg.payload().userName(); - if (!authenticator.checkValid(clientId, login, pwd)) { - LOG.error("Authenticator has rejected the MQTT credentials CId={}, username={}", clientId, login); - return false; - } - NettyUtils.userName(channel, login); + + return authenticator.checkValid(clientId, login, pwd).handleAsync((status, t) -> { + if (t == null) { + if (status) NettyUtils.userName(channel, login); + else + LOG.error("Authenticator has rejected the MQTT credentials CId={}, username={}", clientId, login); + + return status; + } else { + LOG.error("Authenticator has rejected the MQTT credentials CId={}, username={}. Error message: {}", clientId, login, t.getMessage()); + return false; + } + }); } else if (!brokerConfig.isAllowAnonymous()) { LOG.error("Client didn't supply any credentials and MQTT anonymous mode is disabled. CId={}", clientId); - return false; + return CompletableFuture.completedFuture(false); } - return true; + return CompletableFuture.completedFuture(true); } void handleConnectionLost() { diff --git a/broker/src/main/java/io/moquette/broker/PostOffice.java b/broker/src/main/java/io/moquette/broker/PostOffice.java index f416666bf..61ff33a65 100644 --- a/broker/src/main/java/io/moquette/broker/PostOffice.java +++ b/broker/src/main/java/io/moquette/broker/PostOffice.java @@ -310,7 +310,9 @@ void dispatchConnectionLost(String clientId,String userName){ } void flushInFlight(MQTTConnection mqttConnection) { - Session targetSession = sessionRegistry.retrieve(mqttConnection.getClientId()); - targetSession.flushAllQueuedMessages(); + if(mqttConnection.getClientId() != null) { + Session targetSession = sessionRegistry.retrieve(mqttConnection.getClientId()); + targetSession.flushAllQueuedMessages(); + } } } diff --git a/broker/src/main/java/io/moquette/broker/Server.java b/broker/src/main/java/io/moquette/broker/Server.java index 874855847..75c19890f 100644 --- a/broker/src/main/java/io/moquette/broker/Server.java +++ b/broker/src/main/java/io/moquette/broker/Server.java @@ -232,7 +232,7 @@ private IAuthenticator initializeAuthenticator(IAuthenticator authenticator, ICo if (passwdPath.isEmpty()) { authenticator = new AcceptAllAuthenticator(); } else { - authenticator = new ResourceAuthenticator(resourceLoader, passwdPath); + authenticator = new ResourceAuthenticator(resourceLoader, passwdPath, props); } LOG.info("An {} authenticator instance will be used", authenticator.getClass().getName()); } diff --git a/broker/src/main/java/io/moquette/broker/Utils.java b/broker/src/main/java/io/moquette/broker/Utils.java index 8c178e117..220b0a006 100644 --- a/broker/src/main/java/io/moquette/broker/Utils.java +++ b/broker/src/main/java/io/moquette/broker/Utils.java @@ -20,6 +20,8 @@ import io.netty.handler.codec.mqtt.MqttMessage; import io.netty.handler.codec.mqtt.MqttMessageIdVariableHeader; import java.util.Map; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.TimeUnit; /** * Utility static methods, like Map get with default value, or elvis operator. @@ -46,6 +48,24 @@ public static byte[] readBytesAndRewind(ByteBuf payload) { return payloadContent; } + public static void shutdownAndAwaitTermination(ExecutorService pool) { + pool.shutdown(); // Disable new tasks from being submitted + try { + // Wait a while for existing tasks to terminate + if (!pool.awaitTermination(30, TimeUnit.SECONDS)) { + pool.shutdownNow(); // Cancel currently executing tasks + // Wait a while for tasks to respond to being cancelled + if (!pool.awaitTermination(30, TimeUnit.SECONDS)) + System.err.println("Pool did not terminate"); + } + } catch (InterruptedException ie) { + // (Re-)Cancel if current thread also interrupted + pool.shutdownNow(); + // Preserve interrupt status + Thread.currentThread().interrupt(); + } + } + private Utils() { } } diff --git a/broker/src/main/java/io/moquette/broker/security/AcceptAllAuthenticator.java b/broker/src/main/java/io/moquette/broker/security/AcceptAllAuthenticator.java index 02d820419..3596d07bd 100644 --- a/broker/src/main/java/io/moquette/broker/security/AcceptAllAuthenticator.java +++ b/broker/src/main/java/io/moquette/broker/security/AcceptAllAuthenticator.java @@ -16,10 +16,12 @@ package io.moquette.broker.security; +import java.util.concurrent.CompletableFuture; + public class AcceptAllAuthenticator implements IAuthenticator { @Override - public boolean checkValid(String clientId, String username, byte[] password) { - return true; + public CompletableFuture checkValid(String clientId, String username, byte[] password) { + return CompletableFuture.completedFuture(true); } } diff --git a/broker/src/main/java/io/moquette/broker/security/DBAuthenticator.java b/broker/src/main/java/io/moquette/broker/security/DBAuthenticator.java index e23fc0f29..28dcc513a 100644 --- a/broker/src/main/java/io/moquette/broker/security/DBAuthenticator.java +++ b/broker/src/main/java/io/moquette/broker/security/DBAuthenticator.java @@ -18,6 +18,7 @@ import com.zaxxer.hikari.HikariDataSource; import io.moquette.BrokerConstants; +import io.moquette.broker.Utils; import io.moquette.broker.config.IConfig; import org.apache.commons.codec.binary.Hex; import org.slf4j.Logger; @@ -29,6 +30,9 @@ import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; /** * Load user credentials from a SQL database. sql driver must be provided at runtime @@ -38,6 +42,7 @@ public class DBAuthenticator implements IAuthenticator { private static final Logger LOG = LoggerFactory.getLogger(DBAuthenticator.class); private final MessageDigest messageDigest; + private final ExecutorService executor; private HikariDataSource dataSource; private String sqlQuery; @@ -45,7 +50,8 @@ public DBAuthenticator(IConfig conf) { this(conf.getProperty(BrokerConstants.DB_AUTHENTICATOR_DRIVER, ""), conf.getProperty(BrokerConstants.DB_AUTHENTICATOR_URL, ""), conf.getProperty(BrokerConstants.DB_AUTHENTICATOR_QUERY, ""), - conf.getProperty(BrokerConstants.DB_AUTHENTICATOR_DIGEST, "")); + conf.getProperty(BrokerConstants.DB_AUTHENTICATOR_DIGEST, ""), + Integer.parseInt(conf.getProperty(BrokerConstants.AUTH_THREAD_POOL_SIZE, "1"))); } /** @@ -59,11 +65,14 @@ public DBAuthenticator(IConfig conf) { * : sql query like : "SELECT PASSWORD FROM USER WHERE LOGIN=?" * @param digestMethod * : password encoding algorithm : "MD5", "SHA-1", "SHA-256" + * @param authExecutorPoolSize + * : auth executor pool size. Defaults to 1. */ - public DBAuthenticator(String driver, String jdbcUrl, String sqlQuery, String digestMethod) { + public DBAuthenticator(String driver, String jdbcUrl, String sqlQuery, String digestMethod, int authExecutorPoolSize) { this.sqlQuery = sqlQuery; this.dataSource = new HikariDataSource(); this.dataSource.setJdbcUrl(jdbcUrl); + this.executor = Executors.newFixedThreadPool(authExecutorPoolSize); try { this.messageDigest = MessageDigest.getInstance(digestMethod); @@ -74,46 +83,52 @@ public DBAuthenticator(String driver, String jdbcUrl, String sqlQuery, String di } @Override - public synchronized boolean checkValid(String clientId, String username, byte[] password) { - // Check Username / Password in DB using sqlQuery - if (username == null || password == null) { - LOG.info("username or password was null"); - return false; - } - - ResultSet resultSet = null; - PreparedStatement preparedStatement = null; - Connection conn = null; - try { - conn = this.dataSource.getConnection(); - - preparedStatement = conn.prepareStatement(this.sqlQuery); - preparedStatement.setString(1, username); - resultSet = preparedStatement.executeQuery(); - if (resultSet.next()) { - final String foundPwq = resultSet.getString(1); - messageDigest.update(password); - byte[] digest = messageDigest.digest(); - String encodedPasswd = new String(Hex.encodeHex(digest)); - return foundPwq.equals(encodedPasswd); + public synchronized CompletableFuture checkValid(String clientId, String username, byte[] password) { + return CompletableFuture.supplyAsync(() -> { + // Check Username / Password in DB using sqlQuery + if (username == null || password == null) { + LOG.info("username or password was null"); + return false; } - } catch (SQLException sqlex) { - LOG.error("Error quering DB for username: {}", username, sqlex); - } finally { + + ResultSet resultSet = null; + PreparedStatement preparedStatement = null; + Connection conn = null; try { - if (resultSet != null) { - resultSet.close(); - } - if (preparedStatement != null) { - preparedStatement.close(); + conn = this.dataSource.getConnection(); + + preparedStatement = conn.prepareStatement(this.sqlQuery); + preparedStatement.setString(1, username); + resultSet = preparedStatement.executeQuery(); + if (resultSet.next()) { + final String foundPwq = resultSet.getString(1); + messageDigest.update(password); + byte[] digest = messageDigest.digest(); + String encodedPasswd = new String(Hex.encodeHex(digest)); + return foundPwq.equals(encodedPasswd); } - if (conn != null) { - conn.close(); + } catch (SQLException sqlex) { + LOG.error("Error quering DB for username: {}", username, sqlex); + } finally { + try { + if (resultSet != null) { + resultSet.close(); + } + if (preparedStatement != null) { + preparedStatement.close(); + } + if (conn != null) { + conn.close(); + } + } catch (SQLException e) { + LOG.error("Error releasing connection to the datasource", username, e); } - } catch (SQLException e) { - LOG.error("Error releasing connection to the datasource", username, e); } - } - return false; + return false; + }, executor); + } + + public void cleanup() { + Utils.shutdownAndAwaitTermination(executor); } } diff --git a/broker/src/main/java/io/moquette/broker/security/FileAuthenticator.java b/broker/src/main/java/io/moquette/broker/security/FileAuthenticator.java index f33ec5f3f..8390f212f 100644 --- a/broker/src/main/java/io/moquette/broker/security/FileAuthenticator.java +++ b/broker/src/main/java/io/moquette/broker/security/FileAuthenticator.java @@ -17,6 +17,7 @@ package io.moquette.broker.security; import io.moquette.broker.config.FileResourceLoader; +import io.moquette.broker.config.IConfig; /** * Load user credentials from a text file. Each line of the file is formatted as @@ -35,7 +36,7 @@ */ public class FileAuthenticator extends ResourceAuthenticator { - public FileAuthenticator(String parent, String filePath) { - super(new FileResourceLoader(parent), filePath); + public FileAuthenticator(String parent, String filePath, IConfig config) { + super(new FileResourceLoader(parent), filePath, config); } } diff --git a/broker/src/main/java/io/moquette/broker/security/IAuthenticator.java b/broker/src/main/java/io/moquette/broker/security/IAuthenticator.java index 1b3b183f1..0a9676986 100644 --- a/broker/src/main/java/io/moquette/broker/security/IAuthenticator.java +++ b/broker/src/main/java/io/moquette/broker/security/IAuthenticator.java @@ -16,10 +16,12 @@ package io.moquette.broker.security; +import java.util.concurrent.CompletableFuture; + /** * username and password checker */ public interface IAuthenticator { - boolean checkValid(String clientId, String username, byte[] password); + CompletableFuture checkValid(String clientId, String username, byte[] password); } diff --git a/broker/src/main/java/io/moquette/broker/security/ResourceAuthenticator.java b/broker/src/main/java/io/moquette/broker/security/ResourceAuthenticator.java index 7a102210b..00ac812d7 100644 --- a/broker/src/main/java/io/moquette/broker/security/ResourceAuthenticator.java +++ b/broker/src/main/java/io/moquette/broker/security/ResourceAuthenticator.java @@ -16,6 +16,8 @@ package io.moquette.broker.security; +import io.moquette.BrokerConstants; +import io.moquette.broker.config.IConfig; import io.moquette.broker.config.IResourceLoader; import org.apache.commons.codec.digest.DigestUtils; import org.slf4j.Logger; @@ -28,6 +30,9 @@ import java.text.ParseException; import java.util.HashMap; import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; /** * Load user credentials from a text resource. Each line of the file is formatted as @@ -48,7 +53,9 @@ public class ResourceAuthenticator implements IAuthenticator { private Map m_identities = new HashMap<>(); - public ResourceAuthenticator(IResourceLoader resourceLoader, String resourceName) { + private final ExecutorService executor; + + public ResourceAuthenticator(IResourceLoader resourceLoader, String resourceName, IConfig conf) { try { MessageDigest.getInstance("SHA-256"); } catch (NoSuchAlgorithmException nsaex) { @@ -56,6 +63,9 @@ public ResourceAuthenticator(IResourceLoader resourceLoader, String resourceName throw new RuntimeException(nsaex); } + int authExecutorPoolSize = Integer.parseInt(conf.getProperty(BrokerConstants.AUTH_THREAD_POOL_SIZE, "1")); + this.executor = Executors.newFixedThreadPool(authExecutorPoolSize); + LOG.info(String.format("Loading password %s %s", resourceLoader.getName(), resourceName)); Reader reader = null; try { @@ -112,17 +122,18 @@ private void parse(Reader reader) throws ParseException { } @Override - public boolean checkValid(String clientId, String username, byte[] password) { - if (username == null || password == null) { - LOG.info("username or password was null"); - return false; - } - String foundPwq = m_identities.get(username); - if (foundPwq == null) { - return false; - } - String encodedPasswd = DigestUtils.sha256Hex(password); - return foundPwq.equals(encodedPasswd); + public CompletableFuture checkValid(String clientId, String username, byte[] password) { + return CompletableFuture.supplyAsync(() -> { + if (username == null || password == null) { + LOG.info("username or password was null"); + return false; + } + String foundPwq = m_identities.get(username); + if (foundPwq == null) { + return false; + } + String encodedPasswd = DigestUtils.sha256Hex(password); + return foundPwq.equals(encodedPasswd); + }, executor); } - } diff --git a/broker/src/test/java/io/moquette/broker/MQTTConnectionConnectTest.java b/broker/src/test/java/io/moquette/broker/MQTTConnectionConnectTest.java index 4ab92f275..3e68a7303 100644 --- a/broker/src/test/java/io/moquette/broker/MQTTConnectionConnectTest.java +++ b/broker/src/test/java/io/moquette/broker/MQTTConnectionConnectTest.java @@ -22,12 +22,16 @@ import io.moquette.persistence.MemorySubscriptionsRepository; import io.netty.channel.Channel; import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.mqtt.MqttConnAckMessage; import io.netty.handler.codec.mqtt.MqttConnectMessage; import io.netty.handler.codec.mqtt.MqttMessageBuilders; import io.netty.handler.codec.mqtt.MqttVersion; import org.junit.Before; import org.junit.Test; +import java.util.Objects; +import java.util.Optional; + import static io.moquette.broker.NettyChannelAssertions.assertEqualsConnAck; import static io.netty.handler.codec.mqtt.MqttConnectReturnCode.*; import static java.util.Collections.singleton; @@ -116,7 +120,8 @@ public void invalidAuthentication() { sut.processConnect(msg); // Verify - assertEqualsConnAck(CONNECTION_REFUSED_BAD_USER_NAME_OR_PASSWORD, channel.readOutbound()); + Optional connAckMessage = NettyChannelAssertions.retry(() -> channel.readOutbound(), Objects::nonNull, 5); + assertEqualsConnAck(CONNECTION_REFUSED_BAD_USER_NAME_OR_PASSWORD, connAckMessage.orElse(null)); assertFalse("Connection should be closed by the broker.", channel.isOpen()); } @@ -184,7 +189,8 @@ public void validAuthentication() { sut.processConnect(msg); // Verify - assertEqualsConnAck(CONNECTION_ACCEPTED, channel.readOutbound()); + Optional connAckMessage = NettyChannelAssertions.retry(() -> channel.readOutbound(), Objects::nonNull, 5); + assertEqualsConnAck(CONNECTION_ACCEPTED, connAckMessage.orElse(null)); assertTrue("Connection is accepted and therefore must remain open", channel.isOpen()); } @@ -198,7 +204,8 @@ public void noPasswdAuthentication() { sut.processConnect(msg); // Verify - assertEqualsConnAck(CONNECTION_REFUSED_BAD_USER_NAME_OR_PASSWORD, channel.readOutbound()); + Optional connAckMessage = NettyChannelAssertions.retry(() -> channel.readOutbound(), Objects::nonNull, 5); + assertEqualsConnAck(CONNECTION_REFUSED_BAD_USER_NAME_OR_PASSWORD, connAckMessage.orElse(null)); assertFalse("Connection must be closed by the broker", channel.isOpen()); } @@ -231,7 +238,8 @@ public void prohibitAnonymousClient_providingUsername() { sut.processConnect(msg); // Verify - assertEqualsConnAck(CONNECTION_REFUSED_BAD_USER_NAME_OR_PASSWORD, channel.readOutbound()); + Optional connAckMessage = NettyChannelAssertions.retry(() -> channel.readOutbound(), Objects::nonNull, 5); + assertEqualsConnAck(CONNECTION_REFUSED_BAD_USER_NAME_OR_PASSWORD, connAckMessage.orElse(null)); assertFalse("Connection should be closed by the broker.", channel.isOpen()); } @@ -250,7 +258,7 @@ public void testZeroByteClientIdNotAllowed() { sut.processConnect(msg); assertEqualsConnAck("Zero byte client identifiers are not allowed", - CONNECTION_REFUSED_IDENTIFIER_REJECTED, channel.readOutbound()); + CONNECTION_REFUSED_IDENTIFIER_REJECTED, channel.readOutbound()); assertFalse("Connection must closed", channel.isOpen()); } @@ -263,7 +271,7 @@ public void testZeroByteClientIdWithoutCleanSession() { sut.processConnect(msg); assertEqualsConnAck("Identifier must be rejected due to having clean session set to false", - CONNECTION_REFUSED_IDENTIFIER_REJECTED, channel.readOutbound()); + CONNECTION_REFUSED_IDENTIFIER_REJECTED, channel.readOutbound()); assertFalse("Connection must be closed by the broker", channel.isOpen()); } @@ -275,7 +283,9 @@ public void testBindWithSameClientIDBadCredentialsDoesntDropExistingClient() { .password(TEST_PWD) .build(); sut.processConnect(msg); - assertEqualsConnAck(CONNECTION_ACCEPTED, channel.readOutbound()); + + Optional connAckMessage = NettyChannelAssertions.retry(() -> channel.readOutbound(), Objects::nonNull, 5); + assertEqualsConnAck(CONNECTION_ACCEPTED, connAckMessage.orElse(null)); // create another connect same clientID but with bad credentials MqttConnectMessage evilClientConnMsg = MqttMessageBuilders.connect() @@ -294,7 +304,8 @@ public void testBindWithSameClientIDBadCredentialsDoesntDropExistingClient() { // Verify // the evil client gets a not auth notification - assertEqualsConnAck(CONNECTION_REFUSED_BAD_USER_NAME_OR_PASSWORD, evilChannel.readOutbound()); + Optional connAckMessage2 = NettyChannelAssertions.retry(evilChannel::readOutbound, Objects::nonNull, 5); + assertEqualsConnAck(CONNECTION_REFUSED_BAD_USER_NAME_OR_PASSWORD, connAckMessage2.orElse(null)); // the good client remains connected assertTrue("Original connected client must remain connected", channel.isOpen()); assertFalse("Channel trying to connect with bad credentials must be closed", evilChannel.isOpen()); @@ -307,13 +318,17 @@ public void testForceClientDisconnection_issue116() { .password(TEST_PWD) .build(); sut.processConnect(msg); - assertEqualsConnAck(CONNECTION_ACCEPTED, channel.readOutbound()); + + Optional connAckMessage = NettyChannelAssertions.retry(() -> channel.readOutbound(), Objects::nonNull, 5); + assertEqualsConnAck(CONNECTION_ACCEPTED, connAckMessage.orElse(null)); // now create another connection and check the new one closes the older MQTTConnection anotherConnection = createMQTTConnection(CONFIG); anotherConnection.processConnect(msg); EmbeddedChannel anotherChannel = (EmbeddedChannel) anotherConnection.channel; - assertEqualsConnAck(CONNECTION_ACCEPTED, anotherChannel.readOutbound()); + + Optional connAckMessage2 = NettyChannelAssertions.retry(anotherChannel::readOutbound, Objects::nonNull, 5); + assertEqualsConnAck(CONNECTION_ACCEPTED, connAckMessage2.orElse(null)); // Verify assertFalse("First 'FAKE_CLIENT_ID' channel MUST be closed by the broker", channel.isOpen()); diff --git a/broker/src/test/java/io/moquette/broker/MockAuthenticator.java b/broker/src/test/java/io/moquette/broker/MockAuthenticator.java index 1811270a5..983422ae9 100644 --- a/broker/src/test/java/io/moquette/broker/MockAuthenticator.java +++ b/broker/src/test/java/io/moquette/broker/MockAuthenticator.java @@ -16,9 +16,11 @@ package io.moquette.broker; +import io.moquette.broker.security.IAuthenticator; + import java.util.Map; import java.util.Set; -import io.moquette.broker.security.IAuthenticator; +import java.util.concurrent.CompletableFuture; import static java.nio.charset.StandardCharsets.UTF_8; @@ -36,7 +38,11 @@ public MockAuthenticator(Set clientIds, Map userPwds) { } @Override - public boolean checkValid(String clientId, String username, byte[] password) { + public CompletableFuture checkValid(String clientId, String username, byte[] password) { + return CompletableFuture.supplyAsync(() -> check(clientId, username, password)); + } + + private boolean check(String clientId, String username, byte[] password) { if (!m_clientIds.contains(clientId)) { return false; } @@ -48,5 +54,4 @@ public boolean checkValid(String clientId, String username, byte[] password) { } return m_userPwds.get(username).equals(new String(password, UTF_8)); } - } diff --git a/broker/src/test/java/io/moquette/broker/NettyChannelAssertions.java b/broker/src/test/java/io/moquette/broker/NettyChannelAssertions.java index 2eac65704..ec8000492 100644 --- a/broker/src/test/java/io/moquette/broker/NettyChannelAssertions.java +++ b/broker/src/test/java/io/moquette/broker/NettyChannelAssertions.java @@ -20,6 +20,11 @@ import io.netty.handler.codec.mqtt.MqttConnAckMessage; import io.netty.handler.codec.mqtt.MqttConnectReturnCode; import io.netty.handler.codec.mqtt.MqttSubAckMessage; + +import java.util.Optional; +import java.util.function.Predicate; +import java.util.function.Supplier; + import static io.netty.handler.codec.mqtt.MqttConnectReturnCode.CONNECTION_ACCEPTED; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @@ -56,4 +61,15 @@ public static void assertEqualsSubAck(/* byte expectedCode, */ Object subAck) { private NettyChannelAssertions() { } + + static Optional retry(Supplier op, Predicate predicate, int maxAttempts) { + A a = op.get(); + + if (maxAttempts == 0) return Optional.empty(); + else if (predicate.test(a)) return Optional.of(a); + else { + try { Thread.sleep(50L); } catch (InterruptedException ie) {} + return retry(op, predicate, maxAttempts - 1); + } + } } diff --git a/broker/src/test/java/io/moquette/broker/config/ClasspathResourceLoaderTest.java b/broker/src/test/java/io/moquette/broker/config/ClasspathResourceLoaderTest.java index 58b27d574..5653c214d 100644 --- a/broker/src/test/java/io/moquette/broker/config/ClasspathResourceLoaderTest.java +++ b/broker/src/test/java/io/moquette/broker/config/ClasspathResourceLoaderTest.java @@ -20,16 +20,30 @@ import org.junit.Test; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; public class ClasspathResourceLoaderTest { @Test public void testSetProperties() { - IResourceLoader classpathLoader = new ClasspathResourceLoader(); - final IConfig classPathConfig = new ResourceLoaderConfig(classpathLoader); + IConfig classPathConfig = initConfig(); + assertEquals("" + BrokerConstants.PORT, classPathConfig.getProperty(BrokerConstants.PORT_PROPERTY_NAME)); classPathConfig.setProperty(BrokerConstants.PORT_PROPERTY_NAME, "9999"); assertEquals("9999", classPathConfig.getProperty(BrokerConstants.PORT_PROPERTY_NAME)); } + @Test + public void testSetAuthThreadPoolSize() { + IConfig classPathConfig = initConfig(); + + assertNull(classPathConfig.getProperty(BrokerConstants.AUTH_THREAD_POOL_SIZE)); + classPathConfig.setProperty(BrokerConstants.AUTH_THREAD_POOL_SIZE, "2"); + assertEquals("2", classPathConfig.getProperty(BrokerConstants.AUTH_THREAD_POOL_SIZE)); + } + + + private IConfig initConfig() { + return new ResourceLoaderConfig(new ClasspathResourceLoader()); + } } diff --git a/broker/src/test/java/io/moquette/broker/security/DBAuthenticatorTest.java b/broker/src/test/java/io/moquette/broker/security/DBAuthenticatorTest.java index 9bb68ee73..048b8e8f3 100644 --- a/broker/src/test/java/io/moquette/broker/security/DBAuthenticatorTest.java +++ b/broker/src/test/java/io/moquette/broker/security/DBAuthenticatorTest.java @@ -17,13 +17,19 @@ package io.moquette.broker.security; import org.apache.commons.codec.binary.Hex; -import org.junit.*; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import java.nio.charset.StandardCharsets; import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; -import java.sql.*; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.SQLException; +import java.sql.Statement; import static java.nio.charset.StandardCharsets.UTF_8; import static org.junit.Assert.assertFalse; @@ -71,8 +77,9 @@ public void Db_verifyValid() { ORG_H2_DRIVER, JDBC_H2_MEM_TEST, "SELECT PASSWORD FROM ACCOUNT WHERE LOGIN=?", - SHA_256); - assertTrue(dbAuthenticator.checkValid(null, "dbuser", "password".getBytes(UTF_8))); + SHA_256, + 1); + assertTrue(dbAuthenticator.checkValid(null, "dbuser", "password".getBytes(UTF_8)).join()); } @Test @@ -81,8 +88,9 @@ public void Db_verifyInvalidLogin() { ORG_H2_DRIVER, JDBC_H2_MEM_TEST, "SELECT PASSWORD FROM ACCOUNT WHERE LOGIN=?", - SHA_256); - assertFalse(dbAuthenticator.checkValid(null, "dbuser2", "password".getBytes(UTF_8))); + SHA_256, + 1); + assertFalse(dbAuthenticator.checkValid(null, "dbuser2", "password".getBytes(UTF_8)).join()); } @Test @@ -91,8 +99,9 @@ public void Db_verifyInvalidPassword() { ORG_H2_DRIVER, JDBC_H2_MEM_TEST, "SELECT PASSWORD FROM ACCOUNT WHERE LOGIN=?", - SHA_256); - assertFalse(dbAuthenticator.checkValid(null, "dbuser", "wrongPassword".getBytes(UTF_8))); + SHA_256, + 1); + assertFalse(dbAuthenticator.checkValid(null, "dbuser", "wrongPassword".getBytes(UTF_8)).join()); } @After diff --git a/broker/src/test/java/io/moquette/broker/security/FileAuthenticatorTest.java b/broker/src/test/java/io/moquette/broker/security/FileAuthenticatorTest.java index 5e17b7c88..cba1f2e7c 100644 --- a/broker/src/test/java/io/moquette/broker/security/FileAuthenticatorTest.java +++ b/broker/src/test/java/io/moquette/broker/security/FileAuthenticatorTest.java @@ -16,8 +16,12 @@ package io.moquette.broker.security; +import io.moquette.broker.config.IConfig; +import io.moquette.broker.config.MemoryConfig; import org.junit.Test; +import java.util.Properties; + import static java.nio.charset.StandardCharsets.UTF_8; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; @@ -25,27 +29,29 @@ @SuppressWarnings("deprecation") public class FileAuthenticatorTest { + private IConfig config = new MemoryConfig(new Properties()); + @Test public void loadPasswordFile_verifyValid() { String file = getClass().getResource("/password_file.conf").getPath(); - IAuthenticator auth = new FileAuthenticator(null, file); + IAuthenticator auth = new FileAuthenticator(null, file, config); - assertTrue(auth.checkValid(null, "testuser", "passwd".getBytes(UTF_8))); + assertTrue(auth.checkValid(null, "testuser", "passwd".getBytes(UTF_8)).join()); } @Test public void loadPasswordFile_verifyInvalid() { String file = getClass().getResource("/password_file.conf").getPath(); - IAuthenticator auth = new FileAuthenticator(null, file); + IAuthenticator auth = new FileAuthenticator(null, file, config); - assertFalse(auth.checkValid(null, "testuser2", "passwd".getBytes(UTF_8))); + assertFalse(auth.checkValid(null, "testuser2", "passwd".getBytes(UTF_8)).join()); } @Test public void loadPasswordFile_verifyDirectoryRef() { - IAuthenticator auth = new FileAuthenticator("", ""); + IAuthenticator auth = new FileAuthenticator("", "", config); - assertFalse(auth.checkValid(null, "testuser2", "passwd".getBytes(UTF_8))); + assertFalse(auth.checkValid(null, "testuser2", "passwd".getBytes(UTF_8)).join()); } } diff --git a/broker/src/test/java/io/moquette/integration/ConfigurationClassLoaderTest.java b/broker/src/test/java/io/moquette/integration/ConfigurationClassLoaderTest.java index 236e3c2f9..7b32091dd 100644 --- a/broker/src/test/java/io/moquette/integration/ConfigurationClassLoaderTest.java +++ b/broker/src/test/java/io/moquette/integration/ConfigurationClassLoaderTest.java @@ -28,6 +28,7 @@ import java.io.IOException; import java.util.Properties; +import java.util.concurrent.CompletableFuture; import static org.junit.Assert.assertTrue; @@ -64,8 +65,8 @@ public void loadAuthorizator() throws Exception { } @Override - public boolean checkValid(String clientID, String username, byte[] password) { - return true; + public CompletableFuture checkValid(String clientID, String username, byte[] password) { + return CompletableFuture.completedFuture(true); } @Override