From 3f213ff442eacbcc1f10faf1138bcc741f7c6467 Mon Sep 17 00:00:00 2001 From: "Jonas R. Hansen" Date: Sun, 24 Nov 2024 19:57:58 +0100 Subject: [PATCH] Always setup options --- src/ToMqttNet/MqttConnectionOptions.cs | 17 +++-- src/ToMqttNet/MqttConnectionService.cs | 71 +++++++++++++---- ...ttConnectionServiceCollectionExtensions.cs | 1 + src/ToMqttNet/ToMqttNet.csproj | 2 +- .../WatchingMqttCertificateProvider.cs | 76 +++++++++++++++++++ .../MqttConnectionServiceTests.cs | 15 ++-- .../ToMqttNet.Test.Unit.csproj | 1 + 7 files changed, 156 insertions(+), 27 deletions(-) create mode 100644 src/ToMqttNet/WatchingMqttCertificateProvider.cs diff --git a/src/ToMqttNet/MqttConnectionOptions.cs b/src/ToMqttNet/MqttConnectionOptions.cs index 2d99e07..afc6054 100644 --- a/src/ToMqttNet/MqttConnectionOptions.cs +++ b/src/ToMqttNet/MqttConnectionOptions.cs @@ -1,15 +1,20 @@ -using HomeAssistantDiscoveryNet; -using MQTTnet.Client; using System.ComponentModel.DataAnnotations; +using HomeAssistantDiscoveryNet; namespace ToMqttNet; public class MqttConnectionOptions { - [Required] - public string NodeId { get; set; } = null!; + public const string Section = "MqttConnection"; - public MqttClientOptions ClientOptions { get; set; } = new MqttClientOptions { }; + public int? Port { get; set; } + public bool UseTls { get; set; } + [Required] + public string NodeId { get; set; } = null!; + public string? Server { get; set; } + public string? CaCrt { get; set; } + public string? ClientCrt { get; set; } + public string? ClientKey { get; set; } - public MqttDiscoveryConfigOrigin? OriginConfig { get; set; } + public MqttDiscoveryConfigOrigin? OriginConfig { get; set; } } diff --git a/src/ToMqttNet/MqttConnectionService.cs b/src/ToMqttNet/MqttConnectionService.cs index 840a7c9..23a84a0 100644 --- a/src/ToMqttNet/MqttConnectionService.cs +++ b/src/ToMqttNet/MqttConnectionService.cs @@ -6,6 +6,8 @@ using MQTTnet.Client; using MQTTnet.Extensions.ManagedClient; using MQTTnet.Packets; +using System.Net; +using System.Security.Cryptography.X509Certificates; using System.Text; namespace ToMqttNet; @@ -14,11 +16,13 @@ public class MqttConnectionService( ILogger logger, IOptions mqttOptions, [FromKeyedServices(typeof(MqttConnectionService))] IManagedMqttClient managedMqttClient, - MqttCounters counters) : BackgroundService, IMqttConnectionService + MqttCounters counters, + IServiceProvider serviceProvider) : BackgroundService, IMqttConnectionService { private readonly ILogger _logger = logger; private readonly MqttCounters _counters = counters; private readonly string _instanceId = Guid.NewGuid().ToString(); + private WatchingMqttCertificateProvider? _certificateWatcher; public MqttConnectionOptions MqttOptions { get; } = mqttOptions.Value; private readonly IManagedMqttClient _mqttClient = managedMqttClient; @@ -30,22 +34,18 @@ public class MqttConnectionService( protected override async Task ExecuteAsync(CancellationToken stoppingToken) { _logger.LogInformation("Executing {backgroundService}", GetType().FullName); - var options = MqttOptions.ClientOptions; + var options = new MqttClientOptions { }; - if(string.IsNullOrEmpty(options.ClientId)) - { - options.ClientId = MqttOptions.NodeId + "-" + _instanceId; - } + if (string.IsNullOrEmpty(options.ClientId)) + { + options.ClientId = MqttOptions.NodeId + "-" + _instanceId; + } options.WillPayload = Encoding.UTF8.GetBytes("0"); options.WillTopic = $"{MqttOptions.NodeId}/connected"; options.WillRetain = true; - options.ChannelOptions ??= new MqttClientTcpOptions - { - Server = "mosquitto", - Port = 1883 - }; + options.ChannelOptions = BuildChannelOptions(); var optionsBuilder = new ManagedMqttClientOptionsBuilder() .WithAutoReconnectDelay(TimeSpan.FromSeconds(5)) @@ -53,12 +53,14 @@ protected override async Task ExecuteAsync(CancellationToken stoppingToken) _counters.SetPendingMessages(() => _mqttClient.PendingApplicationMessagesCount); - _mqttClient.ConnectionStateChangedAsync += (evnt) => { + _mqttClient.ConnectionStateChangedAsync += (evnt) => + { _counters.SetConnections(_mqttClient.IsConnected ? 1 : 0); return Task.CompletedTask; }; - _mqttClient.ConnectingFailedAsync += (evnt) => { + _mqttClient.ConnectingFailedAsync += (evnt) => + { _logger.LogWarning(evnt.Exception, "Connection to mqtt failed"); _counters.SetConnections(0); @@ -90,12 +92,13 @@ await _mqttClient.EnqueueAsync( { try { - if(_logger.IsEnabled(LogLevel.Trace)) + if (_logger.IsEnabled(LogLevel.Trace)) { _logger.LogTrace("{topic}: {message}", evnt.ApplicationMessage.Topic, evnt.ApplicationMessage.ConvertPayloadToString()); } OnApplicationMessageReceived?.Invoke(this, evnt); - }catch(Exception e) + } + catch (Exception e) { _logger.LogWarning(e, "Failed to handle message to topic {topic}", evnt.ApplicationMessage.Topic); _counters.IncreaseMessagesHandled(success: false); @@ -124,4 +127,42 @@ public Task UnsubscribeAsync(params string[] topics) { return _mqttClient!.UnsubscribeAsync(topics); } + + private IMqttClientChannelOptions BuildChannelOptions() + { + var tcpOptions = new MqttClientTcpOptions + { + RemoteEndpoint = new DnsEndPoint(MqttOptions.Server ?? "mosquitto", MqttOptions.Port ?? 1883), + }; + + if (MqttOptions.UseTls) + { + _certificateWatcher = ActivatorUtilities.CreateInstance(serviceProvider); + + tcpOptions.TlsOptions = new MqttClientTlsOptions + { + UseTls = true, + SslProtocol = System.Security.Authentication.SslProtocols.Tls12, + ClientCertificatesProvider = _certificateWatcher, + CertificateValidationHandler = (certContext) => + { + X509Chain chain = new(); + chain.ChainPolicy.RevocationMode = X509RevocationMode.NoCheck; + chain.ChainPolicy.RevocationFlag = X509RevocationFlag.ExcludeRoot; + chain.ChainPolicy.VerificationFlags = X509VerificationFlags.NoFlag; + chain.ChainPolicy.VerificationTime = DateTime.Now; + chain.ChainPolicy.UrlRetrievalTimeout = new TimeSpan(0, 0, 0); + chain.ChainPolicy.CustomTrustStore.Add(_certificateWatcher.CaCertificate!); + chain.ChainPolicy.TrustMode = X509ChainTrustMode.CustomRootTrust; + + // convert provided X509Certificate to X509Certificate2 + var x5092 = new X509Certificate2(certContext.Certificate); + + return chain.Build(x5092); + } + }; + } + + return tcpOptions; + } } diff --git a/src/ToMqttNet/MqttConnectionServiceCollectionExtensions.cs b/src/ToMqttNet/MqttConnectionServiceCollectionExtensions.cs index 94b8c32..736dd98 100644 --- a/src/ToMqttNet/MqttConnectionServiceCollectionExtensions.cs +++ b/src/ToMqttNet/MqttConnectionServiceCollectionExtensions.cs @@ -19,6 +19,7 @@ public static OptionsBuilder AddMqttConnection(this IServ services.AddSingleton(x => x.GetRequiredService()); services.AddHostedService(x => x.GetRequiredService()); services.AddKeyedSingleton(typeof(MqttConnectionService), (services, key) => new MqttFactory().CreateManagedMqttClient()); + services.AddOptions().BindConfiguration(MqttConnectionOptions.Section); return services.AddOptions(); } diff --git a/src/ToMqttNet/ToMqttNet.csproj b/src/ToMqttNet/ToMqttNet.csproj index 4132404..bf4dadd 100644 --- a/src/ToMqttNet/ToMqttNet.csproj +++ b/src/ToMqttNet/ToMqttNet.csproj @@ -20,7 +20,7 @@ - + diff --git a/src/ToMqttNet/WatchingMqttCertificateProvider.cs b/src/ToMqttNet/WatchingMqttCertificateProvider.cs new file mode 100644 index 0000000..78162aa --- /dev/null +++ b/src/ToMqttNet/WatchingMqttCertificateProvider.cs @@ -0,0 +1,76 @@ +using System.Security.Cryptography.X509Certificates; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using MQTTnet.Client; + +namespace ToMqttNet; + +public class WatchingMqttCertificateProvider : IMqttClientCertificatesProvider +{ + private readonly List _watchers = []; + private readonly MqttConnectionOptions _options; + private readonly ILogger _logger; + + public X509Certificate2? CaCertificate { get; private set; } + public X509Certificate2Collection ClientCertificates { get; private set; } = new(); + + public WatchingMqttCertificateProvider(ILogger logger, IOptions options) + { + _options = options.Value; + _logger = logger; + var certDirectories = new string?[] { _options.CaCrt, _options.ClientCrt, _options.ClientKey } + .Where(x => x != null) + .Select(x => Path.GetDirectoryName(x)!) + .Distinct() + .ToList(); + + foreach (var directory in certDirectories) + { + var watcher = new FileSystemWatcher(directory); + _watchers.Add(watcher); + + watcher.Changed += OnCertificateChanged; + watcher.EnableRaisingEvents = true; + } + + LoadCertificates(); + } + + private void LoadCertificates() + { + _logger.LogInformation("Loading certificates"); + try + { + ClientCertificates.Clear(); + if (_options.ClientCrt != null && _options.ClientKey != null) + { + var clientCert = X509Certificate2.CreateFromPemFile(_options.ClientCrt, _options.ClientKey); + ClientCertificates.Add(clientCert); + _logger.LogInformation("Loaded Client Certificate {name} from {certPath}, {keyPath}", clientCert.Thumbprint, _options.ClientCrt, _options.ClientKey); + } + + if (_options.CaCrt != null) + { + CaCertificate = new X509Certificate2(_options.CaCrt); + ClientCertificates.Add(CaCertificate); + _logger.LogInformation("Loaded CA Certificate {name} from {path}", CaCertificate.Thumbprint, _options.CaCrt); + } + } + catch (Exception ex) + { + _logger.LogError(ex, "Failed to load certificates"); + throw; + } + _logger.LogInformation("Certificates loaded"); + } + + private void OnCertificateChanged(object sender, FileSystemEventArgs e) + { + LoadCertificates(); + } + + public X509CertificateCollection GetCertificates() + { + return ClientCertificates; + } +} \ No newline at end of file diff --git a/test/ToMqttNet.Test.Unit/MqttConnectionServiceTests.cs b/test/ToMqttNet.Test.Unit/MqttConnectionServiceTests.cs index eff2599..64fcd44 100644 --- a/test/ToMqttNet.Test.Unit/MqttConnectionServiceTests.cs +++ b/test/ToMqttNet.Test.Unit/MqttConnectionServiceTests.cs @@ -1,4 +1,5 @@ -using Microsoft.Extensions.Options; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; using MQTTnet; using MQTTnet.Client; using MQTTnet.Extensions.ManagedClient; @@ -21,13 +22,17 @@ public class MqttConnectionServiceTests public MqttConnectionServiceTests(ITestOutputHelper testOutputHelper) { + var serviceCollection = new ServiceCollection(); + var provider = serviceCollection.BuildServiceProvider(); + _meterFactoryStub = new MeterFactoryStub(); _clientStub = new MqttClientStub(); _sut = new MqttConnectionService( testOutputHelper.CreateLogger(), Options.Create(new MqttConnectionOptions()), _clientStub, - new MqttCounters(_meterFactoryStub)); + new MqttCounters(_meterFactoryStub), + provider); } [Fact] @@ -70,7 +75,7 @@ public async Task ShouldIncreaseMessagesSent() var messagesSent = -1L; listener.InstrumentPublished = (instrument, listener) => { - if(instrument.Meter.Name == "ToMqttNet") + if (instrument.Meter.Name == "ToMqttNet") { listener.EnableMeasurementEvents(instrument); } @@ -141,7 +146,7 @@ public Task CallDisconnectedAsync(MqttClientDisconnectedEventArgs args) public List EnqueuedMessage = []; public List Subscriptions = []; - public void Dispose(){} + public void Dispose() { } public Task EnqueueAsync(MqttApplicationMessage applicationMessage) { @@ -198,5 +203,5 @@ public Meter Create(MeterOptions options) return new Meter(options); } - public void Dispose(){} + public void Dispose() { } } \ No newline at end of file diff --git a/test/ToMqttNet.Test.Unit/ToMqttNet.Test.Unit.csproj b/test/ToMqttNet.Test.Unit/ToMqttNet.Test.Unit.csproj index ff7a3da..e5e0142 100644 --- a/test/ToMqttNet.Test.Unit/ToMqttNet.Test.Unit.csproj +++ b/test/ToMqttNet.Test.Unit/ToMqttNet.Test.Unit.csproj @@ -8,6 +8,7 @@ + runtime; build; native; contentfiles; analyzers; buildtransitive