diff --git a/.github/workflows/ReleaseNotes.md b/.github/workflows/ReleaseNotes.md index 069f272e4..d499d16e0 100644 --- a/.github/workflows/ReleaseNotes.md +++ b/.github/workflows/ReleaseNotes.md @@ -1,4 +1,6 @@ * [Client] Fixed _PlatformNotSupportedException_ when using Blazor (#1755, thanks to @Nickztar). +* [Client] Added hot reload of client certificates (#1781). +* [Client] Added several new option builders and aligned usage (#1781, BREAKING CHANGE!). * [Client] Added support for _RemoteCertificateValidationCallback_ for .NET 4.5.2, 4.6.1 and 4.8 (#1806, thanks to @troky). * [Client] Fixed wrong logging of obsolete feature when connection was not successful (#1801, thanks to @ramonsmits). * [Client] Fixed _NullReferenceException_ when performing several actions when not connected (#1800, thanks to @ramonsmits). diff --git a/Samples/Client/Client_Connection_Samples.cs b/Samples/Client/Client_Connection_Samples.cs index d74ccb855..6fca63398 100644 --- a/Samples/Client/Client_Connection_Samples.cs +++ b/Samples/Client/Client_Connection_Samples.cs @@ -70,32 +70,6 @@ public static async Task Connect_Client() await mqttClient.DisconnectAsync(mqttClientDisconnectOptions, CancellationToken.None); } } - - public static async Task Connect_With_Amazon_AWS() - { - /* - * This sample creates a simple MQTT client and connects to an Amazon Web Services broker. - * - * The broker requires special settings which are set here. - */ - - var mqttFactory = new MqttFactory(); - - using (var mqttClient = mqttFactory.CreateMqttClient()) - { - var mqttClientOptions = new MqttClientOptionsBuilder() - .WithTcpServer("amazon.web.services.broker") - // Disabling packet fragmentation is very important! - .WithoutPacketFragmentation() - .Build(); - - await mqttClient.ConnectAsync(mqttClientOptions, CancellationToken.None); - - Console.WriteLine("The MQTT client is connected."); - - await mqttClient.DisconnectAsync(); - } - } public static async Task Connect_Client_Timeout() { @@ -161,15 +135,15 @@ public static async Task Connect_Client_Using_TLS_1_2() using (var mqttClient = mqttFactory.CreateMqttClient()) { var mqttClientOptions = new MqttClientOptionsBuilder().WithTcpServer("mqtt.fluux.io") - .WithTls( + .WithTlsOptions( o => { // The used public broker sometimes has invalid certificates. This sample accepts all // certificates. This should not be used in live environments. - o.CertificateValidationHandler = _ => true; + o.WithCertificateValidationHandler(_ => true); // The default value is determined by the OS. Set manually to force version. - o.SslProtocol = SslProtocols.Tls12; + o.WithSslProtocols(SslProtocols.Tls12); }) .Build(); @@ -196,7 +170,7 @@ public static async Task Connect_Client_Using_WebSocket4Net() using (var mqttClient = mqttFactory.CreateMqttClient()) { - var mqttClientOptions = new MqttClientOptionsBuilder().WithWebSocketServer("broker.hivemq.com:8000/mqtt").Build(); + var mqttClientOptions = new MqttClientOptionsBuilder().WithWebSocketServer(o => o.WithUri("broker.hivemq.com:8000/mqtt")).Build(); var response = await mqttClient.ConnectAsync(mqttClientOptions, CancellationToken.None); @@ -218,7 +192,7 @@ public static async Task Connect_Client_Using_WebSockets() using (var mqttClient = mqttFactory.CreateMqttClient()) { - var mqttClientOptions = new MqttClientOptionsBuilder().WithWebSocketServer("broker.hivemq.com:8000/mqtt").Build(); + var mqttClientOptions = new MqttClientOptionsBuilder().WithWebSocketServer(o => o.WithUri("broker.hivemq.com:8000/mqtt")).Build(); var response = await mqttClient.ConnectAsync(mqttClientOptions, CancellationToken.None); @@ -241,13 +215,11 @@ public static async Task Connect_Client_With_TLS_Encryption() using (var mqttClient = mqttFactory.CreateMqttClient()) { var mqttClientOptions = new MqttClientOptionsBuilder().WithTcpServer("test.mosquitto.org", 8883) - .WithTls( - o => - { + .WithTlsOptions( + o => o.WithCertificateValidationHandler( // The used public broker sometimes has invalid certificates. This sample accepts all // certificates. This should not be used in live environments. - o.CertificateValidationHandler = _ => true; - }) + _ => true)) .Build(); // In MQTTv5 the response contains much more information. @@ -262,6 +234,31 @@ public static async Task Connect_Client_With_TLS_Encryption() } } + public static async Task Connect_With_Amazon_AWS() + { + /* + * This sample creates a simple MQTT client and connects to an Amazon Web Services broker. + * + * The broker requires special settings which are set here. + */ + + var mqttFactory = new MqttFactory(); + + using (var mqttClient = mqttFactory.CreateMqttClient()) + { + var mqttClientOptions = new MqttClientOptionsBuilder().WithTcpServer("amazon.web.services.broker") + // Disabling packet fragmentation is very important! + .WithoutPacketFragmentation() + .Build(); + + await mqttClient.ConnectAsync(mqttClientOptions, CancellationToken.None); + + Console.WriteLine("The MQTT client is connected."); + + await mqttClient.DisconnectAsync(); + } + } + public static async Task Disconnect_Clean() { /* @@ -317,18 +314,19 @@ public static async Task Inspect_Certificate_Validation_Errors() using (var mqttClient = mqttFactory.CreateMqttClient()) { var mqttClientOptions = new MqttClientOptionsBuilder().WithTcpServer("mqtt.fluux.io", 8883) - .WithTls( + .WithTlsOptions( o => { - o.CertificateValidationHandler = eventArgs => - { - eventArgs.Certificate.Subject.DumpToConsole(); - eventArgs.Certificate.GetExpirationDateString().DumpToConsole(); - eventArgs.Chain.ChainPolicy.RevocationMode.DumpToConsole(); - eventArgs.Chain.ChainStatus.DumpToConsole(); - eventArgs.SslPolicyErrors.DumpToConsole(); - return true; - }; + o.WithCertificateValidationHandler( + eventArgs => + { + eventArgs.Certificate.Subject.DumpToConsole(); + eventArgs.Certificate.GetExpirationDateString().DumpToConsole(); + eventArgs.Chain.ChainPolicy.RevocationMode.DumpToConsole(); + eventArgs.Chain.ChainStatus.DumpToConsole(); + eventArgs.SslPolicyErrors.DumpToConsole(); + return true; + }); }) .Build(); @@ -434,4 +432,4 @@ public static void Reconnect_Using_Timer() }); } } -} +} \ No newline at end of file diff --git a/Source/MQTTnet.Extensions.WebSocket4Net/WebSocket4NetMqttChannel.cs b/Source/MQTTnet.Extensions.WebSocket4Net/WebSocket4NetMqttChannel.cs index 859bd2a3b..b70735943 100644 --- a/Source/MQTTnet.Extensions.WebSocket4Net/WebSocket4NetMqttChannel.cs +++ b/Source/MQTTnet.Extensions.WebSocket4Net/WebSocket4NetMqttChannel.cs @@ -92,18 +92,7 @@ public Task ConnectAsync(CancellationToken cancellationToken) var webSocketVersion = WebSocketVersion.None; var receiveBufferSize = 0; - var certificates = new X509CertificateCollection(); - if (_webSocketOptions.TlsOptions?.Certificates != null) - { - foreach (var certificate in _webSocketOptions.TlsOptions.Certificates) - { -#if WINDOWS_UWP - certificates.Add(new X509Certificate(certificate)); -#else - certificates.Add(certificate); -#endif - } - } + var certificates = _webSocketOptions.TlsOptions?.ClientCertificatesProvider?.GetCertificates(); _webSocket = new WebSocket(uri, subProtocol, cookies, customHeaders, userAgent, origin, webSocketVersion, proxy, sslProtocols, receiveBufferSize) { diff --git a/Source/MQTTnet.TestApp/PublicBrokerTest.cs b/Source/MQTTnet.TestApp/PublicBrokerTest.cs index f95c332ef..1659f2610 100644 --- a/Source/MQTTnet.TestApp/PublicBrokerTest.cs +++ b/Source/MQTTnet.TestApp/PublicBrokerTest.cs @@ -20,7 +20,7 @@ public static async Task RunAsync() { #if NET5_0_OR_GREATER // TLS13 is only available in Net5.0 - var unsafeTls13 = new MqttClientOptionsBuilderTlsParameters + var unsafeTls13 = new MqttClientTlsOptions { UseTls = true, SslProtocol = SslProtocols.Tls13, @@ -29,7 +29,7 @@ public static async Task RunAsync() }; #endif // Also defining TLS12 for servers that don't seem no to support TLS13. - var unsafeTls12 = new MqttClientOptionsBuilderTlsParameters + var unsafeTls12 = new MqttClientTlsOptions { UseTls = true, SslProtocol = SslProtocols.Tls12, @@ -44,16 +44,16 @@ await ExecuteTestAsync( await ExecuteTestAsync( "mqtt.eclipseprojects.io WS", - new MqttClientOptionsBuilder().WithWebSocketServer("mqtt.eclipseprojects.io:80/mqtt").WithProtocolVersion(MqttProtocolVersion.V311).Build()); + new MqttClientOptionsBuilder().WithWebSocketServer(o => o.WithUri("mqtt.eclipseprojects.io:80/mqtt")).WithProtocolVersion(MqttProtocolVersion.V311).Build()); #if NET5_0_OR_GREATER await ExecuteTestAsync("mqtt.eclipseprojects.io WS TLS13", - new MqttClientOptionsBuilder().WithWebSocketServer("mqtt.eclipseprojects.io:443/mqtt") - .WithProtocolVersion(MqttProtocolVersion.V311).WithTls(unsafeTls13).Build()); + new MqttClientOptionsBuilder().WithWebSocketServer(o => o.WithUri("mqtt.eclipseprojects.io:443/mqtt")) + .WithProtocolVersion(MqttProtocolVersion.V311).WithTlsOptions(unsafeTls13).Build()); await ExecuteTestAsync("mqtt.eclipseprojects.io WS TLS13 (WebSocket4Net)", - new MqttClientOptionsBuilder().WithWebSocketServer("mqtt.eclipseprojects.io:443/mqtt") - .WithProtocolVersion(MqttProtocolVersion.V311).WithTls(unsafeTls13).Build(), + new MqttClientOptionsBuilder().WithWebSocketServer(o => o.WithUri("mqtt.eclipseprojects.io:443/mqtt")) + .WithProtocolVersion(MqttProtocolVersion.V311).WithTlsOptions(unsafeTls13).Build(), true); #endif @@ -68,12 +68,12 @@ await ExecuteTestAsync( await ExecuteTestAsync( "test.mosquitto.org TCP TLS12", - new MqttClientOptionsBuilder().WithTcpServer("test.mosquitto.org", 8883).WithProtocolVersion(MqttProtocolVersion.V311).WithTls(unsafeTls12).Build()); + new MqttClientOptionsBuilder().WithTcpServer("test.mosquitto.org", 8883).WithProtocolVersion(MqttProtocolVersion.V311).WithTlsOptions(unsafeTls12).Build()); #if NET5_0_OR_GREATER await ExecuteTestAsync("test.mosquitto.org TCP TLS13", new MqttClientOptionsBuilder().WithTcpServer("test.mosquitto.org", 8883) - .WithProtocolVersion(MqttProtocolVersion.V311).WithTls(unsafeTls13).Build()); + .WithProtocolVersion(MqttProtocolVersion.V311).WithTlsOptions(unsafeTls13).Build()); #endif await ExecuteTestAsync( @@ -81,21 +81,21 @@ await ExecuteTestAsync( new MqttClientOptionsBuilder().WithTcpServer("test.mosquitto.org", 8885) .WithCredentials("rw", "readwrite") .WithProtocolVersion(MqttProtocolVersion.V311) - .WithTls(unsafeTls12) + .WithTlsOptions(unsafeTls12) .Build()); await ExecuteTestAsync( "test.mosquitto.org WS", - new MqttClientOptionsBuilder().WithWebSocketServer("test.mosquitto.org:8080/mqtt").WithProtocolVersion(MqttProtocolVersion.V311).Build()); + new MqttClientOptionsBuilder().WithWebSocketServer(o => o.WithUri("test.mosquitto.org:8080/mqtt")).WithProtocolVersion(MqttProtocolVersion.V311).Build()); await ExecuteTestAsync( "test.mosquitto.org WS (WebSocket4Net)", - new MqttClientOptionsBuilder().WithWebSocketServer("test.mosquitto.org:8080/mqtt").WithProtocolVersion(MqttProtocolVersion.V311).Build(), + new MqttClientOptionsBuilder().WithWebSocketServer(o => o.WithUri("test.mosquitto.org:8080/mqtt")).WithProtocolVersion(MqttProtocolVersion.V311).Build(), true); await ExecuteTestAsync( "test.mosquitto.org WS TLS12", - new MqttClientOptionsBuilder().WithWebSocketServer("test.mosquitto.org:8081/mqtt").WithProtocolVersion(MqttProtocolVersion.V311).WithTls(unsafeTls12).Build()); + new MqttClientOptionsBuilder().WithWebSocketServer(o => o.WithUri("test.mosquitto.org:8081/mqtt")).WithProtocolVersion(MqttProtocolVersion.V311).WithTlsOptions(unsafeTls12).Build()); // await ExecuteTestAsync( // "test.mosquitto.org WS TLS12 (WebSocket4Net)", @@ -109,30 +109,30 @@ await ExecuteTestAsync( await ExecuteTestAsync( "broker.emqx.io TCP TLS12", - new MqttClientOptionsBuilder().WithTcpServer("broker.emqx.io", 8883).WithProtocolVersion(MqttProtocolVersion.V311).WithTls(unsafeTls12).Build()); + new MqttClientOptionsBuilder().WithTcpServer("broker.emqx.io", 8883).WithProtocolVersion(MqttProtocolVersion.V311).WithTlsOptions(unsafeTls12).Build()); #if NET5_0_OR_GREATER await ExecuteTestAsync("broker.emqx.io TCP TLS13", new MqttClientOptionsBuilder().WithTcpServer("broker.emqx.io", 8883) - .WithProtocolVersion(MqttProtocolVersion.V311).WithTls(unsafeTls13).Build()); + .WithProtocolVersion(MqttProtocolVersion.V311).WithTlsOptions(unsafeTls13).Build()); #endif await ExecuteTestAsync( "broker.emqx.io WS", - new MqttClientOptionsBuilder().WithWebSocketServer("broker.emqx.io:8083/mqtt").WithProtocolVersion(MqttProtocolVersion.V311).Build()); + new MqttClientOptionsBuilder().WithWebSocketServer(o => o.WithUri("broker.emqx.io:8083/mqtt")).WithProtocolVersion(MqttProtocolVersion.V311).Build()); await ExecuteTestAsync( "broker.emqx.io WS (WebSocket4Net)", - new MqttClientOptionsBuilder().WithWebSocketServer("broker.emqx.io:8084/mqtt").WithProtocolVersion(MqttProtocolVersion.V311).Build(), + new MqttClientOptionsBuilder().WithWebSocketServer(o => o.WithUri("broker.emqx.io:8084/mqtt")).WithProtocolVersion(MqttProtocolVersion.V311).Build(), true); await ExecuteTestAsync( "broker.emqx.io WS TLS12", - new MqttClientOptionsBuilder().WithWebSocketServer("broker.emqx.io:8084/mqtt").WithProtocolVersion(MqttProtocolVersion.V311).WithTls(unsafeTls12).Build()); + new MqttClientOptionsBuilder().WithWebSocketServer(o => o.WithUri("broker.emqx.io:8084/mqtt")).WithProtocolVersion(MqttProtocolVersion.V311).WithTlsOptions(unsafeTls12).Build()); await ExecuteTestAsync( "broker.emqx.io WS TLS12 (WebSocket4Net)", - new MqttClientOptionsBuilder().WithWebSocketServer("broker.emqx.io:8084/mqtt").WithProtocolVersion(MqttProtocolVersion.V311).WithTls(unsafeTls12).Build(), + new MqttClientOptionsBuilder().WithWebSocketServer(o => o.WithUri("broker.emqx.io:8084/mqtt")).WithProtocolVersion(MqttProtocolVersion.V311).WithTlsOptions(unsafeTls12).Build(), true); // broker.hivemq.com @@ -142,11 +142,11 @@ await ExecuteTestAsync( await ExecuteTestAsync( "broker.hivemq.com WS", - new MqttClientOptionsBuilder().WithWebSocketServer("broker.hivemq.com:8000/mqtt").WithProtocolVersion(MqttProtocolVersion.V311).Build()); + new MqttClientOptionsBuilder().WithWebSocketServer(o => o.WithUri("broker.hivemq.com:8000/mqtt")).WithProtocolVersion(MqttProtocolVersion.V311).Build()); await ExecuteTestAsync( "broker.hivemq.com WS (WebSocket4Net)", - new MqttClientOptionsBuilder().WithWebSocketServer("broker.hivemq.com:8000/mqtt").WithProtocolVersion(MqttProtocolVersion.V311).Build(), + new MqttClientOptionsBuilder().WithWebSocketServer(o => o.WithUri("broker.hivemq.com:8000/mqtt")).WithProtocolVersion(MqttProtocolVersion.V311).Build(), true); // mqtt.swifitch.cz: Does not seem to operate any more diff --git a/Source/MQTTnet.Tests/Clients/ManagedMqttClient/ManagedMqttClient_Tests.cs b/Source/MQTTnet.Tests/Clients/ManagedMqttClient/ManagedMqttClient_Tests.cs index 665519836..376e800b6 100644 --- a/Source/MQTTnet.Tests/Clients/ManagedMqttClient/ManagedMqttClient_Tests.cs +++ b/Source/MQTTnet.Tests/Clients/ManagedMqttClient/ManagedMqttClient_Tests.cs @@ -421,7 +421,7 @@ public async Task Subscriptions_Are_Published_Immediately() var receivingClient = await CreateManagedClientAsync(testEnvironment, null, connectionCheckInterval); var sendingClient = await testEnvironment.ConnectClient(); - await sendingClient.PublishAsync(new MqttApplicationMessage { Topic = "topic", Payload = new byte[] { 1 }, Retain = true }); + await sendingClient.PublishAsync(new MqttApplicationMessage { Topic = "topic", PayloadSegment = new ArraySegment( new byte[] { 1 }), Retain = true }); var subscribeTime = DateTime.UtcNow; @@ -454,7 +454,7 @@ public async Task Subscriptions_Subscribe_Only_New_Subscriptions() //wait a bit for the subscription to become established await Task.Delay(500); - await sendingClient.PublishAsync(new MqttApplicationMessage { Topic = "topic", Payload = new byte[] { 1 }, Retain = true }); + await sendingClient.PublishAsync(new MqttApplicationMessage { Topic = "topic", PayloadSegment = new ArraySegment(new byte[] { 1 }), Retain = true }); var messages = await SetupReceivingOfMessages(managedClient, 1); diff --git a/Source/MQTTnet.Tests/Clients/MqttClient/MqttClient_Tests.cs b/Source/MQTTnet.Tests/Clients/MqttClient/MqttClient_Tests.cs index 7b65e07f7..6894b0102 100644 --- a/Source/MQTTnet.Tests/Clients/MqttClient/MqttClient_Tests.cs +++ b/Source/MQTTnet.Tests/Clients/MqttClient/MqttClient_Tests.cs @@ -296,7 +296,7 @@ await receiver.SubscribeAsync( Assert.IsNotNull(receivedMessage); Assert.AreEqual("A", receivedMessage.Topic); - Assert.AreEqual(null, receivedMessage.Payload); + Assert.AreEqual(null, receivedMessage.PayloadSegment.Array); } } @@ -507,7 +507,7 @@ public async Task Publish_QoS_1_In_ApplicationMessageReceiveHandler() client2.ApplicationMessageReceivedAsync += e => { - client2TopicResults.Add(Encoding.UTF8.GetString(e.ApplicationMessage.Payload)); + client2TopicResults.Add(Encoding.UTF8.GetString(e.ApplicationMessage.PayloadSegment.ToArray())); return CompletedTask.Instance; }; diff --git a/Source/MQTTnet.Tests/Extensions/WebSocket4Net_Tests.cs b/Source/MQTTnet.Tests/Extensions/WebSocket4Net_Tests.cs index 8c70fef08..49adede65 100644 --- a/Source/MQTTnet.Tests/Extensions/WebSocket4Net_Tests.cs +++ b/Source/MQTTnet.Tests/Extensions/WebSocket4Net_Tests.cs @@ -23,7 +23,7 @@ public async Task Connect_Failed_With_Invalid_Server() using (var client = factory.CreateMqttClient()) { - var options = new MqttClientOptionsBuilder().WithWebSocketServer("ws://a.b/mqtt").WithTimeout(TimeSpan.FromSeconds(2)).Build(); + var options = new MqttClientOptionsBuilder().WithWebSocketServer(o => o.WithUri("ws://a.b/mqtt")).WithTimeout(TimeSpan.FromSeconds(2)).Build(); await client.ConnectAsync(options).ConfigureAwait(false); } } diff --git a/Source/MQTTnet.Tests/MqttApplicationMessageBuilder_Tests.cs b/Source/MQTTnet.Tests/MqttApplicationMessageBuilder_Tests.cs index 50aa44752..f4f28f0bc 100644 --- a/Source/MQTTnet.Tests/MqttApplicationMessageBuilder_Tests.cs +++ b/Source/MQTTnet.Tests/MqttApplicationMessageBuilder_Tests.cs @@ -4,6 +4,7 @@ using System; using System.IO; +using System.Linq; using System.Text; using Microsoft.VisualStudio.TestTools.UnitTesting; using MQTTnet.Protocol; @@ -11,7 +12,7 @@ namespace MQTTnet.Tests { [TestClass] - public class MqttApplicationMessageBuilder_Tests + public sealed class MqttApplicationMessageBuilder_Tests { [TestMethod] public void CreateApplicationMessage_TopicOnly() @@ -29,7 +30,7 @@ public void CreateApplicationMessage_TimeStampPayload() Assert.AreEqual("xyz", message.Topic); Assert.IsFalse(message.Retain); Assert.AreEqual(MqttQualityOfServiceLevel.AtMostOnce, message.QualityOfServiceLevel); - Assert.AreEqual(Encoding.UTF8.GetString(message.Payload), "00:06:00"); + Assert.AreEqual(Encoding.UTF8.GetString(message.PayloadSegment.ToArray()), "00:06:00"); } [TestMethod] @@ -41,7 +42,7 @@ public void CreateApplicationMessage_StreamPayload() Assert.AreEqual("123", message.Topic); Assert.IsFalse(message.Retain); Assert.AreEqual(MqttQualityOfServiceLevel.AtMostOnce, message.QualityOfServiceLevel); - Assert.AreEqual(Encoding.UTF8.GetString(message.Payload), "Hello"); + Assert.AreEqual(Encoding.UTF8.GetString(message.PayloadSegment.ToArray()), "Hello"); } [TestMethod] diff --git a/Source/MQTTnet.Tests/Server/HotSwapCerts_Tests.cs b/Source/MQTTnet.Tests/Server/HotSwapCerts_Tests.cs new file mode 100644 index 000000000..64282ffb2 --- /dev/null +++ b/Source/MQTTnet.Tests/Server/HotSwapCerts_Tests.cs @@ -0,0 +1,445 @@ +using System; +using System.Collections.Concurrent; +using System.Diagnostics; +using System.Linq; +using System.Net; +using System.Net.Security; +using System.Security.Authentication; +using System.Security.Cryptography; +using System.Security.Cryptography.X509Certificates; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using MQTTnet.Certificates; +using MQTTnet.Client; +using MQTTnet.Extensions.ManagedClient; +using MQTTnet.Formatter; +using MQTTnet.Protocol; +using MQTTnet.Server; + +namespace MQTTnet.Tests.Server +{ + // missing certificate builder api means tests won't work for older frameworks +#if !(NET452 || NET461) + [TestClass] +#endif + public sealed class HotSwapCerts_Tests + { + readonly TimeSpan DEFAULT_TIMEOUT = TimeSpan.FromSeconds(10); + + [TestMethod] + public void ClientCertChangeWithoutServerUpdateFailsReconnect() + { + using (var server = new ServerTestHarness()) + using (var client01 = new ClientTestHarness()) + { + server.InstallNewClientCert(client01.GetCurrentClientCert()); + client01.InstallNewServerCert(server.GetCurrentServerCert()); + + server.StartServer().Wait(); + + client01.Connect(); + + client01.WaitForConnectOrFail(DEFAULT_TIMEOUT); + + client01.HotSwapClientCert(); + server.ForceDisconnectAsync(client01).Wait(DEFAULT_TIMEOUT); + client01.WaitForDisconnectOrFail(DEFAULT_TIMEOUT); + + client01.WaitForConnectToFail(DEFAULT_TIMEOUT); + } + } + + [TestMethod] + public void ClientCertChangeWithServerUpdateAcceptsReconnect() + { + using (var server = new ServerTestHarness()) + using (var client01 = new ClientTestHarness()) + { + server.InstallNewClientCert(client01.GetCurrentClientCert()); + client01.InstallNewServerCert(server.GetCurrentServerCert()); + + server.StartServer().Wait(); + client01.Connect(); + + client01.WaitForConnectOrFail(DEFAULT_TIMEOUT); + + client01.HotSwapClientCert(); + server.ForceDisconnectAsync(client01).Wait(DEFAULT_TIMEOUT); + client01.WaitForDisconnectOrFail(DEFAULT_TIMEOUT); + + server.InstallNewClientCert(client01.GetCurrentClientCert()); + + client01.WaitForConnectOrFail(DEFAULT_TIMEOUT); + } + } + + [TestMethod] + public void ServerCertChangeWithClientCertUpdateAllowsReconnect() + { + using (var server = new ServerTestHarness()) + using (var client01 = new ClientTestHarness()) + { + server.InstallNewClientCert(client01.GetCurrentClientCert()); + client01.InstallNewServerCert(server.GetCurrentServerCert()); + + server.StartServer().Wait(); + client01.Connect(); + + client01.WaitForConnectOrFail(DEFAULT_TIMEOUT); + server.HotSwapServerCert(); + + server.ForceDisconnectAsync(client01).Wait(DEFAULT_TIMEOUT); + client01.WaitForDisconnectOrFail(DEFAULT_TIMEOUT); + client01.InstallNewServerCert(server.GetCurrentServerCert()); + + client01.WaitForConnectOrFail(DEFAULT_TIMEOUT); + } + } + + [TestMethod] + public void ServerCertChangeWithoutClientCertUpdateFailsReconnect() + { + using (var server = new ServerTestHarness()) + using (var client01 = new ClientTestHarness()) + { + server.InstallNewClientCert(client01.GetCurrentClientCert()); + client01.InstallNewServerCert(server.GetCurrentServerCert()); + + server.StartServer().Wait(); + client01.Connect(); + + client01.WaitForConnectOrFail(DEFAULT_TIMEOUT); + server.HotSwapServerCert(); + + server.ForceDisconnectAsync(client01).Wait(DEFAULT_TIMEOUT); + client01.WaitForDisconnectOrFail(DEFAULT_TIMEOUT); + + client01.WaitForConnectToFail(DEFAULT_TIMEOUT); + } + } + + static X509Certificate2 CreateSelfSignedCertificate(string oid) + { +#if NET452 || NET461 + throw new NotImplementedException(); +#else + var sanBuilder = new SubjectAlternativeNameBuilder(); + sanBuilder.AddIpAddress(IPAddress.Loopback); + sanBuilder.AddIpAddress(IPAddress.IPv6Loopback); + sanBuilder.AddDnsName("localhost"); + + using (var rsa = RSA.Create()) + { + var certRequest = new CertificateRequest("CN=localhost", rsa, HashAlgorithmName.SHA512, RSASignaturePadding.Pkcs1); + + certRequest.CertificateExtensions.Add( + new X509KeyUsageExtension(X509KeyUsageFlags.DataEncipherment | X509KeyUsageFlags.KeyEncipherment | X509KeyUsageFlags.DigitalSignature, false)); + + certRequest.CertificateExtensions.Add(new X509EnhancedKeyUsageExtension(new OidCollection { new Oid(oid) }, false)); + + certRequest.CertificateExtensions.Add(sanBuilder.Build()); + + using (var certificate = certRequest.CreateSelfSigned(DateTimeOffset.Now.AddMinutes(-10), DateTimeOffset.Now.AddMinutes(10))) + { + var pfxCertificate = new X509Certificate2( + certificate.Export(X509ContentType.Pfx), + (string)null, + X509KeyStorageFlags.MachineKeySet | X509KeyStorageFlags.Exportable); + + return pfxCertificate; + } + } +#endif + } + + class ClientTestHarness : IDisposable + { + IManagedMqttClient _client; + readonly HotSwappableClientCertProvider _hotSwapClient = new HotSwappableClientCertProvider(); + + public string ClientID => _client.InternalClient.Options.ClientId; + + public void ClearServerCerts() + { + _hotSwapClient.ClearServerCerts(); + } + + public void Connect() + { + Run_Client_Connection().Wait(); + } + + public void Dispose() + { + _client.Dispose(); + } + + public X509Certificate2 GetCurrentClientCert() + { + var result = _hotSwapClient.GetCertificates()[0]; + return new X509Certificate2(result); + } + + public void HotSwapClientCert() + { + _hotSwapClient.HotSwapCert(); + } + + public void InstallNewServerCert(X509Certificate2 serverCert) + { + _hotSwapClient.InstallNewServerCert(serverCert); + } + + public void WaitForConnect(TimeSpan timeout) + { + var timer = Stopwatch.StartNew(); + while ((_client == null || !_client.IsConnected) && timer.Elapsed < timeout) + { + Thread.Sleep(5); + } + } + + public void WaitForConnectOrFail(TimeSpan timeout) + { + Assert.IsFalse(_client.IsConnected, "Client should be disconnected before waiting for connect."); + + WaitForConnect(timeout); + + Assert.IsNotNull(_client, "Client was never initialized"); + Assert.IsTrue(_client.IsConnected, $"Client connection failed after {timeout}"); + } + + public void WaitForConnectToFail(TimeSpan timeout) + { + Assert.IsFalse(_client.IsConnected, "Client should be disconnected before waiting for connect."); + + WaitForConnect(timeout); + + Assert.IsNotNull(_client, "Client was never initialized"); + Assert.IsFalse(_client.IsConnected, "Client connection success but test wanted fail"); + } + + public void WaitForDisconnect(TimeSpan timeout) + { + var timer = Stopwatch.StartNew(); + while ((_client == null || _client.IsConnected) && timer.Elapsed < timeout) + { + Thread.Sleep(5); + } + } + + public void WaitForDisconnectOrFail(TimeSpan timeout) + { + WaitForConnect(timeout); + + Assert.IsNotNull(_client, "Client was never initialized"); + Assert.IsFalse(_client.IsConnected, $"Client connection should have disconnected after {timeout}"); + } + + async Task Run_Client_Connection() + { + var optionsBuilder = new MqttClientOptionsBuilder() + .WithTlsOptions( + o => o.WithClientCertificatesProvider(_hotSwapClient) + .WithCertificateValidationHandler(_hotSwapClient.OnCertifciateValidation) + .WithSslProtocols(SslProtocols.Tls12)) + .WithTcpServer("localhost") + .WithCleanSession() + .WithProtocolVersion(MqttProtocolVersion.V500); + var mqttClientOptions = optionsBuilder.Build(); + + var managedClientOptionsBuilder = new ManagedMqttClientOptionsBuilder().WithClientOptions(mqttClientOptions); + var managedClientOptions = managedClientOptionsBuilder.Build(); + + var factory = new MqttFactory(); + var mqttClient = factory.CreateManagedMqttClient(); + _client = mqttClient; + + await mqttClient.StartAsync(managedClientOptions); + } + } + + class ServerTestHarness : IDisposable + { + CancellationTokenSource _cts = new CancellationTokenSource(); + readonly HotSwappableServerCertProvider _hotSwapServer = new HotSwappableServerCertProvider(); + MqttServer _server; + + public void ClearClientCerts() + { + _hotSwapServer.ClearClientCerts(); + } + + public void Dispose() + { + if (_server != null) + { + _server.StopAsync().Wait(); + _server.Dispose(); + } + + if (_hotSwapServer != null) + { + _hotSwapServer.Dispose(); + } + } + + public async Task ForceDisconnectAsync(ClientTestHarness client) + { + await _server.DisconnectClientAsync(client.ClientID, MqttDisconnectReasonCode.UnspecifiedError); + } + + public X509Certificate2 GetCurrentServerCert() + { + return _hotSwapServer.GetCertificate(); + } + + public void HotSwapServerCert() + { + _hotSwapServer.HotSwapCert(); + } + + public void InstallNewClientCert(X509Certificate2 serverCert) + { + _hotSwapServer.InstallNewClientCert(serverCert); + } + + public async Task StartServer() + { + var mqttFactory = new MqttFactory(); + + var mqttServerOptions = new MqttServerOptionsBuilder().WithEncryptionCertificate(_hotSwapServer) + .WithRemoteCertificateValidationCallback(_hotSwapServer.RemoteCertificateValidationCallback) + .WithEncryptedEndpoint() + .Build(); + mqttServerOptions.TlsEndpointOptions.ClientCertificateRequired = true; + _server = mqttFactory.CreateMqttServer(mqttServerOptions); + await _server.StartAsync(); + } + } + + class HotSwappableClientCertProvider : IMqttClientCertificatesProvider + { + X509Certificate2Collection _certificates; + ConcurrentBag ServerCerts = new ConcurrentBag(); + + public HotSwappableClientCertProvider() + { + _certificates = new X509Certificate2Collection(CreateSelfSignedCertificate("1.3.6.1.5.5.7.3.2")); + } + + public void ClearServerCerts() + { + ServerCerts = new ConcurrentBag(); + } + + public X509CertificateCollection GetCertificates() + { + return new X509Certificate2Collection(_certificates); + } + + public void HotSwapCert() + { + var newCert = new X509Certificate2Collection(CreateSelfSignedCertificate("1.3.6.1.5.5.7.3.2")); + var oldCerts = Interlocked.Exchange(ref _certificates, newCert); + } + + public void InstallNewServerCert(X509Certificate2 serverCert) + { + ServerCerts.Add(serverCert); + } + + public bool OnCertifciateValidation(MqttClientCertificateValidationEventArgs certContext) + { + var serverCerts = ServerCerts.ToArray(); + + var providedCert = certContext.Certificate.GetRawCertData(); + for (int i = 0, n = serverCerts.Length; i < n; i++) + { + var currentcert = serverCerts[i]; + + if (currentcert.RawData.SequenceEqual(providedCert)) + { + return true; + } + } + + return false; + } + + void Dispose() + { + if (_certificates != null) + { + foreach (var certs in _certificates) + { +#if !NET452 + certs.Dispose(); +#endif + } + } + } + } + + class HotSwappableServerCertProvider : ICertificateProvider, IDisposable + { + X509Certificate2 _certificate; + ConcurrentBag ClientCerts = new ConcurrentBag(); + + public HotSwappableServerCertProvider() + { + _certificate = CreateSelfSignedCertificate("1.3.6.1.5.5.7.3.1"); + } + + public void ClearClientCerts() + { + ClientCerts = new ConcurrentBag(); + } + + public void Dispose() + { +#if !NET452 + _certificate.Dispose(); +#endif + } + + public X509Certificate2 GetCertificate() + { + return _certificate; + } + + public void HotSwapCert() + { + var newCert = CreateSelfSignedCertificate("1.3.6.1.5.5.7.3.1"); + var oldCert = Interlocked.Exchange(ref _certificate, newCert); +#if !NET452 + oldCert.Dispose(); +#endif + } + + public void InstallNewClientCert(X509Certificate2 certificate) + { + ClientCerts.Add(certificate); + } + + public bool RemoteCertificateValidationCallback(object sender, X509Certificate certificate, X509Chain chain, SslPolicyErrors sslPolicyErrors) + { + var serverCerts = ClientCerts.ToArray(); + + var providedCert = certificate.GetRawCertData(); + for (int i = 0, n = serverCerts.Length; i < n; i++) + { + var currentcert = serverCerts[i]; + + if (currentcert.RawData.SequenceEqual(providedCert)) + { + return true; + } + } + + return false; + } + } + } +} \ No newline at end of file diff --git a/Source/MQTTnet/Client/Options/DefaultMqttCertificatesProvider.cs b/Source/MQTTnet/Client/Options/DefaultMqttCertificatesProvider.cs new file mode 100644 index 000000000..7c3949d91 --- /dev/null +++ b/Source/MQTTnet/Client/Options/DefaultMqttCertificatesProvider.cs @@ -0,0 +1,36 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System.Security.Cryptography.X509Certificates; + +namespace MQTTnet.Client +{ + public sealed class DefaultMqttCertificatesProvider : IMqttClientCertificatesProvider + { + readonly X509Certificate2Collection _certificates; + + public DefaultMqttCertificatesProvider(X509Certificate2Collection certificates) + { + _certificates = certificates; + } + + public DefaultMqttCertificatesProvider(IEnumerable certificates) + { + if (certificates != null) + { + _certificates = new X509Certificate2Collection(); + foreach (var certificate in certificates) + { + _certificates.Add(certificate); + } + } + } + + public X509CertificateCollection GetCertificates() + { + return _certificates; + } + } +} \ No newline at end of file diff --git a/Source/MQTTnet/Client/Options/IMqttClientCertificatesProvider.cs b/Source/MQTTnet/Client/Options/IMqttClientCertificatesProvider.cs new file mode 100644 index 000000000..b90017392 --- /dev/null +++ b/Source/MQTTnet/Client/Options/IMqttClientCertificatesProvider.cs @@ -0,0 +1,11 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace MQTTnet.Client +{ + public interface IMqttClientCertificatesProvider + { + System.Security.Cryptography.X509Certificates.X509CertificateCollection GetCertificates(); + } +} \ No newline at end of file diff --git a/Source/MQTTnet/Client/Options/MqttClientOptionsBuilder.cs b/Source/MQTTnet/Client/Options/MqttClientOptionsBuilder.cs index 805eda661..768d87541 100644 --- a/Source/MQTTnet/Client/Options/MqttClientOptionsBuilder.cs +++ b/Source/MQTTnet/Client/Options/MqttClientOptionsBuilder.cs @@ -15,10 +15,14 @@ namespace MQTTnet.Client public sealed class MqttClientOptionsBuilder { readonly MqttClientOptions _options = new MqttClientOptions(); - MqttClientWebSocketProxyOptions _proxyOptions; + + [Obsolete] MqttClientWebSocketProxyOptions _proxyOptions; MqttClientTcpOptions _tcpOptions; - MqttClientOptionsBuilderTlsParameters _tlsParameters; + MqttClientTlsOptions _tlsOptions; + + [Obsolete] MqttClientOptionsBuilderTlsParameters _tlsParameters; + MqttClientWebSocketOptions _webSocketOptions; public MqttClientOptions Build() @@ -28,11 +32,12 @@ public MqttClientOptions Build() throw new InvalidOperationException("A channel must be set."); } + var tlsOptions = _tlsOptions; if (_tlsParameters != null) { if (_tlsParameters?.UseTls == true) { - var tlsOptions = new MqttClientTlsOptions + tlsOptions = new MqttClientTlsOptions { UseTls = true, SslProtocol = _tlsParameters.SslProtocol, @@ -40,28 +45,23 @@ public MqttClientOptions Build() CertificateValidationHandler = _tlsParameters.CertificateValidationHandler, IgnoreCertificateChainErrors = _tlsParameters.IgnoreCertificateChainErrors, IgnoreCertificateRevocationErrors = _tlsParameters.IgnoreCertificateRevocationErrors, -#if WINDOWS_UWP - Certificates = _tlsParameters.Certificates?.Select(c => c.ToArray()).ToList(), -#else - Certificates = _tlsParameters.Certificates?.ToList(), -#endif - + ClientCertificatesProvider = _tlsParameters.CertificatesProvider, #if NETCOREAPP3_1_OR_GREATER ApplicationProtocols = _tlsParameters.ApplicationProtocols, #endif }; - - if (_tcpOptions != null) - { - _tcpOptions.TlsOptions = tlsOptions; - } - else if (_webSocketOptions != null) - { - _webSocketOptions.TlsOptions = tlsOptions; - } } } + if (_tcpOptions != null) + { + _tcpOptions.TlsOptions = tlsOptions; + } + else if (_webSocketOptions != null) + { + _webSocketOptions.TlsOptions = tlsOptions; + } + if (_proxyOptions != null) { if (_webSocketOptions == null) @@ -78,17 +78,7 @@ public MqttClientOptions Build() return _options; } - - /// - /// The client will not throw an exception when the MQTT server responses with a non success ACK packet. - /// This will become the default behavior in future versions of the library. - /// - public MqttClientOptionsBuilder WithoutThrowOnNonSuccessfulConnectResponse() - { - _options.ThrowOnNonSuccessfulConnectResponse = false; - return this; - } - + public MqttClientOptionsBuilder WithAuthentication(string method, byte[] data) { _options.AuthenticationMethod = method; @@ -120,6 +110,7 @@ public MqttClientOptionsBuilder WithClientId(string value) return this; } + [Obsolete("Use WithTcpServer(... configure) or WithWebSocketServer(... configure) instead.")] public MqttClientOptionsBuilder WithConnectionUri(Uri uri) { if (uri == null) @@ -136,12 +127,16 @@ public MqttClientOptionsBuilder WithConnectionUri(Uri uri) break; case "mqtts": - WithTcpServer(uri.Host, port).WithTls(); + WithTcpServer(uri.Host, port) + .WithTlsOptions( + o => + { + }); break; case "ws": case "wss": - WithWebSocketServer(uri.ToString()); + WithWebSocketServer(o => o.WithUri(uri.ToString())); break; default: @@ -159,6 +154,7 @@ public MqttClientOptionsBuilder WithConnectionUri(Uri uri) return this; } + [Obsolete("Use WithTcpServer(... configure) or WithWebSocketServer(... configure) instead.")] public MqttClientOptionsBuilder WithConnectionUri(string uri) { return WithConnectionUri(new Uri(uri, UriKind.Absolute)); @@ -222,6 +218,16 @@ public MqttClientOptionsBuilder WithoutPacketFragmentation() return this; } + /// + /// The client will not throw an exception when the MQTT server responses with a non success ACK packet. + /// This will become the default behavior in future versions of the library. + /// + public MqttClientOptionsBuilder WithoutThrowOnNonSuccessfulConnectResponse() + { + _options.ThrowOnNonSuccessfulConnectResponse = false; + return this; + } + public MqttClientOptionsBuilder WithProtocolVersion(MqttProtocolVersion value) { if (value == MqttProtocolVersion.Unknown) @@ -233,6 +239,7 @@ public MqttClientOptionsBuilder WithProtocolVersion(MqttProtocolVersion value) return this; } + [Obsolete("Use WithWebSocketServer(... configure) instead.")] public MqttClientOptionsBuilder WithProxy( string address, string username = null, @@ -254,6 +261,7 @@ public MqttClientOptionsBuilder WithProxy( return this; } + [Obsolete("Use WithWebSocketServer(... configure) instead.")] public MqttClientOptionsBuilder WithProxy(Action optionsBuilder) { if (optionsBuilder == null) @@ -324,17 +332,20 @@ public MqttClientOptionsBuilder WithTimeout(TimeSpan value) return this; } + [Obsolete("Use WithTlsOptions(... configure) instead.")] public MqttClientOptionsBuilder WithTls(MqttClientOptionsBuilderTlsParameters parameters) { _tlsParameters = parameters; return this; } + [Obsolete("Use WithTlsOptions(... configure) instead.")] public MqttClientOptionsBuilder WithTls() { return WithTls(new MqttClientOptionsBuilderTlsParameters { UseTls = true }); } + [Obsolete("Use WithTlsOptions(... configure) instead.")] public MqttClientOptionsBuilder WithTls(Action optionsBuilder) { if (optionsBuilder == null) @@ -351,6 +362,26 @@ public MqttClientOptionsBuilder WithTls(Action configure) + { + if (configure == null) + { + throw new ArgumentNullException(nameof(configure)); + } + + var builder = new MqttClientTlsOptionsBuilder(); + configure.Invoke(builder); + + _tlsOptions = builder.Build(); + return this; + } + public MqttClientOptionsBuilder WithTopicAliasMaximum(ushort topicAliasMaximum) { _options.TopicAliasMaximum = topicAliasMaximum; @@ -382,6 +413,7 @@ public MqttClientOptionsBuilder WithUserProperty(string name, string value) return this; } + [Obsolete("Use WithWebSocketServer(... configure) instead.")] public MqttClientOptionsBuilder WithWebSocketServer(string uri, MqttClientOptionsBuilderWebSocketParameters parameters = null) { _webSocketOptions = new MqttClientWebSocketOptions @@ -394,6 +426,21 @@ public MqttClientOptionsBuilder WithWebSocketServer(string uri, MqttClientOption return this; } + public MqttClientOptionsBuilder WithWebSocketServer(Action configure) + { + if (configure == null) + { + throw new ArgumentNullException(nameof(configure)); + } + + var webSocketOptionsBuilder = new MqttClientWebSocketOptionsBuilder(); + configure.Invoke(webSocketOptionsBuilder); + + _webSocketOptions = webSocketOptionsBuilder.Build(); + return this; + } + + [Obsolete("Use WithWebSocketServer(... configure) instead.")] public MqttClientOptionsBuilder WithWebSocketServer(Action optionsBuilder) { if (optionsBuilder == null) @@ -425,6 +472,12 @@ public MqttClientOptionsBuilder WithWillDelayInterval(uint willDelayInterval) return this; } + public MqttClientOptionsBuilder WithWillMessageExpiryInterval(uint willMessageExpiryInterval) + { + _options.WillMessageExpiryInterval = willMessageExpiryInterval; + return this; + } + public MqttClientOptionsBuilder WithWillPayload(byte[] willPayload) { _options.WillPayload = willPayload; @@ -478,12 +531,6 @@ public MqttClientOptionsBuilder WithWillRetain(bool willRetain = true) return this; } - public MqttClientOptionsBuilder WithWillMessageExpiryInterval(uint willMessageExpiryInterval) - { - _options.WillMessageExpiryInterval = willMessageExpiryInterval; - return this; - } - public MqttClientOptionsBuilder WithWillTopic(string willTopic) { _options.WillTopic = willTopic; diff --git a/Source/MQTTnet/Client/Options/MqttClientOptionsBuilderTlsParameters.cs b/Source/MQTTnet/Client/Options/MqttClientOptionsBuilderTlsParameters.cs index 5f5f7ce74..c79f3a4da 100644 --- a/Source/MQTTnet/Client/Options/MqttClientOptionsBuilderTlsParameters.cs +++ b/Source/MQTTnet/Client/Options/MqttClientOptionsBuilderTlsParameters.cs @@ -5,12 +5,14 @@ using System; using System.Collections.Generic; using System.Security.Authentication; -using System.Security.Cryptography.X509Certificates; namespace MQTTnet.Client { + [Obsolete("Use methods from MqttClientOptionsBuilder instead.")] public sealed class MqttClientOptionsBuilderTlsParameters { + IEnumerable _obsoleteCertificates; + public bool UseTls { get; set; } public Func CertificateValidationHandler { get; set; } @@ -24,7 +26,24 @@ public sealed class MqttClientOptionsBuilderTlsParameters #if WINDOWS_UWP public IEnumerable> Certificates { get; set; } #else - public IEnumerable Certificates { get; set; } + [Obsolete("Use CertificatesProvider instead.")] + public IEnumerable Certificates + { + get => _obsoleteCertificates; + set + { + _obsoleteCertificates = value; + + if (value == null) + { + CertificatesProvider = null; + } + else + { + CertificatesProvider = new DefaultMqttCertificatesProvider(value); + } + } + } #endif #if NETCOREAPP3_1_OR_GREATER @@ -36,5 +55,7 @@ public sealed class MqttClientOptionsBuilderTlsParameters public bool IgnoreCertificateChainErrors { get; set; } public bool IgnoreCertificateRevocationErrors { get; set; } + + public IMqttClientCertificatesProvider CertificatesProvider { get; set; } } } \ No newline at end of file diff --git a/Source/MQTTnet/Client/Options/MqttClientOptionsBuilderWebSocketParameters.cs b/Source/MQTTnet/Client/Options/MqttClientOptionsBuilderWebSocketParameters.cs index ba20a109b..d469a579c 100644 --- a/Source/MQTTnet/Client/Options/MqttClientOptionsBuilderWebSocketParameters.cs +++ b/Source/MQTTnet/Client/Options/MqttClientOptionsBuilderWebSocketParameters.cs @@ -2,11 +2,13 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System; using System.Collections.Generic; using System.Net; namespace MQTTnet.Client { + [Obsolete("Use dedicated methods in MqttClientOptionsBuilder.")] public class MqttClientOptionsBuilderWebSocketParameters { public IDictionary RequestHeaders { get; set; } diff --git a/Source/MQTTnet/Client/Options/MqttClientTcpOptionsExtensions.cs b/Source/MQTTnet/Client/Options/MqttClientTcpOptionsExtensions.cs index f2141504b..fae6bc1e9 100644 --- a/Source/MQTTnet/Client/Options/MqttClientTcpOptionsExtensions.cs +++ b/Source/MQTTnet/Client/Options/MqttClientTcpOptionsExtensions.cs @@ -17,7 +17,7 @@ public static int GetPort(this MqttClientTcpOptions options) return options.Port.Value; } - return !options.TlsOptions.UseTls ? 1883 : 8883; + return !(options.TlsOptions?.UseTls ?? false) ? 1883 : 8883; } } } diff --git a/Source/MQTTnet/Client/Options/MqttClientTlsOptions.cs b/Source/MQTTnet/Client/Options/MqttClientTlsOptions.cs index 69ad9d9e4..75bd03471 100644 --- a/Source/MQTTnet/Client/Options/MqttClientTlsOptions.cs +++ b/Source/MQTTnet/Client/Options/MqttClientTlsOptions.cs @@ -24,11 +24,13 @@ public sealed class MqttClientTlsOptions public X509RevocationMode RevocationMode { get; set; } = X509RevocationMode.Online; -#if WINDOWS_UWP - public List Certificates { get; set; } -#else - public List Certificates { get; set; } -#endif + /// + /// Gets or sets the provider for certificates. + /// This provider gets called whenever the client wants to connect + /// with the server and requires certificates for authentication. + /// The implementation may return different certificates each time. + /// + public IMqttClientCertificatesProvider ClientCertificatesProvider { get; set; } #if NETCOREAPP3_1_OR_GREATER public List ApplicationProtocols { get; set; } diff --git a/Source/MQTTnet/Client/Options/MqttClientTlsOptionsBuilder.cs b/Source/MQTTnet/Client/Options/MqttClientTlsOptionsBuilder.cs new file mode 100644 index 000000000..a07a1111c --- /dev/null +++ b/Source/MQTTnet/Client/Options/MqttClientTlsOptionsBuilder.cs @@ -0,0 +1,136 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Security.Authentication; +using System.Security.Cryptography.X509Certificates; +#if NETCOREAPP3_1_OR_GREATER +using System.Net.Security; +#endif + +namespace MQTTnet.Client +{ + public sealed class MqttClientTlsOptionsBuilder + { + readonly MqttClientTlsOptions _tlsOptions = new MqttClientTlsOptions + { + // If someone used this builder the change is very very high that TLS + // should be actually used. + UseTls = true + }; + + public MqttClientTlsOptions Build() + { + return _tlsOptions; + } + + public MqttClientTlsOptionsBuilder UseTls(bool useTls = true) + { + _tlsOptions.UseTls = useTls; + return this; + } + + public MqttClientTlsOptionsBuilder WithAllowUntrustedCertificates(bool allowUntrustedCertificates = true) + { + _tlsOptions.AllowUntrustedCertificates = allowUntrustedCertificates; + return this; + } + + public MqttClientTlsOptionsBuilder WithCertificateValidationHandler(Func certificateValidationHandler) + { + if (certificateValidationHandler == null) + { + throw new ArgumentNullException(nameof(certificateValidationHandler)); + } + + _tlsOptions.CertificateValidationHandler = certificateValidationHandler; + return this; + } + + public MqttClientTlsOptionsBuilder WithClientCertificates(IEnumerable certificates) + { + if (certificates == null) + { + throw new ArgumentNullException(nameof(certificates)); + } + + _tlsOptions.ClientCertificatesProvider = new DefaultMqttCertificatesProvider(certificates); + return this; + } + + public MqttClientTlsOptionsBuilder WithClientCertificates(X509Certificate2Collection certificates) + { + if (certificates == null) + { + throw new ArgumentNullException(nameof(certificates)); + } + + _tlsOptions.ClientCertificatesProvider = new DefaultMqttCertificatesProvider(certificates); + return this; + } + + public MqttClientTlsOptionsBuilder WithClientCertificatesProvider(IMqttClientCertificatesProvider clientCertificatesProvider) + { + _tlsOptions.ClientCertificatesProvider = clientCertificatesProvider; + return this; + } + + public MqttClientTlsOptionsBuilder WithIgnoreCertificateChainErrors(bool ignoreCertificateChainErrors = true) + { + _tlsOptions.IgnoreCertificateChainErrors = ignoreCertificateChainErrors; + return this; + } + + public MqttClientTlsOptionsBuilder WithIgnoreCertificateRevocationErrors(bool ignoreCertificateRevocationErrors = true) + { + _tlsOptions.IgnoreCertificateRevocationErrors = ignoreCertificateRevocationErrors; + return this; + } + + public MqttClientTlsOptionsBuilder WithRevocationMode(X509RevocationMode revocationMode) + { + _tlsOptions.RevocationMode = revocationMode; + return this; + } + + public MqttClientTlsOptionsBuilder WithSslProtocols(SslProtocols sslProtocols) + { + _tlsOptions.SslProtocol = sslProtocols; + return this; + } + + public MqttClientTlsOptionsBuilder WithTargetHost(string targetHost) + { + _tlsOptions.TargetHost = targetHost; + return this; + } + +#if NETCOREAPP3_1_OR_GREATER + public MqttClientTlsOptionsBuilder WithAllowRenegotiation(bool allowRenegotiation = true) + { + _tlsOptions.AllowRenegotiation = allowRenegotiation; + return this; + } + + public MqttClientTlsOptionsBuilder WithApplicationProtocols(List applicationProtocols) + { + _tlsOptions.ApplicationProtocols = applicationProtocols; + return this; + } + + public MqttClientTlsOptionsBuilder WithCipherSuitesPolicy(CipherSuitesPolicy cipherSuitePolicy) + { + _tlsOptions.CipherSuitesPolicy = cipherSuitePolicy; + return this; + } + + public MqttClientTlsOptionsBuilder WithCipherSuitesPolicy(EncryptionPolicy encryptionPolicy) + { + _tlsOptions.EncryptionPolicy = encryptionPolicy; + return this; + } +#endif + } +} \ No newline at end of file diff --git a/Source/MQTTnet/Client/Options/MqttClientWebSocketOptionsBuilder.cs b/Source/MQTTnet/Client/Options/MqttClientWebSocketOptionsBuilder.cs new file mode 100644 index 000000000..5b78406d6 --- /dev/null +++ b/Source/MQTTnet/Client/Options/MqttClientWebSocketOptionsBuilder.cs @@ -0,0 +1,85 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Net; + +namespace MQTTnet.Client +{ + public sealed class MqttClientWebSocketOptionsBuilder + { + readonly MqttClientWebSocketOptions _webSocketOptions = new MqttClientWebSocketOptions(); + + public MqttClientWebSocketOptions Build() + { + return _webSocketOptions; + } + + public MqttClientWebSocketOptionsBuilder WithCookieContainer(CookieContainer cookieContainer) + { + _webSocketOptions.CookieContainer = cookieContainer; + return this; + } + + public MqttClientWebSocketOptionsBuilder WithCookieContainer(ICredentials credentials) + { + _webSocketOptions.Credentials = credentials; + return this; + } + + public MqttClientWebSocketOptionsBuilder WithProxyOptions(MqttClientWebSocketProxyOptions proxyOptions) + { + _webSocketOptions.ProxyOptions = proxyOptions; + return this; + } + + public MqttClientWebSocketOptionsBuilder WithProxyOptions(Action configure) + { + if (configure == null) + { + throw new ArgumentNullException(nameof(configure)); + } + + var proxyOptionsBuilder = new MqttClientWebSocketProxyOptionsBuilder(); + configure.Invoke(proxyOptionsBuilder); + + _webSocketOptions.ProxyOptions = proxyOptionsBuilder.Build(); + return this; + } + + public MqttClientWebSocketOptionsBuilder WithRequestHeaders(IDictionary requestHeaders) + { + _webSocketOptions.RequestHeaders = requestHeaders; + return this; + } + + public MqttClientWebSocketOptionsBuilder WithSubProtocols(ICollection subProtocols) + { + _webSocketOptions.SubProtocols = subProtocols; + return this; + } + + public MqttClientWebSocketOptionsBuilder WithUri(string uri) + { + _webSocketOptions.Uri = uri; + return this; + } + +#if !NETSTANDARD1_3 + public MqttClientWebSocketOptionsBuilder WithKeepAliveInterval(TimeSpan keepAliveInterval) + { + _webSocketOptions.KeepAliveInterval = keepAliveInterval; + return this; + } +#endif +#if !WINDOWS_UWP && !NETSTANDARD1_3 + public MqttClientWebSocketOptionsBuilder WithUseDefaultCredentials(bool useDefaultCredentials = true) + { + _webSocketOptions.UseDefaultCredentials = useDefaultCredentials; + return this; + } +#endif + } +} \ No newline at end of file diff --git a/Source/MQTTnet/Client/Options/MqttClientWebSocketProxyOptionsBuilder.cs b/Source/MQTTnet/Client/Options/MqttClientWebSocketProxyOptionsBuilder.cs new file mode 100644 index 000000000..3fa736a97 --- /dev/null +++ b/Source/MQTTnet/Client/Options/MqttClientWebSocketProxyOptionsBuilder.cs @@ -0,0 +1,67 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System.Linq; + +namespace MQTTnet.Client +{ + public sealed class MqttClientWebSocketProxyOptionsBuilder + { + readonly MqttClientWebSocketProxyOptions _proxyOptions = new MqttClientWebSocketProxyOptions(); + + public MqttClientWebSocketProxyOptionsBuilder WithAddress(string address) + { + _proxyOptions.Address = address; + return this; + } + + public MqttClientWebSocketProxyOptionsBuilder WithUsername(string username) + { + _proxyOptions.Username = username; + return this; + } + + public MqttClientWebSocketProxyOptionsBuilder WithPassword(string password) + { + _proxyOptions.Password = password; + return this; + } + + public MqttClientWebSocketProxyOptionsBuilder WithDomain(string domain) + { + _proxyOptions.Domain = domain; + return this; + } + + public MqttClientWebSocketProxyOptionsBuilder WithBypassOnLocal(bool bypassOnLocal = true) + { + _proxyOptions.BypassOnLocal = bypassOnLocal; + return this; + } + + public MqttClientWebSocketProxyOptionsBuilder WithBypassList(string[] bypassList) + { + _proxyOptions.BypassList = bypassList; + return this; + } + + public MqttClientWebSocketProxyOptionsBuilder WithBypassList(IEnumerable bypassList) + { + _proxyOptions.BypassList = bypassList?.ToArray(); + return this; + } + + public MqttClientWebSocketProxyOptionsBuilder WithUseDefaultCredentials(bool useDefaultCredentials = true) + { + _proxyOptions.UseDefaultCredentials = useDefaultCredentials; + return this; + } + + public MqttClientWebSocketProxyOptions Build() + { + return _proxyOptions; + } + } +} \ No newline at end of file diff --git a/Source/MQTTnet/Implementations/MqttTcpChannel.Uwp.cs b/Source/MQTTnet/Implementations/MqttTcpChannel.Uwp.cs index dbca74657..d26ebe6a0 100644 --- a/Source/MQTTnet/Implementations/MqttTcpChannel.Uwp.cs +++ b/Source/MQTTnet/Implementations/MqttTcpChannel.Uwp.cs @@ -6,8 +6,6 @@ using System; using System.Collections.Generic; using System.IO; -using System.Linq; -using System.Runtime.InteropServices.WindowsRuntime; using System.Security.Authentication; using System.Security.Cryptography.X509Certificates; using System.Threading; @@ -18,6 +16,7 @@ using MQTTnet.Channel; using MQTTnet.Client; using MQTTnet.Server; +using System.Runtime.InteropServices.WindowsRuntime; namespace MQTTnet.Implementations { @@ -126,17 +125,19 @@ public void Dispose() private static Certificate LoadCertificate(IMqttClientChannelOptions options) { - if (options.TlsOptions.Certificates == null || !options.TlsOptions.Certificates.Any()) + var certificates = options.TlsOptions?.ClientCertificatesProvider?.GetCertificates(); + + if (certificates == null || certificates.Count == 0) { return null; } - if (options.TlsOptions.Certificates.Count > 1) + if (certificates.Count > 1) { throw new NotSupportedException("Only one client certificate is supported when using 'uap10.0'."); } - return new Certificate(options.TlsOptions.Certificates.First().AsBuffer()); + return new Certificate(certificates[0].Export(X509ContentType.Cert).AsBuffer()); } private IEnumerable ResolveIgnorableServerCertificateErrors() diff --git a/Source/MQTTnet/Implementations/MqttTcpChannel.cs b/Source/MQTTnet/Implementations/MqttTcpChannel.cs index fb1a63f89..8e4de584f 100644 --- a/Source/MQTTnet/Implementations/MqttTcpChannel.cs +++ b/Source/MQTTnet/Implementations/MqttTcpChannel.cs @@ -3,9 +3,6 @@ // See the LICENSE file in the project root for more information. #if !WINDOWS_UWP -using MQTTnet.Channel; -using MQTTnet.Client; -using MQTTnet.Exceptions; using System; using System.IO; using System.Net.Security; @@ -14,6 +11,9 @@ using System.Security.Cryptography.X509Certificates; using System.Threading; using System.Threading.Tasks; +using MQTTnet.Channel; +using MQTTnet.Client; +using MQTTnet.Exceptions; using MQTTnet.Internal; namespace MQTTnet.Implementations @@ -73,7 +73,7 @@ public async Task ConnectAsync(CancellationToken cancellationToken) { socket.Bind(_tcpOptions.LocalEndpoint); } - + socket.ReceiveBufferSize = _tcpOptions.BufferSize; socket.SendBufferSize = _tcpOptions.BufferSize; socket.SendTimeout = (int)_clientOptions.Timeout.TotalMilliseconds; @@ -105,7 +105,7 @@ public async Task ConnectAsync(CancellationToken cancellationToken) { targetHost = _tcpOptions.Server; } - + var sslStream = new SslStream(networkStream, false, InternalUserCertificateValidationCallback); try { @@ -298,18 +298,7 @@ bool InternalUserCertificateValidationCallback(object sender, X509Certificate x5 X509CertificateCollection LoadCertificates() { - if (_tcpOptions.TlsOptions.Certificates == null) - { - return null; - } - - var certificates = new X509CertificateCollection(); - foreach (var certificate in _tcpOptions.TlsOptions.Certificates) - { - certificates.Add(certificate); - } - - return certificates; + return _tcpOptions.TlsOptions.ClientCertificatesProvider?.GetCertificates(); } } } diff --git a/Source/MQTTnet/Implementations/MqttWebSocketChannel.cs b/Source/MQTTnet/Implementations/MqttWebSocketChannel.cs index be31c2a51..fdddcf6e0 100644 --- a/Source/MQTTnet/Implementations/MqttWebSocketChannel.cs +++ b/Source/MQTTnet/Implementations/MqttWebSocketChannel.cs @@ -206,16 +206,12 @@ void SetupClientWebSocket(ClientWebSocket clientWebSocket) clientWebSocket.Options.Cookies = _options.CookieContainer; } - if (_options.TlsOptions?.UseTls == true && _options.TlsOptions?.Certificates != null) + if (_options.TlsOptions?.UseTls == true) { - clientWebSocket.Options.ClientCertificates = new X509CertificateCollection(); - foreach (var certificate in _options.TlsOptions.Certificates) + var certificates = _options.TlsOptions?.ClientCertificatesProvider?.GetCertificates(); + if (certificates?.Count > 0) { -#if WINDOWS_UWP - clientWebSocket.Options.ClientCertificates.Add(new X509Certificate(certificate)); -#else - clientWebSocket.Options.ClientCertificates.Add(certificate); -#endif + clientWebSocket.Options.ClientCertificates = certificates; } }