From 59450bf156a3a1a3234c426c2d3251c837ab9599 Mon Sep 17 00:00:00 2001 From: Victor Elias Date: Wed, 20 Mar 2024 23:48:37 -0300 Subject: [PATCH] eth/watchers: Add tests for pricefeedwatcher --- eth/watchers/pricefeedwatcher.go | 39 +++-- eth/watchers/pricefeedwatcher_test.go | 202 ++++++++++++++++++++++++++ 2 files changed, 221 insertions(+), 20 deletions(-) diff --git a/eth/watchers/pricefeedwatcher.go b/eth/watchers/pricefeedwatcher.go index c6fb489ebd..2975f1ab40 100644 --- a/eth/watchers/pricefeedwatcher.go +++ b/eth/watchers/pricefeedwatcher.go @@ -13,12 +13,13 @@ import ( const ( priceUpdateMaxRetries = 5 priceUpdateBaseRetryDelay = 30 * time.Second + priceUpdatePeriod = 1 * time.Hour ) type PriceFeedWatcher struct { - ctx context.Context + ctx context.Context + baseRetryDelay time.Duration - updatePeriod time.Duration priceFeed eth.PriceFeedEthClient currencyBase, currencyQuote string @@ -26,11 +27,7 @@ type PriceFeedWatcher struct { priceUpdated chan eth.PriceData } -func NewPriceFeedWatcher(ctx context.Context, rpcUrl, priceFeedAddr string, updatePeriod time.Duration) (*PriceFeedWatcher, error) { - if updatePeriod <= 0 { - updatePeriod = 1 * time.Hour - } - +func NewPriceFeedWatcher(ctx context.Context, rpcUrl, priceFeedAddr string) (*PriceFeedWatcher, error) { priceFeed, err := eth.NewPriceFeedEthClient(ctx, rpcUrl, priceFeedAddr) if err != nil { return nil, fmt.Errorf("failed to create price feed client: %w", err) @@ -47,19 +44,25 @@ func NewPriceFeedWatcher(ctx context.Context, rpcUrl, priceFeedAddr string, upda } w := &PriceFeedWatcher{ - ctx: ctx, - updatePeriod: updatePeriod, - priceFeed: priceFeed, - currencyBase: currencyFrom, - currencyQuote: currencyTo, - priceUpdated: make(chan eth.PriceData, 1), + ctx: ctx, + baseRetryDelay: priceUpdateBaseRetryDelay, + priceFeed: priceFeed, + currencyBase: currencyFrom, + currencyQuote: currencyTo, + priceUpdated: make(chan eth.PriceData, 1), } err = w.updatePrice() if err != nil { return nil, fmt.Errorf("failed to update price: %w", err) } - go w.watch() + + go func() { + ctx, cancel := context.WithCancel(w.ctx) + defer cancel() + ticker := newTruncatedTicker(ctx, priceUpdatePeriod) + w.watch(ctx, ticker) + }() return w, nil } @@ -95,17 +98,13 @@ func (w *PriceFeedWatcher) updatePrice() error { return nil } -func (w *PriceFeedWatcher) watch() { - ctx, cancel := context.WithCancel(w.ctx) - defer cancel() - ticker := newTruncatedTicker(ctx, w.updatePeriod) - +func (w *PriceFeedWatcher) watch(ctx context.Context, ticker <-chan time.Time) { for { select { case <-w.ctx.Done(): return case <-ticker: - attempt, retryDelay := 1, priceUpdateBaseRetryDelay + attempt, retryDelay := 1, w.baseRetryDelay for { err := w.updatePrice() if err == nil { diff --git a/eth/watchers/pricefeedwatcher_test.go b/eth/watchers/pricefeedwatcher_test.go index b46d691c9c..26512102c9 100644 --- a/eth/watchers/pricefeedwatcher_test.go +++ b/eth/watchers/pricefeedwatcher_test.go @@ -2,12 +2,214 @@ package watchers import ( "context" + "errors" + "math/big" "testing" "time" + "github.com/livepeer/go-livepeer/eth" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) +type mockPriceFeedEthClient struct { + mock.Mock +} + +func (m *mockPriceFeedEthClient) FetchPriceData() (eth.PriceData, error) { + args := m.Called() + return args.Get(0).(eth.PriceData), args.Error(1) +} + +func (m *mockPriceFeedEthClient) Description() (string, error) { + args := m.Called() + return args.String(0), args.Error(1) +} + +func TestPriceFeedWatcher_UpdatePrice(t *testing.T) { + priceFeedMock := new(mockPriceFeedEthClient) + defer priceFeedMock.AssertExpectations(t) + + priceData := eth.PriceData{ + RoundID: 10, + Price: big.NewRat(3, 2), + UpdatedAt: time.Now(), + } + priceFeedMock.On("FetchPriceData").Return(priceData, nil).Once() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + w := &PriceFeedWatcher{ + ctx: ctx, + priceFeed: priceFeedMock, + currencyBase: "ETH", + currencyQuote: "USD", + priceUpdated: make(chan eth.PriceData, 1), + } + + require.NoError(t, w.updatePrice()) + require.Equal(t, priceData, w.current) + select { + case updatedPrice := <-w.priceUpdated: + require.Equal(t, priceData, updatedPrice) + case <-time.After(2 * time.Second): + t.Error("Updated price hasn't been received on channel") + } +} + +func TestPriceFeedWatcher_Watch(t *testing.T) { + require := require.New(t) + priceFeedMock := new(mockPriceFeedEthClient) + defer priceFeedMock.AssertExpectations(t) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + w := &PriceFeedWatcher{ + ctx: ctx, + priceFeed: priceFeedMock, + currencyBase: "ETH", + currencyQuote: "USD", + priceUpdated: make(chan eth.PriceData, 1), + } + + priceData := eth.PriceData{ + RoundID: 10, + Price: big.NewRat(9, 2), + UpdatedAt: time.Now(), + } + checkPriceUpdated := func() { + select { + case updatedPrice := <-w.priceUpdated: + require.Equal(priceData, updatedPrice) + require.Equal(priceData, w.current) + case <-time.After(1 * time.Second): + require.Fail("Updated price hasn't been received on channel in a timely manner") + } + priceFeedMock.AssertExpectations(t) + } + checkNoPriceUpdate := func() { + select { + case <-w.priceUpdated: + require.Fail("Unexpected price update given it hasn't changed") + case <-time.After(1 * time.Second): + // all good + } + priceFeedMock.AssertExpectations(t) + } + + // Start the watch loop + fakeTicker := make(chan time.Time, 10) + go func() { + w.watch(ctx, fakeTicker) + }() + + // First time should trigger an update + priceFeedMock.On("FetchPriceData").Return(priceData, nil).Once() + fakeTicker <- time.Now() + checkPriceUpdated() + + // Trigger a dummy update given price hasn't changed + priceFeedMock.On("FetchPriceData").Return(priceData, nil).Once() + fakeTicker <- time.Now() + checkNoPriceUpdate() + + // still shouldn't update given UpdatedAt stayed the same + priceData.Price = big.NewRat(1, 1) + priceFeedMock.On("FetchPriceData").Return(priceData, nil).Once() + fakeTicker <- time.Now() + checkNoPriceUpdate() + + // bump the UpdatedAt time to trigger an update + priceData.UpdatedAt = priceData.UpdatedAt.Add(1 * time.Minute) + priceFeedMock.On("FetchPriceData").Return(priceData, nil).Once() + fakeTicker <- time.Now() + checkPriceUpdated() + + priceData.UpdatedAt = priceData.UpdatedAt.Add(1 * time.Hour) + priceData.Price = big.NewRat(3, 2) + priceFeedMock.On("FetchPriceData").Return(priceData, nil).Once() + fakeTicker <- time.Now() + checkPriceUpdated() +} + +func TestPriceFeedWatcher_WatchErrorRetries(t *testing.T) { + priceFeedMock := new(mockPriceFeedEthClient) + defer priceFeedMock.AssertExpectations(t) + + // First 4 calls should fail then succeed on the 5th + for i := 0; i < 4; i++ { + priceFeedMock.On("FetchPriceData").Return(eth.PriceData{}, errors.New("error")).Once() + } + priceData := eth.PriceData{ + RoundID: 10, + Price: big.NewRat(3, 2), + UpdatedAt: time.Now(), + } + priceFeedMock.On("FetchPriceData").Return(priceData, nil) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + w := &PriceFeedWatcher{ + ctx: ctx, + baseRetryDelay: 5 * time.Millisecond, + priceFeed: priceFeedMock, + currencyBase: "ETH", + currencyQuote: "USD", + priceUpdated: make(chan eth.PriceData, 1), + } + + // Start watch loop + fakeTicker := make(chan time.Time, 10) + go func() { + w.watch(ctx, fakeTicker) + }() + + fakeTicker <- time.Now() + select { + case updatedPrice := <-w.priceUpdated: + require.Equal(t, priceData, updatedPrice) + case <-time.After(2 * time.Second): + t.Error("Updated price hasn't been received on channel") + } +} + +func TestParseCurrencies(t *testing.T) { + t.Run("Valid currencies", func(t *testing.T) { + description := "ETH / USD" + currencyBase, currencyQuote, err := parseCurrencies(description) + + require.NoError(t, err) + require.Equal(t, "ETH", currencyBase) + require.Equal(t, "USD", currencyQuote) + }) + + t.Run("Missing separator", func(t *testing.T) { + description := "ETHUSD" + _, _, err := parseCurrencies(description) + + require.Error(t, err) + require.Contains(t, err.Error(), "aggregator description must be in the format 'FROM / TO'") + }) + + t.Run("Extra spaces", func(t *testing.T) { + description := " ETH / USD " + currencyBase, currencyQuote, err := parseCurrencies(description) + + require.NoError(t, err) + require.Equal(t, "ETH", currencyBase) + require.Equal(t, "USD", currencyQuote) + }) + + t.Run("Lowercase currency", func(t *testing.T) { + description := "eth / usd" + currencyBase, currencyQuote, err := parseCurrencies(description) + + require.NoError(t, err) + require.Equal(t, "eth", currencyBase) + require.Equal(t, "usd", currencyQuote) + }) +} + func TestNewTruncatedTicker(t *testing.T) { testTimeout := time.After(10 * time.Second) require := require.New(t)