diff --git a/caddy/app.go b/caddy/app.go
index 540f13e..8f8c1a2 100644
--- a/caddy/app.go
+++ b/caddy/app.go
@@ -19,6 +19,7 @@ package caddy

 import (
 	"errors"
+	"fmt"
 	"log/slog"

 	outline_prometheus "github.com/Jigsaw-Code/outline-ss-server/prometheus"
@@ -30,9 +31,14 @@ import (
 const outlineModuleName = "outline"

 func init() {
+	replayCache := outline.NewReplayCache(0)
 	caddy.RegisterModule(ModuleRegistration{
-		ID:  outlineModuleName,
-		New: func() caddy.Module { return new(OutlineApp) },
+		ID: outlineModuleName,
+		New: func() caddy.Module {
+			app := new(OutlineApp)
+			app.ReplayCache = replayCache
+			return app
+		},
 	})
 }

@@ -65,8 +71,9 @@ func (app *OutlineApp) Provision(ctx caddy.Context) error {
 	app.logger.Info("provisioning app instance")

 	if app.ShadowsocksConfig != nil {
-		// TODO: Persist replay cache across config reloads.
-		app.ReplayCache = outline.NewReplayCache(app.ShadowsocksConfig.ReplayHistory)
+		if err := app.ReplayCache.Resize(app.ShadowsocksConfig.ReplayHistory); err != nil {
+			return fmt.Errorf("failed to configure replay history with capacity %d: %v", app.ShadowsocksConfig.ReplayHistory, err)
+		}
 	}

 	if err := app.defineMetrics(); err != nil {
diff --git a/caddy/shadowsocks_handler.go b/caddy/shadowsocks_handler.go
index c348417..a0c4874 100644
--- a/caddy/shadowsocks_handler.go
+++ b/caddy/shadowsocks_handler.go
@@ -31,7 +31,7 @@ const ssModuleName = "layer4.handlers.shadowsocks"

 func init() {
 	caddy.RegisterModule(ModuleRegistration{
-		ID: ssModuleName,
+		ID:  ssModuleName,
 		New: func() caddy.Module { return new(ShadowsocksHandler) },
 	})
 }
diff --git a/service/replay.go b/service/replay.go
index 818fde0..27b3d12 100644
--- a/service/replay.go
+++ b/service/replay.go
@@ -16,6 +16,7 @@ package service

 import (
 	"encoding/binary"
+	"errors"
 	"sync"
 )

@@ -92,7 +93,7 @@ func (c *ReplayCache) Add(id string, salt []byte) bool {
 		return false
 	}
 	_, inArchive := c.archive[hash]
-	if len(c.active) == c.capacity {
+	if len(c.active) >= c.capacity {
 		// Discard the archive and move active to archive.
 		c.archive = c.active
 		c.active = make(map[uint32]empty, c.capacity)
@@ -100,3 +101,17 @@ func (c *ReplayCache) Add(id string, salt []byte) bool {
 	c.active[hash] = empty{}
 	return !inArchive
 }
+
+// Resize adjusts the capacity of the ReplayCache.
+func (c *ReplayCache) Resize(capacity int) error {
+	if capacity > MaxCapacity {
+		return errors.New("ReplayCache capacity would result in too many false positives")
+	}
+	c.mutex.Lock()
+	defer c.mutex.Unlock()
+	c.capacity = capacity
+	// NOTE: The active and archive maps are not explicitly shrunk.
+	// Their sizes will naturally adjust as new handshakes are added and the cache
+	// adheres to the updated capacity.
+	return nil
+}
diff --git a/service/replay_test.go b/service/replay_test.go
index c0187c0..6cfd98a 100644
--- a/service/replay_test.go
+++ b/service/replay_test.go
@@ -17,6 +17,9 @@ package service
 import (
 	"encoding/binary"
 	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )

 const keyID = "the key"
@@ -91,6 +94,81 @@ func TestReplayCache_Archive(t *testing.T) {
 	}
 }

+func TestReplayCache_Resize(t *testing.T) {
+	t.Run("Smaller resizes active and archive maps", func(t *testing.T) {
+		salts := makeSalts(10)
+		cache := NewReplayCache(5)
+		for _, s := range salts {
+			cache.Add(keyID, s)
+		}
+
+		err := cache.Resize(3)
+
+		require.NoError(t, err)
+		assert.Equal(t, cache.capacity, 3, "Expected capacity to be updated")
+
+		// Adding a new salt should trigger a shrinking of the active map as it hits the new
+		// capacity immediately.
+		cache.Add(keyID, salts[0])
+		assert.Len(t, cache.active, 1, "Expected active handshakes length to have shrunk")
+		assert.Len(t, cache.archive, 5, "Expected archive handshakes length to not have shrunk")
+
+		// Adding more new salts should eventually trigger a shrinking of the archive map as well,
+		// when the shrunken active map gets moved to the archive.
+		for _, s := range salts {
+			cache.Add(keyID, s)
+		}
+		assert.Len(t, cache.archive, 3, "Expected archive handshakes length to have shrunk")
+	})
+
+	t.Run("Larger resizes active and archive maps", func(t *testing.T) {
+		salts := makeSalts(10)
+		cache := NewReplayCache(5)
+		for _, s := range salts {
+			cache.Add(keyID, s)
+		}
+
+		err := cache.Resize(10)
+
+		require.NoError(t, err)
+		assert.Equal(t, cache.capacity, 10, "Expected capacity to be updated")
+		assert.Len(t, cache.active, 5, "Expected active handshakes length not to have changed")
+		assert.Len(t, cache.archive, 5, "Expected archive handshakes length not to have changed")
+	})
+
+	t.Run("Still detect salts", func(t *testing.T) {
+		salts := makeSalts(10)
+		cache := NewReplayCache(5)
+		for _, s := range salts {
+			cache.Add(keyID, s)
+		}
+
+		cache.Resize(10)
+
+		for _, s := range salts {
+			if cache.Add(keyID, s) {
+				t.Error("Should still be able to detect the salts after resizing")
+			}
+		}
+
+		cache.Resize(3)
+
+		for _, s := range salts {
+			if cache.Add(keyID, s) {
+				t.Error("Should still be able to detect the salts after resizing")
+			}
+		}
+	})
+
+	t.Run("Exceeding maximum capacity", func(t *testing.T) {
+		cache := &ReplayCache{}
+
+		err := cache.Resize(MaxCapacity + 1)
+
+		require.Error(t, err)
+	})
+}
+
 // Benchmark to determine the memory usage of ReplayCache.
 // Note that NewReplayCache only allocates the active set,
 // so the eventual memory usage will be roughly double.
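
For context, the reason `Resize` exists at all is the pattern `caddy/app.go` now follows: the `ReplayCache` is constructed once in `init()` and handed to every freshly constructed `OutlineApp`, so it outlives Caddy config reloads, and `Provision` only adjusts its capacity. Below is a minimal, self-contained sketch of that share-one-cache-across-reloads pattern, not the real code: the names (`replayCache`, `app`, `shared`) and the `maxCapacity` value are illustrative stand-ins for the outline-ss-server and Caddy types.

```go
package main

import (
	"errors"
	"fmt"
	"sync"
)

// Illustrative stand-in for service.MaxCapacity (the real value differs).
const maxCapacity = 20_000_000

type empty struct{}

// replayCache is a reduced stand-in for service.ReplayCache: two map
// generations ("active" and "archive") guarded by a mutex.
type replayCache struct {
	mutex    sync.Mutex
	capacity int
	active   map[uint32]empty
	archive  map[uint32]empty
}

// Resize mirrors the shape of the new ReplayCache.Resize: validate, lock,
// update the capacity, and let the maps shrink lazily on later Adds.
func (c *replayCache) Resize(capacity int) error {
	if capacity > maxCapacity {
		return errors.New("capacity would result in too many false positives")
	}
	c.mutex.Lock()
	defer c.mutex.Unlock()
	c.capacity = capacity
	return nil
}

// app stands in for OutlineApp; Caddy constructs a fresh one per config load.
type app struct {
	cache *replayCache
}

// Provision resizes the shared cache instead of replacing it, so salts seen
// before the reload are still detected afterwards.
func (a *app) Provision(replayHistory int) error {
	if err := a.cache.Resize(replayHistory); err != nil {
		return fmt.Errorf("failed to configure replay history with capacity %d: %v", replayHistory, err)
	}
	return nil
}

func main() {
	// Created once at process start, like the package-level replayCache in
	// the new init() in caddy/app.go.
	shared := &replayCache{
		active:  make(map[uint32]empty),
		archive: make(map[uint32]empty),
	}

	// Two simulated config reloads: each builds a new app value, but both
	// receive the same cache, so its contents survive the reload.
	for i, history := range []int{10_000, 5_000} {
		a := &app{cache: shared}
		if err := a.Provision(history); err != nil {
			fmt.Println("provision failed:", err)
			return
		}
		fmt.Printf("reload %d: shared cache capacity is now %d\n", i+1, shared.capacity)
	}
}
```

Resizing in place rather than recreating the cache is what preserves replay detection across reloads; the "Still detect salts" subtest pins that behavior down for both growing and shrinking capacities.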