diff --git a/README.md b/README.md index bb66ba6a..fd3f1551 100644 --- a/README.md +++ b/README.md @@ -1,27 +1,25 @@ -**Rate-limiting pattern** +**Rate-limiting pattern** -Rate limiting involves restricting the number of requests that can be made by a client. -A client is identified with an access token, which is used for every request to a resource. -To prevent abuse of the server, APIs enforce rate-limiting techniques. -Based on the client, the rate-limiting application can decide whether to allow the request to go through or not. -The client makes an API call to a particular resource; the server checks whether the request for this client is within the limit. -If the request is within the limit, then the request goes through. -Otherwise, the API call is restricted. +Based on RateLimiting service from [ocelot repo](https://github.com/ThreeMammals/Ocelot). +You can see an example of using in the [SimpleSample](https://github.com/DAKnyazev/rate-limiter/tree/master/Samples/SimpleSample) project. -Some examples of request-limiting rules (you could imagine any others) -* X requests per timespan; -* a certain timespan passed since the last call; -* for US-based tokens, we use X requests per timespan, for EU-based - certain timespan passed since the last call. +## Usage -The goal is to design a class(-es) that manage rate limits for every provided API resource by a set of provided *configurable and extendable* rules. For example, for one resource you could configure the limiter to use Rule A, for another one - Rule B, for a third one - both A + B, etc. Any combinations of rules should be possible, keep this fact in mind when designing the classes. +To access rate limiting counter you need to resolve service 'IRateLimitingService' -We're more interested in the design itself than in some smart and tricky rate limiting algorithm. There is no need to use neither database (in-memory storage is fine) nor any web framework. Do not waste time on preparing complex environment, reusable class library covered by a set of tests is more than enough. +### With memory cache +```csharp + .AddRateLimitingServiceWithMemoryCache(); +``` -There is a Test Project set up for you to use. You are welcome to create your own test project and use whatever test runner you would like. +### With distributed cache +```csharp + .AddRateLimitingServiceWithDistributedCache(); +``` -You are welcome to ask any questions regarding the requirements - treat us as product owners/analysts/whoever who knows the business. -Should you have any questions or concerns, submit them as a [GitHub issue](https://github.com/crexi-dev/rate-limiter/issues). - -You should [fork](https://help.github.com/en/github/getting-started-with-github/fork-a-repo) the project, and [create a pull request](https://help.github.com/en/github/collaborating-with-issues-and-pull-requests/creating-a-pull-request-from-a-fork) once you are finished. - -Good luck! +### With custom cache +```csharp + .AddRateLimitingServiceCore() + .AddSingleton(); +``` +You need to implement IRateLimitStorageService by yourself. diff --git a/RateLimiter.Tests/RateLimiter.Tests.csproj b/RateLimiter.Tests/RateLimiter.Tests.csproj index 5cbfc4e8..6eddb166 100644 --- a/RateLimiter.Tests/RateLimiter.Tests.csproj +++ b/RateLimiter.Tests/RateLimiter.Tests.csproj @@ -8,8 +8,13 @@ + - - + + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + - \ No newline at end of file + \ No newline at end of file diff --git a/RateLimiter.Tests/RateLimiterTest.cs b/RateLimiter.Tests/RateLimiterTest.cs deleted file mode 100644 index 172d44a7..00000000 --- a/RateLimiter.Tests/RateLimiterTest.cs +++ /dev/null @@ -1,13 +0,0 @@ -using NUnit.Framework; - -namespace RateLimiter.Tests; - -[TestFixture] -public class RateLimiterTest -{ - [Test] - public void Example() - { - Assert.That(true, Is.True); - } -} \ No newline at end of file diff --git a/RateLimiter.Tests/RateLimitingServiceTests.cs b/RateLimiter.Tests/RateLimitingServiceTests.cs new file mode 100644 index 00000000..f45acd8b --- /dev/null +++ b/RateLimiter.Tests/RateLimitingServiceTests.cs @@ -0,0 +1,163 @@ +using System; +using FluentAssertions; +using Moq; +using RateLimiter.Extensions; +using RateLimiter.Models; +using RateLimiter.Services; +using RateLimiter.Services.Interfaces; +using Xunit; + +namespace RateLimiter.Tests; + +public class RateLimitingServiceTests +{ + private const long MaxRequestsPerPeriod = 3; + private static readonly TimeSpan Period = TimeSpan.FromSeconds(5); + private static readonly DateTime Now = DateTime.Now; + + private readonly IRateLimitingService _rateLimitingService; + + private readonly Mock _rateLimitStorageServiceMock; + private readonly Mock _dateTimeProviderMock; + private readonly RateLimitRule _rateLimitRule; + + public RateLimitingServiceTests() + { + _rateLimitStorageServiceMock = new Mock(); + _dateTimeProviderMock = new Mock(MockBehavior.Strict); + _dateTimeProviderMock.SetupGet(x => x.UtcNow).Returns(Now); + + _rateLimitingService = new RateLimitingService( + _rateLimitStorageServiceMock.Object, + _dateTimeProviderMock.Object); + + _rateLimitRule = new RateLimitRule(Period, MaxRequestsPerPeriod); + } + + [Fact] + public void GetRateLimitCounter_FirstRequest_ShouldReturnDefaultRateLimitCounter() + { + // Arrange + var identity = new ClientRequestIdentity(Guid.NewGuid().ToString(), "/create", "POST"); + var key = identity.GetStorageKey(Period); + _rateLimitStorageServiceMock.Setup(x => x.Get(key)).Returns((RateLimitCounter?)null); + + // Act + var result = _rateLimitingService.GetRateLimitCounter(identity, _rateLimitRule); + + // Assert + result.StartedAt.Should().Be(Now); + result.ExceededAt.Should().BeNull(); + result.TotalRequests.Should().Be(1); + _rateLimitStorageServiceMock.VerifyAll(); + _rateLimitStorageServiceMock + .Verify(x => + x.Set( + key, + It.Is(c => + c.StartedAt == Now + && c.ExceededAt.HasValue == false + && c.TotalRequests == 1), + Period), + Times.Once); + _rateLimitStorageServiceMock.VerifyNoOtherCalls(); + } + + [Fact] + public void GetRateLimitCounter_OneRequestUntilBan_ShouldReturnRateLimitCounterWithMaxRequestsPerPeriod() + { + // Arrange + var identity = new ClientRequestIdentity(Guid.NewGuid().ToString(), "/create", "POST"); + var key = identity.GetStorageKey(Period); + var halfOfPeriod = Period / 2; + var startedAt = Now.Add(-halfOfPeriod); + var exceededAt = Now.Add(halfOfPeriod); + _rateLimitStorageServiceMock + .Setup(x => x.Get(key)) + .Returns(new RateLimitCounter( + startedAt, + exceededAt, + MaxRequestsPerPeriod - 1)); + + // Act + var result = _rateLimitingService.GetRateLimitCounter(identity, _rateLimitRule); + + // Assert + result.StartedAt.Should().Be(startedAt); + result.ExceededAt.Should().Be(exceededAt); + result.TotalRequests.Should().Be(MaxRequestsPerPeriod); + _rateLimitStorageServiceMock.VerifyAll(); + _rateLimitStorageServiceMock + .Verify(x => + x.Set( + key, + It.Is(c => + c.StartedAt == startedAt + && c.ExceededAt == exceededAt + && c.TotalRequests == MaxRequestsPerPeriod), + Period), + Times.Once); + _rateLimitStorageServiceMock.VerifyNoOtherCalls(); + } + + [Fact] + public void GetRateLimitCounter_BanExpired_ShouldReturnDefaultRateLimitCounter() + { + // Arrange + var identity = new ClientRequestIdentity(Guid.NewGuid().ToString(), "/create", "POST"); + var key = identity.GetStorageKey(Period); + var startedAt = Now.Add(-Period).AddSeconds(-1); + var exceededAt = Now.Add(-Period).AddMilliseconds(-1); + _rateLimitStorageServiceMock + .Setup(x => x.Get(key)) + .Returns(new RateLimitCounter( + startedAt, + exceededAt, + MaxRequestsPerPeriod)); + + // Act + var result = _rateLimitingService.GetRateLimitCounter(identity, _rateLimitRule); + + // Assert + result.StartedAt.Should().Be(Now); + result.ExceededAt.Should().BeNull(); + result.TotalRequests.Should().Be(1); + _rateLimitStorageServiceMock.VerifyAll(); + _rateLimitStorageServiceMock + .Verify(x => + x.Set( + key, + It.Is(c => + c.StartedAt == Now + && c.ExceededAt.HasValue == false + && c.TotalRequests == 1), + Period), + Times.Once); + _rateLimitStorageServiceMock.VerifyNoOtherCalls(); + } + + [Fact] + public void GetRateLimitCounter_BanNotExpired_ShouldReturnRateLimitCounterWithBan() + { + // Arrange + var identity = new ClientRequestIdentity(Guid.NewGuid().ToString(), "/create", "POST"); + var key = identity.GetStorageKey(Period); + var startedAt = Now.Add(-Period).AddSeconds(-1); + var exceededAt = Now.Add(-Period).AddMilliseconds(1); + + _rateLimitStorageServiceMock + .Setup(x => x.Get(key)) + .Returns(new RateLimitCounter( + startedAt, + exceededAt, + MaxRequestsPerPeriod)); + + // Act + var result = _rateLimitingService.GetRateLimitCounter(identity, _rateLimitRule); + + // Assert + result.StartedAt.Should().Be(startedAt); + result.ExceededAt.Should().Be(exceededAt); + result.TotalRequests.Should().Be(MaxRequestsPerPeriod + 1); + } +} diff --git a/RateLimiter.sln b/RateLimiter.sln index 626a7bfa..18d0f864 100644 --- a/RateLimiter.sln +++ b/RateLimiter.sln @@ -1,17 +1,19 @@  Microsoft Visual Studio Solution File, Format Version 12.00 -# Visual Studio 15 -VisualStudioVersion = 15.0.26730.15 +# Visual Studio Version 17 +VisualStudioVersion = 17.9.34902.65 MinimumVisualStudioVersion = 10.0.40219.1 -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "RateLimiter", "RateLimiter\RateLimiter.csproj", "{36F4BDC6-D3DA-403A-8DB7-0C79F94B938F}" +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "RateLimiter", "RateLimiter\RateLimiter.csproj", "{36F4BDC6-D3DA-403A-8DB7-0C79F94B938F}" EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "RateLimiter.Tests", "RateLimiter.Tests\RateLimiter.Tests.csproj", "{C4F9249B-010E-46BE-94B8-DD20D82F1E60}" +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "RateLimiter.Tests", "RateLimiter.Tests\RateLimiter.Tests.csproj", "{C4F9249B-010E-46BE-94B8-DD20D82F1E60}" EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Solution Items", "Solution Items", "{9B206889-9841-4B5E-B79B-D5B2610CCCFF}" ProjectSection(SolutionItems) = preProject README.md = README.md EndProjectSection EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "SimpleSample", "Samples\SimpleSample\SimpleSample.csproj", "{DCC2194E-DB5B-4823-9347-E8A6728C59EB}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -26,10 +28,17 @@ Global {C4F9249B-010E-46BE-94B8-DD20D82F1E60}.Debug|Any CPU.Build.0 = Debug|Any CPU {C4F9249B-010E-46BE-94B8-DD20D82F1E60}.Release|Any CPU.ActiveCfg = Release|Any CPU {C4F9249B-010E-46BE-94B8-DD20D82F1E60}.Release|Any CPU.Build.0 = Release|Any CPU + {DCC2194E-DB5B-4823-9347-E8A6728C59EB}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {DCC2194E-DB5B-4823-9347-E8A6728C59EB}.Debug|Any CPU.Build.0 = Debug|Any CPU + {DCC2194E-DB5B-4823-9347-E8A6728C59EB}.Release|Any CPU.ActiveCfg = Release|Any CPU + {DCC2194E-DB5B-4823-9347-E8A6728C59EB}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE EndGlobalSection + GlobalSection(NestedProjects) = preSolution + {DCC2194E-DB5B-4823-9347-E8A6728C59EB} = {9B206889-9841-4B5E-B79B-D5B2610CCCFF} + EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {67D05CB6-8603-4C96-97E5-C6CEFBEC6134} EndGlobalSection diff --git a/RateLimiter/Extensions/ClientRequestIdentityExtensions.cs b/RateLimiter/Extensions/ClientRequestIdentityExtensions.cs new file mode 100644 index 00000000..e0b804dd --- /dev/null +++ b/RateLimiter/Extensions/ClientRequestIdentityExtensions.cs @@ -0,0 +1,23 @@ +using System; +using System.Security.Cryptography; +using System.Text; +using RateLimiter.Models; + +namespace RateLimiter.Extensions; + +internal static class ClientRequestIdentityExtensions +{ + internal static string GetStorageKey(this ClientRequestIdentity identity, TimeSpan period) + { + var key = $"{identity.ClientId}_{period}_{identity.HttpVerb}_{identity.Path}"; + var idBytes = Encoding.UTF8.GetBytes(key); + + byte[] hashBytes; + using (var algorithm = SHA1.Create()) + { + hashBytes = algorithm.ComputeHash(idBytes); + } + + return BitConverter.ToString(hashBytes).Replace("-", string.Empty); + } +} diff --git a/RateLimiter/IRateLimitingService.cs b/RateLimiter/IRateLimitingService.cs new file mode 100644 index 00000000..14c94e79 --- /dev/null +++ b/RateLimiter/IRateLimitingService.cs @@ -0,0 +1,11 @@ +using RateLimiter.Models; + +namespace RateLimiter; + +/// +/// A service, providing information about count of requests that clients made to the endpoint +/// +public interface IRateLimitingService +{ + RateLimitCounter GetRateLimitCounter(ClientRequestIdentity identity, RateLimitRule rule); +} diff --git a/RateLimiter/Models/ClientRequestIdentity.cs b/RateLimiter/Models/ClientRequestIdentity.cs new file mode 100644 index 00000000..059bfb80 --- /dev/null +++ b/RateLimiter/Models/ClientRequestIdentity.cs @@ -0,0 +1,15 @@ +namespace RateLimiter.Models; + +public sealed class ClientRequestIdentity +{ + public ClientRequestIdentity(string clientId, string path, string httpverb) + { + ClientId = clientId; + Path = path; + HttpVerb = httpverb; + } + + public string ClientId { get; } + public string Path { get; } + public string HttpVerb { get; } +} diff --git a/RateLimiter/Models/RateLimitCounter.cs b/RateLimiter/Models/RateLimitCounter.cs new file mode 100644 index 00000000..41a03ac2 --- /dev/null +++ b/RateLimiter/Models/RateLimitCounter.cs @@ -0,0 +1,22 @@ +using System; + +namespace RateLimiter.Models; + +/// +/// Stores the initial access time and the numbers of calls made from that point. +/// +public struct RateLimitCounter +{ + public RateLimitCounter(DateTime startedAt, DateTime? exceededAt, long totalRequests) + { + StartedAt = startedAt; + ExceededAt = exceededAt; + TotalRequests = totalRequests; + } + + public DateTime StartedAt { get; } + + public DateTime? ExceededAt { get; } + + public long TotalRequests { get; set; } +} diff --git a/RateLimiter/Models/RateLimitRule.cs b/RateLimiter/Models/RateLimitRule.cs new file mode 100644 index 00000000..02201cb8 --- /dev/null +++ b/RateLimiter/Models/RateLimitRule.cs @@ -0,0 +1,20 @@ +using System; + +namespace RateLimiter.Models; + +public sealed class RateLimitRule +{ + public RateLimitRule(TimeSpan period, long limit) + { + Period = period; + + if (limit < 1) + throw new ArgumentOutOfRangeException($"{nameof(Limit)} must be positive."); + + Limit = limit; + } + + public TimeSpan Period { get; } + + public long Limit { get; } +} diff --git a/RateLimiter/RateLimiter.csproj b/RateLimiter/RateLimiter.csproj index 19962f52..fc60dbbb 100644 --- a/RateLimiter/RateLimiter.csproj +++ b/RateLimiter/RateLimiter.csproj @@ -4,4 +4,13 @@ latest enable - \ No newline at end of file + + + + + + + + + + \ No newline at end of file diff --git a/RateLimiter/RateLimiterServiceCollectionExtensions.cs b/RateLimiter/RateLimiterServiceCollectionExtensions.cs new file mode 100644 index 00000000..07d6a5e3 --- /dev/null +++ b/RateLimiter/RateLimiterServiceCollectionExtensions.cs @@ -0,0 +1,40 @@ +using Microsoft.Extensions.DependencyInjection; +using RateLimiter.Services; +using RateLimiter.Services.Interfaces; + +namespace RateLimiter; + +public static class RateLimiterServiceCollectionExtensions +{ + /// + /// Adds an implementation of IRateLimitStorageService with memory cache + /// + public static IServiceCollection AddRateLimitingServiceWithMemoryCache(this IServiceCollection services) + { + return services + .AddRateLimitingServiceCore() + .AddMemoryCache() + .AddSingleton(); + } + + /// + /// Adds an implementation of IRateLimitStorageService with distributed cache + /// + public static IServiceCollection AddRateLimitingServiceWithDistributedCache(this IServiceCollection services) + { + return services + .AddRateLimitingServiceCore() + .AddSingleton(); + } + + /// + /// Adds an implementation of IRateLimitStorageService without IRateLimitStorageService implementation + /// + /// Additional needs to add implementation of IRateLimitStorageService + public static IServiceCollection AddRateLimitingServiceCore(this IServiceCollection services) + { + return services + .AddSingleton() + .AddSingleton(); + } +} diff --git a/RateLimiter/Services/DefaultDateTimeProvider.cs b/RateLimiter/Services/DefaultDateTimeProvider.cs new file mode 100644 index 00000000..de80ad69 --- /dev/null +++ b/RateLimiter/Services/DefaultDateTimeProvider.cs @@ -0,0 +1,9 @@ +using System; +using RateLimiter.Services.Interfaces; + +namespace RateLimiter.Services; + +internal sealed class DefaultDateTimeProvider : IDateTimeProvider +{ + public DateTime UtcNow => DateTime.UtcNow; +} diff --git a/RateLimiter/Services/DistributedCacheRateLimitStorageService.cs b/RateLimiter/Services/DistributedCacheRateLimitStorageService.cs new file mode 100644 index 00000000..accc109d --- /dev/null +++ b/RateLimiter/Services/DistributedCacheRateLimitStorageService.cs @@ -0,0 +1,29 @@ +using System; +using System.Text.Json; +using Microsoft.Extensions.Caching.Distributed; +using RateLimiter.Models; +using RateLimiter.Services.Interfaces; + +namespace RateLimiter.Services; + +internal sealed class DistributedCacheRateLimitStorageService : IRateLimitStorageService +{ + private readonly IDistributedCache _memoryCache; + + public DistributedCacheRateLimitStorageService(IDistributedCache memoryCache) => _memoryCache = memoryCache; + + public void Set(string id, RateLimitCounter counter, TimeSpan expirationTime) + => _memoryCache.SetString(id, JsonSerializer.Serialize(counter), new DistributedCacheEntryOptions().SetAbsoluteExpiration(expirationTime)); + + public bool Exists(string id) => !string.IsNullOrEmpty(_memoryCache.GetString(id)); + + public RateLimitCounter? Get(string id) + { + var stored = _memoryCache.GetString(id); + return !string.IsNullOrEmpty(stored) + ? JsonSerializer.Deserialize(stored) + : null; + } + + public void Remove(string id) => _memoryCache.Remove(id); +} diff --git a/RateLimiter/Services/Interfaces/IDateTimeProvider.cs b/RateLimiter/Services/Interfaces/IDateTimeProvider.cs new file mode 100644 index 00000000..0e7471ba --- /dev/null +++ b/RateLimiter/Services/Interfaces/IDateTimeProvider.cs @@ -0,0 +1,8 @@ +using System; + +namespace RateLimiter.Services.Interfaces; + +internal interface IDateTimeProvider +{ + DateTime UtcNow { get; } +} diff --git a/RateLimiter/Services/Interfaces/IRateLimitStorageService.cs b/RateLimiter/Services/Interfaces/IRateLimitStorageService.cs new file mode 100644 index 00000000..4ef32502 --- /dev/null +++ b/RateLimiter/Services/Interfaces/IRateLimitStorageService.cs @@ -0,0 +1,19 @@ +using System; +using RateLimiter.Models; + +namespace RateLimiter.Services.Interfaces; + +/// +/// Defines a storage for keeping of rate limiting data. +/// +/// Concrete classes should be based on solutions with excellent performance, such as in-memory solutions. +internal interface IRateLimitStorageService +{ + bool Exists(string id); + + RateLimitCounter? Get(string id); + + void Remove(string id); + + void Set(string id, RateLimitCounter counter, TimeSpan expirationTime); +} diff --git a/RateLimiter/Services/MemoryCacheRateLimitStorageService.cs b/RateLimiter/Services/MemoryCacheRateLimitStorageService.cs new file mode 100644 index 00000000..6981ec83 --- /dev/null +++ b/RateLimiter/Services/MemoryCacheRateLimitStorageService.cs @@ -0,0 +1,22 @@ +using System; +using Microsoft.Extensions.Caching.Memory; +using RateLimiter.Models; +using RateLimiter.Services.Interfaces; + +namespace RateLimiter.Services; + +internal sealed class MemoryCacheRateLimitStorageService : IRateLimitStorageService +{ + private readonly IMemoryCache _memoryCache; + + public MemoryCacheRateLimitStorageService(IMemoryCache memoryCache) => _memoryCache = memoryCache; + + public void Set(string id, RateLimitCounter counter, TimeSpan expirationTime) + => _memoryCache.Set(id, counter, new MemoryCacheEntryOptions().SetAbsoluteExpiration(expirationTime)); + + public bool Exists(string id) => _memoryCache.TryGetValue(id, out RateLimitCounter counter); + + public RateLimitCounter? Get(string id) => _memoryCache.TryGetValue(id, out RateLimitCounter counter) ? counter : null; + + public void Remove(string id) => _memoryCache.Remove(id); +} diff --git a/RateLimiter/Services/RateLimitingService.cs b/RateLimiter/Services/RateLimitingService.cs new file mode 100644 index 00000000..657e0d1e --- /dev/null +++ b/RateLimiter/Services/RateLimitingService.cs @@ -0,0 +1,120 @@ +using System; +using RateLimiter.Extensions; +using RateLimiter.Models; +using RateLimiter.Services.Interfaces; + +namespace RateLimiter.Services; + +/// +/// An implementation of IRateLimitingService +/// +/// +/// Based on https://github.com/ThreeMammals/Ocelot/blob/develop/src/Ocelot/RateLimiting/RateLimiting.cs +/// Can be used in middleware, like in original ocelot repo: https://github.com/ThreeMammals/Ocelot/blob/develop/src/Ocelot/RateLimiting/Middleware/RateLimitingMiddleware.cs +/// Or inside an action filter +/// +internal sealed class RateLimitingService : IRateLimitingService +{ + private static readonly TimeSpan DefaultRetryPeriod = TimeSpan.FromSeconds(1); + private static readonly object ProcessLocker = new(); + + private readonly IRateLimitStorageService _storageService; + private readonly IDateTimeProvider _dateTimeProvider; + + public RateLimitingService( + IRateLimitStorageService storageService, + IDateTimeProvider dateTimeProvider) + { + _storageService = storageService ?? throw new ArgumentNullException(nameof(storageService)); + _dateTimeProvider = dateTimeProvider ?? throw new ArgumentNullException(nameof(dateTimeProvider)); + } + + public RateLimitCounter GetRateLimitCounter(ClientRequestIdentity identity, RateLimitRule rule) + { + RateLimitCounter counter; + var counterId = identity.GetStorageKey(rule.Period); + + // Serial reads/writes from/to the storage which must be thread safe + lock (ProcessLocker) + { + var entry = _storageService.Get(counterId); + counter = Count(entry, rule); + var expiration = rule.Period; // default expiration is set for the Period value + + if (counter.TotalRequests > rule.Limit) + { + var retryAfter = RetryAfter(counter, rule); // the calculation depends on the counter returned from CountRequests + + if (retryAfter > TimeSpan.Zero) + { + // Rate Limit exceeded, ban period is active + expiration = rule.Period; // current state should expire in the storage after ban period + } + else + { + // Ban period elapsed, start counting + _storageService.Remove(counterId); // the store can delete the element on its own using an expiration mechanism, but let's force the element to be deleted + counter = new RateLimitCounter(DateTime.UtcNow, null, 1); + } + } + + _storageService.Set(counterId, counter, expiration); + } + + return counter; + } + + private RateLimitCounter Count(RateLimitCounter? entry, RateLimitRule rule) + { + var now = _dateTimeProvider.UtcNow; + if (!entry.HasValue) // no entry, start counting + { + return new RateLimitCounter(now, null, 1); // current request is the 1st one + } + + var counter = entry.Value; + var total = counter.TotalRequests + 1; // increment request count + var startedAt = counter.StartedAt; + if (startedAt + rule.Period >= now) // counting Period is active + { + var exceededAt = total >= rule.Limit && !counter.ExceededAt.HasValue // current request number equals to the limit + ? now // the exceeding moment is now, the next request will fail but the current one doesn't + : counter.ExceededAt; + return new RateLimitCounter(startedAt, exceededAt, total); // deep copy + } + + var wasExceededAt = counter.ExceededAt; + return wasExceededAt + rule.Period >= now // ban PeriodTimespan is active + ? new RateLimitCounter(startedAt, wasExceededAt, total) // still count + : new RateLimitCounter(now, null, 1); // Ban PeriodTimespan elapsed, start counting NOW! + } + + private TimeSpan RetryAfter(RateLimitCounter counter, RateLimitRule rule) + { + var periodTimespan = rule.Period < DefaultRetryPeriod + ? DefaultRetryPeriod // allow values which are greater or equal to 1 second + : rule.Period; // good value + + var now = _dateTimeProvider.UtcNow; + + if (counter.StartedAt + rule.Period >= now) // counting Period is active + { + return counter.TotalRequests < rule.Limit + ? TimeSpan.Zero // happy path, no need to retry, current request is valid + : counter.ExceededAt.HasValue + ? periodTimespan - (now - counter.ExceededAt.Value) // minus seconds past + : periodTimespan; // exceeding not yet detected -> let's ban for whole period + } + + if (counter.ExceededAt.HasValue // limit exceeding was happen + && counter.ExceededAt + periodTimespan >= now) // ban PeriodTimespan is active + { + var startedAt = counter.ExceededAt.Value; // ban period was started at + var secondsPast = now - startedAt; + var retryAfter = periodTimespan - secondsPast; + return retryAfter; // it can be negative, which means the wait in PeriodTimespan seconds has ended + } + + return TimeSpan.Zero; + } +} diff --git a/Samples/SimpleSample/Controllers/WeatherForecastController.cs b/Samples/SimpleSample/Controllers/WeatherForecastController.cs new file mode 100644 index 00000000..37ef94fc --- /dev/null +++ b/Samples/SimpleSample/Controllers/WeatherForecastController.cs @@ -0,0 +1,37 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.AspNetCore.Mvc; +using Microsoft.Extensions.Logging; +using SimpleSample.Filters; + +namespace SimpleSample.Controllers; +[ApiController] +[Route("[controller]")] +public class WeatherForecastController : ControllerBase +{ + private static readonly string[] Summaries = new[] + { + "Freezing", "Bracing", "Chilly", "Cool", "Mild", "Warm", "Balmy", "Hot", "Sweltering", "Scorching" + }; + + private readonly ILogger _logger; + + public WeatherForecastController(ILogger logger) + { + _logger = logger; + } + + [HttpGet] + [TypeFilter(typeof(RateLimitActionFilter))] + public IEnumerable Get() + { + return Enumerable.Range(1, 5).Select(index => new WeatherForecast + { + Date = DateTime.Now.AddDays(index), + TemperatureC = Random.Shared.Next(-20, 55), + Summary = Summaries[Random.Shared.Next(Summaries.Length)] + }) + .ToArray(); + } +} diff --git a/Samples/SimpleSample/Filters/RateLimitActionFilter.cs b/Samples/SimpleSample/Filters/RateLimitActionFilter.cs new file mode 100644 index 00000000..0f3dac46 --- /dev/null +++ b/Samples/SimpleSample/Filters/RateLimitActionFilter.cs @@ -0,0 +1,91 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Mvc; +using Microsoft.AspNetCore.Mvc.Filters; +using RateLimiter; +using RateLimiter.Models; + +namespace SimpleSample.Filters; + +internal sealed class RateLimitActionFilter : ActionFilterAttribute +{ + private static readonly Dictionary> RouteRateLimitDictionary = + new Dictionary>() + { + { + "/weatherforecast", + new List() + { + new RateLimitRule(TimeSpan.FromMinutes(1), 5), + new RateLimitRule(TimeSpan.FromSeconds(10), 1), + } + }, + }; + + private readonly IRateLimitingService _rateLimitingService; + + public RateLimitActionFilter( + IRateLimitingService rateLimitingService) + { + _rateLimitingService = rateLimitingService; + } + + public override void OnActionExecuting(ActionExecutingContext context) + { + var request = context.HttpContext.Request; + if (RouteRateLimitDictionary.TryGetValue(request.Path, out var rules) == false) + { + return; + } + + var clientId = GetClientId(request); + + if (clientId == null) + { + context.Result = new StatusCodeResult((int)HttpStatusCode.TooManyRequests); + } + + var identity = new ClientRequestIdentity(clientId!, request.Path, request.Method); + + foreach (var rule in rules) + { + var counter = _rateLimitingService.GetRateLimitCounter(identity, rule); + + SetRateLimitHeaders(context.HttpContext.Response, counter, rule); + + if (counter.TotalRequests > rule.Limit) + { + context.Result = new StatusCodeResult((int)HttpStatusCode.TooManyRequests); + } + } + } + + private static string? GetClientId(HttpRequest? request) + { + return request?.Headers["X-ClientId"].FirstOrDefault(); + } + + private static void SetRateLimitHeaders(HttpResponse response, RateLimitCounter counter, RateLimitRule rule) + { + long remainingAttemptes = 0; + DateTime reset; + + if (counter.TotalRequests > rule.Limit) + { + remainingAttemptes = rule.Limit; + reset = DateTime.UtcNow + rule.Period; + } + else + { + remainingAttemptes = rule.Limit - counter.TotalRequests; + reset = counter.StartedAt + rule.Period; + } + + response.Headers["X-Rate-Limit-Limit"] = rule.Period.ToString(); + response.Headers["X-Rate-Limit-Remaining"] = remainingAttemptes.ToString(); + response.Headers["X-Rate-Limit-Reset"] = reset.ToString(); + } +} diff --git a/Samples/SimpleSample/Program.cs b/Samples/SimpleSample/Program.cs new file mode 100644 index 00000000..65af73c1 --- /dev/null +++ b/Samples/SimpleSample/Program.cs @@ -0,0 +1,20 @@ +using Microsoft.AspNetCore.Builder; +using Microsoft.Extensions.DependencyInjection; +using RateLimiter; + +var builder = WebApplication.CreateBuilder(args); + +// Add services to the container. + +builder.Services.AddControllers(); +builder.Services.AddRateLimitingServiceWithMemoryCache(); + +var app = builder.Build(); + +// Configure the HTTP request pipeline. + +app.UseAuthorization(); + +app.MapControllers(); + +app.Run(); diff --git a/Samples/SimpleSample/Properties/launchSettings.json b/Samples/SimpleSample/Properties/launchSettings.json new file mode 100644 index 00000000..1ba116c9 --- /dev/null +++ b/Samples/SimpleSample/Properties/launchSettings.json @@ -0,0 +1,31 @@ +{ + "$schema": "https://json.schemastore.org/launchsettings.json", + "iisSettings": { + "windowsAuthentication": false, + "anonymousAuthentication": true, + "iisExpress": { + "applicationUrl": "http://localhost:50054", + "sslPort": 0 + } + }, + "profiles": { + "SimpleSample": { + "commandName": "Project", + "dotnetRunMessages": true, + "launchBrowser": true, + "launchUrl": "weatherforecast", + "applicationUrl": "http://localhost:5175", + "environmentVariables": { + "ASPNETCORE_ENVIRONMENT": "Development" + } + }, + "IIS Express": { + "commandName": "IISExpress", + "launchBrowser": true, + "launchUrl": "weatherforecast", + "environmentVariables": { + "ASPNETCORE_ENVIRONMENT": "Development" + } + } + } +} diff --git a/Samples/SimpleSample/SimpleSample.csproj b/Samples/SimpleSample/SimpleSample.csproj new file mode 100644 index 00000000..0a12a79d --- /dev/null +++ b/Samples/SimpleSample/SimpleSample.csproj @@ -0,0 +1,12 @@ + + + + net6.0 + enable + + + + + + + diff --git a/Samples/SimpleSample/WeatherForecast.cs b/Samples/SimpleSample/WeatherForecast.cs new file mode 100644 index 00000000..38e88401 --- /dev/null +++ b/Samples/SimpleSample/WeatherForecast.cs @@ -0,0 +1,14 @@ +using System; + +namespace SimpleSample; + +public class WeatherForecast +{ + public DateTime Date { get; set; } + + public int TemperatureC { get; set; } + + public int TemperatureF => 32 + (int)(TemperatureC / 0.5556); + + public string? Summary { get; set; } +} diff --git a/Samples/SimpleSample/appsettings.Development.json b/Samples/SimpleSample/appsettings.Development.json new file mode 100644 index 00000000..0c208ae9 --- /dev/null +++ b/Samples/SimpleSample/appsettings.Development.json @@ -0,0 +1,8 @@ +{ + "Logging": { + "LogLevel": { + "Default": "Information", + "Microsoft.AspNetCore": "Warning" + } + } +} diff --git a/Samples/SimpleSample/appsettings.json b/Samples/SimpleSample/appsettings.json new file mode 100644 index 00000000..10f68b8c --- /dev/null +++ b/Samples/SimpleSample/appsettings.json @@ -0,0 +1,9 @@ +{ + "Logging": { + "LogLevel": { + "Default": "Information", + "Microsoft.AspNetCore": "Warning" + } + }, + "AllowedHosts": "*" +}