Skip to content

Commit

Permalink
Always setup options
Browse files Browse the repository at this point in the history
  • Loading branch information
JonasMH committed Nov 24, 2024
1 parent 9037566 commit 3f213ff
Show file tree
Hide file tree
Showing 7 changed files with 156 additions and 27 deletions.
17 changes: 11 additions & 6 deletions src/ToMqttNet/MqttConnectionOptions.cs
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
using HomeAssistantDiscoveryNet;
using MQTTnet.Client;
using System.ComponentModel.DataAnnotations;
using HomeAssistantDiscoveryNet;

namespace ToMqttNet;

public class MqttConnectionOptions
{
[Required]
public string NodeId { get; set; } = null!;
public const string Section = "MqttConnection";

public MqttClientOptions ClientOptions { get; set; } = new MqttClientOptions { };
public int? Port { get; set; }
public bool UseTls { get; set; }
[Required]
public string NodeId { get; set; } = null!;
public string? Server { get; set; }
public string? CaCrt { get; set; }
public string? ClientCrt { get; set; }
public string? ClientKey { get; set; }

public MqttDiscoveryConfigOrigin? OriginConfig { get; set; }
public MqttDiscoveryConfigOrigin? OriginConfig { get; set; }
}
71 changes: 56 additions & 15 deletions src/ToMqttNet/MqttConnectionService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
using MQTTnet.Client;
using MQTTnet.Extensions.ManagedClient;
using MQTTnet.Packets;
using System.Net;
using System.Security.Cryptography.X509Certificates;
using System.Text;

namespace ToMqttNet;
Expand All @@ -14,11 +16,13 @@ public class MqttConnectionService(
ILogger<MqttConnectionService> logger,
IOptions<MqttConnectionOptions> mqttOptions,
[FromKeyedServices(typeof(MqttConnectionService))] IManagedMqttClient managedMqttClient,
MqttCounters counters) : BackgroundService, IMqttConnectionService
MqttCounters counters,
IServiceProvider serviceProvider) : BackgroundService, IMqttConnectionService
{
private readonly ILogger<MqttConnectionService> _logger = logger;
private readonly MqttCounters _counters = counters;
private readonly string _instanceId = Guid.NewGuid().ToString();
private WatchingMqttCertificateProvider? _certificateWatcher;

public MqttConnectionOptions MqttOptions { get; } = mqttOptions.Value;
private readonly IManagedMqttClient _mqttClient = managedMqttClient;
Expand All @@ -30,35 +34,33 @@ public class MqttConnectionService(
protected override async Task ExecuteAsync(CancellationToken stoppingToken)
{
_logger.LogInformation("Executing {backgroundService}", GetType().FullName);
var options = MqttOptions.ClientOptions;
var options = new MqttClientOptions { };

if(string.IsNullOrEmpty(options.ClientId))
{
options.ClientId = MqttOptions.NodeId + "-" + _instanceId;
}
if (string.IsNullOrEmpty(options.ClientId))
{
options.ClientId = MqttOptions.NodeId + "-" + _instanceId;
}

options.WillPayload = Encoding.UTF8.GetBytes("0");
options.WillTopic = $"{MqttOptions.NodeId}/connected";
options.WillRetain = true;

options.ChannelOptions ??= new MqttClientTcpOptions
{
Server = "mosquitto",
Port = 1883
};
options.ChannelOptions = BuildChannelOptions();

var optionsBuilder = new ManagedMqttClientOptionsBuilder()
.WithAutoReconnectDelay(TimeSpan.FromSeconds(5))
.WithClientOptions(options);

_counters.SetPendingMessages(() => _mqttClient.PendingApplicationMessagesCount);

_mqttClient.ConnectionStateChangedAsync += (evnt) => {
_mqttClient.ConnectionStateChangedAsync += (evnt) =>
{
_counters.SetConnections(_mqttClient.IsConnected ? 1 : 0);
return Task.CompletedTask;
};

_mqttClient.ConnectingFailedAsync += (evnt) => {
_mqttClient.ConnectingFailedAsync += (evnt) =>
{
_logger.LogWarning(evnt.Exception, "Connection to mqtt failed");

_counters.SetConnections(0);
Expand Down Expand Up @@ -90,12 +92,13 @@ await _mqttClient.EnqueueAsync(
{
try
{
if(_logger.IsEnabled(LogLevel.Trace))
if (_logger.IsEnabled(LogLevel.Trace))
{
_logger.LogTrace("{topic}: {message}", evnt.ApplicationMessage.Topic, evnt.ApplicationMessage.ConvertPayloadToString());
}
OnApplicationMessageReceived?.Invoke(this, evnt);
}catch(Exception e)
}
catch (Exception e)
{
_logger.LogWarning(e, "Failed to handle message to topic {topic}", evnt.ApplicationMessage.Topic);
_counters.IncreaseMessagesHandled(success: false);
Expand Down Expand Up @@ -124,4 +127,42 @@ public Task UnsubscribeAsync(params string[] topics)
{
return _mqttClient!.UnsubscribeAsync(topics);
}

private IMqttClientChannelOptions BuildChannelOptions()
{
var tcpOptions = new MqttClientTcpOptions
{
RemoteEndpoint = new DnsEndPoint(MqttOptions.Server ?? "mosquitto", MqttOptions.Port ?? 1883),
};

if (MqttOptions.UseTls)
{
_certificateWatcher = ActivatorUtilities.CreateInstance<WatchingMqttCertificateProvider>(serviceProvider);

tcpOptions.TlsOptions = new MqttClientTlsOptions
{
UseTls = true,
SslProtocol = System.Security.Authentication.SslProtocols.Tls12,
ClientCertificatesProvider = _certificateWatcher,
CertificateValidationHandler = (certContext) =>
{
X509Chain chain = new();
chain.ChainPolicy.RevocationMode = X509RevocationMode.NoCheck;
chain.ChainPolicy.RevocationFlag = X509RevocationFlag.ExcludeRoot;
chain.ChainPolicy.VerificationFlags = X509VerificationFlags.NoFlag;
chain.ChainPolicy.VerificationTime = DateTime.Now;
chain.ChainPolicy.UrlRetrievalTimeout = new TimeSpan(0, 0, 0);
chain.ChainPolicy.CustomTrustStore.Add(_certificateWatcher.CaCertificate!);
chain.ChainPolicy.TrustMode = X509ChainTrustMode.CustomRootTrust;

// convert provided X509Certificate to X509Certificate2
var x5092 = new X509Certificate2(certContext.Certificate);

return chain.Build(x5092);
}
};
}

return tcpOptions;
}
}
1 change: 1 addition & 0 deletions src/ToMqttNet/MqttConnectionServiceCollectionExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ public static OptionsBuilder<MqttConnectionOptions> AddMqttConnection(this IServ
services.AddSingleton<IMqttConnectionService>(x => x.GetRequiredService<MqttConnectionService>());
services.AddHostedService(x => x.GetRequiredService<MqttConnectionService>());
services.AddKeyedSingleton(typeof(MqttConnectionService), (services, key) => new MqttFactory().CreateManagedMqttClient());
services.AddOptions<MqttConnectionOptions>().BindConfiguration(MqttConnectionOptions.Section);

Check warning on line 22 in src/ToMqttNet/MqttConnectionServiceCollectionExtensions.cs

View workflow job for this annotation

GitHub Actions / build

Using member 'Microsoft.Extensions.DependencyInjection.OptionsBuilderConfigurationExtensions.BindConfiguration<TOptions>(OptionsBuilder<TOptions>, String, Action<BinderOptions>)' which has 'RequiresUnreferencedCodeAttribute' can break functionality when trimming application code. TOptions's dependent types may have their members trimmed. Ensure all required members are preserved.

Check warning on line 22 in src/ToMqttNet/MqttConnectionServiceCollectionExtensions.cs

View workflow job for this annotation

GitHub Actions / build

Using member 'Microsoft.Extensions.DependencyInjection.OptionsBuilderConfigurationExtensions.BindConfiguration<TOptions>(OptionsBuilder<TOptions>, String, Action<BinderOptions>)' which has 'RequiresDynamicCodeAttribute' can break functionality when AOT compiling. Binding strongly typed objects to configuration values may require generating dynamic code at runtime.

return services.AddOptions<MqttConnectionOptions>();
}
Expand Down
2 changes: 1 addition & 1 deletion src/ToMqttNet/ToMqttNet.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

<ItemGroup>
<PackageReference Include="Microsoft.Extensions.Hosting.Abstractions" Version="9.0.0" />
<PackageReference Include="Microsoft.Extensions.Options.DataAnnotations" Version="9.0.0" />
<PackageReference Include="Microsoft.Extensions.Options.ConfigurationExtensions" Version="9.0.0" />
<PackageReference Include="MQTTnet" Version="4.3.7.1207" />
<PackageReference Include="MQTTnet.Extensions.ManagedClient" Version="4.3.7.1207" />
</ItemGroup>
Expand Down
76 changes: 76 additions & 0 deletions src/ToMqttNet/WatchingMqttCertificateProvider.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
using System.Security.Cryptography.X509Certificates;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using MQTTnet.Client;

namespace ToMqttNet;

public class WatchingMqttCertificateProvider : IMqttClientCertificatesProvider
{
private readonly List<FileSystemWatcher> _watchers = [];
private readonly MqttConnectionOptions _options;
private readonly ILogger<WatchingMqttCertificateProvider> _logger;

public X509Certificate2? CaCertificate { get; private set; }
public X509Certificate2Collection ClientCertificates { get; private set; } = new();

public WatchingMqttCertificateProvider(ILogger<WatchingMqttCertificateProvider> logger, IOptions<MqttConnectionOptions> options)
{
_options = options.Value;
_logger = logger;
var certDirectories = new string?[] { _options.CaCrt, _options.ClientCrt, _options.ClientKey }
.Where(x => x != null)
.Select(x => Path.GetDirectoryName(x)!)
.Distinct()
.ToList();

foreach (var directory in certDirectories)
{
var watcher = new FileSystemWatcher(directory);
_watchers.Add(watcher);

watcher.Changed += OnCertificateChanged;
watcher.EnableRaisingEvents = true;
}

LoadCertificates();
}

private void LoadCertificates()
{
_logger.LogInformation("Loading certificates");
try
{
ClientCertificates.Clear();
if (_options.ClientCrt != null && _options.ClientKey != null)
{
var clientCert = X509Certificate2.CreateFromPemFile(_options.ClientCrt, _options.ClientKey);
ClientCertificates.Add(clientCert);
_logger.LogInformation("Loaded Client Certificate {name} from {certPath}, {keyPath}", clientCert.Thumbprint, _options.ClientCrt, _options.ClientKey);
}

if (_options.CaCrt != null)
{
CaCertificate = new X509Certificate2(_options.CaCrt);

Check warning on line 54 in src/ToMqttNet/WatchingMqttCertificateProvider.cs

View workflow job for this annotation

GitHub Actions / build

'X509Certificate2.X509Certificate2(string)' is obsolete: 'Loading certificate data through the constructor or Import is obsolete. Use X509CertificateLoader instead to load certificates.' (https://aka.ms/dotnet-warnings/SYSLIB0057)
ClientCertificates.Add(CaCertificate);
_logger.LogInformation("Loaded CA Certificate {name} from {path}", CaCertificate.Thumbprint, _options.CaCrt);
}
}
catch (Exception ex)
{
_logger.LogError(ex, "Failed to load certificates");
throw;
}
_logger.LogInformation("Certificates loaded");
}

private void OnCertificateChanged(object sender, FileSystemEventArgs e)
{
LoadCertificates();
}

public X509CertificateCollection GetCertificates()
{
return ClientCertificates;
}
}
15 changes: 10 additions & 5 deletions test/ToMqttNet.Test.Unit/MqttConnectionServiceTests.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Microsoft.Extensions.Options;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Options;
using MQTTnet;
using MQTTnet.Client;
using MQTTnet.Extensions.ManagedClient;
Expand All @@ -21,13 +22,17 @@ public class MqttConnectionServiceTests

public MqttConnectionServiceTests(ITestOutputHelper testOutputHelper)
{
var serviceCollection = new ServiceCollection();
var provider = serviceCollection.BuildServiceProvider();

_meterFactoryStub = new MeterFactoryStub();
_clientStub = new MqttClientStub();
_sut = new MqttConnectionService(
testOutputHelper.CreateLogger<MqttConnectionService>(),
Options.Create(new MqttConnectionOptions()),
_clientStub,
new MqttCounters(_meterFactoryStub));
new MqttCounters(_meterFactoryStub),
provider);
}

[Fact]
Expand Down Expand Up @@ -70,7 +75,7 @@ public async Task ShouldIncreaseMessagesSent()
var messagesSent = -1L;
listener.InstrumentPublished = (instrument, listener) =>
{
if(instrument.Meter.Name == "ToMqttNet")
if (instrument.Meter.Name == "ToMqttNet")
{
listener.EnableMeasurementEvents(instrument);
}
Expand Down Expand Up @@ -141,7 +146,7 @@ public Task CallDisconnectedAsync(MqttClientDisconnectedEventArgs args)
public List<MqttApplicationMessage> EnqueuedMessage = [];
public List<MqttTopicFilter> Subscriptions = [];

public void Dispose(){}
public void Dispose() { }

public Task EnqueueAsync(MqttApplicationMessage applicationMessage)
{
Expand Down Expand Up @@ -198,5 +203,5 @@ public Meter Create(MeterOptions options)
return new Meter(options);
}

public void Dispose(){}
public void Dispose() { }
}
1 change: 1 addition & 0 deletions test/ToMqttNet.Test.Unit/ToMqttNet.Test.Unit.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

<ItemGroup>
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.12.0" />
<PackageReference Include="Microsoft.Extensions.DependencyInjection" Version="9.0.0" />
<PackageReference Include="xunit" Version="2.9.2" />
<PackageReference Include="xunit.runner.visualstudio" Version="2.8.2">
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
Expand Down

0 comments on commit 3f213ff

Please sign in to comment.