diff --git a/lib/pool/compression_cache_pool.go b/lib/pool/compression_cache_pool.go new file mode 100644 index 0000000..3c3e486 --- /dev/null +++ b/lib/pool/compression_cache_pool.go @@ -0,0 +1,94 @@ +// Copyright 2024 openGemini Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pool + +import ( + "bytes" + "compress/gzip" + "errors" + "runtime" + + "github.com/golang/snappy" + "github.com/klauspost/compress/zstd" +) + +var ( + gzipReaderPool = NewCachePool[*gzip.Reader](nil, 2*runtime.NumCPU()) + + snappyReaderPool = NewCachePool[*snappy.Reader](func() *snappy.Reader { + return snappy.NewReader(bytes.NewReader(nil)) + }, 2*runtime.NumCPU()) + + zstdDecoderPool = NewCachePool[*zstd.Decoder](func() *zstd.Decoder { + decoder, _ := zstd.NewReader(nil) + return decoder + }, 2*runtime.NumCPU()) +) + +func GetGzipReader(body []byte) (*gzip.Reader, error) { + // gzip reader not support new with nil writer + // so we need to create a new reader if pool is empty + if gzipReaderPool.AvailableOffers() == gzipReaderPool.Capacity() { + return gzip.NewReader(bytes.NewReader(body)) + } + reader := gzipReaderPool.Get() + if reader == nil { + return nil, errors.New("failed to get gzip reader") + } + err := reader.Reset(bytes.NewReader(body)) + if err != nil { + return nil, err + } + return reader, nil +} + +func PutGzipReader(reader *gzip.Reader) { + reader.Close() + gzipReaderPool.Put(reader) +} + +func GetSnappyReader(body []byte) (*snappy.Reader, error) { + reader := snappyReaderPool.Get() + if reader == nil { + return nil, errors.New("failed to get snappy reader") + } + reader.Reset(bytes.NewReader(body)) + + return reader, nil +} + +func PutSnappyReader(reader *snappy.Reader) { + snappyReaderPool.Put(reader) +} + +func GetZstdDecoder(body []byte) (*zstd.Decoder, error) { + decoder := zstdDecoderPool.Get() + if decoder == nil { + return nil, errors.New("failed to get zstd decoder") + } + err := decoder.Reset(bytes.NewReader(body)) + if err != nil { + return nil, err + } + return decoder, nil +} + +func PutZstdDecoder(decoder *zstd.Decoder) { + err := decoder.Reset(nil) + if err != nil { + return + } + zstdDecoderPool.Put(decoder) +} diff --git a/lib/pool/compression_cache_pool_test.go b/lib/pool/compression_cache_pool_test.go new file mode 100644 index 0000000..5b91bd7 --- /dev/null +++ b/lib/pool/compression_cache_pool_test.go @@ -0,0 +1,109 @@ +// Copyright 2024 openGemini Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pool + +import ( + "bytes" + "compress/gzip" + "io" + "testing" + + "github.com/golang/snappy" + "github.com/klauspost/compress/zstd" +) + +func TestGzipReaderPool(t *testing.T) { + data := []byte("test data") + var buf bytes.Buffer + writer := gzip.NewWriter(&buf) + _, err := writer.Write(data) + if err != nil { + t.Fatalf("failed to write gzip data: %v", err) + } + writer.Close() + + compressedData := buf.Bytes() + + reader, err := GetGzipReader(compressedData) + if err != nil { + t.Fatalf("failed to get gzip reader: %v", err) + } + + decompressedData, err := io.ReadAll(reader) + if err != nil { + t.Fatalf("failed to read gzip data: %v", err) + } + + if !bytes.Equal(decompressedData, data) { + t.Errorf("expected %v, got %v", data, decompressedData) + } + + PutGzipReader(reader) +} + +func TestSnappyReaderPool(t *testing.T) { + data := []byte("test data") + var buf bytes.Buffer + + // Write data to buffer + writer := snappy.NewBufferedWriter(&buf) + _, err := writer.Write(data) + if err != nil { + t.Fatalf("failed to write snappy data: %v", err) + } + writer.Close() + + compressedData := buf.Bytes() + + reader, err := GetSnappyReader(compressedData) + if err != nil { + t.Fatalf("failed to get snappy reader: %v", err) + } + + decompressedData, err := io.ReadAll(reader) + if err != nil { + t.Fatalf("failed to read snappy data: %v", err) + } + + if !bytes.Equal(decompressedData, data) { + t.Errorf("expected %v, got %v", data, decompressedData) + } + + PutSnappyReader(reader) + +} + +func TestZstdDecoderPool(t *testing.T) { + data := []byte("test data") + encoder, _ := zstd.NewWriter(nil) + compressedData := encoder.EncodeAll(data, nil) + encoder.Close() + + decoder, err := GetZstdDecoder(compressedData) + if err != nil { + t.Fatalf("failed to get zstd decoder: %v", err) + } + + decompressedData, err := decoder.DecodeAll(compressedData, nil) + if err != nil { + t.Fatalf("failed to read zstd data: %v", err) + } + + if !bytes.Equal(decompressedData, data) { + t.Errorf("expected %v, got %v", data, decompressedData) + } + + PutZstdDecoder(decoder) +} diff --git a/lib/pool/pool.go b/lib/pool/pool.go index 7f3abb7..7c05437 100644 --- a/lib/pool/pool.go +++ b/lib/pool/pool.go @@ -18,34 +18,54 @@ import ( "sync" ) -type CachePool struct { - pool sync.Pool - size chan struct{} +type CachePool[T any] struct { + pool sync.Pool + capacityChan chan struct{} + newFunc func() T } -func NewCachePool(newFunc func() interface{}, maxSize int) *CachePool { - return &CachePool{ +func NewCachePool[T any](newFunc func() T, maxSize int) *CachePool[T] { + return &CachePool[T]{ pool: sync.Pool{ - New: newFunc, + New: func() interface{} { + if newFunc != nil { + return newFunc() + } + return nil + }, }, - size: make(chan struct{}, maxSize), + capacityChan: make(chan struct{}, maxSize), + newFunc: newFunc, } } -func (c *CachePool) Get() interface{} { +func (c *CachePool[T]) Get() T { select { - case c.size <- struct{}{}: - return c.pool.Get() + case c.capacityChan <- struct{}{}: + item := c.pool.Get() + if item == nil && c.newFunc != nil { + return c.newFunc() + } + return item.(T) default: - return c.pool.New() + var zero T + return zero } } -func (c *CachePool) Put(x interface{}) { +func (c *CachePool[T]) Put(x T) { select { - case <-c.size: + case <-c.capacityChan: c.pool.Put(x) default: // Pool is full, discard the item } } + +func (c *CachePool[T]) AvailableOffers() int { + return cap(c.capacityChan) - len(c.capacityChan) +} + +func (c *CachePool[T]) Capacity() int { + return cap(c.capacityChan) +} diff --git a/lib/pool/pool_test.go b/lib/pool/pool_test.go index 4d17268..50685a0 100644 --- a/lib/pool/pool_test.go +++ b/lib/pool/pool_test.go @@ -4,28 +4,28 @@ // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // -// http://www.apache.org/licenses/LICENSE-2.0 +// http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + package pool import ( - "sync" "testing" ) func TestCachePool(t *testing.T) { // Create a new CachePool with a max size of 2 pool := NewCachePool(func() interface{} { - return new(int) + return new(struct{}) }, 2) // Get an item from the pool - item1 := pool.Get().(*int) + item1 := pool.Get().(*struct{}) if item1 == nil { t.Errorf("expected non-nil item, got nil") } @@ -34,7 +34,7 @@ func TestCachePool(t *testing.T) { pool.Put(item1) // Get another item from the pool - item2 := pool.Get().(*int) + item2 := pool.Get().(*struct{}) if item2 == nil { t.Errorf("expected non-nil item, got nil") } @@ -44,29 +44,29 @@ func TestCachePool(t *testing.T) { t.Errorf("expected the same item, got different items") } - // Put the item back into the pool + if pool.AvailableOffers() != 1 { + t.Errorf("The expected remaining capacity of the pool is 1, got %d", pool.AvailableOffers()) + } pool.Put(item2) - // Get two more items from the pool - item3 := pool.Get().(*int) - item4 := pool.Get().(*int) + item3 := pool.Get().(*struct{}) + if item3 == nil { + t.Errorf("expected non-nil item, got nil") + } - // Ensure the pool does not exceed its max size - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - pool.Put(item3) - pool.Put(item4) - }() + item4 := pool.Get().(*struct{}) + if item4 == nil { + t.Errorf("expected non-nil item, got nil") + } + + if pool.AvailableOffers() != 0 { + t.Errorf("The expected remaining capacity of the pool is 0, got %d", pool.AvailableOffers()) + } - wg.Wait() + pool.Put(item3) + pool.Put(item4) - // Ensure the pool size is correct - select { - case pool.size <- struct{}{}: - return - default: - t.Errorf("expected pool to be full, but it was not") + if pool.AvailableOffers() != 2 { + t.Errorf("The expected remaining capacity of the pool is 2, got %d", pool.AvailableOffers()) } } diff --git a/opengemini/query.go b/opengemini/query.go index 491d4c1..ffa8d74 100644 --- a/opengemini/query.go +++ b/opengemini/query.go @@ -15,8 +15,6 @@ package opengemini import ( - "bytes" - "compress/gzip" "encoding/json" "errors" "fmt" @@ -24,8 +22,7 @@ import ( "net/http" "time" - "github.com/golang/snappy" - "github.com/klauspost/compress/zstd" + compressionPool "github.com/openGemini/opengemini-client-go/lib/pool" "github.com/vmihailenco/msgpack/v5" ) @@ -175,11 +172,11 @@ func decompressBody(encoding string, body []byte) ([]byte, error) { } func decodeGzipBody(body []byte) ([]byte, error) { - decoder, err := gzip.NewReader(bytes.NewReader(body)) + decoder, err := compressionPool.GetGzipReader(body) if err != nil { return nil, errors.New("failed to create gzip decoder: " + err.Error()) } - defer decoder.Close() + defer compressionPool.PutGzipReader(decoder) decompressedBody, err := io.ReadAll(decoder) if err != nil { @@ -190,11 +187,11 @@ func decodeGzipBody(body []byte) ([]byte, error) { } func decodeZstdBody(compressedBody []byte) ([]byte, error) { - decoder, err := zstd.NewReader(nil) + decoder, err := compressionPool.GetZstdDecoder(compressedBody) if err != nil { return nil, errors.New("failed to create zstd decoder: " + err.Error()) } - defer decoder.Close() + defer compressionPool.PutZstdDecoder(decoder) decompressedBody, err := decoder.DecodeAll(compressedBody, nil) if err != nil { @@ -205,7 +202,11 @@ func decodeZstdBody(compressedBody []byte) ([]byte, error) { } func decodeSnappyBody(compressedBody []byte) ([]byte, error) { - reader := snappy.NewReader(bytes.NewReader(compressedBody)) + reader, err := compressionPool.GetSnappyReader(compressedBody) + if err != nil { + return nil, errors.New("failed to create snappy reader: " + err.Error()) + } + defer compressionPool.PutSnappyReader(reader) decompressedBody, err := io.ReadAll(reader) if err != nil { return nil, errors.New("failed to decompress snappy body: " + err.Error())