diff --git a/core/mongo/db.go b/core/mongo/db.go index 3cbcd084..79da5e36 100644 --- a/core/mongo/db.go +++ b/core/mongo/db.go @@ -1,10 +1,15 @@ package mongo import ( + "github.com/crawlab-team/crawlab/core/config" "github.com/spf13/viper" "go.mongodb.org/mongo-driver/mongo" ) +func init() { + config.InitConfig() +} + func GetMongoDb(dbName string) *mongo.Database { // Use default database name if not provided if dbName == "" { diff --git a/core/task/handler/runner_test.go b/core/task/handler/runner_test.go index 00ebf6e4..4c342863 100644 --- a/core/task/handler/runner_test.go +++ b/core/task/handler/runner_test.go @@ -21,7 +21,6 @@ import ( "github.com/crawlab-team/crawlab/core/constants" "github.com/crawlab-team/crawlab/core/models/models" "github.com/crawlab-team/crawlab/core/models/service" - "github.com/crawlab-team/crawlab/grpc" "github.com/spf13/viper" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -96,6 +95,60 @@ func setupRunner(t *testing.T) *Runner { return runner } +func setupPipe(runner *Runner) (pr *io.PipeReader, pw *io.PipeWriter) { + // Create a pipe for testing + pr, pw = io.Pipe() + runner.stdoutPipe = pr + runner.cmd.Stdout = pw + runner.cmd.Stderr = pw + return pr, pw +} + +func initRunner(runner *Runner) chan struct{} { + // Initialize context and other required fields + runner.ctx, runner.cancel = context.WithCancel(context.Background()) + runner.wg = sync.WaitGroup{} + runner.done = make(chan struct{}) + runner.ipcChan = make(chan entity.IPCMessage) + + // Create a channel to signal that the reader is ready + readerReady := make(chan struct{}) + + // Start IPC reader with ready signal + go func() { + defer runner.wg.Done() + runner.wg.Add(1) + close(readerReady) // Signal that reader is ready + + // Read directly from the pipe for debugging + scanner := bufio.NewScanner(runner.stdoutPipe) + for scanner.Scan() { + line := scanner.Text() + log.Infof("Read from pipe: %s", line) + + // Try to parse as IPC message + var ipcMsg entity.IPCMessage + if err := json.Unmarshal([]byte(line), &ipcMsg); err != nil { + log.Errorf("Failed to unmarshal IPC message: %v", err) + continue + } + + if ipcMsg.IPC { + log.Infof("Valid IPC message received: %+v", ipcMsg) + if runner.ipcHandler != nil { + runner.ipcHandler(ipcMsg) + } + } + } + + if err := scanner.Err(); err != nil { + log.Errorf("Scanner error: %v", err) + } + }() + + return readerReady +} + func TestRunner(t *testing.T) { // Setup test data setupGrpc(t) @@ -105,57 +158,9 @@ func TestRunner(t *testing.T) { runner := setupRunner(t) // Create a pipe for testing - pr, pw := io.Pipe() - defer func() { - _ = pr.Close() - log.Infof("closed reader pipe") - }() - defer func() { - _ = pw.Close() - log.Infof("closed writer pipe") - }() - runner.stdoutPipe = pr - - // Initialize context and other required fields - runner.ctx, runner.cancel = context.WithCancel(context.Background()) - runner.wg = sync.WaitGroup{} - runner.done = make(chan struct{}) - runner.ipcChan = make(chan entity.IPCMessage) - - // Create a channel to signal that the reader is ready - readerReady := make(chan struct{}) - - // Start IPC reader with ready signal - go func() { - defer runner.wg.Done() - runner.wg.Add(1) - close(readerReady) // Signal that reader is ready - - // Read directly from the pipe for debugging - scanner := bufio.NewScanner(pr) - for scanner.Scan() { - line := scanner.Text() - log.Infof("Read from pipe: %s", line) - - // Try to parse as IPC message - var ipcMsg entity.IPCMessage - if err := json.Unmarshal([]byte(line), &ipcMsg); err != nil { - log.Errorf("Failed to unmarshal IPC message: %v", err) - continue - } - - if ipcMsg.IPC { - log.Infof("Valid IPC message received: %+v", ipcMsg) - if runner.ipcHandler != nil { - runner.ipcHandler(ipcMsg) - } - } - } + _, pw := setupPipe(runner) - if err := scanner.Err(); err != nil { - log.Errorf("Scanner error: %v", err) - } - }() + readerReady := initRunner(runner) // Wait for reader to be ready <-readerReady @@ -168,7 +173,7 @@ func TestRunner(t *testing.T) { } // Create channels for synchronization - handled := make(chan bool) + processed := make(chan bool) messageError := make(chan error, 1) // Set up message handler @@ -182,7 +187,7 @@ func TestRunner(t *testing.T) { messageError <- fmt.Errorf("expected payload %v, got %v", testMsg.Payload, msg.Payload) return } - handled <- true + processed <- true }) // Convert message to JSON @@ -201,8 +206,8 @@ func TestRunner(t *testing.T) { // Wait for message handling with timeout select { - case <-handled: - log.Info("IPC message was handled successfully") + case <-processed: + log.Info("IPC message was processed successfully") case err := <-messageError: t.Fatalf("error handling message: %v", err) case <-time.After(5 * time.Second): @@ -218,9 +223,7 @@ func TestRunner(t *testing.T) { runner := setupRunner(t) // Create pipes for stdout - pr, pw := io.Pipe() - runner.cmd.Stdout = pw - runner.cmd.Stderr = pw + pr, _ := setupPipe(runner) // Start the command err := runner.cmd.Start() @@ -283,198 +286,156 @@ func TestRunner(t *testing.T) { // Create a runner runner := setupRunner(t) - // Create pipes for testing - pr, pw := io.Pipe() - defer pr.Close() - defer pw.Close() - runner.stdoutPipe = pr - - // Create a channel to signal that the reader is ready - readerReady := make(chan struct{}) + // Create a pipe for testing + _, pw := setupPipe(runner) - // Start IPC reader with ready signal - go func() { - close(readerReady) // Signal that reader is ready - runner.startIPCReader() - }() + readerReady := initRunner(runner) // Wait for reader to be ready <-readerReady // Test cases testCases := []struct { - name string - payload interface{} - expected int // expected number of records + name string + message entity.IPCMessage + expectError bool + errorTimeout bool }{ { - name: "single object", - payload: map[string]interface{}{ - "field1": "value1", - "field2": 123, - }, - expected: 1, - }, - { - name: "array of objects", - payload: []map[string]interface{}{ - { + name: "valid single object", + message: entity.IPCMessage{ + Type: constants.IPCMessageData, + Payload: map[string]interface{}{ "field1": "value1", "field2": 123, }, - { - "field1": "value2", - "field2": 456, - }, + IPC: true, }, - expected: 2, + expectError: false, }, { - name: "empty payload", - payload: nil, - expected: 0, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - // Create a channel to track processed messages - processed := make(chan int) - - // Mock the gRPC connection - runner.conn = &mockConnectClient{ - sendFunc: func(req *grpc.TaskServiceConnectRequest) error { - // Verify the request - assert.Equal(t, grpc.TaskServiceConnectCode_INSERT_DATA, req.Code) - assert.Equal(t, runner.tid.Hex(), req.TaskId) - - // If payload was nil, we expect no data - if tc.payload == nil { - processed <- 0 - return nil - } - - // Unmarshal the data to verify record count - var records []map[string]interface{} - err := json.Unmarshal(req.Data, &records) - assert.NoError(t, err) - processed <- len(records) - return nil + name: "valid array of objects", + message: entity.IPCMessage{ + Type: constants.IPCMessageData, + Payload: []map[string]interface{}{ + { + "field1": "value1", + "field2": 123, + }, + { + "field1": "value2", + "field2": 456, + }, }, - } - - // Create test message - testMsg := entity.IPCMessage{ - Type: constants.IPCMessageData, - Payload: tc.payload, - IPC: true, - } - - // Convert message to JSON and write to pipe - go func() { - jsonData, _ := json.Marshal(testMsg) - _, _ = fmt.Fprintln(pw, string(jsonData)) - }() - - // Wait for processing with timeout - select { - case recordCount := <-processed: - assert.Equal(t, tc.expected, recordCount) - case <-time.After(1 * time.Second): - if tc.expected > 0 { - t.Fatal("timeout waiting for IPC message to be processed") - } - } - }) - } - }) - - t.Run("HandleIPCInvalidData", func(t *testing.T) { - // Create a runner - runner := setupRunner(t) - - // Create pipes for testing - pr, pw := io.Pipe() - defer pr.Close() - defer pw.Close() - runner.stdoutPipe = pr - - // Create a channel to signal that the reader is ready - readerReady := make(chan struct{}) - - // Start IPC reader with ready signal - go func() { - close(readerReady) // Signal that reader is ready - runner.startIPCReader() - }() - - // Wait for reader to be ready - <-readerReady - - // Test cases for invalid data - testCases := []struct { - name string - message string // Raw message to send - }{ - { - name: "invalid json", - message: "{ invalid json", + IPC: true, + }, + expectError: false, }, { - name: "non-ipc json", - message: `{"type": "data", "payload": {"field": "value"}}`, // Missing IPC flag + name: "invalid payload type", + message: entity.IPCMessage{ + Type: constants.IPCMessageData, + Payload: "invalid", + IPC: true, + }, + expectError: true, + errorTimeout: true, }, { - name: "invalid payload type", - message: `{"type": "data", "payload": "invalid", "ipc": true}`, + name: "non-ipc message", + message: entity.IPCMessage{ + Type: constants.IPCMessageData, + Payload: map[string]interface{}{"field": "value"}, + IPC: false, + }, + expectError: true, + errorTimeout: true, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - // Create a channel to ensure no data is processed - processed := make(chan struct{}) - - // Mock the gRPC connection - runner.conn = &mockConnectClient{ - sendFunc: func(req *grpc.TaskServiceConnectRequest) error { - if req.Code == grpc.TaskServiceConnectCode_INSERT_DATA { - // This should not be called for invalid data - processed <- struct{}{} + // Create channels for synchronization + processed := make(chan bool) + messageError := make(chan error, 1) + + // Set up message handler + runner.SetIPCHandler(func(msg entity.IPCMessage) { + log.Infof("Handler received IPC message: %+v", msg) + + // Verify message type matches + if msg.Type != tc.message.Type { + messageError <- fmt.Errorf("expected message type %s, got %s", tc.message.Type, msg.Type) + return + } + + // Verify IPC flag + if msg.IPC != tc.message.IPC { + messageError <- fmt.Errorf("expected IPC flag %v, got %v", tc.message.IPC, msg.IPC) + return + } + + // For data messages, just verify the structure + if msg.Type == constants.IPCMessageData { + switch msg.Payload.(type) { + case map[string]interface{}, []map[string]interface{}, []interface{}: + processed <- true + default: + messageError <- fmt.Errorf("unexpected payload type: %T", msg.Payload) } - return nil - }, + return + } + + processed <- true + }) + + // Convert message to JSON + jsonData, err := json.Marshal(tc.message) + if err != nil { + t.Fatalf("failed to marshal test message: %v", err) } - // Write test message to pipe - go func() { - _, err := fmt.Fprintln(pw, tc.message) - if err != nil { - log.Errorf("failed to write to pipe: %v", err) + // Write message to pipe + log.Infof("Writing message to pipe: %s", string(jsonData)) + _, err = fmt.Fprintln(pw, string(jsonData)) + if err != nil { + t.Fatalf("failed to write to pipe: %v", err) + } + log.Info("Message written to pipe") + + if tc.expectError { + if tc.errorTimeout { + // For invalid messages, expect a timeout + select { + case <-processed: + t.Error("invalid message was unexpectedly processed") + case <-time.After(1 * time.Second): + // Success - no processing occurred + } + } else { + // For other error cases, expect an error message + select { + case err := <-messageError: + log.Infof("received expected error: %v", err) + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for error message") + } + } + } else { + // For valid messages, expect successful processing + select { + case <-processed: + log.Info("IPC message was processed successfully") + case err := <-messageError: + t.Fatalf("error handling message: %v", err) + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for IPC message to be handled") } - }() - - // Wait briefly to ensure no processing occurs - select { - case <-processed: - t.Error("invalid message was processed") - case <-time.After(1 * time.Second): - // Success - no processing occurred } }) } - }) -} -// mockConnectClient is a mock implementation of the gRPC Connect client -type mockConnectClient struct { - grpc.TaskService_ConnectClient - sendFunc func(*grpc.TaskServiceConnectRequest) error -} - -func (m *mockConnectClient) Send(req *grpc.TaskServiceConnectRequest) error { - if m.sendFunc != nil { - return m.sendFunc(req) - } - return nil + // Clean up + runner.cancel() // Cancel context to stop readers + }) }