From 41b6b68c9ce253b8c5d17a76d0ae7a5be7dba2cf Mon Sep 17 00:00:00 2001 From: Yihang Wang Date: Tue, 27 Feb 2024 11:40:56 +0800 Subject: [PATCH] test: add testing for Timeout, Submit and Sharding --- example/complex-http-crawler/main.go | 3 +- example/simple-http-crawler/main.go | 3 +- example/sleeper/main.go | 3 +- gojob.go | 42 +++++++-- gojob_test.go | 134 +++++++++++++++++++++++++++ pkg/util/capture.go | 44 +++++++++ 6 files changed, 216 insertions(+), 13 deletions(-) create mode 100644 gojob_test.go create mode 100644 pkg/util/capture.go diff --git a/example/complex-http-crawler/main.go b/example/complex-http-crawler/main.go index 3d169a9..cc0cf82 100644 --- a/example/complex-http-crawler/main.go +++ b/example/complex-http-crawler/main.go @@ -35,7 +35,8 @@ func main() { SetMaxRuntimePerTaskSeconds(opts.MaxRuntimePerTaskSeconds). SetNumShards(int64(opts.NumShards)). SetShard(int64(opts.Shard)). - SetOutputFilePath(opts.OutputFilePath) + SetOutputFilePath(opts.OutputFilePath). + Start() for line := range util.Cat(opts.InputFilePath) { scheduler.Submit(model.New(string(line))) diff --git a/example/simple-http-crawler/main.go b/example/simple-http-crawler/main.go index a12287e..7c88889 100644 --- a/example/simple-http-crawler/main.go +++ b/example/simple-http-crawler/main.go @@ -34,7 +34,8 @@ func main() { SetMaxRetries(4). SetMaxRuntimePerTaskSeconds(16). SetNumShards(4). - SetShard(0) + SetShard(0). + Start() for line := range util.Cat("input.txt") { scheduler.Submit(New(line)) } diff --git a/example/sleeper/main.go b/example/sleeper/main.go index 9cbb8d5..7d8424d 100644 --- a/example/sleeper/main.go +++ b/example/sleeper/main.go @@ -30,7 +30,8 @@ func main() { SetMaxRetries(4). SetMaxRuntimePerTaskSeconds(16). SetNumShards(4). - SetShard(0) + SetShard(0). + Start() scheduler.Start() for i := 0; i < 256; i++ { scheduler.Submit(New(i, rand.Intn(10))) diff --git a/gojob.go b/gojob.go index 3f07585..ffcdf62 100644 --- a/gojob.go +++ b/gojob.go @@ -3,6 +3,7 @@ package gojob import ( "context" "encoding/json" + "io" "log/slog" "os" "path/filepath" @@ -45,6 +46,7 @@ type Scheduler struct { MaxRuntimePerTaskSeconds int NumShards int64 Shard int64 + IsStarted bool NumTasks atomic.Int64 TaskChan chan *BasicTask LogChan chan string @@ -61,6 +63,7 @@ func NewScheduler() *Scheduler { MaxRuntimePerTaskSeconds: 16, NumShards: 3, Shard: 1, + IsStarted: false, NumTasks: atomic.Int64{}, TaskChan: make(chan *BasicTask), LogChan: make(chan string), @@ -123,6 +126,10 @@ func (s *Scheduler) SetMaxRuntimePerTaskSeconds(maxRuntimePerTaskSeconds int) *S // Submit submits a task to the scheduler func (s *Scheduler) Submit(task Task) { + if !s.IsStarted { + s.Start() + s.IsStarted = true + } index := s.NumTasks.Load() if (index % s.NumShards) == s.Shard { s.taskWg.Add(1) @@ -132,11 +139,15 @@ func (s *Scheduler) Submit(task Task) { } // Start starts the scheduler -func (s *Scheduler) Start() { +func (s *Scheduler) Start() *Scheduler { + if s.IsStarted { + return s + } for i := 0; i < s.NumWorkers; i++ { go s.Worker() } go s.Writer() + return s } // Wait waits for all tasks to finish @@ -182,30 +193,41 @@ func (s *Scheduler) Worker() { // Writer writes logs to file func (s *Scheduler) Writer() { - var fd *os.File + var fd io.Writer var err error - if s.OutputFilePath == "-" { + + switch s.OutputFilePath { + case "-": fd = os.Stdout - } else { + case "": + fd = io.Discard + default: // Create folder if not exists dir := filepath.Dir(s.OutputFilePath) if _, err := os.Stat(dir); os.IsNotExist(err) { - err = os.MkdirAll(dir, 0755) - if err != nil { - slog.Error("error occured while creating folder", slog.String("path", dir), slog.String("error", err.Error())) + if err := os.MkdirAll(dir, 0755); err != nil { + slog.Error("error occurred while creating folder", slog.String("path", dir), slog.String("error", err.Error())) return } } // Open file fd, err = os.OpenFile(s.OutputFilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) if err != nil { - slog.Error("error occured while opening file", slog.String("path", s.OutputFilePath), slog.String("error", err.Error())) + slog.Error("error occurred while opening file", slog.String("path", s.OutputFilePath), slog.String("error", err.Error())) return } - defer fd.Close() + defer func() { + if closeErr := fd.(*os.File).Close(); closeErr != nil { + slog.Error("error occurred while closing file", slog.String("path", s.OutputFilePath), slog.String("error", closeErr.Error())) + } + }() } + for result := range s.LogChan { - fd.WriteString(result + "\n") + if _, err := fd.Write([]byte(result + "\n")); err != nil { + slog.Error("error occurred while writing to file", slog.String("error", err.Error())) + continue + } s.logWg.Done() } } diff --git a/gojob_test.go b/gojob_test.go new file mode 100644 index 0000000..57599d5 --- /dev/null +++ b/gojob_test.go @@ -0,0 +1,134 @@ +package gojob_test + +import ( + "fmt" + "reflect" + "sort" + "strconv" + "strings" + "sync" + "testing" + "time" + + "github.com/WangYihang/gojob" +) + +type SafeWriter struct { + writer *strings.Builder + lock sync.Mutex +} + +func NewSafeWriter() *SafeWriter { + return &SafeWriter{ + writer: new(strings.Builder), + lock: sync.Mutex{}, + } +} + +func (sw *SafeWriter) WriteString(s string) { + sw.lock.Lock() + defer sw.lock.Unlock() + sw.writer.WriteString(s) +} + +func (sw *SafeWriter) String() string { + return sw.writer.String() +} + +type Task struct { + I int + writer *SafeWriter +} + +func NewTask(i int, writer *SafeWriter) *Task { + return &Task{ + I: i, + writer: writer, + } +} + +func (t *Task) Do() error { + t.writer.WriteString(fmt.Sprintf("%d\n", t.I)) + return nil +} + +func TestRunWithTimeout(t *testing.T) { + task := func() error { + time.Sleep(2 * time.Second) + return nil + } + + err := gojob.RunWithTimeout(task, 1*time.Second) + if err == nil { + t.Errorf("Expected timeout error, got nil") + } +} + +func TestSchedulerSubmit(t *testing.T) { + scheduler := gojob.NewScheduler().SetNumShards(2).SetShard(1) + safeWriter := NewSafeWriter() + task := NewTask(1, safeWriter) + scheduler.Submit(task) + if scheduler.NumTasks.Load() != 1 { + t.Errorf("Expected NumTasks to be 1, got %d", scheduler.NumTasks.Load()) + } +} + +func TestSharding(t *testing.T) { + testcases := []struct { + numShards int64 + shard int64 + expected []int + }{ + { + numShards: 2, + shard: 0, + expected: []int{0, 2, 4, 6, 8, 10, 12, 14}, + }, + { + numShards: 2, + shard: 1, + expected: []int{1, 3, 5, 7, 9, 11, 13, 15}, + }, + { + numShards: 3, + shard: 0, + expected: []int{0, 3, 6, 9, 12, 15}, + }, + { + numShards: 3, + shard: 1, + expected: []int{1, 4, 7, 10, 13}, + }, + { + numShards: 3, + shard: 2, + expected: []int{2, 5, 8, 11, 14}, + }, + } + for _, tc := range testcases { + safeWriter := NewSafeWriter() + scheduler := gojob.NewScheduler().SetNumShards(tc.numShards).SetShard(tc.shard).SetOutputFilePath("").Start() + for i := 0; i < 16; i++ { + scheduler.Submit(NewTask(i, safeWriter)) + } + scheduler.Wait() + output := safeWriter.String() + lines := strings.Split(output, "\n") + numbers := []int{} + for _, line := range lines { + if line == "" { + continue + } + number, err := strconv.Atoi(line) + if err != nil { + t.Fatal(err) + } + numbers = append(numbers, number) + } + sort.Ints(numbers) + if !reflect.DeepEqual(numbers, tc.expected) { + t.Errorf("Expected %v, got %v", tc.expected, numbers) + } + } +} diff --git a/pkg/util/capture.go b/pkg/util/capture.go new file mode 100644 index 0000000..688b803 --- /dev/null +++ b/pkg/util/capture.go @@ -0,0 +1,44 @@ +package util + +import ( + "bytes" + "io" + "os" +) + +type OutputCapture struct { + originalStdout *os.File + r *os.File + w *os.File + buffer bytes.Buffer +} + +func NewOutputCapture() *OutputCapture { + return &OutputCapture{} +} + +func (oc *OutputCapture) StartCapture() { + oc.originalStdout = os.Stdout + + r, w, _ := os.Pipe() + os.Stdout = w + + oc.r = r + oc.w = w +} + +func (oc *OutputCapture) StopCapture() { + os.Stdout = oc.originalStdout + oc.w.Close() + + io.Copy(&oc.buffer, oc.r) + oc.r.Close() +} + +func (oc *OutputCapture) GetCapturedOutput() string { + return oc.buffer.String() +} + +// capture := util.NewOutputCapture() +// capture.StartCapture() +// capture.StopCapture()