Skip to content

Commit

Permalink
CSHARP-5017: Retry KMS requests on transient errors (#1541)
Browse files Browse the repository at this point in the history
  • Loading branch information
papafe authored Jan 31, 2025
1 parent c2de507 commit e3f3943
Show file tree
Hide file tree
Showing 8 changed files with 254 additions and 40 deletions.
2 changes: 2 additions & 0 deletions src/MongoDB.Driver.Encryption/CryptClientFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,8 @@ public static CryptClient Create(CryptOptions options)

Library.mongocrypt_setopt_use_need_kms_credentials_state(handle);

Library.mongocrypt_setopt_retry_kms(handle, true);

Library.mongocrypt_init(handle);

if (options.IsCryptSharedLibRequired)
Expand Down
14 changes: 4 additions & 10 deletions src/MongoDB.Driver.Encryption/CryptContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -156,18 +156,12 @@ public Binary FinalizeForEncryption()
}

/// <summary>
/// Gets a collection of KMS message requests to make
/// Gets the next KMS message request
/// </summary>
/// <returns>Collection of KMS Messages</returns>
public KmsRequestCollection GetKmsMessageRequests()
public KmsRequest GetNextKmsMessageRequest()
{
var requests = new List<KmsRequest>();
for (IntPtr request = Library.mongocrypt_ctx_next_kms_ctx(_handle); request != IntPtr.Zero; request = Library.mongocrypt_ctx_next_kms_ctx(_handle))
{
requests.Add(new KmsRequest(request));
}

return new KmsRequestCollection(requests, this);
var request = Library.mongocrypt_ctx_next_kms_ctx(_handle);
return request == IntPtr.Zero ? null : new KmsRequest(request);
}

/// <summary>
Expand Down
11 changes: 11 additions & 0 deletions src/MongoDB.Driver.Encryption/KmsRequest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ public string KmsProvider
}
}

/// <summary>
/// The number of milliseconds to wait before sending this request.
/// </summary>
public int Sleep => (int)(Library.mongocrypt_kms_ctx_usleep(_id) / 1000);

/// <summary>
/// Gets the message to send to KMS.
/// </summary>
Expand All @@ -88,6 +93,12 @@ public Binary GetMessage()
return binary;
}

/// <summary>
/// Indicates a network-level failure.
/// </summary>
/// <returns>A boolean indicating whether the failed request may be retried.</returns>
public bool Fail() => Library.mongocrypt_kms_ctx_fail(_id);

/// <summary>
/// Feeds the response back to the libmongocrypt
/// </summary>
Expand Down
76 changes: 58 additions & 18 deletions src/MongoDB.Driver.Encryption/LibMongoCryptControllerBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -211,22 +211,20 @@ private SslStreamSettings GetTlsStreamSettings(string kmsProvider)

private void ProcessNeedKmsState(CryptContext context, CancellationToken cancellationToken)
{
var requests = context.GetKmsMessageRequests();
foreach (var request in requests)
while (context.GetNextKmsMessageRequest() is { } request)
{
SendKmsRequest(request, cancellationToken);
}
requests.MarkDone();
context.MarkKmsDone();
}

private async Task ProcessNeedKmsStateAsync(CryptContext context, CancellationToken cancellationToken)
{
var requests = context.GetKmsMessageRequests();
foreach (var request in requests)
while (context.GetNextKmsMessageRequest() is { } request)
{
await SendKmsRequestAsync(request, cancellationToken).ConfigureAwait(false);
}
requests.MarkDone();
context.MarkKmsDone();
}

private void ProcessNeedMongoKeysState(CryptContext context, CancellationToken cancellationToken)
Expand Down Expand Up @@ -278,48 +276,90 @@ private static byte[] ProcessReadyState(CryptContext context)

private void SendKmsRequest(KmsRequest request, CancellationToken cancellation)
{
var endpoint = CreateKmsEndPoint(request.Endpoint);

var tlsStreamSettings = GetTlsStreamSettings(request.KmsProvider);
var sslStreamFactory = new SslStreamFactory(tlsStreamSettings, _networkStreamFactory);
using (var sslStream = sslStreamFactory.CreateStream(endpoint, cancellation))
using (var binary = request.GetMessage())
try
{
var endpoint = CreateKmsEndPoint(request.Endpoint);

var tlsStreamSettings = GetTlsStreamSettings(request.KmsProvider);
var sslStreamFactory = new SslStreamFactory(tlsStreamSettings, _networkStreamFactory);
using var sslStream = sslStreamFactory.CreateStream(endpoint, cancellation);

var sleepMs = request.Sleep;
if (sleepMs > 0)
{
Thread.Sleep(sleepMs);
}

using var binary = request.GetMessage();
var requestBytes = binary.ToArray();
sslStream.Write(requestBytes, 0, requestBytes.Length);

while (request.BytesNeeded > 0)
{
var buffer = new byte[request.BytesNeeded]; // BytesNeeded is the maximum number of bytes that libmongocrypt wants to receive.
var count = sslStream.Read(buffer, 0, buffer.Length);

if (count == 0)
{
throw new IOException("Unexpected end of stream. No data was read from the SSL stream.");
}

var responseBytes = new byte[count];
Buffer.BlockCopy(buffer, 0, responseBytes, 0, count);
request.Feed(responseBytes);
}
}
catch (Exception ex) when (ex is IOException or SocketException)
{
if (!request.Fail())
{
throw;
}
}
}

private async Task SendKmsRequestAsync(KmsRequest request, CancellationToken cancellation)
{
var endpoint = CreateKmsEndPoint(request.Endpoint);

var tlsStreamSettings = GetTlsStreamSettings(request.KmsProvider);
var sslStreamFactory = new SslStreamFactory(tlsStreamSettings, _networkStreamFactory);
using (var sslStream = await sslStreamFactory.CreateStreamAsync(endpoint, cancellation).ConfigureAwait(false))
using (var binary = request.GetMessage())
try
{
var endpoint = CreateKmsEndPoint(request.Endpoint);

var tlsStreamSettings = GetTlsStreamSettings(request.KmsProvider);
var sslStreamFactory = new SslStreamFactory(tlsStreamSettings, _networkStreamFactory);
using var sslStream = await sslStreamFactory.CreateStreamAsync(endpoint, cancellation).ConfigureAwait(false);

var sleepMs = request.Sleep;
if (sleepMs > 0)
{
await Task.Delay(sleepMs, cancellation).ConfigureAwait(false);
}

using var binary = request.GetMessage();
var requestBytes = binary.ToArray();
await sslStream.WriteAsync(requestBytes, 0, requestBytes.Length).ConfigureAwait(false);

while (request.BytesNeeded > 0)
{
var buffer = new byte[request.BytesNeeded]; // BytesNeeded is the maximum number of bytes that libmongocrypt wants to receive.
var count = await sslStream.ReadAsync(buffer, 0, buffer.Length).ConfigureAwait(false);

if (count == 0)
{
throw new IOException("Unexpected end of stream. No data was read from the SSL stream.");
}

var responseBytes = new byte[count];
Buffer.BlockCopy(buffer, 0, responseBytes, 0, count);
request.Feed(responseBytes);
}
}
catch (Exception ex) when (ex is IOException or SocketException)
{
if (!request.Fail())
{
throw;
}
}
}

// nested type
Expand Down
24 changes: 24 additions & 0 deletions src/MongoDB.Driver.Encryption/Library.cs
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,9 @@ static Library()
_mongocrypt_ctx_setopt_query_type = new Lazy<Delegates.mongocrypt_ctx_setopt_query_type>(
() => __loader.Value.GetFunction<Delegates.mongocrypt_ctx_setopt_query_type>(
("mongocrypt_ctx_setopt_query_type")), true);
_mongocrypt_setopt_retry_kms = new Lazy<Delegates.mongocrypt_setopt_retry_kms>(
() => __loader.Value.GetFunction<Delegates.mongocrypt_setopt_retry_kms>(
("mongocrypt_setopt_retry_kms")), true);

_mongocrypt_ctx_status = new Lazy<Delegates.mongocrypt_ctx_status>(
() => __loader.Value.GetFunction<Delegates.mongocrypt_ctx_status>(("mongocrypt_ctx_status")), true);
Expand Down Expand Up @@ -210,6 +213,11 @@ static Library()
() => __loader.Value.GetFunction<Delegates.mongocrypt_ctx_destroy>(("mongocrypt_ctx_destroy")), true);
_mongocrypt_kms_ctx_get_kms_provider = new Lazy<Delegates.mongocrypt_kms_ctx_get_kms_provider>(
() => __loader.Value.GetFunction<Delegates.mongocrypt_kms_ctx_get_kms_provider>(("mongocrypt_kms_ctx_get_kms_provider")), true);

_mongocrypt_kms_ctx_usleep = new Lazy<Delegates.mongocrypt_kms_ctx_usleep>(
() => __loader.Value.GetFunction<Delegates.mongocrypt_kms_ctx_usleep>(("mongocrypt_kms_ctx_usleep")), true);
_mongocrypt_kms_ctx_fail = new Lazy<Delegates.mongocrypt_kms_ctx_fail>(
() => __loader.Value.GetFunction<Delegates.mongocrypt_kms_ctx_fail>(("mongocrypt_kms_ctx_fail")), true);
}

/// <summary>
Expand Down Expand Up @@ -287,6 +295,7 @@ public static string Version
internal static Delegates.mongocrypt_ctx_setopt_algorithm_range mongocrypt_ctx_setopt_algorithm_range => _mongocrypt_ctx_setopt_algorithm_range.Value;
internal static Delegates.mongocrypt_ctx_setopt_contention_factor mongocrypt_ctx_setopt_contention_factor => _mongocrypt_ctx_setopt_contention_factor.Value;
internal static Delegates.mongocrypt_ctx_setopt_query_type mongocrypt_ctx_setopt_query_type => _mongocrypt_ctx_setopt_query_type.Value;
internal static Delegates.mongocrypt_setopt_retry_kms mongocrypt_setopt_retry_kms => _mongocrypt_setopt_retry_kms.Value;

internal static Delegates.mongocrypt_ctx_state mongocrypt_ctx_state => _mongocrypt_ctx_state.Value;
internal static Delegates.mongocrypt_ctx_mongo_op mongocrypt_ctx_mongo_op => _mongocrypt_ctx_mongo_op.Value;
Expand All @@ -305,6 +314,9 @@ public static string Version
internal static Delegates.mongocrypt_ctx_destroy mongocrypt_ctx_destroy => _mongocrypt_ctx_destroy.Value;
internal static Delegates.mongocrypt_kms_ctx_get_kms_provider mongocrypt_kms_ctx_get_kms_provider => _mongocrypt_kms_ctx_get_kms_provider.Value;

internal static Delegates.mongocrypt_kms_ctx_usleep mongocrypt_kms_ctx_usleep => _mongocrypt_kms_ctx_usleep.Value;
internal static Delegates.mongocrypt_kms_ctx_fail mongocrypt_kms_ctx_fail => _mongocrypt_kms_ctx_fail.Value;

private static readonly Lazy<LibraryLoader> __loader = new Lazy<LibraryLoader>(
() => new LibraryLoader(), true);
private static readonly Lazy<Delegates.mongocrypt_version> _mongocrypt_version;
Expand Down Expand Up @@ -392,6 +404,10 @@ public static string Version
private static readonly Lazy<Delegates.mongocrypt_ctx_destroy> _mongocrypt_ctx_destroy;
private static readonly Lazy<Delegates.mongocrypt_kms_ctx_get_kms_provider> _mongocrypt_kms_ctx_get_kms_provider;

private static readonly Lazy<Delegates.mongocrypt_kms_ctx_usleep> _mongocrypt_kms_ctx_usleep;
private static readonly Lazy<Delegates.mongocrypt_kms_ctx_fail> _mongocrypt_kms_ctx_fail;
private static readonly Lazy<Delegates.mongocrypt_setopt_retry_kms> _mongocrypt_setopt_retry_kms;

// nested types
internal enum StatusType
{
Expand Down Expand Up @@ -640,6 +656,9 @@ public delegate bool
[return: MarshalAs(UnmanagedType.I1)]
public delegate bool mongocrypt_ctx_setopt_query_type(ContextSafeHandle ctx, [MarshalAs(UnmanagedType.LPStr)] string query_type, int length);

[return: MarshalAs(UnmanagedType.I1)]
public delegate bool mongocrypt_setopt_retry_kms(MongoCryptSafeHandle handle, bool enable);

public delegate CryptContext.StateCode mongocrypt_ctx_state(ContextSafeHandle handle);

[return: MarshalAs(UnmanagedType.I1)]
Expand Down Expand Up @@ -681,6 +700,11 @@ public delegate bool

public delegate void mongocrypt_ctx_destroy(IntPtr ptr);
public delegate IntPtr mongocrypt_kms_ctx_get_kms_provider(IntPtr handle, out uint length);

public delegate long mongocrypt_kms_ctx_usleep(IntPtr handle);

[return: MarshalAs(UnmanagedType.I1)]
public delegate bool mongocrypt_kms_ctx_fail(IntPtr handle);
}
}
}
10 changes: 5 additions & 5 deletions src/MongoDB.Driver.Encryption/MongoDB.Driver.Encryption.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

<Target Name="DownloadNativeBinaries_MacOS" BeforeTargets="BeforeBuild" Condition="!Exists('$(MSBuildProjectDirectory)/runtimes/osx/native/libmongocrypt.dylib')">
<PropertyGroup>
<LibMongoCryptSourceUrl>https://mciuploads.s3.amazonaws.com/libmongocrypt-release/macos/r1.11/9a88ac5698e8e3ffcd6580b98c247f0126f26c40/libmongocrypt.tar.gz</LibMongoCryptSourceUrl>
<LibMongoCryptSourceUrl>https://mciuploads.s3.amazonaws.com/libmongocrypt-release/macos/r1.12/085a0ce6538a28179da6bfd2927aea106924443a/libmongocrypt.tar.gz</LibMongoCryptSourceUrl>
<LibMongoCryptSourcePath>lib/libmongocrypt.dylib</LibMongoCryptSourcePath>
<LibMongoCryptPackagePath>runtimes/osx/native</LibMongoCryptPackagePath>
</PropertyGroup>
Expand All @@ -27,7 +27,7 @@

<Target Name="DownloadNativeBinaries_UbuntuX64" BeforeTargets="BeforeBuild" Condition="!Exists('$(MSBuildProjectDirectory)/runtimes/linux/native/x64/libmongocrypt.so')">
<PropertyGroup>
<LibMongoCryptSourceUrl>https://mciuploads.s3.amazonaws.com/libmongocrypt-release/ubuntu1804-64/r1.11/9a88ac5698e8e3ffcd6580b98c247f0126f26c40/libmongocrypt.tar.gz</LibMongoCryptSourceUrl>
<LibMongoCryptSourceUrl>https://mciuploads.s3.amazonaws.com/libmongocrypt-release/ubuntu1804-64/r1.12/085a0ce6538a28179da6bfd2927aea106924443a/libmongocrypt.tar.gz</LibMongoCryptSourceUrl>
<LibMongoCryptSourcePath>nocrypto/lib/libmongocrypt.so</LibMongoCryptSourcePath>
<LibMongoCryptPackagePath>runtimes/linux/native/x64</LibMongoCryptPackagePath>
</PropertyGroup>
Expand All @@ -39,7 +39,7 @@

<Target Name="DownloadNativeBinaries_UbuntuARM64" BeforeTargets="BeforeBuild" Condition="!Exists('$(MSBuildProjectDirectory)/runtimes/linux/native/arm64/libmongocrypt.so')">
<PropertyGroup>
<LibMongoCryptSourceUrl>https://mciuploads.s3.amazonaws.com/libmongocrypt-release/ubuntu1804-arm64/r1.11/9a88ac5698e8e3ffcd6580b98c247f0126f26c40/libmongocrypt.tar.gz</LibMongoCryptSourceUrl>
<LibMongoCryptSourceUrl>https://mciuploads.s3.amazonaws.com/libmongocrypt-release/ubuntu1804-arm64/r1.12/085a0ce6538a28179da6bfd2927aea106924443a/libmongocrypt.tar.gz</LibMongoCryptSourceUrl>
<LibMongoCryptSourcePath>nocrypto/lib/libmongocrypt.so</LibMongoCryptSourcePath>
<LibMongoCryptPackagePath>runtimes/linux/native/arm64</LibMongoCryptPackagePath>
</PropertyGroup>
Expand All @@ -51,7 +51,7 @@

<Target Name="DownloadNativeBinaries_Alpine" BeforeTargets="BeforeBuild" Condition="!Exists('$(MSBuildProjectDirectory)/runtimes/linux/native/alpine/libmongocrypt.so')">
<PropertyGroup>
<LibMongoCryptSourceUrl>https://mciuploads.s3.amazonaws.com/libmongocrypt-release/alpine-arm64-earthly/r1.11/9a88ac5698e8e3ffcd6580b98c247f0126f26c40/libmongocrypt.tar.gz</LibMongoCryptSourceUrl>
<LibMongoCryptSourceUrl>https://mciuploads.s3.amazonaws.com/libmongocrypt-release/alpine-arm64-earthly/r1.12/085a0ce6538a28179da6bfd2927aea106924443a/libmongocrypt.tar.gz</LibMongoCryptSourceUrl>
<LibMongoCryptSourcePath>nocrypto/lib/libmongocrypt.so</LibMongoCryptSourcePath>
<LibMongoCryptPackagePath>runtimes/linux/native/alpine</LibMongoCryptPackagePath>
</PropertyGroup>
Expand All @@ -63,7 +63,7 @@

<Target Name="DownloadNativeBinaries_Windows" BeforeTargets="BeforeBuild" Condition="!Exists('$(MSBuildProjectDirectory)/runtimes/win/native/mongocrypt.dll')">
<PropertyGroup>
<LibMongoCryptSourceUrl>https://mciuploads.s3.amazonaws.com/libmongocrypt-release/windows-test/r1.11/9a88ac5698e8e3ffcd6580b98c247f0126f26c40/libmongocrypt.tar.gz</LibMongoCryptSourceUrl>
<LibMongoCryptSourceUrl>https://mciuploads.s3.amazonaws.com/libmongocrypt-release/windows-test/r1.12/085a0ce6538a28179da6bfd2927aea106924443a/libmongocrypt.tar.gz</LibMongoCryptSourceUrl>
<LibMongoCryptSourcePath>bin/mongocrypt.dll</LibMongoCryptSourcePath>
<LibMongoCryptPackagePath>runtimes/win/native</LibMongoCryptPackagePath>
</PropertyGroup>
Expand Down
13 changes: 6 additions & 7 deletions tests/MongoDB.Driver.Encryption.Tests/BasicTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ public void TestGetKmsProviderName(string kmsName)
using (var cryptClient = CryptClientFactory.Create(cryptOptions))
using (var context = cryptClient.StartCreateDataKeyContext(keyId))
{
var request = context.GetKmsMessageRequests().Single();
var request = context.GetNextKmsMessageRequest();
request.KmsProvider.Should().Be(kmsName);
}
}
Expand Down Expand Up @@ -632,22 +632,21 @@ private static (CryptContext.StateCode stateProcessed, Binary binaryProduced, Bs

case CryptContext.StateCode.MONGOCRYPT_CTX_NEED_KMS:
{
var requests = context.GetKmsMessageRequests();
foreach (var req in requests)
while (context.GetNextKmsMessageRequest() is { } request)
{
using var binary = req.GetMessage();
using var binary = request.GetMessage();
_output.WriteLine("Key Document: " + binary);
var postRequest = binary.ToString();
// TODO: add different hosts handling
postRequest.Should().Contain("Host:kms.us-east-1.amazonaws.com"); // only AWS

var reply = ReadHttpTestFile(isKmsDecrypt ? "kms-decrypt-reply.txt" : "kms-encrypt-reply.txt");
_output.WriteLine("Reply: " + reply);
req.Feed(Encoding.UTF8.GetBytes(reply));
req.BytesNeeded.Should().Be(0);
request.Feed(Encoding.UTF8.GetBytes(reply));
request.BytesNeeded.Should().Be(0);
}

requests.MarkDone();
context.MarkKmsDone();
return (CryptContext.StateCode.MONGOCRYPT_CTX_NEED_KMS, null, null);
}

Expand Down
Loading

0 comments on commit e3f3943

Please sign in to comment.