diff --git a/lang/channel/channel.go b/lang/channel/channel.go
new file mode 100644
index 00000000..15bcc5b7
--- /dev/null
+++ b/lang/channel/channel.go
@@ -0,0 +1,360 @@
+// Copyright 2023 ByteDance Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package channel
+
+import (
+	"container/list"
+	"runtime"
+	"sync"
+	"sync/atomic"
+	"time"
+)
+
+const (
+	defaultThrottleWindow = time.Millisecond * 100
+	defaultSize           = 0
+)
+
+// terminalSig is the internal sentinel that Close pushes through the input
+// channel to request a graceful shutdown.
+//
+// It is a pointer to a private, non-zero-size value so that no value a user
+// sends through Input() can ever compare equal to it. (A plain struct{}{}
+// sentinel would be equal to any struct{}{} sent by a caller and would shut
+// the channel down by accident.)
+var terminalSig interface{} = &struct{ _ byte }{}
+
+// item is one buffered element together with its optional expiration time.
+type item struct {
+	value    interface{}
+	deadline time.Time
+}
+
+// IsExpired reports whether the item has passed its deadline.
+// A zero deadline means the item never expires.
+func (i item) IsExpired() bool {
+	if i.deadline.IsZero() {
+		return false
+	}
+	return time.Now().After(i.deadline)
+}
+
+// Option configures a channel while it is being built by New.
+type Option func(c *channel)
+
+// Throttle reports whether processing should currently be held back.
+// Returning true means "throttled".
+type Throttle func(c *channel) bool
+
+// WithSize defines the size of the channel.
+// It conflicts with the WithNonBlock option: once non-blocking mode has been
+// selected the size is left alone. Negative sizes are ignored, because they
+// cannot be used as a Go channel capacity.
+func WithSize(size int) Option {
+	return func(c *channel) {
+		// with non block mode, no need to change size
+		if c.nonblock || size < 0 {
+			return
+		}
+		c.size = size
+	}
+}
+
+// WithNonBlock sets the channel to non-blocking mode.
+// It switches the channel to a fixed input buffer of 1024 entries, and the
+// producer side then skips both the throttle check and the wait for buffer space.
+func WithNonBlock() Option {
+	return func(c *channel) {
+		c.nonblock = true
+		c.size = 1024
+	}
+}
+
+// WithTimeout sets the expiration time of each channel item.
+// If the item is not consumed within the timeout duration, it is dropped
+// instead of being delivered (see WithTimeoutCallback to observe such items).
+func WithTimeout(timeout time.Duration) Option {
+	return func(c *channel) {
+		c.timeout = timeout
+	}
+}
+
+// WithTimeoutCallback sets the callback that is invoked with the value of
+// every item that hit its timeout before being consumed.
+func WithTimeoutCallback(timeoutCallback func(interface{})) Option {
+	return func(c *channel) {
+		c.timeoutCallback = timeoutCallback
+	}
+}
+
+// WithThrottle sets both producerThrottle and consumerThrottle.
+// While producerThrottle reports true, the input channel is blocked (blocking mode only).
+// While consumerThrottle reports true, the output channel is blocked.
+//
+// Either argument may be nil, meaning "no throttle on that side". When the
+// option is applied several times, the throttles of one side are chained and
+// that side is throttled only while every chained throttle reports true.
+func WithThrottle(producerThrottle, consumerThrottle Throttle) Option {
+	return func(c *channel) {
+		c.producerThrottle = chainThrottle(c.producerThrottle, producerThrottle)
+		c.consumerThrottle = chainThrottle(c.consumerThrottle, consumerThrottle)
+	}
+}
+
+// chainThrottle combines an already configured throttle with a new one.
+// A nil side is skipped, so a nil function is never wrapped and later called.
+func chainThrottle(prev, next Throttle) Throttle {
+	switch {
+	case prev == nil:
+		return next
+	case next == nil:
+		return prev
+	}
+	return func(c *channel) bool {
+		return prev(c) && next(c)
+	}
+}
+
+// WithThrottleWindow sets the interval between two throttle function checks.
+// Non-positive windows are ignored because a ticker cannot be created from them.
+func WithThrottleWindow(window time.Duration) Option {
+	return func(c *channel) {
+		if window > 0 {
+			c.throttleWindow = window
+		}
+	}
+}
+
+// WithRateThrottle is a helper function to control producer and consumer process rate.
+// produceRate and consumeRate are the number of items that may be processed per second, aka TPS.
+func WithRateThrottle(produceRate, consumeRate int) Option { + // throttle function will be called sequentially + producedMax := uint64(produceRate) + consumedMax := uint64(consumeRate) + var producedBegin, consumedBegin uint64 + var producedTS, consumedTS int64 + return WithThrottle(func(c *channel) bool { + ts := time.Now().Unix() // in second + produced := atomic.LoadUint64(&c.produced) + if producedTS != ts { + // move to a new second, so store the current process as beginning value + producedBegin = produced + producedTS = ts + return false + } + // get the value of beginning + producedDiff := produced - producedBegin + return producedMax > 0 && producedMax < producedDiff + }, func(c *channel) bool { + ts := time.Now().Unix() // in second + consumed := atomic.LoadUint64(&c.consumed) + if consumedTS != ts { + // move to a new second, so store the current process as beginning value + consumedBegin = consumed + consumedTS = ts + return false + } + // get the value of beginning + consumedDiff := consumed - consumedBegin + return consumedMax > 0 && consumedMax < consumedDiff + }) +} + +var _ Channel = (*channel)(nil) + +type Channel interface { + // Input return a native chan for produce task + Input() chan interface{} + // Output return a native chan for consume task + Output() chan interface{} + // Len return the count of un-consumed tasks + Len() int + // Close will close the producer and consumer goroutines gracefully + Close() +} + +// channelWrapper use to detect user never hold the reference of channel object, and we need to close channel implicitly. +type channelWrapper struct { + Channel +} + +// channel implements a safe and feature-rich channel struct for the real world. 
+type channel struct { + size int + state int32 + producer chan interface{} + consumer chan interface{} + timeout time.Duration + timeoutCallback func(interface{}) + producerThrottle Throttle + consumerThrottle Throttle + throttleWindow time.Duration + // statistics + produced uint64 + consumed uint64 + // non blocking mode + nonblock bool + // buffer + buffer *list.List // TODO: use high perf queue to reduce GC here + bufferCond *sync.Cond + bufferLock sync.Mutex +} + +// New create a new channel. +func New(opts ...Option) Channel { + c := new(channel) + c.size = defaultSize + c.throttleWindow = defaultThrottleWindow + c.bufferCond = sync.NewCond(&c.bufferLock) + for _, opt := range opts { + opt(c) + } + c.producer = make(chan interface{}, c.size) + c.consumer = make(chan interface{}) + c.buffer = list.New() + + go c.produce() + go c.consume() + + // register finalizer for wrapper of channel + cw := &channelWrapper{c} + runtime.SetFinalizer(cw, func(obj *channelWrapper) { + // it's ok to call Close again if user already closed the channel + obj.Close() + }) + return cw +} + +// Close will close the producer and consumer goroutines gracefully +func (c *channel) Close() { + if !atomic.CompareAndSwapInt32(&c.state, 0, -1) { + return + } + // empty buffer + c.bufferLock.Lock() + c.buffer.Init() // clear + c.bufferLock.Unlock() + c.bufferCond.Broadcast() + c.producer <- terminalSig +} + +func (c *channel) isClosed() bool { + return atomic.LoadInt32(&c.state) < 0 +} + +// Input return a native chan for produce task +func (c *channel) Input() chan interface{} { + return c.producer +} + +// Output return a native chan for consume task +func (c *channel) Output() chan interface{} { + return c.consumer +} + +// Len return the count of un-consumed tasks. 
+// The two counters are sampled one after the other, so the result is a
+// best-effort snapshot while producers and consumers are running.
+func (c *channel) Len() int {
+	// Load consumed before produced. produce() bumps produced before an item
+	// is enqueued and consume() bumps consumed only after it was dequeued, so
+	// a produced value read later can never be smaller than a consumed value
+	// read earlier. The opposite order could underflow the subtraction.
+	consumed := atomic.LoadUint64(&c.consumed)
+	produced := atomic.LoadUint64(&c.produced)
+	return int(produced - consumed)
+}
+
+// produce moves items from the input channel into the internal buffer.
+// It runs in its own goroutine until the terminal signal sent by Close has
+// been buffered, then closes the input channel and returns.
+func (c *channel) produce() {
+	// a size of 0 still allows one item to sit in the buffer
+	capacity := c.size
+	if c.size == 0 {
+		capacity = 1
+	}
+	for p := range c.producer {
+		// only check throttle function in blocking mode
+		if !c.nonblock {
+			c.throttling(c.producerThrottle)
+		}
+
+		// produced
+		atomic.AddUint64(&c.produced, 1)
+		// prepare item
+		it := item{value: p}
+		if c.timeout > 0 {
+			it.deadline = time.Now().Add(c.timeout)
+		}
+		// enqueue buffer
+		c.bufferLock.Lock()
+		c.enqueueBuffer(it)
+		c.bufferCond.Signal()
+		if !c.nonblock {
+			// blocking mode: do not accept another item until the consumer
+			// side has made room in the buffer
+			for c.buffer.Len() >= capacity {
+				c.bufferCond.Wait()
+			}
+		}
+		c.bufferLock.Unlock()
+
+		if p == terminalSig { // graceful shutdown
+			close(c.producer)
+			return
+		}
+	}
+}
+
+// consume moves items from the internal buffer to the output channel.
+// It runs in its own goroutine until it dequeues the terminal signal, then
+// closes the output channel and marks the channel state as -2.
+func (c *channel) consume() {
+	for {
+		// check throttle
+		c.throttling(c.consumerThrottle)
+
+		// dequeue buffer
+		c.bufferLock.Lock()
+		for c.buffer.Len() == 0 {
+			c.bufferCond.Wait()
+		}
+		it, ok := c.dequeueBuffer()
+		c.bufferLock.Unlock()
+		c.bufferCond.Signal()
+		if !ok {
+			// in fact, this case will never happen
+			continue
+		}
+
+		// graceful shutdown
+		if it.value == terminalSig {
+			atomic.AddUint64(&c.consumed, 1)
+			close(c.consumer)
+			atomic.StoreInt32(&c.state, -2)
+			return
+		}
+
+		// check expired
+		if it.IsExpired() {
+			if c.timeoutCallback != nil {
+				c.timeoutCallback(it.value)
+			}
+			atomic.AddUint64(&c.consumed, 1)
+			continue
+		}
+		// consuming, if block here means consumer is busy
+		c.consumer <- it.value
+		atomic.AddUint64(&c.consumed, 1)
+	}
+}
+
+// throttling blocks the calling goroutine while throttle reports true,
+// re-checking once per throttle window. It returns immediately when throttle
+// is nil, and stops waiting once the channel has been closed.
+func (c *channel) throttling(throttle Throttle) {
+	if throttle == nil {
+		return
+	}
+	throttled := throttle(c)
+	if !throttled {
+		return
+	}
+	// time.NewTicker panics on a non-positive interval, so fall back to the
+	// default window if a bad value was configured.
+	window := c.throttleWindow
+	if window <= 0 {
+		window = defaultThrottleWindow
+	}
+	ticker := time.NewTicker(window)
+	defer ticker.Stop()
+
+	for throttled && !c.isClosed() {
+		<-ticker.C
+		throttled = throttle(c)
+	}
+}
+
+// enqueueBuffer appends it to the tail of the buffer.
+// It does no locking itself; the caller in this file holds bufferLock.
+func (c *channel) enqueueBuffer(it item) {
+	c.buffer.PushBack(it)
+}
+
+// dequeueBuffer removes and returns the item at the head of the buffer.
+// ok is false when the buffer is empty. It does no locking itself; the caller
+// in this file holds bufferLock.
+func (c *channel) dequeueBuffer() (it item, ok bool) {
+	head := c.buffer.Front()
+	if head == nil {
+		return it, false
+	}
+	// Remove hands back the value stored in the removed element.
+	return c.buffer.Remove(head).(item), true
+}
diff --git a/lang/channel/channel_example_test.go b/lang/channel/channel_example_test.go
new file mode 100644
index 00000000..93f33a4d
--- /dev/null
+++ b/lang/channel/channel_example_test.go
@@ -0,0 +1,135 @@
+// Copyright 2023 ByteDance Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package channel
+
+import (
+	"runtime"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+)
+
+// request is the task handed from Service1 to the worker goroutine.
+type request struct {
+	Id      int
+	Latency time.Duration
+	Done    chan struct{}
+}
+
+// response is the result produced by Service2.
+type response struct {
+	Id int
+}
+
+// taskPool is the channel shared by Service1 and the task workers of the
+// tests below; each test replaces it with a fresh channel.
+var taskPool Channel
+
+// Service1 hands req over to taskPool so that it is processed asynchronously.
+func Service1(req *request) {
+	taskPool.Input() <- req // async run
+}
+
+// Service2 simulates a downstream call that takes req.Latency to answer.
+func Service2(req *request) (*response, error) {
+	if req.Latency > 0 {
+		time.Sleep(req.Latency)
+	}
+	return &response{Id: req.Id}, nil
+}
+
+func TestNetworkIsolationOrDownstreamBlock(t *testing.T) {
+	taskPool = New(
+		WithNonBlock(),
+		WithTimeout(time.Millisecond*10),
+	)
+	var responded int32
+	go func() {
+		// task worker
+		for task := range taskPool.Output() {
+			req := task.(*request)
+			done := make(chan struct{})
+			go func() {
+				_, _ = Service2(req)
+				close(done)
+			}()
+			select {
+			case <-time.After(time.Millisecond * 100):
+			case <-done:
+				atomic.AddInt32(&responded, 1)
+			}
+		}
+	}()
+
+	start := time.Now()
+	for i := 1; i <= 100; i++ {
+		req := &request{Id: i}
+		if i > 50 && i <= 60 { // suddenly have network issue for 10 requests
+			req.Latency = time.Hour
+		}
+		Service1(req)
+	}
+	cost := time.Since(start)
+	// Service1 should not block
+	assert.True(t, cost < time.Millisecond*10)
+	// wait all tasks finished
+	time.Sleep(time.Millisecond * 1500)
+	// 50 success and 10 timeout and 40 discard.
+	// The worker goroutine is still alive, so read the counter atomically.
+	assert.Equal(t, int32(50), atomic.LoadInt32(&responded))
+}
+
+func TestCPUHeavy(t *testing.T) {
+	// Restore the previous setting on exit so that the single-P configuration
+	// does not leak into the tests that run after this one.
+	defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(1))
+	var concurrency int32
+	taskPool = New(
+		WithNonBlock(),
+		WithThrottle(nil, func(c *channel) bool {
+			return atomic.LoadInt32(&concurrency) > 10
+		}),
+	)
+	var responded int32
+	go func() {
+		// task worker
+		for task := range taskPool.Output() {
+			req := task.(*request)
+			t.Logf("NumGoroutine: %d", runtime.NumGoroutine())
+			go func() {
+				curConcurrency := atomic.AddInt32(&concurrency, 1)
+				defer atomic.AddInt32(&concurrency, -1)
+				if curConcurrency > 10 {
+					// concurrency too high, request failed
+					return
+				}
+
+				atomic.AddInt32(&responded, 1)
+				if req.Id >= 11 && req.Id <= 20 {
+					// burn CPU for about 100ms
+					start := time.Now()
+					for x := uint64(0); ; x++ {
+						if x%1000 == 0 {
+							if time.Since(start) >= 100*time.Millisecond {
+								return
+							}
+						}
+					}
+				}
+			}()
+		}
+	}()
+
+	start := time.Now()
+	for i := 1; i <= 100; i++ {
+		req := &request{Id: i}
+		Service1(req)
+	}
+	cost := time.Since(start)
+	// Service1 should not block
+	assert.True(t, cost < time.Millisecond*10)
+	// wait all tasks finished
+	time.Sleep(time.Second * 2)
+	// Worker goroutines may still be running, so read the counter atomically.
+	got := atomic.LoadInt32(&responded)
+	t.Logf("responded: %d", got)
+	// most tasks success
+	assert.True(t, int32(50) < got)
+}
diff --git a/lang/channel/channel_test.go b/lang/channel/channel_test.go
new file mode 100644
index 00000000..54700ca9
--- /dev/null
+++ b/lang/channel/channel_test.go
@@ -0,0 +1,449 @@
+// Copyright 2023 ByteDance Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package channel
+
+import (
+	"fmt"
+	"runtime"
+	"sync"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+)
+
+// tlogf logs a formatted message prefixed with the current UTC time.
+func tlogf(t *testing.T, format string, args ...interface{}) {
+	t.Helper() // attribute the log line to the caller, not to this helper
+	t.Logf("[%v] %s", time.Now().UTC(), fmt.Sprintf(format, args...))
+}
+
+//go:noinline
+func factorial(x int) int {
+	if x <= 1 {
+		return x
+	}
+	return x * factorial(x-1)
+}
+
+// benchSizes lists the buffer sizes to benchmark; a negative size selects the
+// non-blocking mode in BenchmarkChannel and is skipped for native channels.
+var benchSizes = []int{1, 10, 100, 1000, -1}
+
+func BenchmarkNativeChan(b *testing.B) {
+	for _, size := range benchSizes {
+		if size < 0 {
+			continue
+		}
+		b.Run(fmt.Sprintf("Size-%d", size), func(b *testing.B) {
+			ch := make(chan interface{}, size)
+			b.RunParallel(func(pb *testing.PB) {
+				n := 0
+				for pb.Next() {
+					n++
+					ch <- n
+					<-ch
+				}
+			})
+		})
+	}
+}
+
+func BenchmarkChannel(b *testing.B) {
+	for _, size := range benchSizes {
+		b.Run(fmt.Sprintf("Size-%d", size), func(b *testing.B) {
+			var ch Channel
+			if size < 0 {
+				ch = New(WithNonBlock())
+			} else {
+				ch = New(WithSize(size))
+			}
+			var producer, consumer chan interface{} = ch.Input(), ch.Output()
+			b.RunParallel(func(pb *testing.PB) {
+				n := 0
+				for pb.Next() {
+					n++
+					producer <- n
+					<-consumer
+				}
+			})
+		})
+	}
+}
+
+func TestChannelDefaultSize(t *testing.T) {
+	channel := New()
+	defer channel.Close()
+
+	producer, consumer := channel.Input(), channel.Output()
+	for i := 1; i <= 10; i++ {
+		producer <- i
+		t.Logf("put %d", i)
+		x := <-consumer
+		t.Logf("get %d", x)
+		assert.Equal(t, i, x)
+	}
+	producer <- 0 // waits to be consumed
+	producer <- 0 // waits in the buffer
+	timeout := false
+	select {
+	case producer <- 0: // block
+	case <-time.After(time.Millisecond * 10):
+		timeout = true
+	}
+	assert.True(t, timeout)
+}
+
+func TestChannelClose(t *testing.T) {
+	beginGs := runtime.NumGoroutine()
+	channel := New()
+	afterGs := runtime.NumGoroutine()
+	assert.Equal(t, 2, afterGs-beginGs)
+	var exit int32
+	go func() {
+		for range channel.Output() {
+		}
+		atomic.AddInt32(&exit, 1)
+	}()
+	for i := 1; i <= 20; i++ {
+		channel.Input() <- i
+	}
+	channel.Close()
+	for runtime.NumGoroutine() > beginGs {
+		runtime.Gosched()
+	}
+	<-channel.Output() // never block
+	assert.Equal(t, int32(1), atomic.LoadInt32(&exit))
+}
+
+func TestChannelGCClose(t *testing.T) {
+	// close implicitly
+	go func() {
+		_ = New()
+	}()
+	go func() {
+		ch := New()
+		ch.Input() <- 1
+		<-ch.Output()
+		tlogf(t, "channel finished")
+	}()
+	for i := 0; i < 3; i++ {
+		time.Sleep(time.Millisecond * 10)
+		runtime.GC()
+	}
+
+	// close explicitly
+	go func() {
+		ch := New()
+		ch.Close()
+	}()
+	for i := 0; i < 3; i++ {
+		time.Sleep(time.Millisecond * 10)
+		runtime.GC()
+	}
+}
+
+func TestChannelTimeout(t *testing.T) {
+	channel := New(
+		WithTimeout(time.Millisecond*50),
+		WithSize(1024),
+	)
+	producer, consumer := channel.Input(), channel.Output()
+	go func() {
+		for i := 1; i <= 20; i++ {
+			producer <- i
+		}
+	}()
+	var total int32
+	go func() {
+		for c := range consumer {
+			id := c.(int)
+			if id >= 10 {
+				time.Sleep(time.Millisecond * 100)
+			}
+			atomic.AddInt32(&total, 1)
+		}
+	}()
+	time.Sleep(time.Second)
+	// success task: id in [1, 11]
+	// note that the task with id=11 is also consumed, because its expiration
+	// had already been checked before the consumer slowed down.
+	assert.Equal(t, int32(11), atomic.LoadInt32(&total))
+}
+
+func TestChannelConsumerInflightLimit(t *testing.T) {
+	var inflight int32
+	var limit int32 = 10
+	var total = 20
+	channel := New(
+		WithThrottle(nil, func(c *channel) bool {
+			return atomic.LoadInt32(&inflight) >= limit
+		}),
+	)
+	producer, consumer := channel.Input(), channel.Output()
+	var wg sync.WaitGroup
+	go func() {
+		for c := range consumer {
+			atomic.AddInt32(&inflight, 1)
+			id := c.(int)
+			tlogf(t, "consumer=%d started", id)
+			go func() {
+				defer atomic.AddInt32(&inflight, -1)
+				defer wg.Done()
+				time.Sleep(time.Second)
+				tlogf(t, "consumer=%d finished", id)
+			}()
+		}
+	}()
+
+	now := time.Now()
+	for i := 1; i <= total; i++ {
+		wg.Add(1)
+		id := i
+		producer <- id
+		tlogf(t, "producer=%d finished", id)
+		time.Sleep(time.Millisecond * 10)
+	}
+	wg.Wait()
+	duration := time.Since(now)
+	assert.Equal(t, 2, int(duration.Seconds()))
+}
+
+func TestChannelProducerSpeedLimit(t *testing.T) {
+	var total = 15
+	channel := New(WithSize(0))
+	producer, consumer := channel.Input(), channel.Output()
+	go func() {
+		for c := range consumer {
+			id := c.(int)
+			time.Sleep(time.Millisecond * 100)
+			tlogf(t, "consumer=%d finished", id)
+		}
+	}()
+
+	now := time.Now()
+	for i := 1; i <= total; i++ {
+		id := i
+		producer <- id
+		tlogf(t, "producer=%d finished", id)
+	}
+	duration := time.Since(now)
+	assert.Equal(t, 1, int(duration.Seconds()))
+}
+
+func TestChannelProducerNoLimit(t *testing.T) {
+	var total = 100
+	channel := New(WithSize(1000))
+	producer, consumer := channel.Input(), channel.Output()
+	go func() {
+		for c := range consumer {
+			id := c.(int)
+			time.Sleep(time.Millisecond * 100)
+			tlogf(t, "consumer=%d finished", id)
+		}
+	}()
+
+	now := time.Now()
+	for i := 1; i <= total; i++ {
+		id := i
+		producer <- id
+	}
+	duration := time.Since(now)
+	assert.Equal(t, 0, int(duration.Seconds()))
+}
+
+func TestChannelGoroutinesThrottle(t *testing.T) {
+	// goroutineChecker throttles while more than maxGoroutines goroutines exist.
+	goroutineChecker := func(maxGoroutines int) Throttle {
+		return func(c *channel) bool {
+			tlogf(t, "%d goroutines", runtime.NumGoroutine())
+			return runtime.NumGoroutine() > maxGoroutines
+		}
+	}
+	var total = 1000
+	throttle := goroutineChecker(100)
+	channel := New(WithThrottle(throttle, throttle), WithThrottleWindow(time.Millisecond*100))
+	producer, consumer := channel.Input(), channel.Output()
+	var wg sync.WaitGroup
+	go func() {
+		for c := range consumer {
+			id := c.(int)
+			go func() {
+				time.Sleep(time.Millisecond * 100)
+				tlogf(t, "consumer=%d finished", id)
+				wg.Done()
+			}()
+		}
+	}()
+
+	for i := 1; i <= total; i++ {
+		wg.Add(1)
+		id := i
+		producer <- id
+		tlogf(t, "producer=%d finished", id)
+		runtime.Gosched()
+	}
+	wg.Wait()
+}
+
+func TestChannelNoConsumer(t *testing.T) {
+	channel := New()
+	producer, consumer := channel.Input(), channel.Output()
+	_ = consumer
+	var sum int32
+	go func() {
+		for i := 1; i <= 20; i++ {
+			producer <- i
+			tlogf(t, "producer=%d finished", i)
+			atomic.AddInt32(&sum, 1)
+		}
+	}()
+	time.Sleep(time.Second)
+	assert.Equal(t, int32(2), atomic.LoadInt32(&sum))
+}
+
+func TestChannelOneSlowTask(t *testing.T) {
+	channel := New(WithTimeout(time.Millisecond*500), WithSize(0))
+	producer, consumer := channel.Input(), channel.Output()
+
+	var total int32
+	go func() {
+		for c := range consumer {
+			id := c.(int)
+			if id == 10 {
+				time.Sleep(time.Second)
+			}
+			atomic.AddInt32(&total, 1)
+		}
+	}()
+
+	for i := 1; i <= 20; i++ {
+		producer <- i
+		tlogf(t, "producer=%d finished", i)
+	}
+	time.Sleep(time.Second)
+	assert.Equal(t, int32(19), atomic.LoadInt32(&total))
+}
+
+func TestChannelProduceRateControl(t *testing.T) {
+	produceMaxRate := 100
+	channel := New(
+		WithRateThrottle(produceMaxRate, 0),
+	)
+
+	go func() {
+		for c := range channel.Output() {
+			id := c.(int)
+			tlogf(t, "consumed: %d", id)
+		}
+	}()
+	begin := time.Now()
+	for i := 1; i <= 500; i++ {
+		channel.Input() <- i
+	}
+	cost := time.Since(begin)
+	tlogf(t, "Cost %dms", cost.Milliseconds())
+}
+
+func TestChannelConsumeRateControl(t *testing.T) { + channel := New( + WithRateThrottle(0, 100), + ) + + go func() { + for c := range channel.Output() { + id := c.(int) + tlogf(t, "consumed: %d", id) + } + }() + begin := time.Now() + for i := 1; i <= 500; i++ { + channel.Input() <- i + } + cost := time.Now().Sub(begin) + tlogf(t, "Cost %dms", cost.Milliseconds()) +} + +func TestChannelNonBlock(t *testing.T) { + channel := New(WithNonBlock()) + begin := time.Now() + for i := 1; i <= 10000; i++ { + channel.Input() <- i + tlogf(t, "producer=%d finished", i) + } + cost := time.Now().Sub(begin) + tlogf(t, "Cost %dms", cost.Milliseconds()) +} + +func TestAvoidGoroutineLeak(t *testing.T) { + // Default channel is safe + recvCh := New() + var wg sync.WaitGroup + wg.Add(1) + // producer + go func() { + time.Sleep(time.Millisecond * 100) // RPC Call + recvCh.Input() <- 1 + wg.Done() + }() + // consumer + select { + case <-recvCh.Output(): + case <-time.After(time.Millisecond * 50): + } + wg.Wait() // goroutine exit +} + +func TestFastRecoverConsumer(t *testing.T) { + var consumed int32 + var aborted int32 + timeout := time.Second * 1 + + channel := New( + WithNonBlock(), + WithTimeout(timeout), + WithTimeoutCallback(func(i interface{}) { + atomic.AddInt32(&aborted, 1) + }), + ) + defer channel.Close() + + // consumer + go func() { + for c := range channel.Output() { + id := c.(int) + t.Logf("consumed: %d", id) + time.Sleep(time.Millisecond * 100) + atomic.AddInt32(&consumed, 1) + } + }() + + // producer + // faster than consumer's ability + for i := 1; i <= 20; i++ { + channel.Input() <- i + time.Sleep(time.Millisecond * 10) + } + for (atomic.LoadInt32(&consumed) + atomic.LoadInt32(&aborted)) != 20 { + runtime.Gosched() + } + assert.True(t, aborted > 5) + consumed = 0 + aborted = 0 + // quick recover consumer + for i := 1; i <= 10; i++ { + channel.Input() <- i + time.Sleep(time.Millisecond * 10) + } + for atomic.LoadInt32(&consumed) != 10 { + runtime.Gosched() + } + // all 
consumed +}