From 3b2fdc00ecbdc5cad6b957b36922f58e1c5852ce Mon Sep 17 00:00:00 2001 From: Tim Voronov Date: Sat, 29 Jun 2024 18:04:07 -0400 Subject: [PATCH] Refactored API --- README.md | 46 +++++++++++++++++++++++++++++++++++----------- throttle.go | 36 ++++++++++++++---------------------- throttle_test.go | 44 +++++++++++++++----------------------------- transport.go | 27 +++++++++++++++++++++++++++ 4 files changed, 91 insertions(+), 62 deletions(-) create mode 100644 transport.go diff --git a/README.md b/README.md index 56b9d17..4fd755f 100644 --- a/README.md +++ b/README.md @@ -22,25 +22,25 @@ import ( type ApiClient struct { transport *http.Client - throttler *throttle.Throttler[*http.Response] + throttler *throttle.Throttler } func NewApiClient(rps uint64) *ApiClient { return &ApiClient{ transport: &http.Client{}, - throttler: throttle.New[*http.Response](rps), + throttler: throttle.New(rps), } } func (c *ApiClient) Do(ctx context.Context, req *http.Request) (*http.Response, error) { - return c.throttler.Do(func() (*http.Response, error) { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - return c.transport.Do(req) - } - }) + c.throttler.Acquire() + + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + return c.transport.Do(req) + } } ``` @@ -72,6 +72,30 @@ func (c *MyClock) Sleep(dur time.Duration) { } func main() { - throttler := throttle.New[any](10, throttle.WithClock(&MyClock{time.Millisecond * 250})) + throttler := throttle.New(10, throttle.WithClock(&MyClock{time.Millisecond * 250})) +} +``` + +## Helpers +### RoundTripper +The package contains a helper that wraps the standard `http.RoundTripper` interface and provides a throttling mechanism. + +```go +package myapp + +import ( + "context" + "net/http" + "github.com/ziflex/throttle" +) + +func main() { + transport := &http.Transport{} + client := &http.Client{ + Transport: throttle.NewRoundTripper(transport, 10), + } + + req, _ := http.NewRequest(http.MethodGet, "https://example.com", nil) + client.Do(req) } ``` \ No newline at end of file diff --git a/throttle.go b/throttle.go index 2916d01..95ffb59 100644 --- a/throttle.go +++ b/throttle.go @@ -7,42 +7,34 @@ import ( const windowSize = time.Second -type ( - // Fn represents a function that returns a value of type T and an error. - Fn[T any] func() (T, error) - - // Throttler manages the execution of operations so that they don't exceed a specified rate limit. - Throttler[T any] struct { - mu sync.Mutex - window time.Time - clock Clock - counter uint64 - limit uint64 - } -) +// Throttler manages the execution of operations so that they don't exceed a specified rate limit. +type Throttler struct { + mu sync.Mutex + window time.Time + clock Clock + counter uint64 + limit uint64 +} // New creates a new instance of Throttler with a specified limit. -func New[T any](limit uint64, setters ...Option) *Throttler[T] { +func New(limit uint64, setters ...Option) *Throttler { opts := buildOptions(setters) - return &Throttler[T]{ + return &Throttler{ limit: limit, clock: opts.clock, } } -// Do executes the provided function fn if the rate limit has not been reached. -// It ensures that the operation respects the throttling constraints. -func (t *Throttler[T]) Do(fn Fn[T]) (T, error) { +// Acquire blocks until the operation can be executed within the rate limit. +func (t *Throttler) Acquire() { t.mu.Lock() t.advance() t.mu.Unlock() - - return fn() } // advance updates the throttler state, advancing the window or incrementing the counter as necessary. -func (t *Throttler[T]) advance() { +func (t *Throttler) advance() { // pass through if t.limit == 0 { return @@ -87,7 +79,7 @@ func (t *Throttler[T]) advance() { } // reset starts a new window from the specified start time and resets the operation counter. -func (t *Throttler[T]) reset(window time.Time) { +func (t *Throttler) reset(window time.Time) { t.window = window t.counter = 1 } diff --git a/throttle_test.go b/throttle_test.go index aee6cc0..89b6aa8 100644 --- a/throttle_test.go +++ b/throttle_test.go @@ -43,7 +43,7 @@ func TestThrottler_Do_Consistent(t *testing.T) { for _, useCase := range useCases { t.Run(fmt.Sprintf("Consistent %d RPS within %d calls", useCase.Limit, useCase.Calls), func(t *testing.T) { calls := make(chan time.Time, useCase.Calls) - throttler := throttle.New[time.Time](useCase.Limit) + throttler := throttle.New(useCase.Limit) ts := time.Now() var wg sync.WaitGroup @@ -51,11 +51,8 @@ func TestThrottler_Do_Consistent(t *testing.T) { for range useCase.Calls { go func() { - res, _ := throttler.Do(func() (time.Time, error) { - return time.Now(), nil - }) - - calls <- res + throttler.Acquire() + calls <- time.Now() wg.Done() }() } @@ -152,7 +149,7 @@ func TestThrottler_Do_Sporadic(t *testing.T) { } calls := make(chan time.Time, buffer) - throttler := throttle.New[time.Time](useCase.Limit) + throttler := throttle.New(useCase.Limit) ts := time.Now() var wg sync.WaitGroup @@ -169,19 +166,13 @@ func TestThrottler_Do_Sporadic(t *testing.T) { } for range callNum { - res, _ := throttler.Do(func() (time.Time, error) { - if latency > 0 { - time.Sleep(latency) - } - - return time.Now(), nil - }) - - calls <- res + throttler.Acquire() - //ts := time.Now() + if latency > 0 { + time.Sleep(latency) + } - //fmt.Println(fmt.Sprintf("Call %dms", time.Since(ts).Milliseconds())) + calls <- time.Now() } wg.Done() @@ -301,7 +292,7 @@ func TestThrottler_Do_Parallel(t *testing.T) { for _, useCase := range useCases { t.Run(fmt.Sprintf("Parallel %d RPS", useCase.Limit), func(t *testing.T) { calls := make(chan time.Time, len(useCase.Calls)) - throttler := throttle.New[time.Time](useCase.Limit) + throttler := throttle.New(useCase.Limit) ts := time.Now() var wg sync.WaitGroup @@ -311,18 +302,13 @@ func TestThrottler_Do_Parallel(t *testing.T) { go func(latency time.Duration) { defer wg.Done() - // callTs := time.Now() - res, _ := throttler.Do(func() (time.Time, error) { - if latency > 0 { - time.Sleep(latency) - } + throttler.Acquire() - return time.Now(), nil - }) - - calls <- res + if latency > 0 { + time.Sleep(latency) + } - // fmt.Println(fmt.Sprintf("Call %dms", time.Since(callTs).Milliseconds())) + calls <- time.Now() }(tpl.Latency) } diff --git a/transport.go b/transport.go new file mode 100644 index 0000000..757cc75 --- /dev/null +++ b/transport.go @@ -0,0 +1,27 @@ +package throttle + +import ( + "net/http" +) + +type throttledRoundTripper struct { + transport http.RoundTripper + throttler *Throttler +} + +func (t *throttledRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) { + t.throttler.Acquire() + + return t.transport.RoundTrip(request) +} + +func NewRoundTripper(transport http.RoundTripper, limit uint64, setters ...Option) http.RoundTripper { + return NewRoundTripperWith(transport, New(limit, setters...)) +} + +func NewRoundTripperWith(transport http.RoundTripper, throttler *Throttler) http.RoundTripper { + return &throttledRoundTripper{ + transport: transport, + throttler: throttler, + } +}