diff --git a/src/Ocelot/DownstreamRouteFinder/Finder/DownstreamRouteCreator.cs b/src/Ocelot/DownstreamRouteFinder/Finder/DownstreamRouteCreator.cs index 71940d18d..b654d32d2 100644 --- a/src/Ocelot/DownstreamRouteFinder/Finder/DownstreamRouteCreator.cs +++ b/src/Ocelot/DownstreamRouteFinder/Finder/DownstreamRouteCreator.cs @@ -61,8 +61,8 @@ public Response Get(string upstreamUrlPath, string upstreamHttp downstreamRoute = new OkResponse(new DownstreamRoute(new List(), reRoute)); - _cache.AddOrUpdate(loadBalancerKey, downstreamRoute, (x, y) => downstreamRoute); - + _cache.AddOrUpdate(loadBalancerKey, downstreamRoute, (x, y) => downstreamRoute); + return downstreamRoute; } diff --git a/src/Ocelot/LoadBalancer/LoadBalancers/NoLoadBalancer.cs b/src/Ocelot/LoadBalancer/LoadBalancers/NoLoadBalancer.cs index 3f69adedc..c24a50d2b 100644 --- a/src/Ocelot/LoadBalancer/LoadBalancers/NoLoadBalancer.cs +++ b/src/Ocelot/LoadBalancer/LoadBalancers/NoLoadBalancer.cs @@ -20,7 +20,7 @@ public NoLoadBalancer(Func>> services) public async Task> Lease(DownstreamContext downstreamContext) { var services = await _services(); - //todo first or default could be null.. + if (services == null || services.Count == 0) { return new ErrorResponse(new ServicesAreEmptyError("There were no services in NoLoadBalancer")); diff --git a/src/Ocelot/Raft/SqlLiteLog.cs b/src/Ocelot/Raft/SqlLiteLog.cs index f4dfed499..99cd03085 100644 --- a/src/Ocelot/Raft/SqlLiteLog.cs +++ b/src/Ocelot/Raft/SqlLiteLog.cs @@ -99,6 +99,7 @@ public async Task LastLogTerm() } } } + _sempaphore.Release(); return result; } @@ -120,6 +121,7 @@ public async Task Count() } } } + _sempaphore.Release(); return result; } @@ -135,6 +137,7 @@ public async Task Apply(LogEntry log) TypeNameHandling = TypeNameHandling.All }; var data = JsonConvert.SerializeObject(log, jsonSerializerSettings); + //todo - sql injection dont copy this.. var sql = $"insert into logs (data) values ('{data}')"; _logger.LogInformation($"id: {_nodeId.Id}, sql: {sql}"); @@ -162,6 +165,7 @@ public async Task DeleteConflictsFromThisLog(int index, LogEntry logEntry) using (var connection = new SqliteConnection($"Data Source={_path};")) { connection.Open(); + //todo - sql injection dont copy this.. var sql = $"select data from logs where id = {index};"; _logger.LogInformation($"id: {_nodeId.Id} sql: {sql}"); @@ -188,6 +192,7 @@ public async Task DeleteConflictsFromThisLog(int index, LogEntry logEntry) } } } + _sempaphore.Release(); } @@ -197,6 +202,7 @@ public async Task IsDuplicate(int index, LogEntry logEntry) using (var connection = new SqliteConnection($"Data Source={_path};")) { connection.Open(); + //todo - sql injection dont copy this.. var sql = $"select data from logs where id = {index};"; using (var command = new SqliteCommand(sql, connection)) @@ -227,6 +233,7 @@ public async Task Get(int index) using (var connection = new SqliteConnection($"Data Source={_path};")) { connection.Open(); + //todo - sql injection dont copy this.. var sql = $"select data from logs where id = {index}"; using (var command = new SqliteCommand(sql, connection)) @@ -251,6 +258,7 @@ public async Task Get(int index) using (var connection = new SqliteConnection($"Data Source={_path};")) { connection.Open(); + //todo - sql injection dont copy this.. var sql = $"select id, data from logs where id >= {index}"; using (var command = new SqliteCommand(sql, connection)) @@ -267,10 +275,10 @@ public async Task Get(int index) }; var log = JsonConvert.DeserializeObject(data, jsonSerializerSettings); logsToReturn.Add((id, log)); - } } } + _sempaphore.Release(); return logsToReturn; } @@ -283,6 +291,7 @@ public async Task GetTermAtIndex(int index) using (var connection = new SqliteConnection($"Data Source={_path};")) { connection.Open(); + //todo - sql injection dont copy this.. var sql = $"select data from logs where id = {index}"; using (var command = new SqliteCommand(sql, connection)) @@ -299,15 +308,18 @@ public async Task GetTermAtIndex(int index) } } } + _sempaphore.Release(); return result; } + public async Task Remove(int indexOfCommand) { _sempaphore.Wait(); using (var connection = new SqliteConnection($"Data Source={_path};")) { connection.Open(); + //todo - sql injection dont copy this.. var deleteSql = $"delete from logs where id >= {indexOfCommand};"; _logger.LogInformation($"id: {_nodeId.Id} Remove {deleteSql}"); @@ -316,6 +328,7 @@ public async Task Remove(int indexOfCommand) var result = await deleteCommand.ExecuteNonQueryAsync(); } } + _sempaphore.Release(); } } diff --git a/src/Ocelot/Requester/IHttpClientCache.cs b/src/Ocelot/Requester/IHttpClientCache.cs index ce80dde36..2c4571ce6 100644 --- a/src/Ocelot/Requester/IHttpClientCache.cs +++ b/src/Ocelot/Requester/IHttpClientCache.cs @@ -1,16 +1,10 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using System.Net.Http; -using System.Threading.Tasks; - -namespace Ocelot.Requester +namespace Ocelot.Requester { + using System; + public interface IHttpClientCache { - bool Exists(string id); IHttpClient Get(string id); - void Remove(string id); void Set(string id, IHttpClient handler, TimeSpan expirationTime); } } diff --git a/src/Ocelot/Requester/MemoryHttpClientCache.cs b/src/Ocelot/Requester/MemoryHttpClientCache.cs index e37a46a3f..9e4059e57 100644 --- a/src/Ocelot/Requester/MemoryHttpClientCache.cs +++ b/src/Ocelot/Requester/MemoryHttpClientCache.cs @@ -1,21 +1,20 @@ -using Microsoft.Extensions.Caching.Memory; -using System; -using System.Collections.Concurrent; -using System.Collections.Generic; -using System.Linq; -using System.Net.Http; -using System.Threading.Tasks; - -namespace Ocelot.Requester +namespace Ocelot.Requester { + using System; + using System.Collections.Concurrent; + public class MemoryHttpClientCache : IHttpClientCache { - private readonly ConcurrentDictionary> _httpClientsCache = new ConcurrentDictionary>(); + private readonly ConcurrentDictionary> _httpClientsCache; + + public MemoryHttpClientCache() + { + _httpClientsCache = new ConcurrentDictionary>(); + } public void Set(string id, IHttpClient client, TimeSpan expirationTime) { - ConcurrentQueue connectionQueue; - if (_httpClientsCache.TryGetValue(id, out connectionQueue)) + if (_httpClientsCache.TryGetValue(id, out var connectionQueue)) { connectionQueue.Enqueue(client); } @@ -27,28 +26,15 @@ public void Set(string id, IHttpClient client, TimeSpan expirationTime) } } - public bool Exists(string id) - { - ConcurrentQueue connectionQueue; - return _httpClientsCache.TryGetValue(id, out connectionQueue); - } - public IHttpClient Get(string id) { IHttpClient client= null; - ConcurrentQueue connectionQueue; - if (_httpClientsCache.TryGetValue(id, out connectionQueue)) + if (_httpClientsCache.TryGetValue(id, out var connectionQueue)) { connectionQueue.TryDequeue(out client); } return client; - } - - public void Remove(string id) - { - ConcurrentQueue connectionQueue; - _httpClientsCache.TryRemove(id, out connectionQueue); - } + } } } diff --git a/test/Ocelot.UnitTests/Middleware/OcelotPipelineExtensionsTests.cs b/test/Ocelot.UnitTests/Middleware/OcelotPipelineExtensionsTests.cs index 717144b0a..74f6c7879 100644 --- a/test/Ocelot.UnitTests/Middleware/OcelotPipelineExtensionsTests.cs +++ b/test/Ocelot.UnitTests/Middleware/OcelotPipelineExtensionsTests.cs @@ -45,7 +45,6 @@ private void GivenTheDepedenciesAreSetUp() services.AddDiscoveryClient(new DiscoveryOptions { ClientType = DiscoveryClientType.EUREKA, - //options can not be null ClientOptions = new EurekaClientOptions() { ShouldFetchRegistry = false, diff --git a/test/Ocelot.UnitTests/Requester/HttpClientBuilderTests.cs b/test/Ocelot.UnitTests/Requester/HttpClientBuilderTests.cs index 1dd705915..429c5a90e 100644 --- a/test/Ocelot.UnitTests/Requester/HttpClientBuilderTests.cs +++ b/test/Ocelot.UnitTests/Requester/HttpClientBuilderTests.cs @@ -23,15 +23,18 @@ namespace Ocelot.UnitTests.Requester { public class HttpClientBuilderTests : IDisposable { - private readonly HttpClientBuilder _builder; + private HttpClientBuilder _builder; private readonly Mock _factory; private IHttpClient _httpClient; private HttpResponseMessage _response; private DownstreamContext _context; private readonly Mock _cacheHandlers; - private Mock _logger; + private readonly Mock _logger; private int _count; private IWebHost _host; + private IHttpClient _againHttpClient; + private IHttpClient _firstHttpClient; + private MemoryHttpClientCache _realCache; public HttpClientBuilderTests() { @@ -61,6 +64,47 @@ public void should_build_http_client() .BDDfy(); } + [Fact] + public void should_get_from_cache() + { + var qosOptions = new QoSOptionsBuilder() + .Build(); + + var reRoute = new DownstreamReRouteBuilder() + .WithQosOptions(qosOptions) + .WithHttpHandlerOptions(new HttpHandlerOptions(false, false, false, true)) + .WithLoadBalancerKey("") + .WithQosOptions(new QoSOptionsBuilder().Build()) + .Build(); + + this.Given(x => GivenARealCache()) + .And(x => GivenTheFactoryReturns()) + .And(x => GivenARequest(reRoute)) + .And(x => WhenIBuildTheFirstTime()) + .And(x => WhenISave()) + .And(x => WhenIBuildAgain()) + .And(x => WhenISave()) + .When(x => WhenIBuildAgain()) + .Then(x => ThenTheHttpClientIsFromTheCache()) + .BDDfy(); + } + + private void GivenARealCache() + { + _realCache = new MemoryHttpClientCache(); + _builder = new HttpClientBuilder(_factory.Object, _realCache, _logger.Object); + } + + private void ThenTheHttpClientIsFromTheCache() + { + _againHttpClient.ShouldBe(_firstHttpClient); + } + + private void WhenISave() + { + _builder.Save(); + } + [Fact] public void should_log_if_ignoring_ssl_errors() { @@ -302,6 +346,17 @@ private void WhenIBuild() _httpClient = _builder.Create(_context); } + private void WhenIBuildTheFirstTime() + { + _firstHttpClient = _builder.Create(_context); + } + + private void WhenIBuildAgain() + { + _builder = new HttpClientBuilder(_factory.Object, _realCache, _logger.Object); + _againHttpClient = _builder.Create(_context); + } + private void ThenTheHttpClientShouldNotBeNull() { _httpClient.ShouldNotBeNull(); diff --git a/test/Ocelot.UnitTests/ServiceDiscovery/PollingConsulServiceDiscoveryProviderTests.cs b/test/Ocelot.UnitTests/ServiceDiscovery/PollingConsulServiceDiscoveryProviderTests.cs index e75bfbcae..86d523e2d 100644 --- a/test/Ocelot.UnitTests/ServiceDiscovery/PollingConsulServiceDiscoveryProviderTests.cs +++ b/test/Ocelot.UnitTests/ServiceDiscovery/PollingConsulServiceDiscoveryProviderTests.cs @@ -1,33 +1,24 @@ -using System; -using System.Collections.Generic; -using System.IO; -using System.Linq; -using Consul; -using Microsoft.AspNetCore.Builder; -using Microsoft.AspNetCore.Hosting; -using Microsoft.AspNetCore.Http; -using Moq; -using Ocelot.Infrastructure.Consul; -using Ocelot.Logging; -using Ocelot.ServiceDiscovery.Configuration; -using Ocelot.ServiceDiscovery.Providers; -using Ocelot.Values; -using Xunit; -using TestStack.BDDfy; -using Shouldly; -using static Ocelot.Infrastructure.Wait; - -namespace Ocelot.UnitTests.ServiceDiscovery +namespace Ocelot.UnitTests.ServiceDiscovery { + using System; + using System.Collections.Generic; + using Moq; + using Ocelot.Logging; + using Ocelot.ServiceDiscovery.Providers; + using Ocelot.Values; + using Xunit; + using TestStack.BDDfy; + using Shouldly; + using static Ocelot.Infrastructure.Wait; + public class PollingConsulServiceDiscoveryProviderTests { private readonly int _delay; private PollingConsulServiceDiscoveryProvider _provider; - private readonly string _serviceName; - private List _services; + private readonly List _services; private readonly Mock _factory; private readonly Mock _logger; - private Mock _consulServiceDiscoveryProvider; + private readonly Mock _consulServiceDiscoveryProvider; private List _result; public PollingConsulServiceDiscoveryProviderTests() @@ -64,7 +55,7 @@ private void ThenTheCountIs(int count) private void WhenIGetTheServices(int expected) { - _provider = new PollingConsulServiceDiscoveryProvider(_delay, _serviceName, _factory.Object, _consulServiceDiscoveryProvider.Object); + _provider = new PollingConsulServiceDiscoveryProvider(_delay, "", _factory.Object, _consulServiceDiscoveryProvider.Object); var result = WaitFor(3000).Until(() => { try