Skip to content

Commit

Permalink
Simplify TLSSecurityMode parsing. (#106)
Browse files Browse the repository at this point in the history
  • Loading branch information
dnwpark authored Feb 14, 2025
1 parent 09ff32d commit 57d9d81
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 71 deletions.
15 changes: 2 additions & 13 deletions src/Gel.Net.Driver/GelConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ internal class ConnectionCredentials
public string? TlsCA { get; init; }

[JsonProperty("tls_security")]
[JsonConverter(typeof(TLSSecurityModeParser))]
public TLSSecurityMode? TlsSecurity { get; init; }
public string? TlsSecurity { get; init; }
}


Expand Down Expand Up @@ -1022,17 +1021,7 @@ internal static ConfigUtils.ResolvedFields _FromDSN(string dsn, ISystemProvider
resolvedFields.TLSServerName = value;
break;
case "tls_security":
resolvedFields.TLSSecurity = value.Convert<TLSSecurityMode>(v =>
{
try
{
return TLSSecurityModeParser.Parse(v);
}
catch (Exception e)
{
return e;
}
});
resolvedFields.TLSSecurity = value.Convert(ConfigUtils.ParseTLSSecurityMode);
break;
case "wait_until_available":
resolvedFields.WaitUntilAvailable = value.Convert(ConfigUtils.ParseWaitUntilAvailable);
Expand Down
42 changes: 2 additions & 40 deletions src/Gel.Net.Driver/Models/TLSSecurityMode.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using Newtonsoft.Json;


namespace Gel;

/// <summary>
Expand Down Expand Up @@ -31,7 +30,7 @@ public enum TLSSecurityMode
Default = Strict
}

internal class TLSSecurityModeParser : JsonConverter<TLSSecurityMode?>
internal class TLSSecurityModeParser
{
internal static bool TryParse(string text, bool emptyAsDefault, out TLSSecurityMode? tlsSecurity)
{
Expand Down Expand Up @@ -59,41 +58,4 @@ internal static bool TryParse(string text, bool emptyAsDefault, out TLSSecurityM
return false;
}

public static TLSSecurityMode Parse(string text, bool emptyAsDefault = false)
{
if (TryParse(text, emptyAsDefault, out TLSSecurityMode? tlsSecurity))
{
return tlsSecurity ?? TLSSecurityMode.Default;
}
else
{
throw new ConfigurationException(
$"Invalid TLS Security: \"{text}\", "
+ "must be one of \"insecure\", \"no_host_verification\", \"strict\", or \"default\"");
}
}

// Json conversion
public override TLSSecurityMode? ReadJson(
JsonReader reader,
Type objectType,
TLSSecurityMode? existingValue,
bool hasExistingValue,
JsonSerializer serializer)
{
if (reader.TokenType == JsonToken.String)
{
return Parse((string)reader.Value!, true);
}
else
{
throw new JsonException("Expected String.");
}
}

public override void WriteJson(
JsonWriter writer, TLSSecurityMode? value, JsonSerializer serializer)
{
throw new NotImplementedException();
}
}
21 changes: 19 additions & 2 deletions src/Gel.Net.Driver/Utils/ConfigUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ internal static ResolvedFields FromCredentials(ConnectionCredentials credentials
if (credentials.Host is not null) { result.Host = credentials.Host; }
if (credentials.Port is not null)
{
result.Port = MergeField(result.Port, ParsePort(credentials.Port));
result.Port = ParsePort(credentials.Port) ?? result.Port;
}
if (credentials.Database is not null)
{
Expand All @@ -289,7 +289,10 @@ internal static ResolvedFields FromCredentials(ConnectionCredentials credentials
if (credentials.User is not null) { result.User = credentials.User; }
if (credentials.Password is not null) { result.Password = credentials.Password; }
if (credentials.TlsCA is not null) { result.TLSCertificateAuthority = credentials.TlsCA; }
if (credentials.TlsSecurity is not null) { result.TLSSecurity = credentials.TlsSecurity; }
if (credentials.TlsSecurity is not null)
{
result.TLSSecurity = ParseTLSSecurityMode(credentials.TlsSecurity);
}

return result;
}
Expand All @@ -315,6 +318,20 @@ internal static ResolvedFields FromCredentials(ConnectionCredentials credentials
}
}

public static ResolvedField<TLSSecurityMode> ParseTLSSecurityMode(string text)
{
if (TLSSecurityModeParser.TryParse(text, false, out TLSSecurityMode? tlsSecurity))
{
return tlsSecurity ?? TLSSecurityMode.Default;
}
else
{
return new ConfigurationException(
$"Invalid TLS Security: \"{text}\", "
+ "must be one of \"insecure\", \"no_host_verification\", \"strict\", or \"default\"");
}
}

private static readonly Regex _isoUnitlessHours = new Regex(
@"^(-?\d+|-?\d+\.\d*|-?\d*\.\d+)$",
RegexOptions.Compiled);
Expand Down
50 changes: 34 additions & 16 deletions tests/Gel.Tests.Unit/SharedClientTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -114,33 +114,44 @@ private static TestResult ParseConnection(TestCase testCase)

try
{
int? optionsPort = null;
if (testCase?.Options?.Port is not null)
{
if (int.TryParse(testCase?.Options?.Port, out var parsedPort))
{
optionsPort = parsedPort;
}
else
{
throw new ConfigurationException(
$"Invalid port: {testCase?.Options?.Port}, not an integer");
}
}

TLSSecurityMode? tlsSecurity = null;
if (testCase?.Options?.TlsSecurity is not null)
{
tlsSecurity =
ConfigUtils.ParseTLSSecurityMode(testCase.Options.TlsSecurity)
.CheckAndGetValue();
}

GelConnection.Options config = new()
{
Instance = testCase?.Options?.Instance,
Dsn = testCase?.Options?.Dsn,
Credentials = testCase?.Options?.Credentials,
CredentialsFile = testCase?.Options?.CredentialsFile,
Host = testCase?.Options?.Host,
Port = (
testCase?.Options?.Port is null
? null
: int.TryParse(testCase?.Options?.Port, out var parsedPort)
? parsedPort
: throw new ConfigurationException(
$"Invalid port: {testCase?.Options?.Port}, not an integer")
),
Port = optionsPort,
Database = testCase?.Options?.Database,
Branch = testCase?.Options?.Branch,
User = testCase?.Options?.User,
Password = testCase?.Options?.Password,
SecretKey = testCase?.Options?.SecretKey,
TLSCertificateAuthority = testCase?.Options?.TlsCA,
TLSCertificateAuthorityFile = testCase?.Options?.TlsCAFile,
TLSSecurity = (
testCase?.Options?.TlsSecurity is null
? null
: TLSSecurityModeParser.Parse(testCase.Options.TlsSecurity)
),
TLSSecurity = tlsSecurity,
TLSServerName = testCase?.Options?.TlsServerName,
WaitUntilAvailable = testCase?.Options?.WaitUntilAvailable,
ServerSettings = testCase?.Options?.ServerSettings,
Expand Down Expand Up @@ -172,7 +183,15 @@ private static void AssertSameConnection(TestResult result, TestCase.ExpectedRes
string expectedPassword = expectedResult.Password ?? "";
string? expectedSecretKey = expectedResult.SecretKey;
string? expectedTLSCertificateAuthority = expectedResult.TlsCAData;
TLSSecurityMode expectedTLSSecurity = expectedResult.TlsSecurity ?? TLSSecurityMode.Strict;
TLSSecurityMode expectedTLSSecurity = TLSSecurityMode.Default;
if (expectedResult.TlsSecurity is not null)
{
if (TLSSecurityModeParser.TryParse(
expectedResult.TlsSecurity, false, out var parsedTlsSecurity))
{
expectedTLSSecurity = parsedTlsSecurity ?? TLSSecurityMode.Default;
}
};
string? expectedTLSServerName = expectedResult.TlsServerName;
int expectedWaitUntilAvailable =
expectedResult.WaitUntilAvailable is not null
Expand Down Expand Up @@ -600,8 +619,7 @@ public class ExpectedResult
public string? TlsCAData { get; init; }

[JsonProperty("tlsSecurity")]
[JsonConverter(typeof(TLSSecurityModeParser))]
public TLSSecurityMode? TlsSecurity { get; init; }
public string? TlsSecurity { get; init; }

[JsonProperty("tlsServerName")]
public string? TlsServerName { get; init; }
Expand Down

0 comments on commit 57d9d81

Please sign in to comment.