diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5e57c1367..f484ab8a4 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,8 +1,9 @@ name: CI on: [push, pull_request] env: - go-version: '1.17.x' + go-version: '1.17.5' postgis-version: '3.1' + redis-version: '3.2.4' jobs: test: name: Test @@ -17,7 +18,7 @@ jobs: - name: Install Redis uses: zhulik/redis-action@v1.0.0 with: - redis version: '5' + redis version: ${{ env.redis-version }} - name: Install PostgreSQL uses: nyaruka/postgis-action@v2 diff --git a/.gitignore b/.gitignore index 494792c01..f39bf7c00 100644 --- a/.gitignore +++ b/.gitignore @@ -21,6 +21,7 @@ _storage/ # Output of the go coverage tool, specifically when used with LiteIDE *.out +deploy fabric fabfile.py fabfile.pyc diff --git a/CHANGELOG.md b/CHANGELOG.md index bbfe70f60..333fd8c29 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,182 @@ +v7.1.33 +---------- + * Set wait fields on sessions for dial waits as well + * Create completed sessions with wait_resume_on_expire = false + * Reduce exit sessions batch size to 100 + * Clear contact.current_flow_id when exiting sessions + +v7.1.32 +---------- + * Rework expirations to use ExitSessions + +v7.1.31 +---------- + * Consolidate how we interrupt sessions + * Tweak mailroom shutdown to only stop ES client if there is one + +v7.1.30 +---------- + * Remove deprecated fields on search endpoint + * Include flow reference when queuing messages + * Tweak coureier payload to not include unused fields + +v7.1.29 +---------- + * Update to latest goflow (fixes allowing bad URNs in start_session actions and adds @trigger.campaign) + * Commit modified_on changes outside of transaction + +v7.1.28 +---------- + * Include redis stats in analytics cron job + * Update wait_resume_on_expire on session writes + +v7.1.27 +---------- + * Always read run status instead of is_active + * Rename Session.TimeoutOn to WaitTimeoutOn + * Add flow_id to msg and record for flow messages + +v7.1.26 +---------- + * Add testdata functions for testing campaigns and events + * Use models.FireID consistently + * Replace broken redigo dep version and anything that was depending on it + * Simplify how we queue event fire tasks and improve logging + +v7.1.25 +---------- + * Update to latest gocommon + * Stop writing events on flow runs + +v7.1.24 +---------- + * Switch to dbutil package in gocommon + * Always exclude router arguments from PO file extraction + +v7.1.23 +---------- + * Session.CurrentFlowID whould be cleared when session exits + * Start writing FlowSession.wait_expires_on + * Update to latest goflow which removes activated waits + * Clamp flow expiration values to valid ranges when loading flows + +v7.1.22 +---------- + * Replace redisx package with new dependency + * Update test database to use big ids for flow run and session ids + * Move session storage mode to the runtime.Config instead of an org config value + +v7.1.21 +---------- + * Update to latest gocommon to get instagram scheme + +v7.1.20 +---------- + * Update to latest gocommon and goflow to get fix for random.IntN concurrency + +v7.1.19 +---------- + * Update to latest goflow + +v7.1.18 +---------- + * Fix not logging details of query errors + * CI with go 1.17.5 + +v7.1.17 +---------- + * Include segments in simulation responses + +v7.1.16 +---------- + * Record recent contacts for all segments + * Allow cron jobs to declare that they can run on all instances at same time - needed for analytics job + * Write failed messages when missing channel or URNs + * Switch to redisx.Locker for cron job locking + * Update goflow + * Rename redisx structs and remove legacy support from IntervalSet + +v7.1.15 +---------- + * Update goflow + * Use new key format with redisx.Marker but also use legacy key format for backwards compatibility + +v7.1.14 +---------- + * Update to latest goflow + * Add failed_reason to msg and set when failing messages due to looping or org suspension + * Simplify cron functions by not passing lock name and value which aren't used + * Stop writing msgs_msg.connection_id + * Stop writing msgs_msg.response_to + +v7.1.13 +---------- + * Replace trackers with series to determine unhealthy webhooks + * Correct use of KEYS vs ARGV in redisx scripts + * Rework how we create outgoing messages, and fix retries of high priority messages + +v7.1.12 +---------- + * Move msg level loop detection from courier + +v7.1.11 +---------- + * Add imports for missing task packages + +v7.1.10 +---------- + * Add redisx.Cacher util + +v7.1.9 +---------- + * Don't include response_to_id in courier payloads + * Add logging for ending webhook incidents + +v7.1.8 +---------- + * Update sessions and runs in batches when exiting + +v7.1.7 +---------- + * Fix handling of add label actions after msg resumes in IVR flows + * Add cron job to end webhook incidents when nodes are no longer unhealthy + * Re-add new locker code but this time don't let locking code hold redis connections for any length of time + * Create incident once org has had unhealthy webhooks for 20 minutes + +v7.1.6 +---------- + * Revert "Rework locker code for reusablity" + +v7.1.5 +---------- + * Pin to go 1.17.2 + +v7.1.4 +---------- + * Rework redis marker and locker code for reusablity + * Test with Redis 3.2.4 + * Add util class to track the state of something in redis over a recent time period + * Remove unneeded check for RP's celery task to retry messages + +v7.1.3 +---------- + * Add logging to msg retry task + +v7.1.2 +---------- + * Add task to retry errored messages + +v7.1.1 +---------- + * Remove notification.channel_id + +v7.1.0 +---------- + * Update to latest goflow with expression changes + * Make LUA script to queue messages to courier easier to understand + * Explicitly exclude msg fields from marshalling that courier doesn't use + * Remove unused code for looking up msgs by UUID + v7.0.1 ---------- * Update to latest goflow diff --git a/cmd/mailroom/main.go b/cmd/mailroom/main.go index 76817321d..cc642aed7 100644 --- a/cmd/mailroom/main.go +++ b/cmd/mailroom/main.go @@ -14,15 +14,17 @@ import ( _ "github.com/nyaruka/mailroom/core/handlers" _ "github.com/nyaruka/mailroom/core/hooks" - _ "github.com/nyaruka/mailroom/core/tasks/broadcasts" + _ "github.com/nyaruka/mailroom/core/tasks/analytics" _ "github.com/nyaruka/mailroom/core/tasks/campaigns" _ "github.com/nyaruka/mailroom/core/tasks/contacts" _ "github.com/nyaruka/mailroom/core/tasks/expirations" + _ "github.com/nyaruka/mailroom/core/tasks/handler" + _ "github.com/nyaruka/mailroom/core/tasks/incidents" _ "github.com/nyaruka/mailroom/core/tasks/interrupts" _ "github.com/nyaruka/mailroom/core/tasks/ivr" + _ "github.com/nyaruka/mailroom/core/tasks/msgs" _ "github.com/nyaruka/mailroom/core/tasks/schedules" _ "github.com/nyaruka/mailroom/core/tasks/starts" - _ "github.com/nyaruka/mailroom/core/tasks/stats" _ "github.com/nyaruka/mailroom/core/tasks/timeouts" _ "github.com/nyaruka/mailroom/services/ivr/twiml" _ "github.com/nyaruka/mailroom/services/ivr/vonage" diff --git a/core/handlers/base_test.go b/core/handlers/base_test.go index 8899ca433..754f524ae 100644 --- a/core/handlers/base_test.go +++ b/core/handlers/base_test.go @@ -5,8 +5,8 @@ import ( "encoding/json" "fmt" "testing" - "time" + "github.com/nyaruka/gocommon/dbutil/assertdb" "github.com/nyaruka/gocommon/uuids" "github.com/nyaruka/goflow/assets" "github.com/nyaruka/goflow/envs" @@ -17,7 +17,6 @@ import ( "github.com/nyaruka/mailroom/core/models" "github.com/nyaruka/mailroom/core/runner" "github.com/nyaruka/mailroom/runtime" - "github.com/nyaruka/mailroom/testsuite" "github.com/nyaruka/mailroom/testsuite/testdata" "github.com/gomodule/redigo/redis" @@ -253,11 +252,9 @@ func RunTestCases(t *testing.T, ctx context.Context, rt *runtime.Runtime, tcs [] err = tx.Commit() assert.NoError(t, err) - time.Sleep(500 * time.Millisecond) - // now check our assertions for j, a := range tc.SQLAssertions { - testsuite.AssertQuery(t, rt.DB, a.SQL, a.Args...).Returns(a.Count, "%d:%d: mismatch in expected count for query: %s", i, j, a.SQL) + assertdb.Query(t, rt.DB, a.SQL, a.Args...).Returns(a.Count, "%d:%d: mismatch in expected count for query: %s", i, j, a.SQL) } for j, a := range tc.Assertions { @@ -266,3 +263,32 @@ func RunTestCases(t *testing.T, ctx context.Context, rt *runtime.Runtime, tcs [] } } } + +func RunFlowAndApplyEvents(t *testing.T, ctx context.Context, rt *runtime.Runtime, env envs.Environment, eng flows.Engine, oa *models.OrgAssets, flowRef *assets.FlowReference, contact *flows.Contact) { + trigger := triggers.NewBuilder(env, flowRef, contact).Manual().Build() + fs, sprint, err := eng.NewSession(oa.SessionAssets(), trigger) + require.NoError(t, err) + + tx, err := rt.DB.BeginTxx(ctx, nil) + require.NoError(t, err) + + session, err := models.NewSession(ctx, tx, oa, fs, sprint) + require.NoError(t, err) + + err = tx.Commit() + require.NoError(t, err) + + scene := models.NewSceneForSession(session) + + tx, err = rt.DB.BeginTxx(ctx, nil) + require.NoError(t, err) + + err = models.HandleEvents(ctx, rt, tx, oa, scene, sprint.Events()) + require.NoError(t, err) + + err = models.ApplyEventPreCommitHooks(ctx, rt, tx, oa, []*models.Scene{scene}) + require.NoError(t, err) + + err = tx.Commit() + require.NoError(t, err) +} diff --git a/core/handlers/campaigns_test.go b/core/handlers/campaigns_test.go index c706d7978..0106f9921 100644 --- a/core/handlers/campaigns_test.go +++ b/core/handlers/campaigns_test.go @@ -21,18 +21,10 @@ func TestCampaigns(t *testing.T) { joined := assets.NewFieldReference("joined", "Joined") // insert an event on our campaign that is based on created_on - db.MustExec( - `INSERT INTO campaigns_campaignevent(is_active, created_on, modified_on, uuid, "offset", unit, event_type, delivery_hour, - campaign_id, created_by_id, modified_by_id, flow_id, relative_to_id, start_mode) - VALUES(TRUE, NOW(), NOW(), $1, 1000, 'W', 'F', -1, $2, 1, 1, $3, $4, 'I')`, - uuids.New(), testdata.RemindersCampaign.ID, testdata.Favorites.ID, testdata.CreatedOnField.ID) + testdata.InsertCampaignFlowEvent(db, testdata.RemindersCampaign, testdata.Favorites, testdata.CreatedOnField, 1000, "W") // insert an event on our campaign that is based on last_seen_on - db.MustExec( - `INSERT INTO campaigns_campaignevent(is_active, created_on, modified_on, uuid, "offset", unit, event_type, delivery_hour, - campaign_id, created_by_id, modified_by_id, flow_id, relative_to_id, start_mode) - VALUES(TRUE, NOW(), NOW(), $1, 2, 'D', 'F', -1, $2, 1, 1, $3, $4, 'I')`, - uuids.New(), testdata.RemindersCampaign.ID, testdata.Favorites.ID, testdata.LastSeenOnField.ID) + testdata.InsertCampaignFlowEvent(db, testdata.RemindersCampaign, testdata.Favorites, testdata.LastSeenOnField, 2, "D") // init their values db.MustExec( diff --git a/core/handlers/contact_field_changed.go b/core/handlers/contact_field_changed.go index 359ae1681..50f368a80 100644 --- a/core/handlers/contact_field_changed.go +++ b/core/handlers/contact_field_changed.go @@ -27,9 +27,9 @@ func handleContactFieldChanged(ctx context.Context, rt *runtime.Runtime, tx *sql "value": event.Value, }).Debug("contact field changed") - // add our callback scene.AppendToEventPreCommitHook(hooks.CommitFieldChangesHook, event) scene.AppendToEventPreCommitHook(hooks.UpdateCampaignEventsHook, event) + scene.AppendToEventPostCommitHook(hooks.ContactModifiedHook, event) return nil } diff --git a/core/handlers/contact_groups_changed.go b/core/handlers/contact_groups_changed.go index 62ed66109..8c3795e9d 100644 --- a/core/handlers/contact_groups_changed.go +++ b/core/handlers/contact_groups_changed.go @@ -47,7 +47,7 @@ func handleContactGroupsChanged(ctx context.Context, rt *runtime.Runtime, tx *sq // add our add event scene.AppendToEventPreCommitHook(hooks.CommitGroupChangesHook, hookEvent) scene.AppendToEventPreCommitHook(hooks.UpdateCampaignEventsHook, hookEvent) - scene.AppendToEventPreCommitHook(hooks.ContactModifiedHook, scene.ContactID()) + scene.AppendToEventPostCommitHook(hooks.ContactModifiedHook, event) } // add each of our groups @@ -70,7 +70,7 @@ func handleContactGroupsChanged(ctx context.Context, rt *runtime.Runtime, tx *sq scene.AppendToEventPreCommitHook(hooks.CommitGroupChangesHook, hookEvent) scene.AppendToEventPreCommitHook(hooks.UpdateCampaignEventsHook, hookEvent) - scene.AppendToEventPreCommitHook(hooks.ContactModifiedHook, scene.ContactID()) + scene.AppendToEventPostCommitHook(hooks.ContactModifiedHook, event) } return nil diff --git a/core/handlers/contact_language_changed.go b/core/handlers/contact_language_changed.go index ed19a6729..e64cd861f 100644 --- a/core/handlers/contact_language_changed.go +++ b/core/handlers/contact_language_changed.go @@ -27,5 +27,7 @@ func handleContactLanguageChanged(ctx context.Context, rt *runtime.Runtime, tx * }).Debug("changing contact language") scene.AppendToEventPreCommitHook(hooks.CommitLanguageChangesHook, event) + scene.AppendToEventPostCommitHook(hooks.ContactModifiedHook, event) + return nil } diff --git a/core/handlers/contact_name_changed.go b/core/handlers/contact_name_changed.go index 9f025fe87..5ff7bab1a 100644 --- a/core/handlers/contact_name_changed.go +++ b/core/handlers/contact_name_changed.go @@ -27,5 +27,7 @@ func handleContactNameChanged(ctx context.Context, rt *runtime.Runtime, tx *sqlx }).Debug("changing contact name") scene.AppendToEventPreCommitHook(hooks.CommitNameChangesHook, event) + scene.AppendToEventPostCommitHook(hooks.ContactModifiedHook, event) + return nil } diff --git a/core/handlers/contact_status_changed.go b/core/handlers/contact_status_changed.go index 391b8291a..90f4463ec 100644 --- a/core/handlers/contact_status_changed.go +++ b/core/handlers/contact_status_changed.go @@ -27,5 +27,7 @@ func handleContactStatusChanged(ctx context.Context, rt *runtime.Runtime, tx *sq }).Debug("updating contact status") scene.AppendToEventPreCommitHook(hooks.CommitStatusChangesHook, event) + scene.AppendToEventPostCommitHook(hooks.ContactModifiedHook, event) + return nil } diff --git a/core/handlers/contact_urns_changed.go b/core/handlers/contact_urns_changed.go index efc61f41d..44a3b8268 100644 --- a/core/handlers/contact_urns_changed.go +++ b/core/handlers/contact_urns_changed.go @@ -33,9 +33,8 @@ func handleContactURNsChanged(ctx context.Context, rt *runtime.Runtime, tx *sqlx URNs: event.URNs, } - // add our callback scene.AppendToEventPreCommitHook(hooks.CommitURNChangesHook, change) - scene.AppendToEventPreCommitHook(hooks.ContactModifiedHook, scene.ContactID()) + scene.AppendToEventPostCommitHook(hooks.ContactModifiedHook, event) return nil } diff --git a/core/handlers/input_labels_added.go b/core/handlers/input_labels_added.go index 7cc6c7f7b..55c4b692e 100644 --- a/core/handlers/input_labels_added.go +++ b/core/handlers/input_labels_added.go @@ -31,15 +31,10 @@ func handleInputLabelsAdded(ctx context.Context, rt *runtime.Runtime, tx *sqlx.T "labels": event.Labels, }).Debug("input labels added") - // in the case this session was started/resumed from a msg event, we have the msg ID cached on the session + // if the sprint had input, then it was started by a msg event and we should have the message ID saved on the session inputMsgID := scene.Session().IncomingMsgID() - if inputMsgID == models.NilMsgID { - var err error - inputMsgID, err = models.GetMessageIDFromUUID(ctx, tx, flows.MsgUUID(event.InputUUID)) - if err != nil { - return errors.Wrap(err, "unable to find input message") - } + return errors.New("handling input labels added event in session without msg") } // for each label add an insertion diff --git a/core/handlers/msg_created.go b/core/handlers/msg_created.go index 42a25b582..46e4cceaa 100644 --- a/core/handlers/msg_created.go +++ b/core/handlers/msg_created.go @@ -71,14 +71,8 @@ func handleMsgCreated(ctx context.Context, rt *runtime.Runtime, tx *sqlx.Tx, oa "urn": event.Msg.URN(), }).Debug("msg created event") - // ignore events that don't have a channel or URN set - // TODO: maybe we should create these messages in a failed state? - if scene.Session().SessionType() == models.FlowTypeMessaging && (event.Msg.URN() == urns.NilURN || event.Msg.Channel() == nil) { - return nil - } - // messages in messaging flows must have urn id set on them, if not, go look it up - if scene.Session().SessionType() == models.FlowTypeMessaging { + if scene.Session().SessionType() == models.FlowTypeMessaging && event.Msg.URN() != urns.NilURN { urn := event.Msg.URN() if models.GetURNInt(urn, "id") == 0 { urn, err := models.GetOrCreateURN(ctx, tx, oa, scene.ContactID(), event.Msg.URN()) @@ -99,17 +93,14 @@ func handleMsgCreated(ctx context.Context, rt *runtime.Runtime, tx *sqlx.Tx, oa } } - msg, err := models.NewOutgoingMsg(rt.Config, oa.Org(), channel, scene.ContactID(), event.Msg, event.CreatedOn()) + run, _ := scene.Session().FindStep(e.StepUUID()) + flow, _ := oa.Flow(run.FlowReference().UUID) + + msg, err := models.NewOutgoingFlowMsg(rt, oa.Org(), channel, scene.Session(), flow.(*models.Flow), event.Msg, event.CreatedOn()) if err != nil { return errors.Wrapf(err, "error creating outgoing message to %s", event.Msg.URN()) } - // include some information about the session - msg.SetSession(scene.Session().ID(), scene.Session().Status()) - - // set our reply to as well (will be noop in cases when there is no incoming message) - msg.SetResponseTo(scene.Session().IncomingMsgID(), scene.Session().IncomingMsgExternalID()) - // register to have this message committed scene.AppendToEventPreCommitHook(hooks.CommitMessagesHook, msg) diff --git a/core/handlers/msg_created_test.go b/core/handlers/msg_created_test.go index c8fb69933..e1d198ffb 100644 --- a/core/handlers/msg_created_test.go +++ b/core/handlers/msg_created_test.go @@ -66,8 +66,8 @@ func TestMsgCreated(t *testing.T) { }, SQLAssertions: []handlers.SQLAssertion{ { - SQL: "SELECT COUNT(*) FROM msgs_msg WHERE text='Hello World' AND contact_id = $1 AND metadata = $2 AND response_to_id = $3 AND high_priority = TRUE", - Args: []interface{}{testdata.Cathy.ID, `{"quick_replies":["yes","no"]}`, msg1.ID()}, + SQL: "SELECT COUNT(*) FROM msgs_msg WHERE text='Hello World' AND contact_id = $1 AND metadata = $2 AND high_priority = TRUE", + Args: []interface{}{testdata.Cathy.ID, `{"quick_replies":["yes","no"]}`}, Count: 2, }, { @@ -76,9 +76,9 @@ func TestMsgCreated(t *testing.T) { Count: 1, }, { - SQL: "SELECT COUNT(*) FROM msgs_msg WHERE contact_id=$1;", + SQL: "SELECT COUNT(*) FROM msgs_msg WHERE contact_id=$1 AND STATUS = 'F' AND failed_reason = 'D';", Args: []interface{}{testdata.Bob.ID}, - Count: 0, + Count: 1, }, { SQL: "SELECT COUNT(*) FROM msgs_msg WHERE contact_id = $1 AND text = $2 AND metadata = $3 AND direction = 'O' AND status = 'Q' AND channel_id = $4", diff --git a/core/handlers/session_triggered_test.go b/core/handlers/session_triggered_test.go index 9ca2cfaa1..9fc05cb20 100644 --- a/core/handlers/session_triggered_test.go +++ b/core/handlers/session_triggered_test.go @@ -44,7 +44,7 @@ func TestSessionTriggered(t *testing.T) { { Actions: handlers.ContactActionMap{ testdata.Cathy: []flows.Action{ - actions.NewStartSession(handlers.NewActionUUID(), simpleFlow.FlowReference(), nil, []*flows.ContactReference{contactRef}, []*assets.GroupReference{groupRef}, nil, true), + actions.NewStartSession(handlers.NewActionUUID(), simpleFlow.Reference(), nil, []*flows.ContactReference{contactRef}, []*assets.GroupReference{groupRef}, nil, true), }, }, SQLAssertions: []handlers.SQLAssertion{ @@ -105,7 +105,7 @@ func TestQuerySessionTriggered(t *testing.T) { favoriteFlow, err := oa.FlowByID(testdata.Favorites.ID) assert.NoError(t, err) - sessionAction := actions.NewStartSession(handlers.NewActionUUID(), favoriteFlow.FlowReference(), nil, nil, nil, nil, true) + sessionAction := actions.NewStartSession(handlers.NewActionUUID(), favoriteFlow.Reference(), nil, nil, nil, nil, true) sessionAction.ContactQuery = "name ~ @contact.name" tcs := []handlers.TestCase{ diff --git a/core/handlers/testdata/webhook_flow.json b/core/handlers/testdata/webhook_flow.json new file mode 100644 index 000000000..81f7f5948 --- /dev/null +++ b/core/handlers/testdata/webhook_flow.json @@ -0,0 +1,65 @@ +{ + "uuid": "bc5d6b7b-3e18-4d7c-8279-50b460e74f7f", + "name": "Test", + "spec_version": "13.1.0", + "language": "eng", + "type": "messaging", + "nodes": [ + { + "uuid": "1bff8fe4-0714-433e-96a3-437405bf21cf", + "actions": [ + { + "uuid": "4e2ddf56-dd6e-435d-b688-92ae60dcb35c", + "headers": { + "Accept": "application/json" + }, + "type": "call_webhook", + "url": "http://example.com", + "body": "", + "method": "GET", + "result_name": "Result" + } + ], + "router": { + "type": "switch", + "operand": "@results.result.category", + "cases": [ + { + "uuid": "9b74c0bd-7c1a-482b-b219-fb2e1ec50e59", + "type": "has_only_text", + "arguments": [ + "Success" + ], + "category_uuid": "f40e9698-adc7-4528-9761-beb76bfa1801" + } + ], + "categories": [ + { + "uuid": "f40e9698-adc7-4528-9761-beb76bfa1801", + "name": "Success", + "exit_uuid": "e46b09ab-7a9a-4b4b-a9e5-c6a8bd130517" + }, + { + "uuid": "3f884912-d193-4d57-86a6-046477b9e568", + "name": "Failure", + "exit_uuid": "d1318a17-86cc-4733-8c91-1ffb49917cfd" + } + ], + "default_category_uuid": "3f884912-d193-4d57-86a6-046477b9e568" + }, + "exits": [ + { + "uuid": "e46b09ab-7a9a-4b4b-a9e5-c6a8bd130517", + "destination_uuid": null + }, + { + "uuid": "d1318a17-86cc-4733-8c91-1ffb49917cfd", + "destination_uuid": null + } + ] + } + ], + "revision": 24, + "expire_after_minutes": 10080, + "localization": {} +} \ No newline at end of file diff --git a/core/handlers/webhook_called.go b/core/handlers/webhook_called.go index 0e003615a..7f71f20c1 100644 --- a/core/handlers/webhook_called.go +++ b/core/handlers/webhook_called.go @@ -4,13 +4,12 @@ import ( "context" "time" + "github.com/jmoiron/sqlx" "github.com/nyaruka/goflow/flows" "github.com/nyaruka/goflow/flows/events" "github.com/nyaruka/mailroom/core/hooks" "github.com/nyaruka/mailroom/core/models" "github.com/nyaruka/mailroom/runtime" - - "github.com/jmoiron/sqlx" "github.com/sirupsen/logrus" ) @@ -28,6 +27,7 @@ func handleWebhookCalled(ctx context.Context, rt *runtime.Runtime, tx *sqlx.Tx, "status": event.Status, "elapsed_ms": event.ElapsedMS, "resthook": event.Resthook, + "extraction": event.Extraction, }).Debug("webhook called") // if this was a resthook and the status was 410, that means we should remove it @@ -41,7 +41,7 @@ func handleWebhookCalled(ctx context.Context, rt *runtime.Runtime, tx *sqlx.Tx, scene.AppendToEventPreCommitHook(hooks.UnsubscribeResthookHook, unsub) } - run, _ := scene.Session().FindStep(e.StepUUID()) + run, step := scene.Session().FindStep(e.StepUUID()) flow, _ := oa.Flow(run.FlowReference().UUID) // create an HTTP log @@ -58,5 +58,8 @@ func handleWebhookCalled(ctx context.Context, rt *runtime.Runtime, tx *sqlx.Tx, scene.AppendToEventPreCommitHook(hooks.InsertHTTPLogsHook, httpLog) } + // pass node and response time to the hook that monitors webhook health + scene.AppendToEventPreCommitHook(hooks.MonitorWebhooks, &hooks.WebhookCall{NodeUUID: step.NodeUUID(), Event: event}) + return nil } diff --git a/core/handlers/webhook_called_test.go b/core/handlers/webhook_called_test.go index 5e10971f5..d6417204d 100644 --- a/core/handlers/webhook_called_test.go +++ b/core/handlers/webhook_called_test.go @@ -1,15 +1,28 @@ package handlers_test import ( + "fmt" + "net/http" + "os" "testing" + "time" - "github.com/nyaruka/mailroom/core/handlers" - "github.com/nyaruka/mailroom/testsuite" - "github.com/nyaruka/mailroom/testsuite/testdata" - + "github.com/nyaruka/gocommon/dates" + "github.com/nyaruka/gocommon/dbutil/assertdb" "github.com/nyaruka/gocommon/httpx" + "github.com/nyaruka/goflow/assets" + "github.com/nyaruka/goflow/envs" "github.com/nyaruka/goflow/flows" "github.com/nyaruka/goflow/flows/actions" + "github.com/nyaruka/goflow/flows/engine" + "github.com/nyaruka/mailroom/core/handlers" + "github.com/nyaruka/mailroom/core/models" + "github.com/nyaruka/mailroom/testsuite" + "github.com/nyaruka/mailroom/testsuite/testdata" + "github.com/nyaruka/redisx" + "github.com/nyaruka/redisx/assertredis" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestWebhookCalled(t *testing.T) { @@ -83,3 +96,100 @@ func TestWebhookCalled(t *testing.T) { handlers.RunTestCases(t, ctx, rt, tcs) } + +// a webhook service which fakes slow responses +type failingWebhookService struct { + delay time.Duration +} + +func (s *failingWebhookService) Call(session flows.Session, request *http.Request) (*flows.WebhookCall, error) { + return &flows.WebhookCall{ + Trace: &httpx.Trace{ + Request: request, + RequestTrace: []byte(`GET http://rapidpro.io/`), + Response: nil, + ResponseTrace: nil, + StartTime: dates.Now(), + EndTime: dates.Now().Add(s.delay), + }, + }, nil +} + +func TestUnhealthyWebhookCalls(t *testing.T) { + ctx, rt, db, rp := testsuite.Get() + rc := rp.Get() + defer rc.Close() + + defer testsuite.Reset(testsuite.ResetData | testsuite.ResetRedis) + defer dates.SetNowSource(dates.DefaultNowSource) + + dates.SetNowSource(dates.NewSequentialNowSource(time.Date(2021, 11, 17, 7, 0, 0, 0, time.UTC))) + + flowDef, err := os.ReadFile("testdata/webhook_flow.json") + require.NoError(t, err) + + testdata.InsertFlow(db, testdata.Org1, flowDef) + + oa, err := models.GetOrgAssetsWithRefresh(ctx, rt, testdata.Org1.ID, models.RefreshFlows) + require.NoError(t, err) + + env := envs.NewBuilder().Build() + _, cathy := testdata.Cathy.Load(db, oa) + + // webhook service with a 2 second delay + svc := &failingWebhookService{delay: 2 * time.Second} + + eng := engine.NewBuilder().WithWebhookServiceFactory(func(flows.Session) (flows.WebhookService, error) { return svc, nil }).Build() + flowRef := assets.NewFlowReference("bc5d6b7b-3e18-4d7c-8279-50b460e74f7f", "Test") + + handlers.RunFlowAndApplyEvents(t, ctx, rt, env, eng, oa, flowRef, cathy) + handlers.RunFlowAndApplyEvents(t, ctx, rt, env, eng, oa, flowRef, cathy) + + healthySeries := redisx.NewIntervalSeries("webhooks:healthy", time.Minute*5, 4) + unhealthySeries := redisx.NewIntervalSeries("webhooks:unhealthy", time.Minute*5, 4) + + total, err := healthySeries.Total(rc, "1bff8fe4-0714-433e-96a3-437405bf21cf") + assert.NoError(t, err) + assert.Equal(t, int64(2), total) + + total, err = unhealthySeries.Total(rc, "1bff8fe4-0714-433e-96a3-437405bf21cf") + assert.NoError(t, err) + assert.Equal(t, int64(0), total) + + // change webhook service delay to 30 seconds and re-run flow 9 times + svc.delay = 30 * time.Second + for i := 0; i < 9; i++ { + handlers.RunFlowAndApplyEvents(t, ctx, rt, env, eng, oa, flowRef, cathy) + } + + // still no incident tho.. + total, _ = healthySeries.Total(rc, "1bff8fe4-0714-433e-96a3-437405bf21cf") + assert.Equal(t, int64(2), total) + total, _ = unhealthySeries.Total(rc, "1bff8fe4-0714-433e-96a3-437405bf21cf") + assert.Equal(t, int64(9), total) + + assertdb.Query(t, db, `SELECT count(*) FROM notifications_incident WHERE incident_type = 'webhooks:unhealthy'`).Returns(0) + + // however 1 more bad call means this node is considered unhealthy + handlers.RunFlowAndApplyEvents(t, ctx, rt, env, eng, oa, flowRef, cathy) + + total, _ = healthySeries.Total(rc, "1bff8fe4-0714-433e-96a3-437405bf21cf") + assert.Equal(t, int64(2), total) + total, _ = unhealthySeries.Total(rc, "1bff8fe4-0714-433e-96a3-437405bf21cf") + assert.Equal(t, int64(10), total) + + // and now we have an incident + assertdb.Query(t, db, `SELECT count(*) FROM notifications_incident WHERE incident_type = 'webhooks:unhealthy'`).Returns(1) + + var incidentID models.IncidentID + db.Get(&incidentID, `SELECT id FROM notifications_incident`) + + // and a record of the nodes + assertredis.SMembers(t, rp, fmt.Sprintf("incident:%d:nodes", incidentID), []string{"1bff8fe4-0714-433e-96a3-437405bf21cf"}) + + // another bad call won't create another incident.. + handlers.RunFlowAndApplyEvents(t, ctx, rt, env, eng, oa, flowRef, cathy) + + assertdb.Query(t, db, `SELECT count(*) FROM notifications_incident WHERE incident_type = 'webhooks:unhealthy'`).Returns(1) + assertredis.SMembers(t, rp, fmt.Sprintf("incident:%d:nodes", incidentID), []string{"1bff8fe4-0714-433e-96a3-437405bf21cf"}) +} diff --git a/core/hooks/commit_field_changes.go b/core/hooks/commit_field_changes.go index 7ccaef58f..8bb345fef 100644 --- a/core/hooks/commit_field_changes.go +++ b/core/hooks/commit_field_changes.go @@ -103,8 +103,7 @@ const updateContactFieldsSQL = ` UPDATE contacts_contact c SET - fields = COALESCE(fields,'{}'::jsonb) || r.updates::jsonb, - modified_on = NOW() + fields = COALESCE(fields,'{}'::jsonb) || r.updates::jsonb FROM ( VALUES(:contact_id, :updates) ) AS @@ -117,8 +116,7 @@ const deleteContactFieldsSQL = ` UPDATE contacts_contact c SET - fields = fields - r.field_uuid, - modified_on = NOW() + fields = fields - r.field_uuid FROM ( VALUES(:contact_id, :field_uuid) ) AS diff --git a/core/hooks/commit_language_changes.go b/core/hooks/commit_language_changes.go index 6a2d00814..bfd43b263 100644 --- a/core/hooks/commit_language_changes.go +++ b/core/hooks/commit_language_changes.go @@ -39,8 +39,7 @@ const updateContactLanguageSQL = ` UPDATE contacts_contact c SET - language = r.language, - modified_on = NOW() + language = r.language FROM ( VALUES(:id, :language) ) AS diff --git a/core/hooks/commit_name_changes.go b/core/hooks/commit_name_changes.go index 62adc10b9..d627f8c92 100644 --- a/core/hooks/commit_name_changes.go +++ b/core/hooks/commit_name_changes.go @@ -40,8 +40,7 @@ const updateContactNameSQL = ` UPDATE contacts_contact c SET - name = r.name, - modified_on = NOW() + name = r.name FROM ( VALUES(:id, :name) ) AS diff --git a/core/hooks/monitor_webhooks.go b/core/hooks/monitor_webhooks.go new file mode 100644 index 000000000..f63514a33 --- /dev/null +++ b/core/hooks/monitor_webhooks.go @@ -0,0 +1,61 @@ +package hooks + +import ( + "context" + + "github.com/jmoiron/sqlx" + "github.com/nyaruka/goflow/flows" + "github.com/nyaruka/goflow/flows/events" + "github.com/nyaruka/mailroom/core/models" + "github.com/nyaruka/mailroom/runtime" + "github.com/pkg/errors" +) + +type WebhookCall struct { + NodeUUID flows.NodeUUID + Event *events.WebhookCalledEvent +} + +var MonitorWebhooks models.EventCommitHook = &monitorWebhooks{} + +type monitorWebhooks struct{} + +func (h *monitorWebhooks) Apply(ctx context.Context, rt *runtime.Runtime, tx *sqlx.Tx, oa *models.OrgAssets, scenes map[*models.Scene][]interface{}) error { + // organize events by nodes + eventsByNode := make(map[flows.NodeUUID][]*events.WebhookCalledEvent) + for _, es := range scenes { + for _, e := range es { + wc := e.(*WebhookCall) + eventsByNode[wc.NodeUUID] = append(eventsByNode[wc.NodeUUID], wc.Event) + } + } + + unhealthyNodeUUIDs := make([]flows.NodeUUID, 0, 10) + + // record events against each node and determine if it's healthy + for nodeUUID, events := range eventsByNode { + node := &models.WebhookNode{UUID: nodeUUID} + if err := node.Record(rt, events); err != nil { + return errors.Wrap(err, "error recording events for webhook node") + } + + healthy, err := node.Healthy(rt) + if err != nil { + return errors.Wrap(err, "error getting health of webhook node") + } + + if !healthy { + unhealthyNodeUUIDs = append(unhealthyNodeUUIDs, nodeUUID) + } + } + + // if we have unhealthy nodes, ensure we have an incident + if len(unhealthyNodeUUIDs) > 0 { + _, err := models.IncidentWebhooksUnhealthy(ctx, tx, rt.RP, oa, unhealthyNodeUUIDs) + if err != nil { + return errors.Wrap(err, "error creating unhealthy webhooks incident") + } + } + + return nil +} diff --git a/core/ivr/ivr.go b/core/ivr/ivr.go index fe158ab26..4a48531dc 100644 --- a/core/ivr/ivr.go +++ b/core/ivr/ivr.go @@ -22,6 +22,7 @@ import ( "github.com/nyaruka/mailroom/core/models" "github.com/nyaruka/mailroom/core/runner" "github.com/nyaruka/mailroom/runtime" + "github.com/nyaruka/null" "github.com/gomodule/redigo/redis" "github.com/jmoiron/sqlx" @@ -415,7 +416,7 @@ func ResumeIVRFlow( return errors.Wrapf(err, "error creating flow contact") } - session, err := models.ActiveSessionForContact(ctx, rt.DB, rt.SessionStorage, oa, models.FlowTypeVoice, contact) + session, err := models.FindWaitingSessionForContact(ctx, rt.DB, rt.SessionStorage, oa, models.FlowTypeVoice, contact) if err != nil { return errors.Wrapf(err, "error loading session for contact") } @@ -433,7 +434,7 @@ func ResumeIVRFlow( // check if connection has been marked as errored - it maybe have been updated by status callback if conn.Status() == models.ConnectionStatusErrored || conn.Status() == models.ConnectionStatusFailed { - err = models.ExitSessions(ctx, rt.DB, []models.SessionID{session.ID()}, models.ExitInterrupted, time.Now()) + err = models.ExitSessions(ctx, rt.DB, []models.SessionID{session.ID()}, models.SessionStatusInterrupted) if err != nil { logrus.WithError(err).Error("error interrupting session") } @@ -483,6 +484,7 @@ func ResumeIVRFlow( switch res := ivrResume.(type) { case InputResume: resume, svcErr, err = buildMsgResume(ctx, rt, svc, channel, contact, urn, conn, oa, r, res) + session.SetIncomingMsg(resume.(*resumes.MsgResume).Msg().ID(), null.NullString) case DialResume: resume, svcErr, err = buildDialResume(oa, contact, res) @@ -513,7 +515,7 @@ func ResumeIVRFlow( return errors.Wrapf(err, "error writing ivr response for resume") } } else { - err = models.ExitSessions(ctx, rt.DB, []models.SessionID{session.ID()}, models.ExitCompleted, time.Now()) + err = models.ExitSessions(ctx, rt.DB, []models.SessionID{session.ID()}, models.SessionStatusCompleted) if err != nil { logrus.WithError(err).Error("error closing session") } diff --git a/core/models/airtime_test.go b/core/models/airtime_test.go index 7794ac41d..aaaeea458 100644 --- a/core/models/airtime_test.go +++ b/core/models/airtime_test.go @@ -4,6 +4,7 @@ import ( "testing" "time" + "github.com/nyaruka/gocommon/dbutil/assertdb" "github.com/nyaruka/gocommon/urns" "github.com/nyaruka/mailroom/core/models" "github.com/nyaruka/mailroom/testsuite" @@ -33,7 +34,7 @@ func TestAirtimeTransfers(t *testing.T) { err := models.InsertAirtimeTransfers(ctx, db, []*models.AirtimeTransfer{transfer}) assert.Nil(t, err) - testsuite.AssertQuery(t, db, `SELECT org_id, status from airtime_airtimetransfer`).Columns(map[string]interface{}{"org_id": int64(1), "status": "S"}) + assertdb.Query(t, db, `SELECT org_id, status from airtime_airtimetransfer`).Columns(map[string]interface{}{"org_id": int64(1), "status": "S"}) // insert a failed transfer with nil sender, empty currency transfer = models.NewAirtimeTransfer( @@ -50,5 +51,5 @@ func TestAirtimeTransfers(t *testing.T) { err = models.InsertAirtimeTransfers(ctx, db, []*models.AirtimeTransfer{transfer}) assert.Nil(t, err) - testsuite.AssertQuery(t, db, `SELECT count(*) from airtime_airtimetransfer WHERE org_id = $1 AND status = $2`, testdata.Org1.ID, models.AirtimeTransferStatusFailed).Returns(1) + assertdb.Query(t, db, `SELECT count(*) from airtime_airtimetransfer WHERE org_id = $1 AND status = $2`, testdata.Org1.ID, models.AirtimeTransferStatusFailed).Returns(1) } diff --git a/core/models/campaigns.go b/core/models/campaigns.go index 46c050da7..aa35fe94e 100644 --- a/core/models/campaigns.go +++ b/core/models/campaigns.go @@ -6,21 +6,20 @@ import ( "encoding/json" "time" + "github.com/jmoiron/sqlx" + "github.com/lib/pq" + "github.com/nyaruka/gocommon/dbutil" "github.com/nyaruka/gocommon/uuids" "github.com/nyaruka/goflow/assets" "github.com/nyaruka/goflow/flows" "github.com/nyaruka/mailroom/runtime" - "github.com/nyaruka/mailroom/utils/dbutil" "github.com/nyaruka/null" - - "github.com/jmoiron/sqlx" - "github.com/lib/pq" "github.com/pkg/errors" "github.com/sirupsen/logrus" ) // FireID is our id for our event fires -type FireID int +type FireID int64 // CampaignID is our type for campaign ids type CampaignID int @@ -270,7 +269,7 @@ func loadCampaigns(ctx context.Context, db sqlx.Queryer, orgID OrgID) ([]*Campai campaigns := make([]*Campaign, 0, 2) for rows.Next() { campaign := &Campaign{} - err := dbutil.ReadJSONRow(rows, &campaign.c) + err := dbutil.ScanJSON(rows, &campaign.c) if err != nil { return nil, errors.Wrapf(err, "error unmarshalling campaign") } @@ -404,7 +403,7 @@ type EventFire struct { } // LoadEventFires loads all the event fires with the passed in ids -func LoadEventFires(ctx context.Context, db Queryer, ids []int64) ([]*EventFire, error) { +func LoadEventFires(ctx context.Context, db Queryer, ids []FireID) ([]*EventFire, error) { start := time.Now() q, vs, err := sqlx.In(loadEventFireSQL, ids) @@ -522,10 +521,10 @@ func AddEventFires(ctx context.Context, tx Queryer, adds []*FireAdd) error { // DeleteUnfiredEventsForGroupRemoval deletes any unfired events for all campaigns that are // based on the passed in group id for all the passed in contacts. -func DeleteUnfiredEventsForGroupRemoval(ctx context.Context, tx Queryer, org *OrgAssets, contactIDs []ContactID, groupID GroupID) error { +func DeleteUnfiredEventsForGroupRemoval(ctx context.Context, tx Queryer, oa *OrgAssets, contactIDs []ContactID, groupID GroupID) error { fds := make([]*FireDelete, 0, 10) - for _, c := range org.CampaignByGroupID(groupID) { + for _, c := range oa.CampaignByGroupID(groupID) { for _, e := range c.Events() { for _, cid := range contactIDs { fds = append(fds, &FireDelete{ @@ -542,14 +541,14 @@ func DeleteUnfiredEventsForGroupRemoval(ctx context.Context, tx Queryer, org *Or // AddCampaignEventsForGroupAddition first removes the passed in contacts from any events that group change may effect, then recreates // the campaign events they qualify for. -func AddCampaignEventsForGroupAddition(ctx context.Context, tx Queryer, org *OrgAssets, contacts []*flows.Contact, groupID GroupID) error { +func AddCampaignEventsForGroupAddition(ctx context.Context, tx Queryer, oa *OrgAssets, contacts []*flows.Contact, groupID GroupID) error { cids := make([]ContactID, len(contacts)) for i, c := range contacts { cids[i] = ContactID(c.ID()) } // first remove all unfired events that may be affected by our group change - err := DeleteUnfiredEventsForGroupRemoval(ctx, tx, org, cids, groupID) + err := DeleteUnfiredEventsForGroupRemoval(ctx, tx, oa, cids, groupID) if err != nil { return errors.Wrapf(err, "error removing unfired campaign events for contacts") } @@ -557,12 +556,12 @@ func AddCampaignEventsForGroupAddition(ctx context.Context, tx Queryer, org *Org // now calculate which event fires need to be added fas := make([]*FireAdd, 0, 10) - tz := org.Env().Timezone() + tz := oa.Env().Timezone() // for each of our contacts for _, contact := range contacts { // for each campaign that may have changed from this group change - for _, c := range org.CampaignByGroupID(groupID) { + for _, c := range oa.CampaignByGroupID(groupID) { // check each event for _, e := range c.Events() { // and if we qualify by field diff --git a/core/models/campaigns_test.go b/core/models/campaigns_test.go index 3781f74b7..7f35dccff 100644 --- a/core/models/campaigns_test.go +++ b/core/models/campaigns_test.go @@ -6,10 +6,10 @@ import ( "testing" "time" + "github.com/nyaruka/gocommon/dbutil/assertdb" "github.com/nyaruka/mailroom/core/models" "github.com/nyaruka/mailroom/testsuite" "github.com/nyaruka/mailroom/testsuite/testdata" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -95,10 +95,10 @@ func TestAddEventFires(t *testing.T) { }) require.NoError(t, err) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM campaigns_eventfire`).Returns(3) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM campaigns_eventfire WHERE contact_id = $1 AND event_id = $2`, testdata.Cathy.ID, testdata.RemindersEvent1.ID).Returns(1) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM campaigns_eventfire WHERE contact_id = $1 AND event_id = $2`, testdata.Bob.ID, testdata.RemindersEvent1.ID).Returns(1) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM campaigns_eventfire WHERE contact_id = $1 AND event_id = $2`, testdata.Bob.ID, testdata.RemindersEvent2.ID).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM campaigns_eventfire`).Returns(3) + assertdb.Query(t, db, `SELECT count(*) FROM campaigns_eventfire WHERE contact_id = $1 AND event_id = $2`, testdata.Cathy.ID, testdata.RemindersEvent1.ID).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM campaigns_eventfire WHERE contact_id = $1 AND event_id = $2`, testdata.Bob.ID, testdata.RemindersEvent1.ID).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM campaigns_eventfire WHERE contact_id = $1 AND event_id = $2`, testdata.Bob.ID, testdata.RemindersEvent2.ID).Returns(1) db.MustExec(`UPDATE campaigns_eventfire SET fired = NOW() WHERE contact_id = $1`, testdata.Cathy.ID) @@ -110,7 +110,7 @@ func TestAddEventFires(t *testing.T) { }) require.NoError(t, err) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM campaigns_eventfire`).Returns(4) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM campaigns_eventfire WHERE contact_id = $1 AND event_id = $2`, testdata.Cathy.ID, testdata.RemindersEvent1.ID).Returns(2) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM campaigns_eventfire WHERE contact_id = $1`, testdata.Bob.ID).Returns(2) + assertdb.Query(t, db, `SELECT count(*) FROM campaigns_eventfire`).Returns(4) + assertdb.Query(t, db, `SELECT count(*) FROM campaigns_eventfire WHERE contact_id = $1 AND event_id = $2`, testdata.Cathy.ID, testdata.RemindersEvent1.ID).Returns(2) + assertdb.Query(t, db, `SELECT count(*) FROM campaigns_eventfire WHERE contact_id = $1`, testdata.Bob.ID).Returns(2) } diff --git a/core/models/channel_connection_test.go b/core/models/channel_connection_test.go index 60e28a2ce..c8a150563 100644 --- a/core/models/channel_connection_test.go +++ b/core/models/channel_connection_test.go @@ -3,10 +3,10 @@ package models_test import ( "testing" + "github.com/nyaruka/gocommon/dbutil/assertdb" "github.com/nyaruka/mailroom/core/models" "github.com/nyaruka/mailroom/testsuite" "github.com/nyaruka/mailroom/testsuite/testdata" - "github.com/stretchr/testify/assert" ) @@ -23,7 +23,7 @@ func TestChannelConnections(t *testing.T) { err = conn.UpdateExternalID(ctx, db, "test1") assert.NoError(t, err) - testsuite.AssertQuery(t, db, `SELECT count(*) from channels_channelconnection where external_id = 'test1' AND id = $1`, conn.ID()).Returns(1) + assertdb.Query(t, db, `SELECT count(*) from channels_channelconnection where external_id = 'test1' AND id = $1`, conn.ID()).Returns(1) conn2, err := models.SelectChannelConnection(ctx, db, conn.ID()) assert.NoError(t, err) diff --git a/core/models/channel_logs_test.go b/core/models/channel_logs_test.go index 71f2ba612..058fe1648 100644 --- a/core/models/channel_logs_test.go +++ b/core/models/channel_logs_test.go @@ -4,6 +4,7 @@ import ( "net/http" "testing" + "github.com/nyaruka/gocommon/dbutil/assertdb" "github.com/nyaruka/gocommon/httpx" "github.com/nyaruka/mailroom/core/models" "github.com/nyaruka/mailroom/testsuite" @@ -48,8 +49,8 @@ func TestChannelLogs(t *testing.T) { err = models.InsertChannelLogs(ctx, db, []*models.ChannelLog{log1, log2, log3}) require.NoError(t, err) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM channels_channellog`).Returns(3) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM channels_channellog WHERE url = 'http://rapidpro.io' AND is_error = FALSE AND channel_id = $1`, channel.ID()).Returns(1) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM channels_channellog WHERE url = 'http://rapidpro.io/bad' AND is_error = TRUE AND channel_id = $1`, channel.ID()).Returns(1) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM channels_channellog WHERE url = 'https://rapidpro.io/old' AND is_error = FALSE AND channel_id = $1`, channel.ID()).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM channels_channellog`).Returns(3) + assertdb.Query(t, db, `SELECT count(*) FROM channels_channellog WHERE url = 'http://rapidpro.io' AND is_error = FALSE AND channel_id = $1`, channel.ID()).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM channels_channellog WHERE url = 'http://rapidpro.io/bad' AND is_error = TRUE AND channel_id = $1`, channel.ID()).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM channels_channellog WHERE url = 'https://rapidpro.io/old' AND is_error = FALSE AND channel_id = $1`, channel.ID()).Returns(1) } diff --git a/core/models/channels.go b/core/models/channels.go index 16ede2f29..1b9c99d58 100644 --- a/core/models/channels.go +++ b/core/models/channels.go @@ -7,12 +7,11 @@ import ( "math" "time" + "github.com/lib/pq" + "github.com/nyaruka/gocommon/dbutil" "github.com/nyaruka/goflow/assets" "github.com/nyaruka/goflow/envs" - "github.com/nyaruka/mailroom/utils/dbutil" "github.com/nyaruka/null" - - "github.com/jmoiron/sqlx" "github.com/pkg/errors" "github.com/sirupsen/logrus" ) @@ -124,11 +123,48 @@ func (c *Channel) ChannelReference() *assets.ChannelReference { return assets.NewChannelReference(c.UUID(), c.Name()) } +// GetChannelsByID fetches channels by ID - NOTE these are "lite" channels and only include fields for sending +func GetChannelsByID(ctx context.Context, db Queryer, ids []ChannelID) ([]*Channel, error) { + rows, err := db.QueryxContext(ctx, selectChannelsByIDSQL, pq.Array(ids)) + if err != nil { + return nil, errors.Wrapf(err, "error querying channels by id") + } + defer rows.Close() + + channels := make([]*Channel, 0, 5) + for rows.Next() { + channel := &Channel{} + err := dbutil.ScanJSON(rows, &channel.c) + if err != nil { + return nil, errors.Wrapf(err, "error unmarshalling channel") + } + + channels = append(channels, channel) + } + + return channels, nil +} + +const selectChannelsByIDSQL = ` +SELECT ROW_TO_JSON(r) FROM (SELECT + c.id as id, + c.uuid as uuid, + c.name as name, + c.channel_type as channel_type, + COALESCE(c.tps, 10) as tps, + COALESCE(c.config, '{}')::json as config +FROM + channels_channel c +WHERE + c.id = ANY($1) +) r; +` + // loadChannels loads all the channels for the passed in org -func loadChannels(ctx context.Context, db sqlx.Queryer, orgID OrgID) ([]assets.Channel, error) { +func loadChannels(ctx context.Context, db Queryer, orgID OrgID) ([]assets.Channel, error) { start := time.Now() - rows, err := db.Queryx(selectChannelsSQL, orgID) + rows, err := db.QueryxContext(ctx, selectChannelsSQL, orgID) if err != nil { return nil, errors.Wrapf(err, "error querying channels for org: %d", orgID) } @@ -137,7 +173,7 @@ func loadChannels(ctx context.Context, db sqlx.Queryer, orgID OrgID) ([]assets.C channels := make([]assets.Channel, 0, 2) for rows.Next() { channel := &Channel{} - err := dbutil.ReadJSONRow(rows, &channel.c) + err := dbutil.ScanJSON(rows, &channel.c) if err != nil { return nil, errors.Wrapf(err, "error unmarshalling channel") } diff --git a/core/models/classifiers.go b/core/models/classifiers.go index 9cf0336c7..f0e1da216 100644 --- a/core/models/classifiers.go +++ b/core/models/classifiers.go @@ -5,6 +5,8 @@ import ( "database/sql/driver" "time" + "github.com/jmoiron/sqlx" + "github.com/nyaruka/gocommon/dbutil" "github.com/nyaruka/goflow/assets" "github.com/nyaruka/goflow/flows" "github.com/nyaruka/goflow/flows/engine" @@ -13,10 +15,7 @@ import ( "github.com/nyaruka/goflow/services/classification/wit" "github.com/nyaruka/mailroom/core/goflow" "github.com/nyaruka/mailroom/runtime" - "github.com/nyaruka/mailroom/utils/dbutil" "github.com/nyaruka/null" - - "github.com/jmoiron/sqlx" "github.com/pkg/errors" "github.com/sirupsen/logrus" ) @@ -140,7 +139,7 @@ func loadClassifiers(ctx context.Context, db sqlx.Queryer, orgID OrgID) ([]asset classifiers := make([]assets.Classifier, 0, 2) for rows.Next() { classifier := &Classifier{} - err := dbutil.ReadJSONRow(rows, &classifier.c) + err := dbutil.ScanJSON(rows, &classifier.c) if err != nil { return nil, errors.Wrapf(err, "error unmarshalling classifier") } diff --git a/core/models/contacts.go b/core/models/contacts.go index 2657529ff..3be0d1e54 100644 --- a/core/models/contacts.go +++ b/core/models/contacts.go @@ -9,17 +9,16 @@ import ( "strconv" "time" + "github.com/lib/pq" "github.com/nyaruka/gocommon/dates" + "github.com/nyaruka/gocommon/dbutil" "github.com/nyaruka/gocommon/urns" "github.com/nyaruka/gocommon/uuids" "github.com/nyaruka/goflow/assets" "github.com/nyaruka/goflow/envs" "github.com/nyaruka/goflow/excellent/types" "github.com/nyaruka/goflow/flows" - "github.com/nyaruka/mailroom/utils/dbutil" "github.com/nyaruka/null" - - "github.com/lib/pq" "github.com/pkg/errors" "github.com/sirupsen/logrus" ) @@ -127,7 +126,7 @@ func (c *Contact) UpdateLastSeenOn(ctx context.Context, db Queryer, lastSeenOn t // UpdatePreferredURN updates the URNs for the contact (if needbe) to have the passed in URN as top priority // with the passed in channel as the preferred channel -func (c *Contact) UpdatePreferredURN(ctx context.Context, db Queryer, org *OrgAssets, urnID URNID, channel *Channel) error { +func (c *Contact) UpdatePreferredURN(ctx context.Context, db Queryer, oa *OrgAssets, urnID URNID, channel *Channel) error { // no urns? that's an error if len(c.urns) == 0 { return errors.Errorf("can't set preferred URN on contact with no URNs") @@ -135,7 +134,7 @@ func (c *Contact) UpdatePreferredURN(ctx context.Context, db Queryer, org *OrgAs // is this already our top URN? topURNID := URNID(GetURNInt(c.urns[0], "id")) - topChannelID := GetURNChannelID(org, c.urns[0]) + topChannelID := GetURNChannelID(oa, c.urns[0]) // we are already the top URN, nothing to do if topURNID == urnID && topChannelID != NilChannelID && channel != nil && topChannelID == channel.ID() { @@ -178,7 +177,7 @@ func (c *Contact) UpdatePreferredURN(ctx context.Context, db Queryer, org *OrgAs } // write our new state to the db - err := UpdateContactURNs(ctx, db, org, []*ContactURNsChanged{change}) + err := UpdateContactURNs(ctx, db, oa, []*ContactURNsChanged{change}) if err != nil { return errors.Wrapf(err, "error updating urns for contact") } @@ -234,8 +233,8 @@ func (c *Contact) FlowContact(oa *OrgAssets) (*flows.Contact, error) { } // LoadContact loads a contact from the passed in id -func LoadContact(ctx context.Context, db Queryer, org *OrgAssets, id ContactID) (*Contact, error) { - contacts, err := LoadContacts(ctx, db, org, []ContactID{id}) +func LoadContact(ctx context.Context, db Queryer, oa *OrgAssets, id ContactID) (*Contact, error) { + contacts, err := LoadContacts(ctx, db, oa, []ContactID{id}) if err != nil { return nil, err } @@ -247,10 +246,10 @@ func LoadContact(ctx context.Context, db Queryer, org *OrgAssets, id ContactID) // LoadContacts loads a set of contacts for the passed in ids. Note that the order of the returned contacts // won't necessarily match the order of the ids. -func LoadContacts(ctx context.Context, db Queryer, org *OrgAssets, ids []ContactID) ([]*Contact, error) { +func LoadContacts(ctx context.Context, db Queryer, oa *OrgAssets, ids []ContactID) ([]*Contact, error) { start := time.Now() - rows, err := db.QueryxContext(ctx, selectContactSQL, pq.Array(ids), org.OrgID()) + rows, err := db.QueryxContext(ctx, selectContactSQL, pq.Array(ids), oa.OrgID()) if err != nil { return nil, errors.Wrap(err, "error selecting contacts") } @@ -259,7 +258,7 @@ func LoadContacts(ctx context.Context, db Queryer, org *OrgAssets, ids []Contact contacts := make([]*Contact, 0, len(ids)) for rows.Next() { e := &contactEnvelope{} - err := dbutil.ReadJSONRow(rows, e) + err := dbutil.ScanJSON(rows, e) if err != nil { return nil, errors.Wrap(err, "error scanning contact json") } @@ -278,7 +277,7 @@ func LoadContacts(ctx context.Context, db Queryer, org *OrgAssets, ids []Contact // load our real groups groups := make([]*Group, 0, len(e.GroupIDs)) for _, g := range e.GroupIDs { - group := org.GroupByID(g) + group := oa.GroupByID(g) if group != nil { groups = append(groups, group) } @@ -287,7 +286,7 @@ func LoadContacts(ctx context.Context, db Queryer, org *OrgAssets, ids []Contact // create our map of field values filtered by what we know exists fields := make(map[string]*flows.Value) - orgFields, _ := org.Fields() + orgFields, _ := oa.Fields() for _, f := range orgFields { field := f.(*Field) cv, found := e.Fields[field.UUID()] @@ -308,9 +307,9 @@ func LoadContacts(ctx context.Context, db Queryer, org *OrgAssets, ids []Contact // finally build up our URN objects contactURNs := make([]urns.URN, 0, len(e.URNs)) for _, u := range e.URNs { - urn, err := u.AsURN(org) + urn, err := u.AsURN(oa) if err != nil { - logrus.WithField("urn", u).WithField("org_id", org.OrgID()).WithField("contact_id", contact.id).Warn("invalid URN, ignoring") + logrus.WithField("urn", u).WithField("org_id", oa.OrgID()).WithField("contact_id", contact.id).Warn("invalid URN, ignoring") continue } contactURNs = append(contactURNs, urn) @@ -320,9 +319,9 @@ func LoadContacts(ctx context.Context, db Queryer, org *OrgAssets, ids []Contact // initialize our tickets tickets := make([]*Ticket, 0, len(e.Tickets)) for _, t := range e.Tickets { - ticketer := org.TicketerByID(t.TicketerID) + ticketer := oa.TicketerByID(t.TicketerID) if ticketer != nil { - tickets = append(tickets, NewTicket(t.UUID, org.OrgID(), contact.ID(), ticketer.ID(), t.ExternalID, t.TopicID, t.Body, t.AssigneeID, nil)) + tickets = append(tickets, NewTicket(t.UUID, oa.OrgID(), contact.ID(), ticketer.ID(), t.ExternalID, t.TopicID, t.Body, t.AssigneeID, nil)) } } contact.tickets = tickets @@ -345,10 +344,10 @@ func LoadContactsByUUID(ctx context.Context, db Queryer, oa *OrgAssets, uuids [] } // GetNewestContactModifiedOn returns the newest modified_on for a contact in the passed in org -func GetNewestContactModifiedOn(ctx context.Context, db Queryer, org *OrgAssets) (*time.Time, error) { - rows, err := db.QueryxContext(ctx, "SELECT modified_on FROM contacts_contact WHERE org_id = $1 ORDER BY modified_on DESC LIMIT 1", org.OrgID()) +func GetNewestContactModifiedOn(ctx context.Context, db Queryer, oa *OrgAssets) (*time.Time, error) { + rows, err := db.QueryxContext(ctx, "SELECT modified_on FROM contacts_contact WHERE org_id = $1 ORDER BY modified_on DESC LIMIT 1", oa.OrgID()) if err != nil && err != sql.ErrNoRows { - return nil, errors.Wrapf(err, "error selecting most recently changed contact for org: %d", org.OrgID()) + return nil, errors.Wrapf(err, "error selecting most recently changed contact for org: %d", oa.OrgID()) } defer rows.Close() if err != sql.ErrNoRows { @@ -356,7 +355,7 @@ func GetNewestContactModifiedOn(ctx context.Context, db Queryer, org *OrgAssets) var newest time.Time err = rows.Scan(&newest) if err != nil { - return nil, errors.Wrapf(err, "error scanning most recent contact modified_on for org: %d", org.OrgID()) + return nil, errors.Wrapf(err, "error scanning most recent contact modified_on for org: %d", oa.OrgID()) } return &newest, nil @@ -417,11 +416,11 @@ type ContactURN struct { } // AsURN returns a full URN representation including the query parameters needed by goflow and mailroom -func (u *ContactURN) AsURN(org *OrgAssets) (urns.URN, error) { +func (u *ContactURN) AsURN(oa *OrgAssets) (urns.URN, error) { // load any channel if present var channel *Channel if u.ChannelID != ChannelID(0) { - channel = org.ChannelByID(u.ChannelID) + channel = oa.ChannelByID(u.ChannelID) } // we build our query from a combination of preferred channel and auth @@ -830,11 +829,11 @@ func insertContactAndURNs(ctx context.Context, db Queryer, orgID OrgID, userID U // URNForURN will return a URN for the passed in URN including all the special query parameters // set that goflow and mailroom depend on. -func URNForURN(ctx context.Context, db Queryer, org *OrgAssets, u urns.URN) (urns.URN, error) { +func URNForURN(ctx context.Context, db Queryer, oa *OrgAssets, u urns.URN) (urns.URN, error) { urn := &ContactURN{} rows, err := db.QueryxContext(ctx, `SELECT row_to_json(r) FROM (SELECT id, scheme, path, display, auth, channel_id, priority FROM contacts_contacturn WHERE identity = $1 AND org_id = $2) r;`, - u.Identity(), org.OrgID(), + u.Identity(), oa.OrgID(), ) if err != nil { return urns.NilURN, errors.Errorf("error selecting URN: %s", u.Identity()) @@ -845,7 +844,7 @@ func URNForURN(ctx context.Context, db Queryer, org *OrgAssets, u urns.URN) (urn return urns.NilURN, errors.Errorf("no urn with identity: %s", u.Identity()) } - err = dbutil.ReadJSONRow(rows, urn) + err = dbutil.ScanJSON(rows, urn) if err != nil { return urns.NilURN, errors.Wrapf(err, "error loading contact urn") } @@ -854,13 +853,13 @@ func URNForURN(ctx context.Context, db Queryer, org *OrgAssets, u urns.URN) (urn return urns.NilURN, errors.Wrapf(err, "more than one URN returned for identity query") } - return urn.AsURN(org) + return urn.AsURN(oa) } // GetOrCreateURN will look up a URN by identity, creating it if needbe and associating it with the contact -func GetOrCreateURN(ctx context.Context, db Queryer, org *OrgAssets, contactID ContactID, u urns.URN) (urns.URN, error) { +func GetOrCreateURN(ctx context.Context, db Queryer, oa *OrgAssets, contactID ContactID, u urns.URN) (urns.URN, error) { // first try to get it directly - urn, _ := URNForURN(ctx, db, org, u) + urn, _ := URNForURN(ctx, db, oa, u) // found it? we are done if urn != urns.NilURN { @@ -876,7 +875,7 @@ func GetOrCreateURN(ctx context.Context, db Queryer, org *OrgAssets, contactID C Auth: GetURNAuth(u), Scheme: u.Scheme(), Priority: defaultURNPriority, - OrgID: org.OrgID(), + OrgID: oa.OrgID(), } _, err := db.NamedExecContext(ctx, insertContactURNsSQL, insert) @@ -885,13 +884,13 @@ func GetOrCreateURN(ctx context.Context, db Queryer, org *OrgAssets, contactID C } // do a lookup once more - return URNForURN(ctx, db, org, u) + return URNForURN(ctx, db, oa, u) } // URNForID will return a URN for the passed in ID including all the special query parameters // set that goflow and mailroom depend on. Generally this URN is built when loading a contact // but occasionally we need to load URNs one by one and this accomplishes that -func URNForID(ctx context.Context, db Queryer, org *OrgAssets, urnID URNID) (urns.URN, error) { +func URNForID(ctx context.Context, db Queryer, oa *OrgAssets, urnID URNID) (urns.URN, error) { urn := &ContactURN{} rows, err := db.QueryxContext(ctx, `SELECT row_to_json(r) FROM (SELECT id, scheme, path, display, auth, channel_id, priority FROM contacts_contacturn WHERE id = $1) r;`, @@ -906,17 +905,17 @@ func URNForID(ctx context.Context, db Queryer, org *OrgAssets, urnID URNID) (urn return urns.NilURN, errors.Errorf("no urn with id: %d", urnID) } - err = dbutil.ReadJSONRow(rows, urn) + err = dbutil.ScanJSON(rows, urn) if err != nil { return urns.NilURN, errors.Wrapf(err, "error loading contact urn") } - return urn.AsURN(org) + return urn.AsURN(oa) } // CalculateDynamicGroups recalculates all the dynamic groups for the passed in contact, recalculating // campaigns as necessary based on those group changes. -func CalculateDynamicGroups(ctx context.Context, db Queryer, org *OrgAssets, contacts []*flows.Contact) error { +func CalculateDynamicGroups(ctx context.Context, db Queryer, oa *OrgAssets, contacts []*flows.Contact) error { contactIDs := make([]ContactID, len(contacts)) groupAdds := make([]*GroupAdd, 0, 2*len(contacts)) groupRemoves := make([]*GroupRemove, 0, 2*len(contacts)) @@ -924,10 +923,10 @@ func CalculateDynamicGroups(ctx context.Context, db Queryer, org *OrgAssets, con for i, contact := range contacts { contactIDs[i] = ContactID(contact.ID()) - added, removed := contact.ReevaluateQueryBasedGroups(org.Env()) + added, removed := contact.ReevaluateQueryBasedGroups(oa.Env()) for _, a := range added { - group := org.GroupByUUID(a.UUID()) + group := oa.GroupByUUID(a.UUID()) if group != nil { groupAdds = append(groupAdds, &GroupAdd{ ContactID: ContactID(contact.ID()), @@ -936,13 +935,13 @@ func CalculateDynamicGroups(ctx context.Context, db Queryer, org *OrgAssets, con } // add in any campaigns we may qualify for - for _, campaign := range org.CampaignByGroupID(group.ID()) { + for _, campaign := range oa.CampaignByGroupID(group.ID()) { checkCampaigns[campaign] = append(checkCampaigns[campaign], contact) } } for _, r := range removed { - group := org.GroupByUUID(r.UUID()) + group := oa.GroupByUUID(r.UUID()) if group != nil { groupRemoves = append(groupRemoves, &GroupRemove{ ContactID: ContactID(contact.ID()), @@ -970,7 +969,7 @@ func CalculateDynamicGroups(ctx context.Context, db Queryer, org *OrgAssets, con // for each campaign figure out if we need to be added to any events fireAdds := make([]*FireAdd, 0, 2*len(contacts)) - tz := org.Env().Timezone() + tz := oa.Env().Timezone() now := time.Now() for campaign, eligibleContacts := range checkCampaigns { @@ -1089,7 +1088,7 @@ func GetURNAuth(urn urns.URN) null.String { return null.String(value) } -func GetURNChannelID(org *OrgAssets, urn urns.URN) ChannelID { +func GetURNChannelID(oa *OrgAssets, urn urns.URN) ChannelID { values, err := urn.Query() if err != nil { return NilChannelID @@ -1100,7 +1099,7 @@ func GetURNChannelID(org *OrgAssets, urn urns.URN) ChannelID { return NilChannelID } - channel := org.ChannelByUUID(assets.ChannelUUID(channelUUID)) + channel := oa.ChannelByUUID(assets.ChannelUUID(channelUUID)) if channel != nil { return channel.ID() } @@ -1153,7 +1152,7 @@ func UpdateContactLastSeenOn(ctx context.Context, db Queryer, contactID ContactI } // UpdateContactURNs updates the contact urns in our database to match the passed in changes -func UpdateContactURNs(ctx context.Context, db Queryer, org *OrgAssets, changes []*ContactURNsChanged) error { +func UpdateContactURNs(ctx context.Context, db Queryer, oa *OrgAssets, changes []*ContactURNsChanged) error { // keep track of all our inserts inserts := make([]interface{}, 0, len(changes)) @@ -1176,7 +1175,7 @@ func UpdateContactURNs(ctx context.Context, db Queryer, org *OrgAssets, changes // for each of our urns for _, urn := range change.URNs { // figure out if we have a channel - channelID := GetURNChannelID(org, urn) + channelID := GetURNChannelID(oa, urn) // do we have an id? urnID := URNID(GetURNInt(urn, "id")) @@ -1200,7 +1199,7 @@ func UpdateContactURNs(ctx context.Context, db Queryer, org *OrgAssets, changes Auth: GetURNAuth(urn), Scheme: urn.Scheme(), Priority: priority, - OrgID: org.OrgID(), + OrgID: oa.OrgID(), }) identities = append(identities, urn.Identity().String()) @@ -1230,7 +1229,7 @@ func UpdateContactURNs(ctx context.Context, db Queryer, org *OrgAssets, changes if len(inserts) > 0 { // find the unique ids of the contacts that may be affected by our URN inserts - orphanedIDs, err := queryContactIDs(ctx, db, `SELECT contact_id FROM contacts_contacturn WHERE identity = ANY($1) AND org_id = $2 AND contact_id IS NOT NULL`, pq.Array(identities), org.OrgID()) + orphanedIDs, err := queryContactIDs(ctx, db, `SELECT contact_id FROM contacts_contacturn WHERE identity = ANY($1) AND org_id = $2 AND contact_id IS NOT NULL`, pq.Array(identities), oa.OrgID()) if err != nil { return errors.Wrapf(err, "error finding contacts for URNs") } diff --git a/core/models/contacts_test.go b/core/models/contacts_test.go index 61f8a636b..8a34a0aca 100644 --- a/core/models/contacts_test.go +++ b/core/models/contacts_test.go @@ -6,6 +6,7 @@ import ( "testing" "time" + "github.com/nyaruka/gocommon/dbutil/assertdb" "github.com/nyaruka/gocommon/urns" "github.com/nyaruka/goflow/assets" "github.com/nyaruka/goflow/envs" @@ -14,7 +15,6 @@ import ( "github.com/nyaruka/mailroom/testsuite" "github.com/nyaruka/mailroom/testsuite/testdata" "github.com/nyaruka/mailroom/utils/test" - "github.com/shopspring/decimal" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -459,10 +459,10 @@ func TestStopContact(t *testing.T) { assert.NoError(t, err) // verify she's only in the stopped group - testsuite.AssertQuery(t, db, `SELECT count(*) FROM contacts_contactgroup_contacts WHERE contact_id = $1`, testdata.Cathy.ID).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM contacts_contactgroup_contacts WHERE contact_id = $1`, testdata.Cathy.ID).Returns(1) // verify she's stopped - testsuite.AssertQuery(t, db, `SELECT count(*) FROM contacts_contact WHERE id = $1 AND status = 'S' AND is_active = TRUE`, testdata.Cathy.ID).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM contacts_contact WHERE id = $1 AND status = 'S' AND is_active = TRUE`, testdata.Cathy.ID).Returns(1) } func TestUpdateContactLastSeenAndModifiedOn(t *testing.T) { @@ -478,7 +478,7 @@ func TestUpdateContactLastSeenAndModifiedOn(t *testing.T) { err = models.UpdateContactModifiedOn(ctx, db, []models.ContactID{testdata.Cathy.ID}) assert.NoError(t, err) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM contacts_contact WHERE modified_on > $1 AND last_seen_on IS NULL`, t0).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM contacts_contact WHERE modified_on > $1 AND last_seen_on IS NULL`, t0).Returns(1) t1 := time.Now().Truncate(time.Millisecond) time.Sleep(time.Millisecond * 5) @@ -486,7 +486,7 @@ func TestUpdateContactLastSeenAndModifiedOn(t *testing.T) { err = models.UpdateContactLastSeenOn(ctx, db, testdata.Cathy.ID, t1) assert.NoError(t, err) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM contacts_contact WHERE modified_on > $1 AND last_seen_on = $1`, t1).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM contacts_contact WHERE modified_on > $1 AND last_seen_on = $1`, t1).Returns(1) cathy, err := models.LoadContact(ctx, db, oa, testdata.Cathy.ID) require.NoError(t, err) @@ -517,17 +517,17 @@ func TestUpdateContactModifiedBy(t *testing.T) { err := models.UpdateContactModifiedBy(ctx, db, []models.ContactID{}, models.UserID(0)) assert.NoError(t, err) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM contacts_contact WHERE id = $1 AND modified_by_id = NULL`, testdata.Cathy.ID).Returns(0) + assertdb.Query(t, db, `SELECT count(*) FROM contacts_contact WHERE id = $1 AND modified_by_id = NULL`, testdata.Cathy.ID).Returns(0) err = models.UpdateContactModifiedBy(ctx, db, []models.ContactID{testdata.Cathy.ID}, models.UserID(0)) assert.NoError(t, err) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM contacts_contact WHERE id = $1 AND modified_by_id = NULL`, testdata.Cathy.ID).Returns(0) + assertdb.Query(t, db, `SELECT count(*) FROM contacts_contact WHERE id = $1 AND modified_by_id = NULL`, testdata.Cathy.ID).Returns(0) err = models.UpdateContactModifiedBy(ctx, db, []models.ContactID{testdata.Cathy.ID}, models.UserID(1)) assert.NoError(t, err) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM contacts_contact WHERE id = $1 AND modified_by_id = 1`, testdata.Cathy.ID).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM contacts_contact WHERE id = $1 AND modified_by_id = 1`, testdata.Cathy.ID).Returns(1) } func TestUpdateContactStatus(t *testing.T) { @@ -538,8 +538,8 @@ func TestUpdateContactStatus(t *testing.T) { err := models.UpdateContactStatus(ctx, db, []*models.ContactStatusChange{}) assert.NoError(t, err) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM contacts_contact WHERE id = $1 AND status = 'B'`, testdata.Cathy.ID).Returns(0) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM contacts_contact WHERE id = $1 AND status = 'S'`, testdata.Cathy.ID).Returns(0) + assertdb.Query(t, db, `SELECT count(*) FROM contacts_contact WHERE id = $1 AND status = 'B'`, testdata.Cathy.ID).Returns(0) + assertdb.Query(t, db, `SELECT count(*) FROM contacts_contact WHERE id = $1 AND status = 'S'`, testdata.Cathy.ID).Returns(0) changes := make([]*models.ContactStatusChange, 0, 1) changes = append(changes, &models.ContactStatusChange{testdata.Cathy.ID, flows.ContactStatusBlocked}) @@ -547,8 +547,8 @@ func TestUpdateContactStatus(t *testing.T) { err = models.UpdateContactStatus(ctx, db, changes) assert.NoError(t, err) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM contacts_contact WHERE id = $1 AND status = 'B'`, testdata.Cathy.ID).Returns(1) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM contacts_contact WHERE id = $1 AND status = 'S'`, testdata.Cathy.ID).Returns(0) + assertdb.Query(t, db, `SELECT count(*) FROM contacts_contact WHERE id = $1 AND status = 'B'`, testdata.Cathy.ID).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM contacts_contact WHERE id = $1 AND status = 'S'`, testdata.Cathy.ID).Returns(0) changes = make([]*models.ContactStatusChange, 0, 1) changes = append(changes, &models.ContactStatusChange{testdata.Cathy.ID, flows.ContactStatusStopped}) @@ -556,8 +556,8 @@ func TestUpdateContactStatus(t *testing.T) { err = models.UpdateContactStatus(ctx, db, changes) assert.NoError(t, err) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM contacts_contact WHERE id = $1 AND status = 'B'`, testdata.Cathy.ID).Returns(0) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM contacts_contact WHERE id = $1 AND status = 'S'`, testdata.Cathy.ID).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM contacts_contact WHERE id = $1 AND status = 'B'`, testdata.Cathy.ID).Returns(0) + assertdb.Query(t, db, `SELECT count(*) FROM contacts_contact WHERE id = $1 AND status = 'S'`, testdata.Cathy.ID).Returns(1) } @@ -597,15 +597,15 @@ func TestUpdateContactURNs(t *testing.T) { assert.NoError(t, err) assertContactURNs(testdata.Bob.ID, []string{"tel:+16055742222", "tel:+16055700002"}) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM contacts_contacturn WHERE contact_id IS NULL`).Returns(0) // shouldn't be any orphan URNs - testsuite.AssertQuery(t, db, `SELECT count(*) FROM contacts_contacturn`).Returns(numInitialURNs + 2) // but 2 new URNs + assertdb.Query(t, db, `SELECT count(*) FROM contacts_contacturn WHERE contact_id IS NULL`).Returns(0) // shouldn't be any orphan URNs + assertdb.Query(t, db, `SELECT count(*) FROM contacts_contacturn`).Returns(numInitialURNs + 2) // but 2 new URNs // remove a URN from Cathy err = models.UpdateContactURNs(ctx, db, oa, []*models.ContactURNsChanged{{testdata.Cathy.ID, testdata.Org1.ID, []urns.URN{"tel:+16055700001"}}}) assert.NoError(t, err) assertContactURNs(testdata.Cathy.ID, []string{"tel:+16055700001"}) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM contacts_contacturn WHERE contact_id IS NULL`).Returns(1) // now orphaned + assertdb.Query(t, db, `SELECT count(*) FROM contacts_contacturn WHERE contact_id IS NULL`).Returns(1) // now orphaned // steal a URN from Bob err = models.UpdateContactURNs(ctx, db, oa, []*models.ContactURNsChanged{{testdata.Cathy.ID, testdata.Org1.ID, []urns.URN{"tel:+16055700001", "tel:+16055700002"}}}) @@ -626,5 +626,5 @@ func TestUpdateContactURNs(t *testing.T) { assertContactURNs(testdata.Bob.ID, []string{"tel:+16055742222", "tel:+16055700002"}) assertContactURNs(testdata.George.ID, []string{"tel:+16055743333"}) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM contacts_contacturn`).Returns(numInitialURNs + 3) + assertdb.Query(t, db, `SELECT count(*) FROM contacts_contacturn`).Returns(numInitialURNs + 3) } diff --git a/core/models/fields.go b/core/models/fields.go index 2468023ea..a0a520ca4 100644 --- a/core/models/fields.go +++ b/core/models/fields.go @@ -4,10 +4,9 @@ import ( "context" "time" - "github.com/nyaruka/goflow/assets" - "github.com/nyaruka/mailroom/utils/dbutil" - "github.com/jmoiron/sqlx" + "github.com/nyaruka/gocommon/dbutil" + "github.com/nyaruka/goflow/assets" "github.com/pkg/errors" "github.com/sirupsen/logrus" ) @@ -60,7 +59,7 @@ func loadFields(ctx context.Context, db sqlx.Queryer, orgID OrgID) ([]assets.Fie for rows.Next() { field := &Field{} - err = dbutil.ReadJSONRow(rows, &field.f) + err = dbutil.ScanJSON(rows, &field.f) if err != nil { return nil, nil, errors.Wrap(err, "error reading field") } diff --git a/core/models/flow_stats.go b/core/models/flow_stats.go new file mode 100644 index 000000000..0f112739a --- /dev/null +++ b/core/models/flow_stats.go @@ -0,0 +1,104 @@ +package models + +import ( + "context" + "fmt" + "time" + + "github.com/buger/jsonparser" + "github.com/jmoiron/sqlx" + "github.com/nyaruka/goflow/flows" + "github.com/nyaruka/goflow/utils" + "github.com/nyaruka/mailroom/runtime" + "github.com/nyaruka/redisx" + "github.com/pkg/errors" +) + +const ( + recentContactsCap = 5 // number of recent contacts we keep per segment + recentContactsExpire = time.Hour * 24 // how long we keep recent contacts + recentContactsKey = "recent_contacts:%s" +) + +var storeOperandsForTypes = map[string]bool{"wait_for_response": true, "split_by_expression": true, "split_by_contact_field": true, "split_by_run_result": true} + +type segmentID struct { + exitUUID flows.ExitUUID + destUUID flows.NodeUUID +} + +func (s segmentID) String() string { + return fmt.Sprintf("%s:%s", s.exitUUID, s.destUUID) +} + +type segmentContact struct { + contact *flows.Contact + operand string + time time.Time +} + +// RecordFlowStatistics records statistics from the given parallel slices of sessions and sprints +func RecordFlowStatistics(ctx context.Context, rt *runtime.Runtime, tx *sqlx.Tx, sessions []flows.Session, sprints []flows.Sprint) error { + rc := rt.RP.Get() + defer rc.Close() + + segmentIDs := make([]segmentID, 0, 10) + recentBySegment := make(map[segmentID][]*segmentContact, 10) + nodeTypeCache := make(map[flows.NodeUUID]string) + + for i, sprint := range sprints { + session := sessions[i] + + for _, seg := range sprint.Segments() { + segID := segmentID{seg.Exit().UUID(), seg.Destination().UUID()} + uiNodeType := getNodeUIType(seg.Flow(), seg.Node(), nodeTypeCache) + + // only store operand values for certain node types + operand := "" + if storeOperandsForTypes[uiNodeType] { + operand = seg.Operand() + } + + if _, seen := recentBySegment[segID]; !seen { + segmentIDs = append(segmentIDs, segID) + } + recentBySegment[segID] = append(recentBySegment[segID], &segmentContact{contact: session.Contact(), operand: operand, time: seg.Time()}) + } + } + + for _, segID := range segmentIDs { + recentContacts := recentBySegment[segID] + + // trim recent set for each segment - no point in trying to add more values than we keep + if len(recentContacts) > recentContactsCap { + recentBySegment[segID] = recentContacts[:len(recentContacts)-recentContactsCap] + } + + recentSet := redisx.NewCappedZSet(fmt.Sprintf(recentContactsKey, segID), recentContactsCap, recentContactsExpire) + + for _, recent := range recentContacts { + // set members need to be unique, so we include a random string + value := fmt.Sprintf("%s|%d|%s", redisx.RandomBase64(10), recent.contact.ID(), utils.TruncateEllipsis(recent.operand, 100)) + score := float64(recent.time.UnixNano()) / float64(1e9) // score is UNIX time as floating point + + err := recentSet.Add(rc, value, score) + if err != nil { + return errors.Wrap(err, "error adding recent contact to set") + } + } + } + + return nil +} + +func getNodeUIType(flow flows.Flow, node flows.Node, cache map[flows.NodeUUID]string) string { + uiType, cached := cache[node.UUID()] + if cached { + return uiType + } + + // try to lookup node type but don't error if we can't find it.. could be a bad flow + value, _ := jsonparser.GetString(flow.UI(), "nodes", string(node.UUID()), "type") + cache[node.UUID()] = value + return value +} diff --git a/core/models/flow_stats_test.go b/core/models/flow_stats_test.go new file mode 100644 index 000000000..a02aa406d --- /dev/null +++ b/core/models/flow_stats_test.go @@ -0,0 +1,84 @@ +package models_test + +import ( + "os" + "testing" + + "github.com/nyaruka/gocommon/random" + "github.com/nyaruka/goflow/flows" + "github.com/nyaruka/goflow/test" + "github.com/nyaruka/mailroom/core/models" + "github.com/nyaruka/mailroom/testsuite" + "github.com/nyaruka/redisx/assertredis" + "github.com/stretchr/testify/require" +) + +func TestRecordFlowStatistics(t *testing.T) { + ctx, rt, _, rp := testsuite.Get() + + defer testsuite.Reset(testsuite.ResetRedis) + + defer random.SetGenerator(random.DefaultGenerator) + random.SetGenerator(random.NewSeededGenerator(123)) + + assetsJSON, err := os.ReadFile("testdata/flow_stats_test.json") + require.NoError(t, err) + + session1, session1Sprint1 := test.NewSessionBuilder().WithAssets(assetsJSON).WithFlow("19eab6aa-4a88-42a1-8882-b9956823c680"). + WithContact("4ad4f0a6-fb95-4845-b4cb-335f67eafe96", 123, "Bob", "eng", "").MustBuild() + session2, session2Sprint1 := test.NewSessionBuilder().WithAssets(assetsJSON).WithFlow("19eab6aa-4a88-42a1-8882-b9956823c680"). + WithContact("5cfe8b70-0d4a-4862-8fb5-e72603d832a9", 234, "Ann", "eng", "").MustBuild() + session3, session3Sprint1 := test.NewSessionBuilder().WithAssets(assetsJSON).WithFlow("19eab6aa-4a88-42a1-8882-b9956823c680"). + WithContact("367c8ef2-aac7-4264-9a03-40877371995d", 345, "Jim", "eng", "").MustBuild() + + err = models.RecordFlowStatistics(ctx, rt, nil, []flows.Session{session1, session2, session3}, []flows.Sprint{session1Sprint1, session2Sprint1, session3Sprint1}) + require.NoError(t, err) + + assertredis.Keys(t, rp, []string{ + "recent_contacts:5fd2e537-0534-4c12-8425-bef87af09d46:072b95b3-61c3-4e0e-8dd1-eb7481083f94", // "what's your fav color" -> color split + }) + + // all 3 contacts went from first msg to the color split - no operands recorded for this segment + assertredis.ZRange(t, rp, "recent_contacts:5fd2e537-0534-4c12-8425-bef87af09d46:072b95b3-61c3-4e0e-8dd1-eb7481083f94", 0, -1, + []string{"LZbbzXDPJH|123|", "reuPYVP90u|234|", "qWARtWDACk|345|"}, + ) + + _, session1Sprint2, err := test.ResumeSession(session1, assetsJSON, "blue") + require.NoError(t, err) + _, session2Sprint2, err := test.ResumeSession(session2, assetsJSON, "BLUE") + require.NoError(t, err) + session3, session3Sprint2, err := test.ResumeSession(session3, assetsJSON, "teal") + require.NoError(t, err) + _, session3Sprint3, err := test.ResumeSession(session3, assetsJSON, "azure") + require.NoError(t, err) + + err = models.RecordFlowStatistics(ctx, rt, nil, []flows.Session{session1, session2, session3}, []flows.Sprint{session1Sprint2, session2Sprint2, session3Sprint2}) + require.NoError(t, err) + err = models.RecordFlowStatistics(ctx, rt, nil, []flows.Session{session3}, []flows.Sprint{session3Sprint3}) + require.NoError(t, err) + + assertredis.Keys(t, rp, []string{ + "recent_contacts:5fd2e537-0534-4c12-8425-bef87af09d46:072b95b3-61c3-4e0e-8dd1-eb7481083f94", // "what's your fav color" -> color split + "recent_contacts:c02fc3ba-369a-4c87-9bc4-c3b376bda6d2:57b50d33-2b5a-4726-82de-9848c61eff6e", // color split :: Blue exit -> next node + "recent_contacts:ea6c38dc-11e2-4616-9f3e-577e44765d44:8712db6b-25ff-4789-892c-581f24eeeb95", // color split :: Other exit -> next node + "recent_contacts:2b698218-87e5-4ab8-922e-e65f91d12c10:88d8bf00-51ce-4e5e-aae8-4f957a0761a0", // split by expression :: Other exit -> next node + "recent_contacts:0a4f2ea9-c47f-4e9c-a242-89ae5b38d679:072b95b3-61c3-4e0e-8dd1-eb7481083f94", // "sorry I don't know that color" -> color split + "recent_contacts:97cd44ce-dec2-4e19-8ca2-4e20db51dc08:0e1fe072-6f03-4f29-98aa-7bedbe930dab", // "X is a great color" -> split by expression + "recent_contacts:614e7451-e0bd-43d9-b317-2aded3c8d790:a1e649db-91e0-47c4-ab14-eba0d1475116", // "you have X tickets" -> group split + }) + + // check recent operands for color split :: Blue exit -> next node + assertredis.ZRange(t, rp, "recent_contacts:c02fc3ba-369a-4c87-9bc4-c3b376bda6d2:57b50d33-2b5a-4726-82de-9848c61eff6e", 0, -1, + []string{"2SS5dyuJzp|123|blue", "6MBPV0gqT9|234|BLUE"}, + ) + + // check recent operands for color split :: Other exit -> next node + assertredis.ZRange(t, rp, "recent_contacts:ea6c38dc-11e2-4616-9f3e-577e44765d44:8712db6b-25ff-4789-892c-581f24eeeb95", 0, -1, + []string{"uI8bPiuaeA|345|teal", "2Vz/MpdX9s|345|azure"}, + ) + + // check recent operands for split by expression :: Other exit -> next node + assertredis.ZRange(t, rp, "recent_contacts:2b698218-87e5-4ab8-922e-e65f91d12c10:88d8bf00-51ce-4e5e-aae8-4f957a0761a0", 0, -1, + []string{"2MsZZ/N3TH|123|0", "KKLrT60Tr9|234|0"}, + ) +} diff --git a/core/models/flows.go b/core/models/flows.go index ebc65be60..811d95d0d 100644 --- a/core/models/flows.go +++ b/core/models/flows.go @@ -6,9 +6,9 @@ import ( "encoding/json" "time" + "github.com/nyaruka/gocommon/dbutil" "github.com/nyaruka/goflow/assets" "github.com/nyaruka/goflow/flows" - "github.com/nyaruka/mailroom/utils/dbutil" "github.com/nyaruka/null" "github.com/jmoiron/sqlx" @@ -53,6 +53,7 @@ var flowTypeMapping = map[flows.FlowType]FlowType{ type Flow struct { f struct { ID FlowID `json:"id"` + OrgID OrgID `json:"org_id"` UUID assets.FlowUUID `json:"uuid"` Name string `json:"name"` Config null.Map `json:"config"` @@ -66,6 +67,9 @@ type Flow struct { // ID returns the ID for this flow func (f *Flow) ID() FlowID { return f.f.ID } +// OrgID returns the Org ID for this flow +func (f *Flow) OrgID() OrgID { return f.f.OrgID } + // UUID returns the UUID for this flow func (f *Flow) UUID() assets.FlowUUID { return f.f.UUID } @@ -102,8 +106,8 @@ func (f *Flow) IVRRetryWait() *time.Duration { // IgnoreTriggers returns whether this flow ignores triggers func (f *Flow) IgnoreTriggers() bool { return f.f.IgnoreTriggers } -// FlowReference return a flow reference for this flow -func (f *Flow) FlowReference() *assets.FlowReference { +// Reference return a flow reference for this flow +func (f *Flow) Reference() *assets.FlowReference { return assets.NewFlowReference(f.UUID(), f.Name()) } @@ -151,7 +155,7 @@ func loadFlow(ctx context.Context, db Queryer, sql string, orgID OrgID, arg inte return nil, nil } - err = dbutil.ReadJSONRow(rows, &flow.f) + err = dbutil.ScanJSON(rows, &flow.f) if err != nil { return nil, errors.Wrapf(err, "error reading flow definition by: %s", arg) } @@ -164,6 +168,7 @@ func loadFlow(ctx context.Context, db Queryer, sql string, orgID OrgID, arg inte const selectFlowByUUIDSQL = ` SELECT ROW_TO_JSON(r) FROM (SELECT id, + org_id, uuid, name, ignore_triggers, @@ -174,8 +179,13 @@ SELECT ROW_TO_JSON(r) FROM (SELECT jsonb_build_object( 'name', f.name, 'uuid', f.uuid, - 'flow_type', f.flow_type, - 'expire_after_minutes', f.expires_after_minutes, + 'flow_type', f.flow_type, + 'expire_after_minutes', + CASE f.flow_type + WHEN 'M' THEN GREATEST(5, LEAST(f.expires_after_minutes, 43200)) + WHEN 'V' THEN GREATEST(1, LEAST(f.expires_after_minutes, 15)) + ELSE 0 + END, 'metadata', jsonb_build_object( 'uuid', f.uuid, 'id', f.id, @@ -211,6 +221,7 @@ WHERE const selectFlowByIDSQL = ` SELECT ROW_TO_JSON(r) FROM (SELECT id, + org_id, uuid, name, ignore_triggers, @@ -222,7 +233,12 @@ SELECT ROW_TO_JSON(r) FROM (SELECT 'name', f.name, 'uuid', f.uuid, 'flow_type', f.flow_type, - 'expire_after_minutes', f.expires_after_minutes, + 'expire_after_minutes', + CASE f.flow_type + WHEN 'M' THEN GREATEST(5, LEAST(f.expires_after_minutes, 43200)) + WHEN 'V' THEN GREATEST(1, LEAST(f.expires_after_minutes, 15)) + ELSE 0 + END, 'metadata', jsonb_build_object( 'uuid', f.uuid, 'id', f.id, diff --git a/core/models/flows_test.go b/core/models/flows_test.go index a0467dc04..ee4385a29 100644 --- a/core/models/flows_test.go +++ b/core/models/flows_test.go @@ -1,10 +1,12 @@ package models_test import ( + "fmt" "testing" "time" "github.com/nyaruka/goflow/assets" + "github.com/nyaruka/goflow/flows" "github.com/nyaruka/mailroom/core/goflow" "github.com/nyaruka/mailroom/core/models" "github.com/nyaruka/mailroom/testsuite" @@ -15,55 +17,119 @@ import ( func TestLoadFlows(t *testing.T) { ctx, rt, db, _ := testsuite.Get() + defer testsuite.Reset(testsuite.ResetAll) + db.MustExec(`UPDATE flows_flow SET metadata = '{"ivr_retry": 30}'::json WHERE id = $1`, testdata.IVRFlow.ID) db.MustExec(`UPDATE flows_flow SET metadata = '{"ivr_retry": -1}'::json WHERE id = $1`, testdata.SurveyorFlow.ID) + db.MustExec(`UPDATE flows_flow SET expires_after_minutes = 720 WHERE id = $1`, testdata.Favorites.ID) + db.MustExec(`UPDATE flows_flow SET expires_after_minutes = 1 WHERE id = $1`, testdata.PickANumber.ID) // too small for messaging + db.MustExec(`UPDATE flows_flow SET expires_after_minutes = 12345678 WHERE id = $1`, testdata.SingleMessage.ID) // too large for messaging + db.MustExec(`UPDATE flows_flow SET expires_after_minutes = 123 WHERE id = $1`, testdata.SurveyorFlow.ID) // surveyor flows shouldn't have expires sixtyMinutes := 60 * time.Minute thirtyMinutes := 30 * time.Minute - tcs := []struct { - org *testdata.Org - flowID models.FlowID - flowUUID assets.FlowUUID - expectedName string - expectedIVRRetry *time.Duration - }{ - {testdata.Org1, testdata.Favorites.ID, testdata.Favorites.UUID, "Favorites", &sixtyMinutes}, // will use default IVR retry - {testdata.Org1, testdata.IVRFlow.ID, testdata.IVRFlow.UUID, "IVR Flow", &thirtyMinutes}, // will have explicit IVR retry - {testdata.Org1, testdata.SurveyorFlow.ID, testdata.SurveyorFlow.UUID, "Contact Surveyor", nil}, // will have no IVR retry - {testdata.Org2, models.FlowID(0), assets.FlowUUID("51e3c67d-8483-449c-abf7-25e50686f0db"), "", nil}, + type testcase struct { + org *testdata.Org + flowID models.FlowID + flowUUID assets.FlowUUID + expectedName string + expectedType models.FlowType + expectedEngineType flows.FlowType + expectedExpire int + expectedIVRRetry *time.Duration } - for i, tc := range tcs { - // test loading by UUID - flow, err := models.LoadFlowByUUID(ctx, db, tc.org.ID, tc.flowUUID) + tcs := []testcase{ + { + testdata.Org1, + testdata.Favorites.ID, + testdata.Favorites.UUID, + "Favorites", + models.FlowTypeMessaging, + flows.FlowTypeMessaging, + 720, + &sixtyMinutes, // uses default + }, + { + testdata.Org1, + testdata.PickANumber.ID, + testdata.PickANumber.UUID, + "Pick a Number", + models.FlowTypeMessaging, + flows.FlowTypeMessaging, + 5, // clamped to minimum + &sixtyMinutes, // uses default + }, + { + testdata.Org1, + testdata.SingleMessage.ID, + testdata.SingleMessage.UUID, + "Send All", + models.FlowTypeMessaging, + flows.FlowTypeMessaging, + 43200, // clamped to maximum + &sixtyMinutes, // uses default + }, + { + testdata.Org1, + testdata.IVRFlow.ID, + testdata.IVRFlow.UUID, + "IVR Flow", + models.FlowTypeVoice, + flows.FlowTypeVoice, + 5, + &thirtyMinutes, // uses explicit + }, + { + testdata.Org1, + testdata.SurveyorFlow.ID, + testdata.SurveyorFlow.UUID, + "Contact Surveyor", + models.FlowTypeSurveyor, + flows.FlowTypeMessagingOffline, + 0, // explicit ignored + nil, // no retry + }, + } + + assertFlow := func(tc *testcase, dbFlow *models.Flow) { + desc := fmt.Sprintf("flow id=%d uuid=%s", tc.flowID, tc.flowUUID) + + // check properties of flow model + assert.Equal(t, tc.flowID, dbFlow.ID()) + assert.Equal(t, tc.flowUUID, dbFlow.UUID()) + assert.Equal(t, tc.expectedName, dbFlow.Name(), "db name mismatch for %s", desc) + assert.Equal(t, tc.expectedIVRRetry, dbFlow.IVRRetryWait(), "db IVR retry mismatch for %s", desc) + + // load as engine flow and check that too + flow, err := goflow.ReadFlow(rt.Config, dbFlow.Definition()) assert.NoError(t, err) - if tc.expectedName != "" { - assert.Equal(t, tc.flowID, flow.ID()) - assert.Equal(t, tc.flowUUID, flow.UUID()) - assert.Equal(t, tc.expectedName, flow.Name(), "%d: name mismatch", i) - assert.Equal(t, tc.expectedIVRRetry, flow.IVRRetryWait(), "%d: IVR retry mismatch", i) + assert.Equal(t, tc.flowUUID, flow.UUID(), "engine UUID mismatch for %s", desc) + assert.Equal(t, tc.flowUUID, flow.UUID(), "engine UUID mismatch for %s", desc) + assert.Equal(t, tc.expectedName, flow.Name(), "engine name mismatch for %s", desc) + assert.Equal(t, tc.expectedEngineType, flow.Type(), "engine type mismatch for %s", desc) + assert.Equal(t, tc.expectedExpire, flow.ExpireAfterMinutes(), "engine expire mismatch for %s", desc) - _, err := goflow.ReadFlow(rt.Config, flow.Definition()) - assert.NoError(t, err) - } else { - assert.Nil(t, flow) - } + } - // test loading by ID - flow, err = models.LoadFlowByID(ctx, db, tc.org.ID, tc.flowID) + for _, tc := range tcs { + // test loading by UUID + dbFlow, err := models.LoadFlowByUUID(ctx, db, tc.org.ID, tc.flowUUID) assert.NoError(t, err) + assertFlow(&tc, dbFlow) - if tc.expectedName != "" { - assert.Equal(t, tc.flowID, flow.ID()) - assert.Equal(t, tc.flowUUID, flow.UUID()) - assert.Equal(t, tc.expectedName, flow.Name(), "%d: name mismatch", i) - assert.Equal(t, tc.expectedIVRRetry, flow.IVRRetryWait(), "%d: IVR retry mismatch", i) - } else { - assert.Nil(t, flow) - } + // test loading by ID + dbFlow, err = models.LoadFlowByID(ctx, db, tc.org.ID, tc.flowID) + assert.NoError(t, err) + assertFlow(&tc, dbFlow) } + + // test loading flow with wrong org + dbFlow, err := models.LoadFlowByID(ctx, db, testdata.Org2.ID, testdata.Favorites.ID) + assert.NoError(t, err) + assert.Nil(t, dbFlow) } func TestFlowIDForUUID(t *testing.T) { diff --git a/core/models/globals.go b/core/models/globals.go index 18baad2cb..18baa7fa4 100644 --- a/core/models/globals.go +++ b/core/models/globals.go @@ -5,8 +5,8 @@ import ( "encoding/json" "time" + "github.com/nyaruka/gocommon/dbutil" "github.com/nyaruka/goflow/assets" - "github.com/nyaruka/mailroom/utils/dbutil" "github.com/jmoiron/sqlx" "github.com/pkg/errors" @@ -44,7 +44,7 @@ func loadGlobals(ctx context.Context, db sqlx.Queryer, orgID OrgID) ([]assets.Gl globals := make([]assets.Global, 0) for rows.Next() { global := &Global{} - err = dbutil.ReadJSONRow(rows, &global.g) + err = dbutil.ScanAndValidateJSON(rows, &global.g) if err != nil { return nil, errors.Wrap(err, "error reading global row") } diff --git a/core/models/groups.go b/core/models/groups.go index ce8ab590b..f1550903d 100644 --- a/core/models/groups.go +++ b/core/models/groups.go @@ -4,9 +4,9 @@ import ( "context" "time" + "github.com/nyaruka/gocommon/dbutil" "github.com/nyaruka/goflow/assets" "github.com/nyaruka/goflow/flows" - "github.com/nyaruka/mailroom/utils/dbutil" "github.com/jmoiron/sqlx" "github.com/lib/pq" @@ -62,7 +62,7 @@ func LoadGroups(ctx context.Context, db Queryer, orgID OrgID) ([]assets.Group, e groups := make([]assets.Group, 0, 10) for rows.Next() { group := &Group{} - err = dbutil.ReadJSONRow(rows, &group.g) + err = dbutil.ScanJSON(rows, &group.g) if err != nil { return nil, errors.Wrap(err, "error reading group row") } @@ -192,7 +192,7 @@ func UpdateGroupStatus(ctx context.Context, db Queryer, groupID GroupID, status // RemoveContactsFromGroupAndCampaigns removes the passed in contacts from the passed in group, taking care of also // removing them from any associated campaigns -func RemoveContactsFromGroupAndCampaigns(ctx context.Context, db *sqlx.DB, org *OrgAssets, groupID GroupID, contactIDs []ContactID) error { +func RemoveContactsFromGroupAndCampaigns(ctx context.Context, db *sqlx.DB, oa *OrgAssets, groupID GroupID, contactIDs []ContactID) error { removeBatch := func(batch []ContactID) error { tx, err := db.BeginTxx(ctx, nil) @@ -215,7 +215,7 @@ func RemoveContactsFromGroupAndCampaigns(ctx context.Context, db *sqlx.DB, org * } // remove from any campaign events - err = DeleteUnfiredEventsForGroupRemoval(ctx, tx, org, batch, groupID) + err = DeleteUnfiredEventsForGroupRemoval(ctx, tx, oa, batch, groupID) if err != nil { tx.Rollback() return errors.Wrapf(err, "error removing contacts from unfired campaign events for group: %d", groupID) @@ -254,7 +254,7 @@ func RemoveContactsFromGroupAndCampaigns(ctx context.Context, db *sqlx.DB, org * // AddContactsToGroupAndCampaigns takes care of adding the passed in contacts to the passed in group, updating any // associated campaigns as needed -func AddContactsToGroupAndCampaigns(ctx context.Context, db *sqlx.DB, org *OrgAssets, groupID GroupID, contactIDs []ContactID) error { +func AddContactsToGroupAndCampaigns(ctx context.Context, db *sqlx.DB, oa *OrgAssets, groupID GroupID, contactIDs []ContactID) error { // we need session assets in order to recalculate campaign events addBatch := func(batch []ContactID) error { tx, err := db.BeginTxx(ctx, nil) @@ -278,7 +278,7 @@ func AddContactsToGroupAndCampaigns(ctx context.Context, db *sqlx.DB, org *OrgAs } // now load our contacts and add update their campaign events - contacts, err := LoadContacts(ctx, tx, org, batch) + contacts, err := LoadContacts(ctx, tx, oa, batch) if err != nil { tx.Rollback() return errors.Wrapf(err, "error loading contacts when adding to group: %d", groupID) @@ -287,7 +287,7 @@ func AddContactsToGroupAndCampaigns(ctx context.Context, db *sqlx.DB, org *OrgAs // convert to flow contacts fcs := make([]*flows.Contact, len(contacts)) for i, c := range contacts { - fcs[i], err = c.FlowContact(org) + fcs[i], err = c.FlowContact(oa) if err != nil { tx.Rollback() return errors.Wrapf(err, "error converting contact to flow contact: %s", c.UUID()) @@ -295,7 +295,7 @@ func AddContactsToGroupAndCampaigns(ctx context.Context, db *sqlx.DB, org *OrgAs } // schedule any upcoming events that were affected by this group - err = AddCampaignEventsForGroupAddition(ctx, tx, org, fcs, groupID) + err = AddCampaignEventsForGroupAddition(ctx, tx, oa, fcs, groupID) if err != nil { tx.Rollback() return errors.Wrapf(err, "error calculating new campaign events during group addition: %d", groupID) @@ -334,7 +334,7 @@ func AddContactsToGroupAndCampaigns(ctx context.Context, db *sqlx.DB, org *OrgAs // PopulateDynamicGroup calculates which members should be part of a group and populates the contacts // for that group by performing the minimum number of inserts / deletes. -func PopulateDynamicGroup(ctx context.Context, db *sqlx.DB, es *elastic.Client, org *OrgAssets, groupID GroupID, query string) (int, error) { +func PopulateDynamicGroup(ctx context.Context, db *sqlx.DB, es *elastic.Client, oa *OrgAssets, groupID GroupID, query string) (int, error) { err := UpdateGroupStatus(ctx, db, groupID, GroupStatusEvaluating) if err != nil { return 0, errors.Wrapf(err, "error marking dynamic group as evaluating") @@ -345,9 +345,9 @@ func PopulateDynamicGroup(ctx context.Context, db *sqlx.DB, es *elastic.Client, // we have a bit of a race with the indexer process.. we want to make sure that any contacts that changed // before this group was updated but after the last index are included, so if a contact was modified // more recently than 10 seconds ago, we wait that long before starting in populating our group - newest, err := GetNewestContactModifiedOn(ctx, db, org) + newest, err := GetNewestContactModifiedOn(ctx, db, oa) if err != nil { - return 0, errors.Wrapf(err, "error getting most recent contact modified_on for org: %d", org.OrgID()) + return 0, errors.Wrapf(err, "error getting most recent contact modified_on for org: %d", oa.OrgID()) } if newest != nil { n := *newest @@ -371,7 +371,7 @@ func PopulateDynamicGroup(ctx context.Context, db *sqlx.DB, es *elastic.Client, } // calculate new set of ids - new, err := ContactIDsForQuery(ctx, es, org, query) + new, err := ContactIDsForQuery(ctx, es, oa, query) if err != nil { return 0, errors.Wrapf(err, "error performing query: %s for group: %d", query, groupID) } @@ -392,13 +392,13 @@ func PopulateDynamicGroup(ctx context.Context, db *sqlx.DB, es *elastic.Client, } // first remove all the contacts - err = RemoveContactsFromGroupAndCampaigns(ctx, db, org, groupID, removals) + err = RemoveContactsFromGroupAndCampaigns(ctx, db, oa, groupID, removals) if err != nil { return 0, errors.Wrapf(err, "error removing contacts from group: %d", groupID) } // then add them all - err = AddContactsToGroupAndCampaigns(ctx, db, org, groupID, adds) + err = AddContactsToGroupAndCampaigns(ctx, db, oa, groupID, adds) if err != nil { return 0, errors.Wrapf(err, "error adding contacts to group: %d", groupID) } diff --git a/core/models/groups_test.go b/core/models/groups_test.go index c57da761f..b4339dc73 100644 --- a/core/models/groups_test.go +++ b/core/models/groups_test.go @@ -5,7 +5,7 @@ import ( "fmt" "testing" - "github.com/nyaruka/gocommon/uuids" + "github.com/nyaruka/gocommon/dbutil/assertdb" "github.com/nyaruka/goflow/assets" "github.com/nyaruka/mailroom/core/models" "github.com/nyaruka/mailroom/testsuite" @@ -60,12 +60,7 @@ func TestDynamicGroups(t *testing.T) { defer testsuite.Reset(testsuite.ResetAll) // insert an event on our campaign - var eventID models.CampaignEventID - db.Get(&eventID, - `INSERT INTO campaigns_campaignevent(is_active, created_on, modified_on, uuid, "offset", unit, event_type, delivery_hour, - campaign_id, created_by_id, modified_by_id, flow_id, relative_to_id, start_mode) - VALUES(TRUE, NOW(), NOW(), $1, 1000, 'W', 'F', -1, $2, 1, 1, $3, $4, 'I') RETURNING id`, - uuids.New(), testdata.RemindersCampaign.ID, testdata.Favorites.ID, testdata.JoinedField.ID) + newEvent := testdata.InsertCampaignFlowEvent(db, testdata.RemindersCampaign, testdata.Favorites, testdata.JoinedField, 1000, "W") // clear Cathy's value db.MustExec( @@ -163,13 +158,13 @@ func TestDynamicGroups(t *testing.T) { assert.NoError(t, err) assert.Equal(t, tc.ContactIDs, contactIDs) - testsuite.AssertQuery(t, db, `SELECT count(*) from contacts_contactgroup WHERE id = $1 AND status = 'R'`, testdata.DoctorsGroup.ID). + assertdb.Query(t, db, `SELECT count(*) from contacts_contactgroup WHERE id = $1 AND status = 'R'`, testdata.DoctorsGroup.ID). Returns(1, "wrong number of contacts in group for query: %s", tc.Query) - testsuite.AssertQuery(t, db, `SELECT count(*) from campaigns_eventfire WHERE event_id = $1`, eventID). + assertdb.Query(t, db, `SELECT count(*) from campaigns_eventfire WHERE event_id = $1`, newEvent.ID). Returns(len(tc.EventContactIDs), "wrong number of contacts with events for query: %s", tc.Query) - testsuite.AssertQuery(t, db, `SELECT count(*) from campaigns_eventfire WHERE event_id = $1 AND contact_id = ANY($2)`, eventID, pq.Array(tc.EventContactIDs)). + assertdb.Query(t, db, `SELECT count(*) from campaigns_eventfire WHERE event_id = $1 AND contact_id = ANY($2)`, newEvent.ID, pq.Array(tc.EventContactIDs)). Returns(len(tc.EventContactIDs), "wrong contacts with events for query: %s", tc.Query) } } diff --git a/core/models/http_logs_test.go b/core/models/http_logs_test.go index 4ccf9385c..922f1ef91 100644 --- a/core/models/http_logs_test.go +++ b/core/models/http_logs_test.go @@ -5,6 +5,7 @@ import ( "testing" "time" + "github.com/nyaruka/gocommon/dbutil/assertdb" "github.com/nyaruka/gocommon/httpx" "github.com/nyaruka/goflow/flows" "github.com/nyaruka/mailroom/core/models" @@ -25,21 +26,21 @@ func TestHTTPLogs(t *testing.T) { err := models.InsertHTTPLogs(ctx, db, []*models.HTTPLog{log}) assert.Nil(t, err) - testsuite.AssertQuery(t, db, `SELECT count(*) from request_logs_httplog WHERE org_id = $1 AND status_code = 200 AND classifier_id = $2 AND is_error = FALSE`, testdata.Org1.ID, testdata.Wit.ID).Returns(1) + assertdb.Query(t, db, `SELECT count(*) from request_logs_httplog WHERE org_id = $1 AND status_code = 200 AND classifier_id = $2 AND is_error = FALSE`, testdata.Org1.ID, testdata.Wit.ID).Returns(1) // insert a log with nil response log = models.NewClassifierCalledLog(testdata.Org1.ID, testdata.Wit.ID, "http://foo.bar", 0, "GET /", "", true, time.Second, 0, time.Now()) err = models.InsertHTTPLogs(ctx, db, []*models.HTTPLog{log}) assert.Nil(t, err) - testsuite.AssertQuery(t, db, `SELECT count(*) from request_logs_httplog WHERE org_id = $1 AND status_code = 0 AND classifier_id = $2 AND is_error = TRUE AND response IS NULL`, testdata.Org1.ID, testdata.Wit.ID).Returns(1) + assertdb.Query(t, db, `SELECT count(*) from request_logs_httplog WHERE org_id = $1 AND status_code = 0 AND classifier_id = $2 AND is_error = TRUE AND response IS NULL`, testdata.Org1.ID, testdata.Wit.ID).Returns(1) // insert a webhook log log = models.NewWebhookCalledLog(testdata.Org1.ID, testdata.Favorites.ID, "http://foo.bar", 400, "GET /", "HTTP 200", false, time.Second, 2, time.Now()) err = models.InsertHTTPLogs(ctx, db, []*models.HTTPLog{log}) assert.Nil(t, err) - testsuite.AssertQuery(t, db, `SELECT count(*) from request_logs_httplog WHERE org_id = $1 AND status_code = 400 AND flow_id = $2 AND num_retries = 2`, testdata.Org1.ID, testdata.Favorites.ID).Returns(1) + assertdb.Query(t, db, `SELECT count(*) from request_logs_httplog WHERE org_id = $1 AND status_code = 400 AND flow_id = $2 AND num_retries = 2`, testdata.Org1.ID, testdata.Favorites.ID).Returns(1) } func TestHTTPLogger(t *testing.T) { @@ -75,5 +76,5 @@ func TestHTTPLogger(t *testing.T) { err = logger.Insert(ctx, db) assert.NoError(t, err) - testsuite.AssertQuery(t, db, `SELECT count(*) from request_logs_httplog WHERE org_id = $1 AND ticketer_id = $2`, testdata.Org1.ID, testdata.Mailgun.ID).Returns(2) + assertdb.Query(t, db, `SELECT count(*) from request_logs_httplog WHERE org_id = $1 AND ticketer_id = $2`, testdata.Org1.ID, testdata.Mailgun.ID).Returns(2) } diff --git a/core/models/imports_test.go b/core/models/imports_test.go index 9a9bb56e5..9fbc6a0a0 100644 --- a/core/models/imports_test.go +++ b/core/models/imports_test.go @@ -8,6 +8,7 @@ import ( "strings" "testing" + "github.com/nyaruka/gocommon/dbutil/assertdb" "github.com/nyaruka/gocommon/jsonx" "github.com/nyaruka/gocommon/urns" "github.com/nyaruka/gocommon/uuids" @@ -46,8 +47,7 @@ func TestContactImports(t *testing.T) { // give our org a country by setting country on a channel db.MustExec(`UPDATE channels_channel SET country = 'US' WHERE id = $1`, testdata.TwilioChannel.ID) - testJSON, err := os.ReadFile("testdata/imports.json") - require.NoError(t, err) + testJSON := testsuite.ReadFile("testdata/imports.json") tcs := []struct { Description string `json:"description"` @@ -58,8 +58,7 @@ func TestContactImports(t *testing.T) { Errors json.RawMessage `json:"errors"` Contacts []*models.ContactSpec `json:"contacts"` }{} - err = jsonx.Unmarshal(testJSON, &tcs) - require.NoError(t, err) + jsonx.MustUnmarshal(testJSON, &tcs) oa, err := models.GetOrgAssets(ctx, rt, 1) require.NoError(t, err) @@ -188,8 +187,8 @@ func TestLoadContactImport(t *testing.T) { sort.Strings(batchStatuses) assert.Equal(t, []string{"C", "P"}, batchStatuses) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM contacts_contactimportbatch WHERE status = 'C' AND finished_on IS NOT NULL`).Returns(1) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM contacts_contactimportbatch WHERE status = 'P' AND finished_on IS NULL`).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM contacts_contactimportbatch WHERE status = 'C' AND finished_on IS NOT NULL`).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM contacts_contactimportbatch WHERE status = 'P' AND finished_on IS NULL`).Returns(1) } func TestContactSpecUnmarshal(t *testing.T) { diff --git a/core/models/incident.go b/core/models/incident.go new file mode 100644 index 000000000..fcd6fe644 --- /dev/null +++ b/core/models/incident.go @@ -0,0 +1,208 @@ +package models + +import ( + "context" + "database/sql" + "database/sql/driver" + "fmt" + "time" + + "github.com/gomodule/redigo/redis" + "github.com/lib/pq" + "github.com/nyaruka/gocommon/dates" + "github.com/nyaruka/goflow/flows" + "github.com/nyaruka/goflow/flows/events" + "github.com/nyaruka/mailroom/runtime" + "github.com/nyaruka/null" + "github.com/nyaruka/redisx" + "github.com/pkg/errors" +) + +// IncidentID is our type for incident ids +type IncidentID null.Int + +const NilIncidentID = IncidentID(0) + +// MarshalJSON marshals into JSON. 0 values will become null +func (i IncidentID) MarshalJSON() ([]byte, error) { + return null.Int(i).MarshalJSON() +} + +// UnmarshalJSON unmarshals from JSON. null values become 0 +func (i *IncidentID) UnmarshalJSON(b []byte) error { + return null.UnmarshalInt(b, (*null.Int)(i)) +} + +// Value returns the db value, null is returned for 0 +func (i IncidentID) Value() (driver.Value, error) { + return null.Int(i).Value() +} + +// Scan scans from the db value. null values become 0 +func (i *IncidentID) Scan(value interface{}) error { + return null.ScanInt(value, (*null.Int)(i)) +} + +type IncidentType string + +const ( + IncidentTypeOrgFlagged IncidentType = "org:flagged" + IncidentTypeWebhooksUnhealthy IncidentType = "webhooks:unhealthy" +) + +type Incident struct { + ID IncidentID `db:"id"` + OrgID OrgID `db:"org_id"` + Type IncidentType `db:"incident_type"` + Scope string `db:"scope"` + StartedOn time.Time `db:"started_on"` + EndedOn *time.Time `db:"ended_on"` + ChannelID ChannelID `db:"channel_id"` +} + +// End ends this incident +func (i *Incident) End(ctx context.Context, db Queryer) error { + now := time.Now() + i.EndedOn = &now + _, err := db.ExecContext(ctx, `UPDATE notifications_incident SET ended_on = $2 WHERE id = $1`, i.ID, i.EndedOn) + return errors.Wrap(err, "error updating incident ended_on") +} + +// IncidentWebhooksUnhealthy ensures there is an open unhealthy webhooks incident for the given org +func IncidentWebhooksUnhealthy(ctx context.Context, db Queryer, rp *redis.Pool, oa *OrgAssets, nodes []flows.NodeUUID) (IncidentID, error) { + id, err := getOrCreateIncident(ctx, db, oa, &Incident{ + OrgID: oa.OrgID(), + Type: IncidentTypeWebhooksUnhealthy, + StartedOn: dates.Now(), + Scope: "", + }) + if err != nil { + return NilIncidentID, err + } + + if len(nodes) > 0 { + rc := rp.Get() + defer rc.Close() + + nodesKey := fmt.Sprintf("incident:%d:nodes", id) + rc.Send("MULTI") + rc.Send("SADD", redis.Args{}.Add(nodesKey).AddFlat(nodes)...) + rc.Send("EXPIRE", nodesKey, 60*30) // 30 minutes + _, err = rc.Do("EXEC") + if err != nil { + return NilIncidentID, errors.Wrap(err, "error adding node uuids to incident") + } + } + + return id, nil +} + +const insertIncidentSQL = ` +INSERT INTO notifications_incident(org_id, incident_type, scope, started_on, channel_id) VALUES($1, $2, $3, $4, $5) +ON CONFLICT DO NOTHING RETURNING id` + +func getOrCreateIncident(ctx context.Context, db Queryer, oa *OrgAssets, incident *Incident) (IncidentID, error) { + var incidentID IncidentID + err := db.GetContext(ctx, &incidentID, insertIncidentSQL, incident.OrgID, incident.Type, incident.Scope, incident.StartedOn, incident.ChannelID) + if err != nil && err != sql.ErrNoRows { + return NilIncidentID, errors.Wrap(err, "error inserting incident") + } + + // if we got back an id, a new incident was actually created + if incidentID != NilIncidentID { + incident.ID = incidentID + + if err := NotifyIncidentStarted(ctx, db, oa, incident); err != nil { + return NilIncidentID, errors.Wrap(err, "error creating notifications for new incident") + } + } else { + err := db.GetContext(ctx, &incidentID, `SELECT id FROM notifications_incident WHERE org_id = $1 AND incident_type = $2 AND scope = $3`, incident.OrgID, incident.Type, incident.Scope) + if err != nil { + return NilIncidentID, errors.Wrap(err, "error looking up existing incident") + } + } + + return incidentID, nil +} + +const selectOpenIncidentsSQL = ` +SELECT id, org_id, incident_type, scope, started_on, ended_on, channel_id +FROM notifications_incident +WHERE ended_on IS NULL AND incident_type = ANY($1)` + +func GetOpenIncidents(ctx context.Context, db Queryer, types []IncidentType) ([]*Incident, error) { + rows, err := db.QueryxContext(ctx, selectOpenIncidentsSQL, pq.Array(types)) + if err != nil { + return nil, errors.Wrap(err, "error querying open incidents") + } + defer rows.Close() + + incidents := make([]*Incident, 0, 10) + for rows.Next() { + obj := &Incident{} + err := rows.StructScan(obj) + if err != nil { + return nil, errors.Wrap(err, "error scanning incident") + } + + incidents = append(incidents, obj) + } + + return incidents, nil +} + +// WebhookNode is a utility to help determine the health of an individual webhook node +type WebhookNode struct { + UUID flows.NodeUUID +} + +func (n *WebhookNode) Record(rt *runtime.Runtime, events []*events.WebhookCalledEvent) error { + numHealthy, numUnhealthy := 0, 0 + for _, e := range events { + if e.ElapsedMS <= rt.Config.WebhooksHealthyResponseLimit { + numHealthy++ + } else { + numUnhealthy++ + } + } + + rc := rt.RP.Get() + defer rc.Close() + + healthySeries, unhealthySeries := n.series() + + if numHealthy > 0 { + if err := healthySeries.Record(rc, string(n.UUID), int64(numHealthy)); err != nil { + return errors.Wrap(err, "error recording healthy calls") + } + } + if numUnhealthy > 0 { + if err := unhealthySeries.Record(rc, string(n.UUID), int64(numUnhealthy)); err != nil { + return errors.Wrap(err, "error recording unhealthy calls") + } + } + + return nil +} + +func (n *WebhookNode) Healthy(rt *runtime.Runtime) (bool, error) { + rc := rt.RP.Get() + defer rc.Close() + + healthySeries, unhealthySeries := n.series() + healthy, err := healthySeries.Total(rc, string(n.UUID)) + if err != nil { + return false, errors.Wrap(err, "error getting healthy series total") + } + unhealthy, err := unhealthySeries.Total(rc, string(n.UUID)) + if err != nil { + return false, errors.Wrap(err, "error getting healthy series total") + } + + // node is healthy if number of unhealthy calls is less than 10 or unhealthy percentage is < 25% + return unhealthy < 10 || (100*unhealthy/(healthy+unhealthy)) < 25, nil +} + +func (n *WebhookNode) series() (*redisx.IntervalSeries, *redisx.IntervalSeries) { + return redisx.NewIntervalSeries("webhooks:healthy", time.Minute*5, 4), redisx.NewIntervalSeries("webhooks:unhealthy", time.Minute*5, 4) +} diff --git a/core/models/incident_test.go b/core/models/incident_test.go new file mode 100644 index 000000000..f62bd82ca --- /dev/null +++ b/core/models/incident_test.go @@ -0,0 +1,141 @@ +package models_test + +import ( + "fmt" + "net/http" + "sort" + "testing" + "time" + + "github.com/nyaruka/gocommon/dates" + "github.com/nyaruka/gocommon/dbutil/assertdb" + "github.com/nyaruka/gocommon/httpx" + "github.com/nyaruka/goflow/flows" + "github.com/nyaruka/goflow/flows/events" + "github.com/nyaruka/mailroom/core/models" + "github.com/nyaruka/mailroom/testsuite" + "github.com/nyaruka/mailroom/testsuite/testdata" + "github.com/nyaruka/redisx/assertredis" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestIncidentWebhooksUnhealthy(t *testing.T) { + ctx, rt, db, rp := testsuite.Get() + + defer testsuite.Reset(testsuite.ResetData) + + oa := testdata.Org1.Load(rt) + + id1, err := models.IncidentWebhooksUnhealthy(ctx, db, rp, oa, []flows.NodeUUID{"5a2e83f1-efa8-40ba-bc0c-8873c525de7d", "aba89043-6f0a-4ccf-ba7f-0e1674b90759"}) + require.NoError(t, err) + assert.NotEqual(t, 0, id1) + + assertdb.Query(t, db, `SELECT count(*) FROM notifications_incident`).Returns(1) + assertredis.SMembers(t, rp, fmt.Sprintf("incident:%d:nodes", id1), []string{"5a2e83f1-efa8-40ba-bc0c-8873c525de7d", "aba89043-6f0a-4ccf-ba7f-0e1674b90759"}) + + // raising same incident doesn't create a new one... + id2, err := models.IncidentWebhooksUnhealthy(ctx, db, rp, oa, []flows.NodeUUID{"3b1743cd-bd8b-449e-8e8a-11a3bc479766"}) + require.NoError(t, err) + assert.Equal(t, id1, id2) + + // but will add new nodes to the incident's node set + assertdb.Query(t, db, `SELECT count(*) FROM notifications_incident`).Returns(1) + assertredis.SMembers(t, rp, fmt.Sprintf("incident:%d:nodes", id1), []string{"3b1743cd-bd8b-449e-8e8a-11a3bc479766", "5a2e83f1-efa8-40ba-bc0c-8873c525de7d", "aba89043-6f0a-4ccf-ba7f-0e1674b90759"}) + + // when the incident has ended, a new one can be created + db.MustExec(`UPDATE notifications_incident SET ended_on = NOW()`) + + id3, err := models.IncidentWebhooksUnhealthy(ctx, db, rp, oa, nil) + require.NoError(t, err) + assert.NotEqual(t, id1, id3) + + assertdb.Query(t, db, `SELECT count(*) FROM notifications_incident`).Returns(2) + +} + +func TestGetOpenIncidents(t *testing.T) { + ctx, rt, db, rp := testsuite.Get() + + defer testsuite.Reset(testsuite.ResetData) + + oa1 := testdata.Org1.Load(rt) + oa2 := testdata.Org2.Load(rt) + + // create incident for org 1 + id1, err := models.IncidentWebhooksUnhealthy(ctx, db, rp, oa1, nil) + require.NoError(t, err) + + incidents, err := models.GetOpenIncidents(ctx, db, []models.IncidentType{models.IncidentTypeWebhooksUnhealthy}) + assert.NoError(t, err) + assert.Equal(t, 1, len(incidents)) + assert.Equal(t, id1, incidents[0].ID) + assert.Equal(t, models.IncidentTypeWebhooksUnhealthy, incidents[0].Type) + + // but then end it + err = incidents[0].End(ctx, db) + require.NoError(t, err) + + // and create another one... + id2, err := models.IncidentWebhooksUnhealthy(ctx, db, rp, oa1, nil) + require.NoError(t, err) + + // create an incident for org 2 + id3, err := models.IncidentWebhooksUnhealthy(ctx, db, rp, oa2, nil) + require.NoError(t, err) + + incidents, err = models.GetOpenIncidents(ctx, db, []models.IncidentType{models.IncidentTypeWebhooksUnhealthy}) + require.NoError(t, err) + + assert.Equal(t, 2, len(incidents)) + + sort.Slice(incidents, func(i, j int) bool { return incidents[i].ID < incidents[j].ID }) // db results aren't ordered + + assert.Equal(t, id2, incidents[0].ID) + assert.Equal(t, id3, incidents[1].ID) +} + +func TestWebhookNode(t *testing.T) { + _, rt, _, _ := testsuite.Get() + + defer testsuite.Reset(testsuite.ResetRedis) + + node := &models.WebhookNode{UUID: "3c703019-8c92-4d28-9be0-a926a934486b"} + healthy, err := node.Healthy(rt) + assert.NoError(t, err) + assert.True(t, healthy) + + createWebhookEvents := func(count int, elapsed time.Duration) []*events.WebhookCalledEvent { + evts := make([]*events.WebhookCalledEvent, count) + for i := range evts { + req, _ := http.NewRequest("GET", "http://example.com", nil) + trace := &httpx.Trace{Request: req, StartTime: dates.Now(), EndTime: dates.Now().Add(elapsed)} + evts[i] = events.NewWebhookCalled(&flows.WebhookCall{Trace: trace}, flows.CallStatusSuccess, "") + } + return evts + } + + // record 10 healthy calls + err = node.Record(rt, createWebhookEvents(10, time.Second*1)) + assert.NoError(t, err) + + healthy, err = node.Healthy(rt) + assert.NoError(t, err) + assert.True(t, healthy) + + // record 5 unhealthy calls + err = node.Record(rt, createWebhookEvents(5, time.Second*30)) + assert.NoError(t, err) + + healthy, err = node.Healthy(rt) + assert.NoError(t, err) + assert.True(t, healthy) + + // record another 5 unhealthy calls + err = node.Record(rt, createWebhookEvents(5, time.Second*30)) + assert.NoError(t, err) + + healthy, err = node.Healthy(rt) + assert.NoError(t, err) + assert.False(t, healthy) +} diff --git a/core/models/labels.go b/core/models/labels.go index d301f149c..9a1e45054 100644 --- a/core/models/labels.go +++ b/core/models/labels.go @@ -4,8 +4,8 @@ import ( "context" "time" + "github.com/nyaruka/gocommon/dbutil" "github.com/nyaruka/goflow/assets" - "github.com/nyaruka/mailroom/utils/dbutil" "github.com/jmoiron/sqlx" "github.com/pkg/errors" @@ -45,7 +45,7 @@ func loadLabels(ctx context.Context, db sqlx.Queryer, orgID OrgID) ([]assets.Lab labels := make([]assets.Label, 0, 10) for rows.Next() { label := &Label{} - err = dbutil.ReadJSONRow(rows, &label.l) + err = dbutil.ScanJSON(rows, &label.l) if err != nil { return nil, errors.Wrap(err, "error scanning label row") } diff --git a/core/models/msgs.go b/core/models/msgs.go index b4c84295d..f95e104f0 100644 --- a/core/models/msgs.go +++ b/core/models/msgs.go @@ -62,12 +62,6 @@ const ( type MsgStatus string -// BroadcastID is our internal type for broadcast ids, which can be null/0 -type BroadcastID null.Int - -// NilBroadcastID is our constant for a nil broadcast id -const NilBroadcastID = BroadcastID(0) - const ( MsgStatusInitializing = MsgStatus("I") MsgStatusPending = MsgStatus("P") @@ -81,6 +75,23 @@ const ( MsgStatusResent = MsgStatus("R") ) +type MsgFailedReason null.String + +const ( + NilMsgFailedReason = MsgFailedReason("") + MsgFailedSuspended = MsgFailedReason("S") + MsgFailedLooping = MsgFailedReason("L") + MsgFailedErrorLimit = MsgFailedReason("E") + MsgFailedTooOld = MsgFailedReason("O") + MsgFailedNoDestination = MsgFailedReason("D") +) + +// BroadcastID is our internal type for broadcast ids, which can be null/0 +type BroadcastID null.Int + +// NilBroadcastID is our constant for a nil broadcast id +const NilBroadcastID = BroadcastID(0) + // TemplateState represents what state are templates are in, either already evaluated, not evaluated or // that they are unevaluated legacy templates type TemplateState string @@ -105,29 +116,31 @@ type Msg struct { QueuedOn time.Time `db:"queued_on" json:"queued_on"` Direction MsgDirection `db:"direction" json:"direction"` Status MsgStatus `db:"status" json:"status"` - Visibility MsgVisibility `db:"visibility" json:"visibility"` - MsgType MsgType `db:"msg_type"` + Visibility MsgVisibility `db:"visibility" json:"-"` + MsgType MsgType `db:"msg_type" json:"-"` MsgCount int `db:"msg_count" json:"tps_cost"` ErrorCount int `db:"error_count" json:"error_count"` NextAttempt *time.Time `db:"next_attempt" json:"next_attempt"` - ExternalID null.String `db:"external_id" json:"external_id"` - Attachments pq.StringArray `db:"attachments" json:"attachments"` + FailedReason MsgFailedReason `db:"failed_reason" json:"-"` + ExternalID null.String `db:"external_id" json:"-"` + ResponseToExternalID null.String ` json:"response_to_external_id,omitempty"` + Attachments pq.StringArray `db:"attachments" json:"attachments,omitempty"` Metadata null.Map `db:"metadata" json:"metadata,omitempty"` ChannelID ChannelID `db:"channel_id" json:"channel_id"` ChannelUUID assets.ChannelUUID ` json:"channel_uuid"` - ConnectionID *ConnectionID `db:"connection_id"` ContactID ContactID `db:"contact_id" json:"contact_id"` ContactURNID *URNID `db:"contact_urn_id" json:"contact_urn_id"` - ResponseToID MsgID `db:"response_to_id" json:"response_to_id"` - ResponseToExternalID null.String ` json:"response_to_external_id"` IsResend bool ` json:"is_resend,omitempty"` - URN urns.URN ` json:"urn"` - URNAuth null.String ` json:"urn_auth,omitempty"` + URN urns.URN `db:"urn_urn" json:"urn"` + URNAuth null.String `db:"urn_auth" json:"urn_auth,omitempty"` OrgID OrgID `db:"org_id" json:"org_id"` - TopupID TopupID `db:"topup_id"` + TopupID TopupID `db:"topup_id" json:"-"` + FlowID FlowID `db:"flow_id" json:"-"` - SessionID SessionID `json:"session_id,omitempty"` - SessionStatus SessionStatus `json:"session_status,omitempty"` + // extra data from handling added to the courier payload + SessionID SessionID `json:"session_id,omitempty"` + SessionStatus SessionStatus `json:"session_status,omitempty"` + Flow *assets.FlowReference `json:"flow,omitempty"` // These fields are set on the last outgoing message in a session's sprint. In the case // of the session being at a wait with a timeout then the timeout will be set. It is up to @@ -155,23 +168,33 @@ func (m *Msg) Visibility() MsgVisibility { return m.m.Visibility } func (m *Msg) MsgType() MsgType { return m.m.MsgType } func (m *Msg) ErrorCount() int { return m.m.ErrorCount } func (m *Msg) NextAttempt() *time.Time { return m.m.NextAttempt } +func (m *Msg) FailedReason() MsgFailedReason { return m.m.FailedReason } func (m *Msg) ExternalID() null.String { return m.m.ExternalID } func (m *Msg) Metadata() map[string]interface{} { return m.m.Metadata.Map() } func (m *Msg) MsgCount() int { return m.m.MsgCount } func (m *Msg) ChannelID() ChannelID { return m.m.ChannelID } func (m *Msg) ChannelUUID() assets.ChannelUUID { return m.m.ChannelUUID } -func (m *Msg) ConnectionID() *ConnectionID { return m.m.ConnectionID } func (m *Msg) URN() urns.URN { return m.m.URN } func (m *Msg) URNAuth() null.String { return m.m.URNAuth } func (m *Msg) OrgID() OrgID { return m.m.OrgID } func (m *Msg) TopupID() TopupID { return m.m.TopupID } +func (m *Msg) FlowID() FlowID { return m.m.FlowID } func (m *Msg) ContactID() ContactID { return m.m.ContactID } func (m *Msg) ContactURNID() *URNID { return m.m.ContactURNID } func (m *Msg) IsResend() bool { return m.m.IsResend } -func (m *Msg) SetTopup(topupID TopupID) { m.m.TopupID = topupID } -func (m *Msg) SetChannelID(channelID ChannelID) { m.m.ChannelID = channelID } -func (m *Msg) SetBroadcastID(broadcastID BroadcastID) { m.m.BroadcastID = broadcastID } +func (m *Msg) SetTopup(topupID TopupID) { m.m.TopupID = topupID } + +func (m *Msg) SetChannel(channel *Channel) { + m.channel = channel + if channel != nil { + m.m.ChannelID = channel.ID() + m.m.ChannelUUID = channel.UUID() + } else { + m.m.ChannelID = NilChannelID + m.m.ChannelUUID = "" + } +} func (m *Msg) SetURN(urn urns.URN) error { // noop for nil urn @@ -202,16 +225,6 @@ func (m *Msg) Attachments() []utils.Attachment { return attachments } -// SetResponseTo set the incoming message that this session should be associated with in this sprint -func (m *Msg) SetResponseTo(id MsgID, externalID null.String) { - m.m.ResponseToID = id - m.m.ResponseToExternalID = externalID - - if id != NilMsgID || externalID != "" { - m.m.HighPriority = true - } -} - func (m *Msg) MarshalJSON() ([]byte, error) { return json.Marshal(m.m) } @@ -232,16 +245,12 @@ func NewIncomingIVR(cfg *runtime.Config, orgID OrgID, conn *ChannelConnection, i urnID := conn.ContactURNID() m.ContactURNID = &urnID - - connID := conn.ID() - m.ConnectionID = &connID + m.ChannelID = conn.ChannelID() m.OrgID = orgID m.TopupID = NilTopupID m.CreatedOn = createdOn - msg.SetChannelID(conn.ChannelID()) - // add any attachments for _, a := range in.Attachments() { m.Attachments = append(m.Attachments, string(NormalizeAttachment(cfg, a))) @@ -267,9 +276,7 @@ func NewOutgoingIVR(cfg *runtime.Config, orgID OrgID, conn *ChannelConnection, o urnID := conn.ContactURNID() m.ContactURNID = &urnID - - connID := conn.ID() - m.ConnectionID = &connID + m.ChannelID = conn.ChannelID() m.URN = out.URN() @@ -277,7 +284,6 @@ func NewOutgoingIVR(cfg *runtime.Config, orgID OrgID, conn *ChannelConnection, o m.TopupID = NilTopupID m.CreatedOn = createdOn m.SentOn = &createdOn - msg.SetChannelID(conn.ChannelID()) // if we have attachments, add them for _, a := range out.Attachments() { @@ -287,46 +293,116 @@ func NewOutgoingIVR(cfg *runtime.Config, orgID OrgID, conn *ChannelConnection, o return msg } -// NewOutgoingMsg creates an outgoing message for the passed in flow message. -func NewOutgoingMsg(cfg *runtime.Config, org *Org, channel *Channel, contactID ContactID, out *flows.MsgOut, createdOn time.Time) (*Msg, error) { - msg := &Msg{} - m := &msg.m +var msgRepetitionsScript = redis.NewScript(3, ` +local key, contact_id, text = KEYS[1], KEYS[2], KEYS[3] +local count = 1 + +-- try to look up in window +local record = redis.call("HGET", key, contact_id) +if record then + local record_count = tonumber(string.sub(record, 1, 2)) + local record_text = string.sub(record, 4, -1) + + if record_text == text then + count = math.min(record_count + 1, 99) + else + count = 1 + end +end + +-- create our new record with our updated count +record = string.format("%02d:%s", count, text) + +-- write our new record with updated count and set expiration +redis.call("HSET", key, contact_id, record) +redis.call("EXPIRE", key, 300) + +return count +`) + +// GetMsgRepetitions gets the number of repetitions of this msg text for the given contact in the current 5 minute window +func GetMsgRepetitions(rp *redis.Pool, contactID ContactID, msg *flows.MsgOut) (int, error) { + rc := rp.Get() + defer rc.Close() + + keyTime := dates.Now().UTC().Round(time.Minute * 5) + key := fmt.Sprintf("msg_repetitions:%s", keyTime.Format("2006-01-02T15:04")) + return redis.Int(msgRepetitionsScript.Do(rc, key, contactID, msg.Text())) +} - // we fail messages for suspended orgs right away - status := MsgStatusQueued - if org.Suspended() { - status = MsgStatusFailed - } +// NewOutgoingFlowMsg creates an outgoing message for the passed in flow message +func NewOutgoingFlowMsg(rt *runtime.Runtime, org *Org, channel *Channel, session *Session, flow *Flow, out *flows.MsgOut, createdOn time.Time) (*Msg, error) { + return newOutgoingMsg(rt, org, channel, session.ContactID(), out, createdOn, session, flow, NilBroadcastID) +} + +// NewOutgoingBroadcastMsg creates an outgoing message which is part of a broadcast +func NewOutgoingBroadcastMsg(rt *runtime.Runtime, org *Org, channel *Channel, contactID ContactID, out *flows.MsgOut, createdOn time.Time, broadcastID BroadcastID) (*Msg, error) { + return newOutgoingMsg(rt, org, channel, contactID, out, createdOn, nil, nil, broadcastID) +} +func newOutgoingMsg(rt *runtime.Runtime, org *Org, channel *Channel, contactID ContactID, out *flows.MsgOut, createdOn time.Time, session *Session, flow *Flow, broadcastID BroadcastID) (*Msg, error) { + msg := &Msg{} + m := &msg.m m.UUID = out.UUID() + m.OrgID = org.ID() + m.ContactID = contactID + m.BroadcastID = broadcastID + m.TopupID = NilTopupID m.Text = out.Text() m.HighPriority = false m.Direction = DirectionOut - m.Status = status + m.Status = MsgStatusQueued m.Visibility = VisibilityVisible m.MsgType = MsgTypeFlow - m.ContactID = contactID - m.OrgID = org.ID() - m.TopupID = NilTopupID + m.MsgCount = 1 m.CreatedOn = createdOn - err := msg.SetURN(out.URN()) - if err != nil { - return nil, errors.Wrapf(err, "error setting msg urn") - } + msg.SetChannel(channel) + msg.SetURN(out.URN()) - if channel != nil { - m.ChannelUUID = channel.UUID() - msg.SetChannelID(channel.ID()) - msg.channel = channel + if org.Suspended() { + // we fail messages for suspended orgs right away + m.Status = MsgStatusFailed + m.FailedReason = MsgFailedSuspended + } else if msg.URN() == urns.NilURN || channel == nil { + // if msg is missing the URN or channel, we also fail it + m.Status = MsgStatusFailed + m.FailedReason = MsgFailedNoDestination + } else { + // also fail right away if this looks like a loop + repetitions, err := GetMsgRepetitions(rt.RP, contactID, out) + if err != nil { + return nil, errors.Wrap(err, "error looking up msg repetitions") + } + if repetitions >= 20 { + m.Status = MsgStatusFailed + m.FailedReason = MsgFailedLooping + + logrus.WithFields(logrus.Fields{"contact_id": contactID, "text": out.Text(), "repetitions": repetitions}).Error("too many repetitions, failing message") + } } - m.MsgCount = 1 + // if we have a session, set fields on the message from that + if session != nil { + m.ResponseToExternalID = session.IncomingMsgExternalID() + m.SessionID = session.ID() + m.SessionStatus = session.Status() + + if flow != nil { + m.FlowID = flow.ID() + m.Flow = flow.Reference() + } + + // if we're responding to an incoming message, send as high priority + if session.IncomingMsgID() != NilMsgID { + m.HighPriority = true + } + } // if we have attachments, add them if len(out.Attachments()) > 0 { for _, a := range out.Attachments() { - m.Attachments = append(m.Attachments, string(NormalizeAttachment(cfg, a))) + m.Attachments = append(m.Attachments, string(NormalizeAttachment(rt.Config, a))) } } @@ -345,11 +421,9 @@ func NewOutgoingMsg(cfg *runtime.Config, org *Org, channel *Channel, contactID C m.Metadata = null.NewMap(metadata) } - // calculate msg count + // if we're sending to a phone, message may have to be sent in multiple parts if m.URN.Scheme() == urns.TelScheme { m.MsgCount = gsm7.Segments(m.Text) + len(m.Attachments) - } else { - m.MsgCount = 1 } return msg, nil @@ -358,9 +432,11 @@ func NewOutgoingMsg(cfg *runtime.Config, org *Org, channel *Channel, contactID C // NewIncomingMsg creates a new incoming message for the passed in text and attachment func NewIncomingMsg(cfg *runtime.Config, orgID OrgID, channel *Channel, contactID ContactID, in *flows.MsgIn, createdOn time.Time) *Msg { msg := &Msg{} - m := &msg.m + msg.SetChannel(channel) msg.SetURN(in.URN()) + + m := &msg.m m.UUID = in.UUID() m.Text = in.Text() m.Direction = DirectionIn @@ -368,17 +444,10 @@ func NewIncomingMsg(cfg *runtime.Config, orgID OrgID, channel *Channel, contactI m.Visibility = VisibilityVisible m.MsgType = MsgTypeFlow m.ContactID = contactID - m.OrgID = orgID m.TopupID = NilTopupID m.CreatedOn = createdOn - if channel != nil { - msg.SetChannelID(channel.ID()) - m.ChannelUUID = channel.UUID() - msg.channel = channel - } - // add any attachments for _, a := range in.Attachments() { m.Attachments = append(m.Attachments, string(NormalizeAttachment(cfg, a))) @@ -400,14 +469,14 @@ SELECT msg_count, error_count, next_attempt, + failed_reason, + coalesce(high_priority, FALSE) as high_priority, external_id, attachments, metadata, channel_id, - connection_id, contact_id, contact_urn_id, - response_to_id, org_id, topup_id FROM @@ -419,15 +488,63 @@ WHERE ORDER BY id ASC` -// LoadMessages loads the given messages for the passed in org -func LoadMessages(ctx context.Context, db Queryer, orgID OrgID, direction MsgDirection, msgIDs []MsgID) ([]*Msg, error) { - rows, err := db.QueryxContext(ctx, loadMessagesSQL, orgID, direction, pq.Array(msgIDs)) +// GetMessagesByID fetches the messages with the given ids +func GetMessagesByID(ctx context.Context, db Queryer, orgID OrgID, direction MsgDirection, msgIDs []MsgID) ([]*Msg, error) { + return loadMessages(ctx, db, loadMessagesSQL, orgID, direction, pq.Array(msgIDs)) +} + +var loadMessagesForRetrySQL = ` +SELECT + m.id, + m.broadcast_id, + m.uuid, + m.text, + m.created_on, + m.direction, + m.status, + m.visibility, + m.msg_count, + m.error_count, + m.next_attempt, + m.failed_reason, + m.high_priority, + m.external_id, + m.attachments, + m.metadata, + m.channel_id, + m.contact_id, + m.contact_urn_id, + m.org_id, + m.topup_id, + u.identity AS "urn_urn", + u.auth AS "urn_auth" +FROM + msgs_msg m +INNER JOIN + contacts_contacturn u ON u.id = m.contact_urn_id +WHERE + m.direction = 'O' AND + m.status = 'E' AND + m.next_attempt <= NOW() +ORDER BY + m.next_attempt ASC, m.created_on ASC +LIMIT 5000` + +func GetMessagesForRetry(ctx context.Context, db Queryer) ([]*Msg, error) { + return loadMessages(ctx, db, loadMessagesForRetrySQL) +} + +func loadMessages(ctx context.Context, db Queryer, sql string, params ...interface{}) ([]*Msg, error) { + rows, err := db.QueryxContext(ctx, sql, params...) if err != nil { - return nil, errors.Wrapf(err, "error querying msgs for org: %d", orgID) + return nil, errors.Wrapf(err, "error querying msgs") } defer rows.Close() msgs := make([]*Msg, 0) + channelIDsSeen := make(map[ChannelID]bool) + channelIDs := make([]ChannelID, 0, 5) + for rows.Next() { msg := &Msg{} err = rows.StructScan(&msg.m) @@ -436,6 +553,25 @@ func LoadMessages(ctx context.Context, db Queryer, orgID OrgID, direction MsgDir } msgs = append(msgs, msg) + + if msg.ChannelID() != NilChannelID && !channelIDsSeen[msg.ChannelID()] { + channelIDsSeen[msg.ChannelID()] = true + channelIDs = append(channelIDs, msg.ChannelID()) + } + } + + channels, err := GetChannelsByID(ctx, db, channelIDs) + if err != nil { + return nil, errors.Wrap(err, "error fetching channels for messages") + } + + channelsByID := make(map[ChannelID]*Channel) + for _, ch := range channels { + channelsByID[ch.ID()] = ch + } + + for _, msg := range msgs { + msg.SetChannel(channelsByID[msg.m.ChannelID]) } return msgs, nil @@ -460,11 +596,6 @@ func NormalizeAttachment(cfg *runtime.Config, attachment utils.Attachment) utils return utils.Attachment(fmt.Sprintf("%s:%s", attachment.ContentType(), url)) } -func (m *Msg) SetSession(id SessionID, status SessionStatus) { - m.m.SessionID = id - m.m.SessionStatus = status -} - // SetTimeout sets the timeout for this message func (m *Msg) SetTimeout(start time.Time, timeout time.Duration) { m.m.SessionWaitStartedOn = &start @@ -484,31 +615,31 @@ func InsertMessages(ctx context.Context, tx Queryer, msgs []*Msg) error { const insertMsgSQL = ` INSERT INTO msgs_msg(uuid, text, high_priority, created_on, modified_on, queued_on, sent_on, direction, status, attachments, metadata, - visibility, msg_type, msg_count, error_count, next_attempt, channel_id, connection_id, response_to_id, - contact_id, contact_urn_id, org_id, topup_id, broadcast_id) + visibility, msg_type, msg_count, error_count, next_attempt, failed_reason, channel_id, + contact_id, contact_urn_id, org_id, topup_id, flow_id, broadcast_id) VALUES(:uuid, :text, :high_priority, :created_on, now(), now(), :sent_on, :direction, :status, :attachments, :metadata, - :visibility, :msg_type, :msg_count, :error_count, :next_attempt, :channel_id, :connection_id, :response_to_id, - :contact_id, :contact_urn_id, :org_id, :topup_id, :broadcast_id) + :visibility, :msg_type, :msg_count, :error_count, :next_attempt, :failed_reason, :channel_id, + :contact_id, :contact_urn_id, :org_id, :topup_id, :flow_id, :broadcast_id) RETURNING id as id, now() as modified_on, now() as queued_on ` -// UpdateMessage updates the passed in message status, visibility and msg type -func UpdateMessage(ctx context.Context, tx Queryer, msgID flows.MsgID, status MsgStatus, visibility MsgVisibility, msgType MsgType, topup TopupID) error { - _, err := tx.ExecContext( - ctx, +// UpdateMessage updates a message after handling +func UpdateMessage(ctx context.Context, tx Queryer, msgID flows.MsgID, status MsgStatus, visibility MsgVisibility, msgType MsgType, flow FlowID, topup TopupID) error { + _, err := tx.ExecContext(ctx, `UPDATE msgs_msg SET status = $2, visibility = $3, msg_type = $4, - topup_id = $5 + flow_id = $5, + topup_id = $6 WHERE id = $1`, - msgID, status, visibility, msgType, topup) + msgID, status, visibility, msgType, flow, topup) if err != nil { return errors.Wrapf(err, "error updating msg: %d", msgID) @@ -517,11 +648,16 @@ func UpdateMessage(ctx context.Context, tx Queryer, msgID flows.MsgID, status Ms return nil } -// MarkMessagesPending marks the passed in messages as pending +// MarkMessagesPending marks the passed in messages as pending(P) func MarkMessagesPending(ctx context.Context, db Queryer, msgs []*Msg) error { return updateMessageStatus(ctx, db, msgs, MsgStatusPending) } +// MarkMessagesQueued marks the passed in messages as queued(Q) +func MarkMessagesQueued(ctx context.Context, db Queryer, msgs []*Msg) error { + return updateMessageStatus(ctx, db, msgs, MsgStatusQueued) +} + func updateMessageStatus(ctx context.Context, db Queryer, msgs []*Msg, status MsgStatus) error { is := make([]interface{}, len(msgs)) for i, msg := range msgs { @@ -546,16 +682,6 @@ WHERE msgs_msg.id = m.id::bigint ` -// GetMessageIDFromUUID gets the ID of a message from its UUID -func GetMessageIDFromUUID(ctx context.Context, db Queryer, uuid flows.MsgUUID) (MsgID, error) { - var id MsgID - err := db.GetContext(ctx, &id, `SELECT id FROM msgs_msg WHERE uuid = $1`, uuid) - if err != nil { - return NilMsgID, errors.Wrapf(err, "error querying id for msg with uuid '%s'", uuid) - } - return id, nil -} - // BroadcastTranslation is the translation for the passed in language type BroadcastTranslation struct { Text string `json:"text"` @@ -737,7 +863,7 @@ INSERT INTO ` // NewBroadcastFromEvent creates a broadcast object from the passed in broadcast event -func NewBroadcastFromEvent(ctx context.Context, tx Queryer, org *OrgAssets, event *events.BroadcastCreatedEvent) (*Broadcast, error) { +func NewBroadcastFromEvent(ctx context.Context, tx Queryer, oa *OrgAssets, event *events.BroadcastCreatedEvent) (*Broadcast, error) { // converst our translations to our type translations := make(map[envs.Language]*BroadcastTranslation) for l, t := range event.Translations { @@ -749,7 +875,7 @@ func NewBroadcastFromEvent(ctx context.Context, tx Queryer, org *OrgAssets, even } // resolve our contact references - contactIDs, err := GetContactIDsFromReferences(ctx, tx, org.OrgID(), event.Contacts) + contactIDs, err := GetContactIDsFromReferences(ctx, tx, oa.OrgID(), event.Contacts) if err != nil { return nil, errors.Wrapf(err, "error resolving contact references") } @@ -757,13 +883,13 @@ func NewBroadcastFromEvent(ctx context.Context, tx Queryer, org *OrgAssets, even // and our groups groupIDs := make([]GroupID, 0, len(event.Groups)) for i := range event.Groups { - group := org.GroupByUUID(event.Groups[i].UUID) + group := oa.GroupByUUID(event.Groups[i].UUID) if group != nil { groupIDs = append(groupIDs, group.ID()) } } - return NewBroadcast(org.OrgID(), NilBroadcastID, translations, TemplateStateEvaluated, event.BaseLanguage, event.URNs, contactIDs, groupIDs, NilTicketID), nil + return NewBroadcast(oa.OrgID(), NilBroadcastID, translations, TemplateStateEvaluated, event.BaseLanguage, event.URNs, contactIDs, groupIDs, NilTicketID), nil } func (b *Broadcast) CreateBatch(contactIDs []ContactID) *BroadcastBatch { @@ -956,8 +1082,7 @@ func CreateBroadcastMessages(ctx context.Context, rt *runtime.Runtime, oa *OrgAs // create our outgoing message out := flows.NewMsgOut(urn, channel.ChannelReference(), text, t.Attachments, t.QuickReplies, nil, flows.NilMsgTopic) - msg, err := NewOutgoingMsg(rt.Config, oa.Org(), channel, c.ID(), out, time.Now()) - msg.SetBroadcastID(bcast.BroadcastID()) + msg, err := NewOutgoingBroadcastMsg(rt, oa.Org(), channel, c.ID(), out, time.Now(), bcast.BroadcastID()) if err != nil { return nil, errors.Wrapf(err, "error creating outgoing message") } @@ -1029,6 +1154,7 @@ const updateMsgForResendingSQL = ` topup_id = r.topup_id::int, status = 'P', error_count = 0, + failed_reason = NULL, queued_on = r.queued_on::timestamp with time zone, sent_on = NULL, modified_on = NOW() @@ -1081,6 +1207,7 @@ func ResendMessages(ctx context.Context, db Queryer, rp *redis.Pool, oa *OrgAsse msg.m.QueuedOn = dates.Now() msg.m.SentOn = nil msg.m.ErrorCount = 0 + msg.m.FailedReason = "" msg.m.IsResend = true resends[i] = msg.m @@ -1150,3 +1277,13 @@ func (i BroadcastID) Value() (driver.Value, error) { func (i *BroadcastID) Scan(value interface{}) error { return null.ScanInt(value, (*null.Int)(i)) } + +// Value returns the db value, null is returned for "" +func (s MsgFailedReason) Value() (driver.Value, error) { + return null.String(s).Value() +} + +// Scan scans from the db value. null values become "" +func (s *MsgFailedReason) Scan(value interface{}) error { + return null.ScanString(value, (*null.String)(s)) +} diff --git a/core/models/msgs_test.go b/core/models/msgs_test.go index c58728589..07e94e0ca 100644 --- a/core/models/msgs_test.go +++ b/core/models/msgs_test.go @@ -1,164 +1,397 @@ package models_test import ( + "context" + "encoding/json" "fmt" "testing" "time" "github.com/nyaruka/gocommon/dates" + "github.com/nyaruka/gocommon/dbutil/assertdb" + "github.com/nyaruka/gocommon/jsonx" "github.com/nyaruka/gocommon/urns" "github.com/nyaruka/goflow/assets" "github.com/nyaruka/goflow/envs" "github.com/nyaruka/goflow/flows" + "github.com/nyaruka/goflow/test" "github.com/nyaruka/goflow/utils" "github.com/nyaruka/mailroom/core/models" + "github.com/nyaruka/mailroom/runtime" "github.com/nyaruka/mailroom/testsuite" "github.com/nyaruka/mailroom/testsuite/testdata" + "github.com/nyaruka/null" + "github.com/nyaruka/redisx/assertredis" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestOutgoingMsgs(t *testing.T) { +func TestNewOutgoingFlowMsg(t *testing.T) { ctx, rt, db, _ := testsuite.Get() + defer testsuite.Reset(testsuite.ResetData) + tcs := []struct { ChannelUUID assets.ChannelUUID Text string - ContactID models.ContactID + Contact *testdata.Contact URN urns.URN URNID models.URNID Attachments []utils.Attachment QuickReplies []string Topic flows.MsgTopic + Flow *testdata.Flow + ResponseTo models.MsgID SuspendedOrg bool - ExpectedStatus models.MsgStatus - ExpectedMetadata map[string]interface{} - ExpectedMsgCount int - HasError bool + ExpectedStatus models.MsgStatus + ExpectedFailedReason models.MsgFailedReason + ExpectedMetadata map[string]interface{} + ExpectedMsgCount int + ExpectedPriority bool }{ { - ChannelUUID: "74729f45-7f29-4868-9dc4-90e491e3c7d8", - Text: "missing urn id", - ContactID: testdata.Cathy.ID, - URN: urns.URN("tel:+250700000001"), - URNID: models.URNID(0), - ExpectedStatus: models.MsgStatusQueued, - ExpectedMetadata: map[string]interface{}{}, - ExpectedMsgCount: 1, - HasError: true, + ChannelUUID: "74729f45-7f29-4868-9dc4-90e491e3c7d8", + Text: "missing urn id", + Contact: testdata.Cathy, + URN: urns.URN("tel:+250700000001"), + URNID: models.URNID(0), + Flow: testdata.Favorites, + ResponseTo: models.MsgID(123425), + ExpectedStatus: models.MsgStatusQueued, + ExpectedFailedReason: models.NilMsgFailedReason, + ExpectedMetadata: map[string]interface{}{}, + ExpectedMsgCount: 1, + ExpectedPriority: true, }, { - ChannelUUID: "74729f45-7f29-4868-9dc4-90e491e3c7d8", - Text: "test outgoing", - ContactID: testdata.Cathy.ID, - URN: urns.URN(fmt.Sprintf("tel:+250700000001?id=%d", testdata.Cathy.URNID)), - URNID: testdata.Cathy.URNID, - QuickReplies: []string{"yes", "no"}, - Topic: flows.MsgTopicPurchase, - ExpectedStatus: models.MsgStatusQueued, + ChannelUUID: "74729f45-7f29-4868-9dc4-90e491e3c7d8", + Text: "test outgoing", + Contact: testdata.Cathy, + URN: urns.URN(fmt.Sprintf("tel:+250700000001?id=%d", testdata.Cathy.URNID)), + URNID: testdata.Cathy.URNID, + QuickReplies: []string{"yes", "no"}, + Topic: flows.MsgTopicPurchase, + Flow: testdata.SingleMessage, + ExpectedStatus: models.MsgStatusQueued, + ExpectedFailedReason: models.NilMsgFailedReason, ExpectedMetadata: map[string]interface{}{ "quick_replies": []string{"yes", "no"}, "topic": "purchase", }, ExpectedMsgCount: 1, + ExpectedPriority: false, }, { - ChannelUUID: "74729f45-7f29-4868-9dc4-90e491e3c7d8", - Text: "test outgoing", - ContactID: testdata.Cathy.ID, - URN: urns.URN(fmt.Sprintf("tel:+250700000001?id=%d", testdata.Cathy.URNID)), - URNID: testdata.Cathy.URNID, - Attachments: []utils.Attachment{utils.Attachment("image/jpeg:https://dl-foo.com/image.jpg")}, - ExpectedStatus: models.MsgStatusQueued, - ExpectedMetadata: map[string]interface{}{}, - ExpectedMsgCount: 2, + ChannelUUID: "74729f45-7f29-4868-9dc4-90e491e3c7d8", + Text: "test outgoing", + Contact: testdata.Cathy, + URN: urns.URN(fmt.Sprintf("tel:+250700000001?id=%d", testdata.Cathy.URNID)), + URNID: testdata.Cathy.URNID, + Attachments: []utils.Attachment{utils.Attachment("image/jpeg:https://dl-foo.com/image.jpg")}, + Flow: testdata.Favorites, + ExpectedStatus: models.MsgStatusQueued, + ExpectedFailedReason: models.NilMsgFailedReason, + ExpectedMetadata: map[string]interface{}{}, + ExpectedMsgCount: 2, + ExpectedPriority: false, }, { - ChannelUUID: "74729f45-7f29-4868-9dc4-90e491e3c7d8", - Text: "suspended org", - ContactID: testdata.Cathy.ID, - URN: urns.URN(fmt.Sprintf("tel:+250700000001?id=%d", testdata.Cathy.URNID)), - URNID: testdata.Cathy.URNID, - SuspendedOrg: true, - ExpectedStatus: models.MsgStatusFailed, - ExpectedMetadata: map[string]interface{}{}, - ExpectedMsgCount: 1, + ChannelUUID: "74729f45-7f29-4868-9dc4-90e491e3c7d8", + Text: "suspended org", + Contact: testdata.Cathy, + URN: urns.URN(fmt.Sprintf("tel:+250700000001?id=%d", testdata.Cathy.URNID)), + URNID: testdata.Cathy.URNID, + Flow: testdata.Favorites, + SuspendedOrg: true, + ExpectedStatus: models.MsgStatusFailed, + ExpectedFailedReason: models.MsgFailedSuspended, + ExpectedMetadata: map[string]interface{}{}, + ExpectedMsgCount: 1, + ExpectedPriority: false, + }, + { + ChannelUUID: "74729f45-7f29-4868-9dc4-90e491e3c7d8", + Text: "missing URN", + Contact: testdata.Cathy, + URN: urns.NilURN, + URNID: models.URNID(0), + Flow: testdata.Favorites, + SuspendedOrg: false, + ExpectedStatus: models.MsgStatusFailed, + ExpectedFailedReason: models.MsgFailedNoDestination, + ExpectedMetadata: map[string]interface{}{}, + ExpectedMsgCount: 1, + ExpectedPriority: false, + }, + { + ChannelUUID: "", + Text: "missing Channel", + Contact: testdata.Cathy, + URN: urns.NilURN, + URNID: models.URNID(0), + Flow: testdata.Favorites, + SuspendedOrg: false, + ExpectedStatus: models.MsgStatusFailed, + ExpectedFailedReason: models.MsgFailedNoDestination, + ExpectedMetadata: map[string]interface{}{}, + ExpectedMsgCount: 1, + ExpectedPriority: false, }, } now := time.Now() for _, tc := range tcs { - tx, err := db.BeginTxx(ctx, nil) - require.NoError(t, err) - db.MustExec(`UPDATE orgs_org SET is_suspended = $1 WHERE id = $2`, tc.SuspendedOrg, testdata.Org1.ID) oa, err := models.GetOrgAssetsWithRefresh(ctx, rt, testdata.Org1.ID, models.RefreshOrg) require.NoError(t, err) channel := oa.ChannelByUUID(tc.ChannelUUID) + flow, _ := oa.FlowByID(tc.Flow.ID) + + session := insertTestSession(t, ctx, rt, testdata.Org1, testdata.Cathy, testdata.Favorites) + if tc.ResponseTo != models.NilMsgID { + session.SetIncomingMsg(flows.MsgID(tc.ResponseTo), null.NullString) + } flowMsg := flows.NewMsgOut(tc.URN, assets.NewChannelReference(tc.ChannelUUID, "Test Channel"), tc.Text, tc.Attachments, tc.QuickReplies, nil, tc.Topic) - msg, err := models.NewOutgoingMsg(rt.Config, oa.Org(), channel, tc.ContactID, flowMsg, now) + msg, err := models.NewOutgoingFlowMsg(rt, oa.Org(), channel, session, flow, flowMsg, now) + + assert.NoError(t, err) - if tc.HasError { - assert.Error(t, err) + err = models.InsertMessages(ctx, db, []*models.Msg{msg}) + assert.NoError(t, err) + assert.Equal(t, oa.OrgID(), msg.OrgID()) + assert.Equal(t, tc.Text, msg.Text()) + assert.Equal(t, tc.Contact.ID, msg.ContactID()) + assert.Equal(t, channel, msg.Channel()) + assert.Equal(t, tc.ChannelUUID, msg.ChannelUUID()) + assert.Equal(t, tc.URN, msg.URN()) + if tc.URNID != models.NilURNID { + assert.Equal(t, tc.URNID, *msg.ContactURNID()) } else { - assert.NoError(t, err) - - err = models.InsertMessages(ctx, tx, []*models.Msg{msg}) - assert.NoError(t, err) - assert.Equal(t, oa.OrgID(), msg.OrgID()) - assert.Equal(t, tc.Text, msg.Text()) - assert.Equal(t, tc.ContactID, msg.ContactID()) - assert.Equal(t, channel, msg.Channel()) - assert.Equal(t, tc.ChannelUUID, msg.ChannelUUID()) - assert.Equal(t, tc.URN, msg.URN()) - if tc.URNID != models.NilURNID { - assert.Equal(t, tc.URNID, *msg.ContactURNID()) - } else { - assert.Nil(t, msg.ContactURNID()) - } - - assert.Equal(t, tc.ExpectedStatus, msg.Status()) - assert.Equal(t, tc.ExpectedMetadata, msg.Metadata()) - assert.Equal(t, tc.ExpectedMsgCount, msg.MsgCount()) - assert.Equal(t, now, msg.CreatedOn()) - assert.True(t, msg.ID() > 0) - assert.True(t, msg.QueuedOn().After(now)) - assert.True(t, msg.ModifiedOn().After(now)) + assert.Nil(t, msg.ContactURNID()) } + assert.Equal(t, tc.Flow.ID, msg.FlowID()) + + assert.Equal(t, tc.ExpectedStatus, msg.Status()) + assert.Equal(t, tc.ExpectedFailedReason, msg.FailedReason()) + assert.Equal(t, tc.ExpectedMetadata, msg.Metadata()) + assert.Equal(t, tc.ExpectedMsgCount, msg.MsgCount()) + assert.Equal(t, now, msg.CreatedOn()) + assert.True(t, msg.ID() > 0) + assert.True(t, msg.QueuedOn().After(now)) + assert.True(t, msg.ModifiedOn().After(now)) + } + + // check nil failed reasons are saved as NULLs + assertdb.Query(t, db, `SELECT count(*) FROM msgs_msg WHERE failed_reason IS NOT NULL`).Returns(3) + + // ensure org is unsuspended + db.MustExec(`UPDATE orgs_org SET is_suspended = FALSE`) + models.FlushCache() + + oa, err := models.GetOrgAssetsWithRefresh(ctx, rt, testdata.Org1.ID, models.RefreshOrg) + require.NoError(t, err) + channel := oa.ChannelByUUID(testdata.TwilioChannel.UUID) + flow, _ := oa.FlowByID(testdata.Favorites.ID) + session := insertTestSession(t, ctx, rt, testdata.Org1, testdata.Cathy, testdata.Favorites) + + // check that msg loop detection triggers after 20 repeats of the same text + newOutgoing := func(text string) *models.Msg { + flowMsg := flows.NewMsgOut(urns.URN(fmt.Sprintf("tel:+250700000001?id=%d", testdata.Cathy.URNID)), assets.NewChannelReference(testdata.TwilioChannel.UUID, "Twilio"), text, nil, nil, nil, flows.NilMsgTopic) + msg, err := models.NewOutgoingFlowMsg(rt, oa.Org(), channel, session, flow, flowMsg, now) + require.NoError(t, err) + return msg + } - tx.Rollback() + for i := 0; i < 19; i++ { + msg := newOutgoing("foo") + assert.Equal(t, models.MsgStatusQueued, msg.Status()) + assert.Equal(t, models.NilMsgFailedReason, msg.FailedReason()) + } + for i := 0; i < 10; i++ { + msg := newOutgoing("foo") + assert.Equal(t, models.MsgStatusFailed, msg.Status()) + assert.Equal(t, models.MsgFailedLooping, msg.FailedReason()) + } + for i := 0; i < 5; i++ { + msg := newOutgoing("bar") + assert.Equal(t, models.MsgStatusQueued, msg.Status()) + assert.Equal(t, models.NilMsgFailedReason, msg.FailedReason()) } } -func TestGetMessageIDFromUUID(t *testing.T) { - ctx, _, db, _ := testsuite.Get() +func TestMarshalMsg(t *testing.T) { + ctx, rt, db, _ := testsuite.Get() + + defer testsuite.Reset(testsuite.ResetData) + + assertdb.Query(t, db, `SELECT count(*) FROM orgs_org WHERE is_suspended = TRUE`).Returns(0) + + oa, err := models.GetOrgAssets(ctx, rt, testdata.Org1.ID) + require.NoError(t, err) + require.False(t, oa.Org().Suspended()) + + channel := oa.ChannelByUUID(testdata.TwilioChannel.UUID) + flow, _ := oa.FlowByID(testdata.Favorites.ID) + urn := urns.URN(fmt.Sprintf("tel:+250700000001?id=%d", testdata.Cathy.URNID)) + flowMsg1 := flows.NewMsgOut( + urn, + assets.NewChannelReference(testdata.TwilioChannel.UUID, "Test Channel"), + "Hi there", + []utils.Attachment{utils.Attachment("image/jpeg:https://dl-foo.com/image.jpg")}, + []string{"yes", "no"}, + nil, + flows.MsgTopicPurchase, + ) + + // create a non-priority flow message.. i.e. the session isn't responding to an incoming message + session := insertTestSession(t, ctx, rt, testdata.Org1, testdata.Cathy, testdata.Favorites) + msg1, err := models.NewOutgoingFlowMsg(rt, oa.Org(), channel, session, flow, flowMsg1, time.Date(2021, 11, 9, 14, 3, 30, 0, time.UTC)) + require.NoError(t, err) + + err = models.InsertMessages(ctx, db, []*models.Msg{msg1}) + require.NoError(t, err) + + marshaled, err := json.Marshal(msg1) + assert.NoError(t, err) + + test.AssertEqualJSON(t, []byte(fmt.Sprintf(`{ + "attachments": [ + "image/jpeg:https://dl-foo.com/image.jpg" + ], + "channel_id": 10000, + "channel_uuid": "74729f45-7f29-4868-9dc4-90e491e3c7d8", + "contact_id": 10000, + "contact_urn_id": 10000, + "created_on": "2021-11-09T14:03:30Z", + "direction": "O", + "error_count": 0, + "flow": {"uuid": "9de3663f-c5c5-4c92-9f45-ecbc09abcc85", "name": "Favorites"}, + "high_priority": false, + "id": %d, + "metadata": { + "quick_replies": [ + "yes", + "no" + ], + "topic": "purchase" + }, + "modified_on": %s, + "next_attempt": null, + "org_id": 1, + "queued_on": %s, + "sent_on": null, + "session_id": %d, + "session_status": "W", + "status": "Q", + "text": "Hi there", + "tps_cost": 2, + "urn": "tel:+250700000001?id=10000", + "uuid": "%s" + }`, msg1.ID(), jsonx.MustMarshal(msg1.ModifiedOn()), jsonx.MustMarshal(msg1.QueuedOn()), session.ID(), msg1.UUID())), marshaled) + + // create a priority flow message.. i.e. the session is responding to an incoming message + flowMsg2 := flows.NewMsgOut( + urn, + assets.NewChannelReference(testdata.TwilioChannel.UUID, "Test Channel"), + "Hi there", + nil, nil, nil, + flows.NilMsgTopic, + ) + in1 := testdata.InsertIncomingMsg(db, testdata.Org1, testdata.TwilioChannel, testdata.Cathy, "test", models.MsgStatusHandled) + session.SetIncomingMsg(flows.MsgID(in1.ID()), null.String("EX123")) + msg2, err := models.NewOutgoingFlowMsg(rt, oa.Org(), channel, session, flow, flowMsg2, time.Date(2021, 11, 9, 14, 3, 30, 0, time.UTC)) + require.NoError(t, err) + + err = models.InsertMessages(ctx, db, []*models.Msg{msg2}) + require.NoError(t, err) - msgIn := testdata.InsertIncomingMsg(db, testdata.Org1, testdata.TwilioChannel, testdata.Cathy, "hi there", models.MsgStatusHandled) + marshaled, err = json.Marshal(msg2) + assert.NoError(t, err) - msgID, err := models.GetMessageIDFromUUID(ctx, db, msgIn.UUID()) + test.AssertEqualJSON(t, []byte(fmt.Sprintf(`{ + "channel_id": 10000, + "channel_uuid": "74729f45-7f29-4868-9dc4-90e491e3c7d8", + "contact_id": 10000, + "contact_urn_id": 10000, + "created_on": "2021-11-09T14:03:30Z", + "direction": "O", + "error_count": 0, + "flow": {"uuid": "9de3663f-c5c5-4c92-9f45-ecbc09abcc85", "name": "Favorites"}, + "response_to_external_id": "EX123", + "high_priority": true, + "id": %d, + "metadata": null, + "modified_on": %s, + "next_attempt": null, + "org_id": 1, + "queued_on": %s, + "sent_on": null, + "session_id": %d, + "session_status": "W", + "status": "Q", + "text": "Hi there", + "tps_cost": 1, + "urn": "tel:+250700000001?id=10000", + "uuid": "%s" + }`, msg2.ID(), jsonx.MustMarshal(msg2.ModifiedOn()), jsonx.MustMarshal(msg2.QueuedOn()), session.ID(), msg2.UUID())), marshaled) + + // try a broadcast message which won't have session and flow fields set + bcastID := testdata.InsertBroadcast(db, testdata.Org1, `eng`, map[envs.Language]string{`eng`: "Blast"}, models.NilScheduleID, []*testdata.Contact{testdata.Cathy}, nil) + bcastMsg1 := flows.NewMsgOut(urn, assets.NewChannelReference(testdata.TwilioChannel.UUID, "Test Channel"), "Blast", nil, nil, nil, flows.NilMsgTopic) + msg3, err := models.NewOutgoingBroadcastMsg(rt, oa.Org(), channel, testdata.Cathy.ID, bcastMsg1, time.Date(2021, 11, 9, 14, 3, 30, 0, time.UTC), bcastID) + require.NoError(t, err) + err = models.InsertMessages(ctx, db, []*models.Msg{msg2}) require.NoError(t, err) - assert.Equal(t, models.MsgID(msgIn.ID()), msgID) + + marshaled, err = json.Marshal(msg3) + assert.NoError(t, err) + + test.AssertEqualJSON(t, []byte(fmt.Sprintf(`{ + "broadcast_id": %d, + "channel_id": 10000, + "channel_uuid": "74729f45-7f29-4868-9dc4-90e491e3c7d8", + "contact_id": 10000, + "contact_urn_id": 10000, + "created_on": "2021-11-09T14:03:30Z", + "direction": "O", + "error_count": 0, + "high_priority": false, + "id": %d, + "metadata": null, + "modified_on": %s, + "next_attempt": null, + "org_id": 1, + "queued_on": %s, + "sent_on": null, + "status": "Q", + "text": "Blast", + "tps_cost": 1, + "urn": "tel:+250700000001?id=10000", + "uuid": "%s" + }`, bcastID, msg3.ID(), jsonx.MustMarshal(msg3.ModifiedOn()), jsonx.MustMarshal(msg3.QueuedOn()), msg3.UUID())), marshaled) } -func TestLoadMessages(t *testing.T) { +func TestGetMessagesByID(t *testing.T) { ctx, _, db, _ := testsuite.Get() + defer testsuite.Reset(testsuite.ResetData) + msgIn1 := testdata.InsertIncomingMsg(db, testdata.Org1, testdata.TwilioChannel, testdata.Cathy, "in 1", models.MsgStatusHandled) - msgOut1 := testdata.InsertOutgoingMsg(db, testdata.Org1, testdata.TwilioChannel, testdata.Cathy, "out 1", []utils.Attachment{"image/jpeg:hi.jpg"}, models.MsgStatusSent) - msgOut2 := testdata.InsertOutgoingMsg(db, testdata.Org1, testdata.TwilioChannel, testdata.Cathy, "out 2", nil, models.MsgStatusSent) - msgOut3 := testdata.InsertOutgoingMsg(db, testdata.Org2, testdata.Org2Channel, testdata.Org2Contact, "out 3", nil, models.MsgStatusSent) - testdata.InsertOutgoingMsg(db, testdata.Org1, testdata.TwilioChannel, testdata.Cathy, "hi 3", nil, models.MsgStatusSent) + msgOut1 := testdata.InsertOutgoingMsg(db, testdata.Org1, testdata.TwilioChannel, testdata.Cathy, "out 1", []utils.Attachment{"image/jpeg:hi.jpg"}, models.MsgStatusSent, false) + msgOut2 := testdata.InsertOutgoingMsg(db, testdata.Org1, testdata.TwilioChannel, testdata.Cathy, "out 2", nil, models.MsgStatusSent, false) + msgOut3 := testdata.InsertOutgoingMsg(db, testdata.Org2, testdata.Org2Channel, testdata.Org2Contact, "out 3", nil, models.MsgStatusSent, false) + testdata.InsertOutgoingMsg(db, testdata.Org1, testdata.TwilioChannel, testdata.Cathy, "hi 3", nil, models.MsgStatusSent, false) ids := []models.MsgID{models.MsgID(msgIn1.ID()), models.MsgID(msgOut1.ID()), models.MsgID(msgOut2.ID()), models.MsgID(msgOut3.ID())} - msgs, err := models.LoadMessages(ctx, db, testdata.Org1.ID, models.DirectionOut, ids) + msgs, err := models.GetMessagesByID(ctx, db, testdata.Org1.ID, models.DirectionOut, ids) // should only return the outgoing messages for this org require.NoError(t, err) @@ -167,7 +400,7 @@ func TestLoadMessages(t *testing.T) { assert.Equal(t, []utils.Attachment{"image/jpeg:hi.jpg"}, msgs[0].Attachments()) assert.Equal(t, "out 2", msgs[1].Text()) - msgs, err = models.LoadMessages(ctx, db, testdata.Org1.ID, models.DirectionIn, ids) + msgs, err = models.GetMessagesByID(ctx, db, testdata.Org1.ID, models.DirectionIn, ids) // should only return the incoming message for this org require.NoError(t, err) @@ -183,14 +416,14 @@ func TestResendMessages(t *testing.T) { oa, err := models.GetOrgAssets(ctx, rt, testdata.Org1.ID) require.NoError(t, err) - msgOut1 := testdata.InsertOutgoingMsg(db, testdata.Org1, testdata.TwilioChannel, testdata.Cathy, "out 1", nil, models.MsgStatusFailed) - msgOut2 := testdata.InsertOutgoingMsg(db, testdata.Org1, testdata.TwilioChannel, testdata.Bob, "out 2", nil, models.MsgStatusFailed) - testdata.InsertOutgoingMsg(db, testdata.Org1, testdata.TwilioChannel, testdata.Cathy, "out 3", nil, models.MsgStatusFailed) + msgOut1 := testdata.InsertOutgoingMsg(db, testdata.Org1, testdata.TwilioChannel, testdata.Cathy, "out 1", nil, models.MsgStatusFailed, false) + msgOut2 := testdata.InsertOutgoingMsg(db, testdata.Org1, testdata.TwilioChannel, testdata.Bob, "out 2", nil, models.MsgStatusFailed, false) + testdata.InsertOutgoingMsg(db, testdata.Org1, testdata.TwilioChannel, testdata.Cathy, "out 3", nil, models.MsgStatusFailed, false) // give Bob's URN an affinity for the Vonage channel db.MustExec(`UPDATE contacts_contacturn SET channel_id = $1 WHERE id = $2`, testdata.VonageChannel.ID, testdata.Bob.URNID) - msgs, err := models.LoadMessages(ctx, db, testdata.Org1.ID, models.DirectionOut, []models.MsgID{models.MsgID(msgOut1.ID()), models.MsgID(msgOut2.ID())}) + msgs, err := models.GetMessagesByID(ctx, db, testdata.Org1.ID, models.DirectionOut, []models.MsgID{models.MsgID(msgOut1.ID()), models.MsgID(msgOut2.ID())}) require.NoError(t, err) now := dates.Now() @@ -207,7 +440,46 @@ func TestResendMessages(t *testing.T) { assert.Equal(t, testdata.VonageChannel.ID, msgs[1].ChannelID()) assert.Equal(t, models.TopupID(1), msgs[1].TopupID()) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM msgs_msg WHERE status = 'P' AND queued_on > $1 AND sent_on IS NULL`, now).Returns(2) + assertdb.Query(t, db, `SELECT count(*) FROM msgs_msg WHERE status = 'P' AND queued_on > $1 AND sent_on IS NULL`, now).Returns(2) +} + +func TestGetMsgRepetitions(t *testing.T) { + _, _, _, rp := testsuite.Get() + + defer testsuite.Reset(testsuite.ResetRedis) + defer dates.SetNowSource(dates.DefaultNowSource) + + dates.SetNowSource(dates.NewFixedNowSource(time.Date(2021, 11, 18, 12, 13, 3, 234567, time.UTC))) + + msg1 := flows.NewMsgOut(testdata.Cathy.URN, nil, "foo", nil, nil, nil, flows.NilMsgTopic) + msg2 := flows.NewMsgOut(testdata.Cathy.URN, nil, "bar", nil, nil, nil, flows.NilMsgTopic) + + assertRepetitions := func(m *flows.MsgOut, expected int) { + count, err := models.GetMsgRepetitions(rp, testdata.Cathy.ID, m) + require.NoError(t, err) + assert.Equal(t, expected, count) + } + + // keep counts up to 99 + for i := 0; i < 99; i++ { + assertRepetitions(msg1, i+1) + } + assertredis.HGetAll(t, rp, "msg_repetitions:2021-11-18T12:15", map[string]string{"10000": "99:foo"}) + + for i := 0; i < 50; i++ { + assertRepetitions(msg1, 99) + } + assertredis.HGetAll(t, rp, "msg_repetitions:2021-11-18T12:15", map[string]string{"10000": "99:foo"}) + + for i := 0; i < 19; i++ { + assertRepetitions(msg2, i+1) + } + assertredis.HGetAll(t, rp, "msg_repetitions:2021-11-18T12:15", map[string]string{"10000": "19:bar"}) + + for i := 0; i < 50; i++ { + assertRepetitions(msg2, 20+i) + } + assertredis.HGetAll(t, rp, "msg_repetitions:2021-11-18T12:15", map[string]string{"10000": "69:bar"}) } func TestNormalizeAttachment(t *testing.T) { @@ -233,46 +505,48 @@ func TestNormalizeAttachment(t *testing.T) { } func TestMarkMessages(t *testing.T) { - ctx, rt, db, _ := testsuite.Get() + ctx, _, db, _ := testsuite.Get() defer testsuite.Reset(testsuite.ResetAll) - oa, err := models.GetOrgAssetsWithRefresh(ctx, rt, testdata.Org1.ID, models.RefreshOrg) + flowMsg1 := testdata.InsertOutgoingMsg(db, testdata.Org1, testdata.TwilioChannel, testdata.Cathy, "Hello", nil, models.MsgStatusQueued, false) + msgs, err := models.GetMessagesByID(ctx, db, testdata.Org1.ID, models.DirectionOut, []models.MsgID{models.MsgID(flowMsg1.ID())}) require.NoError(t, err) + msg1 := msgs[0] - channel := oa.ChannelByUUID(testdata.TwilioChannel.UUID) - - insertMsg := func(text string) *models.Msg { - urn := urns.URN(fmt.Sprintf("tel:+250700000001?id=%d", testdata.Cathy.URNID)) - flowMsg := flows.NewMsgOut(urn, channel.ChannelReference(), text, nil, nil, nil, flows.NilMsgTopic) - msg, err := models.NewOutgoingMsg(rt.Config, oa.Org(), channel, testdata.Cathy.ID, flowMsg, time.Now()) - require.NoError(t, err) - - err = models.InsertMessages(ctx, db, []*models.Msg{msg}) - require.NoError(t, err) - - return msg - } + flowMsg2 := testdata.InsertOutgoingMsg(db, testdata.Org1, testdata.TwilioChannel, testdata.Cathy, "Hola", nil, models.MsgStatusQueued, false) + msgs, err = models.GetMessagesByID(ctx, db, testdata.Org1.ID, models.DirectionOut, []models.MsgID{models.MsgID(flowMsg2.ID())}) + require.NoError(t, err) + msg2 := msgs[0] - msg1 := insertMsg("Hello") - msg2 := insertMsg("Hola") - insertMsg("Howdy") + testdata.InsertOutgoingMsg(db, testdata.Org1, testdata.TwilioChannel, testdata.Cathy, "Howdy", nil, models.MsgStatusQueued, false) models.MarkMessagesPending(ctx, db, []*models.Msg{msg1, msg2}) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM msgs_msg WHERE status = 'P'`).Returns(2) + assertdb.Query(t, db, `SELECT count(*) FROM msgs_msg WHERE status = 'P'`).Returns(2) // try running on database with BIGINT message ids db.MustExec(`ALTER SEQUENCE "msgs_msg_id_seq" AS bigint;`) db.MustExec(`ALTER SEQUENCE "msgs_msg_id_seq" RESTART WITH 3000000000;`) - msg4 := insertMsg("Big messages!") + flowMsg4 := testdata.InsertOutgoingMsg(db, testdata.Org1, testdata.TwilioChannel, testdata.Cathy, "Big messages!", nil, models.MsgStatusQueued, false) + msgs, err = models.GetMessagesByID(ctx, db, testdata.Org1.ID, models.DirectionOut, []models.MsgID{models.MsgID(flowMsg4.ID())}) + require.NoError(t, err) + msg4 := msgs[0] + assert.Equal(t, flows.MsgID(3000000000), msg4.ID()) err = models.MarkMessagesPending(ctx, db, []*models.Msg{msg4}) assert.NoError(t, err) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM msgs_msg WHERE status = 'P'`).Returns(3) + assertdb.Query(t, db, `SELECT count(*) FROM msgs_msg WHERE status = 'P'`).Returns(3) + assertdb.Query(t, db, `SELECT count(*) FROM msgs_msg WHERE status = 'Q'`).Returns(1) + + err = models.MarkMessagesQueued(ctx, db, []*models.Msg{msg4}) + assert.NoError(t, err) + + assertdb.Query(t, db, `SELECT count(*) FROM msgs_msg WHERE status = 'P'`).Returns(2) + assertdb.Query(t, db, `SELECT count(*) FROM msgs_msg WHERE status = 'Q'`).Returns(2) } func TestNonPersistentBroadcasts(t *testing.T) { @@ -326,10 +600,10 @@ func TestNonPersistentBroadcasts(t *testing.T) { assert.Equal(t, 2, len(msgs)) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM msgs_msg WHERE direction = 'O' AND broadcast_id IS NULL AND text = 'Hi there'`).Returns(2) + assertdb.Query(t, db, `SELECT count(*) FROM msgs_msg WHERE direction = 'O' AND broadcast_id IS NULL AND text = 'Hi there'`).Returns(2) // test ticket was updated - testsuite.AssertQuery(t, db, `SELECT count(*) FROM tickets_ticket WHERE id = $1 AND last_activity_on > $2`, ticket.ID, modelTicket.LastActivityOn()).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM tickets_ticket WHERE id = $1 AND last_activity_on > $2`, ticket.ID, modelTicket.LastActivityOn()).Returns(1) } func TestNewOutgoingIVR(t *testing.T) { @@ -356,5 +630,19 @@ func TestNewOutgoingIVR(t *testing.T) { err = models.InsertMessages(ctx, db, []*models.Msg{dbMsg}) require.NoError(t, err) - testsuite.AssertQuery(t, db, `SELECT text, created_on, sent_on FROM msgs_msg WHERE uuid = $1`, dbMsg.UUID()).Columns(map[string]interface{}{"text": "Hello", "created_on": createdOn, "sent_on": createdOn}) + assertdb.Query(t, db, `SELECT text, created_on, sent_on FROM msgs_msg WHERE uuid = $1`, dbMsg.UUID()).Columns(map[string]interface{}{"text": "Hello", "created_on": createdOn, "sent_on": createdOn}) +} + +func insertTestSession(t *testing.T, ctx context.Context, rt *runtime.Runtime, org *testdata.Org, contact *testdata.Contact, flow *testdata.Flow) *models.Session { + testdata.InsertFlowSession(rt.DB, org, contact, models.FlowTypeMessaging, models.SessionStatusWaiting, testdata.Favorites, models.NilConnectionID, nil) + + oa, err := models.GetOrgAssets(ctx, rt, testdata.Org1.ID) + require.NoError(t, err) + + _, flowContact := contact.Load(rt.DB, oa) + + session, err := models.FindWaitingSessionForContact(ctx, rt.DB, rt.SessionStorage, oa, models.FlowTypeMessaging, flowContact) + require.NoError(t, err) + + return session } diff --git a/core/models/notifications.go b/core/models/notifications.go index bc68acab2..90efc1ff2 100644 --- a/core/models/notifications.go +++ b/core/models/notifications.go @@ -3,9 +3,10 @@ package models import ( "context" "fmt" + "strconv" "time" - "github.com/nyaruka/mailroom/utils/dbutil" + "github.com/nyaruka/gocommon/dbutil" "github.com/pkg/errors" ) @@ -15,9 +16,9 @@ type NotificationID int type NotificationType string const ( - NotificationTypeChannelAlert NotificationType = "channel:alert" NotificationTypeExportFinished NotificationType = "export:finished" NotificationTypeImportFinished NotificationType = "import:finished" + NotificationTypeIncidentStarted NotificationType = "incident:started" NotificationTypeTicketsOpened NotificationType = "tickets:opened" NotificationTypeTicketsActivity NotificationType = "tickets:activity" ) @@ -40,11 +41,11 @@ type Notification struct { EmailStatus EmailStatus `db:"email_status"` CreatedOn time.Time `db:"created_on"` - ChannelID ChannelID `db:"channel_id"` ContactImportID ContactImportID `db:"contact_import_id"` + IncidentID IncidentID `db:"incident_id"` } -// NotifyImportFinished logs the the finishing of a contact import +// NotifyImportFinished notifies the user who created an import that it has finished func NotifyImportFinished(ctx context.Context, db Queryer, imp *ContactImport) error { n := &Notification{ OrgID: imp.OrgID, @@ -57,6 +58,24 @@ func NotifyImportFinished(ctx context.Context, db Queryer, imp *ContactImport) e return insertNotifications(ctx, db, []*Notification{n}) } +// NotifyIncidentStarted notifies administrators that an incident has started +func NotifyIncidentStarted(ctx context.Context, db Queryer, oa *OrgAssets, incident *Incident) error { + admins := usersWithRoles(oa, []UserRole{UserRoleAdministrator}) + notifications := make([]*Notification, len(admins)) + + for i, admin := range admins { + notifications[i] = &Notification{ + OrgID: incident.OrgID, + Type: NotificationTypeIncidentStarted, + Scope: strconv.Itoa(int(incident.ID)), + UserID: admin.ID(), + IncidentID: incident.ID, + } + } + + return insertNotifications(ctx, db, notifications) +} + var ticketAssignableToles = []UserRole{UserRoleAdministrator, UserRoleEditor, UserRoleAgent} // NotificationsFromTicketEvents logs the opening of new tickets and notifies all assignable users if tickets is not already assigned @@ -64,15 +83,15 @@ func NotificationsFromTicketEvents(ctx context.Context, db Queryer, oa *OrgAsset notifyTicketsOpened := make(map[UserID]bool) notifyTicketsActivity := make(map[UserID]bool) + assignableUsers := usersWithRoles(oa, ticketAssignableToles) + for ticket, evt := range events { switch evt.EventType() { case TicketEventTypeOpened: // if ticket is unassigned notify all possible assignees if evt.AssigneeID() == NilUserID { - for _, u := range oa.users { - user := u.(*User) - - if hasAnyRole(user, ticketAssignableToles) && evt.CreatedByID() != user.ID() { + for _, user := range assignableUsers { + if evt.CreatedByID() != user.ID() { notifyTicketsOpened[user.ID()] = true } } @@ -116,8 +135,8 @@ func NotificationsFromTicketEvents(ctx context.Context, db Queryer, oa *OrgAsset } const insertNotificationSQL = ` -INSERT INTO notifications_notification(org_id, notification_type, scope, user_id, is_seen, email_status, created_on, channel_id, contact_import_id) - VALUES(:org_id, :notification_type, :scope, :user_id, FALSE, 'N', NOW(), :channel_id, :contact_import_id) +INSERT INTO notifications_notification(org_id, notification_type, scope, user_id, is_seen, email_status, created_on, contact_import_id, incident_id) + VALUES(:org_id, :notification_type, :scope, :user_id, FALSE, 'N', NOW(), :contact_import_id, :incident_id) ON CONFLICT DO NOTHING` func insertNotifications(ctx context.Context, db Queryer, notifications []*Notification) error { @@ -130,6 +149,17 @@ func insertNotifications(ctx context.Context, db Queryer, notifications []*Notif return errors.Wrap(err, "error inserting notifications") } +func usersWithRoles(oa *OrgAssets, roles []UserRole) []*User { + users := make([]*User, 0, 5) + for _, u := range oa.users { + user := u.(*User) + if hasAnyRole(user, roles) { + users = append(users, user) + } + } + return users +} + func hasAnyRole(user *User, roles []UserRole) bool { for _, r := range roles { if user.Role() == r { diff --git a/core/models/notifications_test.go b/core/models/notifications_test.go index f042f7b3f..934e32f1c 100644 --- a/core/models/notifications_test.go +++ b/core/models/notifications_test.go @@ -127,6 +127,24 @@ func TestImportNotifications(t *testing.T) { }) } +func TestIncidentNotifications(t *testing.T) { + ctx, rt, db, rp := testsuite.Get() + + defer testsuite.Reset(testsuite.ResetData) + + oa, err := models.GetOrgAssets(ctx, rt, testdata.Org1.ID) + require.NoError(t, err) + + t0 := time.Now() + + _, err = models.IncidentWebhooksUnhealthy(ctx, db, rp, oa, nil) + require.NoError(t, err) + + assertNotifications(t, ctx, db, t0, map[*testdata.User][]models.NotificationType{ + testdata.Admin: {models.NotificationTypeIncidentStarted}, + }) +} + func assertNotifications(t *testing.T, ctx context.Context, db *sqlx.DB, after time.Time, expected map[*testdata.User][]models.NotificationType) { // check last log var notifications []*models.Notification diff --git a/core/models/orgs.go b/core/models/orgs.go index a058170d1..41d4bc59f 100644 --- a/core/models/orgs.go +++ b/core/models/orgs.go @@ -11,6 +11,8 @@ import ( "strings" "time" + "github.com/jmoiron/sqlx" + "github.com/nyaruka/gocommon/dbutil" "github.com/nyaruka/gocommon/httpx" "github.com/nyaruka/gocommon/jsonx" "github.com/nyaruka/goflow/envs" @@ -22,10 +24,7 @@ import ( "github.com/nyaruka/goflow/utils/smtpx" "github.com/nyaruka/mailroom/core/goflow" "github.com/nyaruka/mailroom/runtime" - "github.com/nyaruka/mailroom/utils/dbutil" "github.com/nyaruka/null" - - "github.com/jmoiron/sqlx" "github.com/pkg/errors" "github.com/sirupsen/logrus" ) @@ -57,9 +56,6 @@ func airtimeServiceFactory(c *runtime.Config) engine.AirtimeServiceFactory { // OrgID is our type for orgs ids type OrgID int -// SessionStorageMode is our type for how we persist our sessions -type SessionStorageMode string - const ( // NilOrgID is the id 0 considered as nil org id NilOrgID = OrgID(0) @@ -67,11 +63,6 @@ const ( configSMTPServer = "smtp_server" configDTOneKey = "dtone_key" configDTOneSecret = "dtone_secret" - - configSessionStorageMode = "session_storage_mode" - - DBSessions = SessionStorageMode("db") - S3Sessions = SessionStorageMode("s3") ) // Org is mailroom's type for RapidPro orgs. It also implements the envs.Environment interface for GoFlow @@ -94,10 +85,6 @@ func (o *Org) Suspended() bool { return o.o.Suspended } // UsesTopups returns whether the org uses topups func (o *Org) UsesTopups() bool { return o.o.UsesTopups } -func (o *Org) SessionStorageMode() SessionStorageMode { - return SessionStorageMode(o.ConfigValue(configSessionStorageMode, string(DBSessions))) -} - // DateFormat returns the date format for this org func (o *Org) DateFormat() envs.DateFormat { return o.env.DateFormat() } @@ -251,7 +238,7 @@ func LoadOrg(ctx context.Context, cfg *runtime.Config, db sqlx.Queryer, orgID Or return nil, errors.Errorf("no org with id: %d", orgID) } - err = dbutil.ReadJSONRow(rows, org) + err = dbutil.ScanJSON(rows, org) if err != nil { return nil, errors.Wrapf(err, "error unmarshalling org") } diff --git a/core/models/resthooks.go b/core/models/resthooks.go index e5a9419fb..970c05a76 100644 --- a/core/models/resthooks.go +++ b/core/models/resthooks.go @@ -4,8 +4,8 @@ import ( "context" "time" + "github.com/nyaruka/gocommon/dbutil" "github.com/nyaruka/goflow/assets" - "github.com/nyaruka/mailroom/utils/dbutil" "github.com/jmoiron/sqlx" "github.com/pkg/errors" @@ -46,7 +46,7 @@ func loadResthooks(ctx context.Context, db sqlx.Queryer, orgID OrgID) ([]assets. resthooks := make([]assets.Resthook, 0, 10) for rows.Next() { resthook := &Resthook{} - err = dbutil.ReadJSONRow(rows, &resthook.r) + err = dbutil.ScanJSON(rows, &resthook.r) if err != nil { return nil, errors.Wrap(err, "error scanning resthook row") } diff --git a/core/models/runs.go b/core/models/runs.go index 9d88ff33c..6cecd24fd 100644 --- a/core/models/runs.go +++ b/core/models/runs.go @@ -2,64 +2,33 @@ package models import ( "context" - "crypto/md5" "database/sql" "encoding/json" - "fmt" - "net/url" - "path" "time" - "github.com/aws/aws-sdk-go/service/s3" - "github.com/nyaruka/gocommon/storage" - "github.com/nyaruka/gocommon/uuids" - "github.com/nyaruka/goflow/assets" - "github.com/nyaruka/goflow/envs" + "github.com/nyaruka/gocommon/jsonx" "github.com/nyaruka/goflow/flows" "github.com/nyaruka/goflow/flows/events" - "github.com/nyaruka/mailroom/core/goflow" - "github.com/nyaruka/mailroom/runtime" "github.com/nyaruka/null" - "github.com/gomodule/redigo/redis" "github.com/jmoiron/sqlx" "github.com/lib/pq" "github.com/pkg/errors" - "github.com/sirupsen/logrus" ) -type SessionCommitHook func(context.Context, *sqlx.Tx, *redis.Pool, *OrgAssets, []*Session) error - -type SessionID int64 -type SessionStatus string - type FlowRunID int64 const NilFlowRunID = FlowRunID(0) -const ( - SessionStatusWaiting = "W" - SessionStatusCompleted = "C" - SessionStatusExpired = "X" - SessionStatusInterrupted = "I" - SessionStatusFailed = "F" -) - -var sessionStatusMap = map[flows.SessionStatus]SessionStatus{ - flows.SessionStatusWaiting: SessionStatusWaiting, - flows.SessionStatusCompleted: SessionStatusCompleted, - flows.SessionStatusFailed: SessionStatusFailed, -} - type RunStatus string const ( - RunStatusActive = "A" - RunStatusWaiting = "W" - RunStatusCompleted = "C" - RunStatusExpired = "X" - RunStatusInterrupted = "I" - RunStatusFailed = "F" + RunStatusActive RunStatus = "A" + RunStatusWaiting RunStatus = "W" + RunStatusCompleted RunStatus = "C" + RunStatusExpired RunStatus = "X" + RunStatusInterrupted RunStatus = "I" + RunStatusFailed RunStatus = "F" ) var runStatusMap = map[flows.RunStatus]RunStatus{ @@ -70,242 +39,36 @@ var runStatusMap = map[flows.RunStatus]RunStatus{ flows.RunStatusFailed: RunStatusFailed, } +// ExitType still needs to be set on runs until database triggers are updated to only look at status type ExitType = null.String -var ( +const ( ExitInterrupted = ExitType("I") ExitCompleted = ExitType("C") ExitExpired = ExitType("E") ExitFailed = ExitType("F") ) -var exitToSessionStatusMap = map[ExitType]SessionStatus{ - ExitInterrupted: SessionStatusInterrupted, - ExitCompleted: SessionStatusCompleted, - ExitExpired: SessionStatusExpired, - ExitFailed: SessionStatusFailed, -} - -var exitToRunStatusMap = map[ExitType]RunStatus{ - ExitInterrupted: RunStatusInterrupted, - ExitCompleted: RunStatusCompleted, - ExitExpired: RunStatusExpired, - ExitFailed: RunStatusFailed, -} - -var keptEvents = map[string]bool{ - events.TypeMsgCreated: true, - events.TypeMsgReceived: true, -} - -// Session is the mailroom type for a FlowSession -type Session struct { - s struct { - ID SessionID `db:"id"` - UUID flows.SessionUUID `db:"uuid"` - SessionType FlowType `db:"session_type"` - Status SessionStatus `db:"status"` - Responded bool `db:"responded"` - Output null.String `db:"output"` - OutputURL null.String `db:"output_url"` - ContactID ContactID `db:"contact_id"` - OrgID OrgID `db:"org_id"` - CreatedOn time.Time `db:"created_on"` - EndedOn *time.Time `db:"ended_on"` - TimeoutOn *time.Time `db:"timeout_on"` - WaitStartedOn *time.Time `db:"wait_started_on"` - CurrentFlowID FlowID `db:"current_flow_id"` - ConnectionID *ConnectionID `db:"connection_id"` - } - - incomingMsgID MsgID - incomingExternalID null.String - - // any channel connection associated with this flow session - channelConnection *ChannelConnection - - // time after our last message is sent that we should timeout - timeout *time.Duration - - contact *flows.Contact - runs []*FlowRun - - seenRuns map[flows.RunUUID]time.Time - - // we keep around a reference to the sprint associated with this session - sprint flows.Sprint - - // the scene for our event hooks - scene *Scene - - // we also keep around a reference to the wait (if any) - wait flows.ActivatedWait - - findStep func(flows.StepUUID) (flows.FlowRun, flows.Step) -} - -func (s *Session) ID() SessionID { return s.s.ID } -func (s *Session) UUID() flows.SessionUUID { return flows.SessionUUID(s.s.UUID) } -func (s *Session) SessionType() FlowType { return s.s.SessionType } -func (s *Session) Status() SessionStatus { return s.s.Status } -func (s *Session) Responded() bool { return s.s.Responded } -func (s *Session) Output() string { return string(s.s.Output) } -func (s *Session) OutputURL() string { return string(s.s.OutputURL) } -func (s *Session) ContactID() ContactID { return s.s.ContactID } -func (s *Session) OrgID() OrgID { return s.s.OrgID } -func (s *Session) CreatedOn() time.Time { return s.s.CreatedOn } -func (s *Session) EndedOn() *time.Time { return s.s.EndedOn } -func (s *Session) TimeoutOn() *time.Time { return s.s.TimeoutOn } -func (s *Session) ClearTimeoutOn() { s.s.TimeoutOn = nil } -func (s *Session) WaitStartedOn() *time.Time { return s.s.WaitStartedOn } -func (s *Session) CurrentFlowID() FlowID { return s.s.CurrentFlowID } -func (s *Session) ConnectionID() *ConnectionID { return s.s.ConnectionID } -func (s *Session) IncomingMsgID() MsgID { return s.incomingMsgID } -func (s *Session) IncomingMsgExternalID() null.String { return s.incomingExternalID } -func (s *Session) Scene() *Scene { return s.scene } - -// WriteSessionsToStorage writes the outputs of the passed in sessions to our storage (S3), updating the -// output_url for each on success. Failure of any will cause all to fail. -func WriteSessionOutputsToStorage(ctx context.Context, rt *runtime.Runtime, sessions []*Session) error { - start := time.Now() - - uploads := make([]*storage.Upload, len(sessions)) - for i, s := range sessions { - uploads[i] = &storage.Upload{ - Path: s.StoragePath(rt.Config), - Body: []byte(s.Output()), - ContentType: "application/json", - ACL: s3.ObjectCannedACLPrivate, - } - } - - err := rt.SessionStorage.BatchPut(ctx, uploads) - if err != nil { - return errors.Wrapf(err, "error writing sessions to storage") - } - - for i, s := range sessions { - s.s.OutputURL = null.String(uploads[i].URL) - } - - logrus.WithField("elapsed", time.Since(start)).WithField("count", len(sessions)).Debug("wrote sessions to s3") - - return nil -} - -const storageTSFormat = "20060102T150405.999Z" - -// StoragePath returns the path for the session -func (s *Session) StoragePath(cfg *runtime.Config) string { - ts := s.CreatedOn().UTC().Format(storageTSFormat) - - // example output: /orgs/1/c/20a5/20a5534c-b2ad-4f18-973a-f1aa3b4e6c74/session_20060102T150405.123Z_8a7fc501-177b-4567-a0aa-81c48e6de1c5_51df83ac21d3cf136d8341f0b11cb1a7.json" - return path.Join( - cfg.S3SessionPrefix, - "orgs", - fmt.Sprintf("%d", s.OrgID()), - "c", - string(s.ContactUUID()[:4]), - string(s.ContactUUID()), - fmt.Sprintf("%s_session_%s_%s.json", ts, s.UUID(), s.OutputMD5()), - ) -} - -// ContactUUID returns the UUID of our contact -func (s *Session) ContactUUID() flows.ContactUUID { - return s.contact.UUID() -} - -// Contact returns the contact for this session -func (s *Session) Contact() *flows.Contact { - return s.contact -} - -// Runs returns our flow run -func (s *Session) Runs() []*FlowRun { - return s.runs -} - -// Sprint returns the sprint associated with this session -func (s *Session) Sprint() flows.Sprint { - return s.sprint -} - -// Wait returns the wait associated with this session (if any) -func (s *Session) Wait() flows.ActivatedWait { - return s.wait -} - -// FindStep finds the run and step with the given UUID -func (s *Session) FindStep(uuid flows.StepUUID) (flows.FlowRun, flows.Step) { - return s.findStep(uuid) -} - -// Timeout returns the amount of time after our last message sends that we should timeout -func (s *Session) Timeout() *time.Duration { - return s.timeout -} - -// OutputMD5 returns the md5 of the passed in session -func (s *Session) OutputMD5() string { - return fmt.Sprintf("%x", md5.Sum([]byte(s.s.Output))) -} - -// SetIncomingMsg set the incoming message that this session should be associated with in this sprint -func (s *Session) SetIncomingMsg(id flows.MsgID, externalID null.String) { - s.incomingMsgID = MsgID(id) - s.incomingExternalID = externalID -} - -// SetChannelConnection sets the channel connection associated with this sprint -func (s *Session) SetChannelConnection(cc *ChannelConnection) { - connID := cc.ID() - s.s.ConnectionID = &connID - s.channelConnection = cc - - // also set it on all our runs - for _, r := range s.runs { - r.SetConnectionID(&connID) - } -} - -func (s *Session) ChannelConnection() *ChannelConnection { - return s.channelConnection -} - -// MarshalJSON is our custom marshaller so that our inner struct get output -func (s *Session) MarshalJSON() ([]byte, error) { - return json.Marshal(s.s) -} - -// UnmarshalJSON is our custom marshaller so that our inner struct get output -func (s *Session) UnmarshalJSON(b []byte) error { - return json.Unmarshal(b, &s.s) +var runStatusToExitType = map[RunStatus]ExitType{ + RunStatusInterrupted: ExitInterrupted, + RunStatusCompleted: ExitCompleted, + RunStatusExpired: ExitExpired, + RunStatusFailed: ExitFailed, } // FlowRun is the mailroom type for a FlowRun type FlowRun struct { r struct { - ID FlowRunID `db:"id"` - UUID flows.RunUUID `db:"uuid"` - Status RunStatus `db:"status"` - IsActive bool `db:"is_active"` - CreatedOn time.Time `db:"created_on"` - ModifiedOn time.Time `db:"modified_on"` - ExitedOn *time.Time `db:"exited_on"` - ExitType ExitType `db:"exit_type"` - ExpiresOn *time.Time `db:"expires_on"` - Responded bool `db:"responded"` - - // TODO: should this be a complex object that can read / write iself to the DB as JSON? - Results string `db:"results"` - - // TODO: should this be a complex object that can read / write iself to the DB as JSON? - Path string `db:"path"` - - // TODO: should this be a complex object that can read / write iself to the DB as JSON? - Events string `db:"events"` - + ID FlowRunID `db:"id"` + UUID flows.RunUUID `db:"uuid"` + Status RunStatus `db:"status"` + CreatedOn time.Time `db:"created_on"` + ModifiedOn time.Time `db:"modified_on"` + ExitedOn *time.Time `db:"exited_on"` + ExpiresOn *time.Time `db:"expires_on"` + Responded bool `db:"responded"` + Results string `db:"results"` + Path string `db:"path"` CurrentNodeUUID null.String `db:"current_node_uuid"` ContactID flows.ContactID `db:"contact_id"` FlowID FlowID `db:"flow_id"` @@ -314,10 +77,14 @@ type FlowRun struct { SessionID SessionID `db:"session_id"` StartID StartID `db:"start_id"` ConnectionID *ConnectionID `db:"connection_id"` + + // deprecated + IsActive bool `db:"is_active"` + ExitType ExitType `db:"exit_type"` } - // we keep a reference to model run as well - run flows.FlowRun + // we keep a reference to the engine's run + run flows.Run } func (r *FlowRun) SetSessionID(sessionID SessionID) { r.r.SessionID = sessionID } @@ -344,563 +111,18 @@ type Step struct { ExitUUID flows.ExitUUID `json:"exit_uuid,omitempty"` } -// NewSession a session objects from the passed in flow session. It does NOT -// commit said session to the database. -func NewSession(ctx context.Context, tx *sqlx.Tx, org *OrgAssets, fs flows.Session, sprint flows.Sprint) (*Session, error) { - output, err := json.Marshal(fs) - if err != nil { - return nil, errors.Wrapf(err, "error marshalling flow session") - } - - // map our status over - sessionStatus, found := sessionStatusMap[fs.Status()] - if !found { - return nil, errors.Errorf("unknown session status: %s", fs.Status()) - } - - // session must have at least one run - if len(fs.Runs()) < 1 { - return nil, errors.Errorf("cannot write session that has no runs") - } - - // figure out our type - sessionType, found := flowTypeMapping[fs.Type()] - if !found { - return nil, errors.Errorf("unknown flow type: %s", fs.Type()) - } - - uuid := fs.UUID() - if uuid == "" { - uuid = flows.SessionUUID(uuids.New()) - } - - // create our session object - session := &Session{} - s := &session.s - s.UUID = uuid - s.Status = sessionStatus - s.SessionType = sessionType - s.Responded = false - s.Output = null.String(output) - s.ContactID = ContactID(fs.Contact().ID()) - s.OrgID = org.OrgID() - s.CreatedOn = fs.Runs()[0].CreatedOn() - - session.contact = fs.Contact() - session.scene = NewSceneForSession(session) - - session.sprint = sprint - session.wait = fs.Wait() - session.findStep = fs.FindStep - - // now build up our runs - for _, r := range fs.Runs() { - run, err := newRun(ctx, tx, org, session, r) - if err != nil { - return nil, errors.Wrapf(err, "error creating run: %s", r.UUID()) - } - - // save the run to our session - session.runs = append(session.runs, run) - - // if this run is waiting, save it as the current flow - if r.Status() == flows.RunStatusWaiting { - flowID, err := FlowIDForUUID(ctx, tx, org, r.FlowReference().UUID) - if err != nil { - return nil, errors.Wrapf(err, "error loading current flow for UUID: %s", r.FlowReference().UUID) - } - s.CurrentFlowID = flowID - } - } - - // calculate our timeout if any - session.calculateTimeout(fs, sprint) - - return session, nil -} - -// ActiveSessionForContact returns the active session for the passed in contact, if any -func ActiveSessionForContact(ctx context.Context, db *sqlx.DB, st storage.Storage, org *OrgAssets, sessionType FlowType, contact *flows.Contact) (*Session, error) { - rows, err := db.QueryxContext(ctx, selectLastSessionSQL, sessionType, contact.ID()) - if err != nil { - return nil, errors.Wrapf(err, "error selecting active session") - } - defer rows.Close() - - // no rows? no sessions! - if !rows.Next() { - return nil, nil - } - - // scan in our session - session := &Session{ - contact: contact, - } - session.scene = NewSceneForSession(session) - - if err := rows.StructScan(&session.s); err != nil { - return nil, errors.Wrapf(err, "error scanning session") - } - - // load our output from storage if necessary - if session.OutputURL() != "" { - // strip just the path out of our output URL - u, err := url.Parse(session.OutputURL()) - if err != nil { - return nil, errors.Wrapf(err, "error parsing output URL: %s", session.OutputURL()) - } - - start := time.Now() - - _, output, err := st.Get(ctx, u.Path) - if err != nil { - return nil, errors.Wrapf(err, "error reading session from storage: %s", session.OutputURL()) - } - - logrus.WithField("elapsed", time.Since(start)).WithField("output_url", session.OutputURL()).Debug("loaded session from storage") - session.s.Output = null.String(output) - } - - return session, nil -} - -const selectLastSessionSQL = ` -SELECT - id, - uuid, - session_type, - status, - responded, - output, - output_url, - contact_id, - org_id, - created_on, - ended_on, - timeout_on, - wait_started_on, - current_flow_id, - connection_id -FROM - flows_flowsession fs -WHERE - session_type = $1 AND - contact_id = $2 AND - status = 'W' -ORDER BY - created_on DESC -LIMIT 1 -` - -const insertCompleteSessionSQL = ` -INSERT INTO - flows_flowsession( uuid, session_type, status, responded, output, output_url, contact_id, org_id, created_on, ended_on, wait_started_on, connection_id) - VALUES(:uuid,:session_type,:status,:responded,:output,:output_url,:contact_id,:org_id, NOW(), NOW(), NULL, :connection_id) -RETURNING id -` - -const insertIncompleteSessionSQL = ` -INSERT INTO - flows_flowsession( uuid, session_type, status, responded, output, output_url, contact_id, org_id, created_on, current_flow_id, timeout_on, wait_started_on, connection_id) - VALUES(:uuid,:session_type,:status,:responded,:output,:output_url,:contact_id,:org_id, NOW(), :current_flow_id,:timeout_on,:wait_started_on,:connection_id) -RETURNING id -` - -const insertCompleteSessionSQLNoOutput = ` -INSERT INTO - flows_flowsession( uuid, session_type, status, responded, output_url, contact_id, org_id, created_on, ended_on, wait_started_on, connection_id) - VALUES(:uuid,:session_type,:status,:responded, :output_url,:contact_id,:org_id, NOW(), NOW(), NULL, :connection_id) -RETURNING id -` - -const insertIncompleteSessionSQLNoOutput = ` -INSERT INTO - flows_flowsession( uuid, session_type, status, responded, output_url, contact_id, org_id, created_on, current_flow_id, timeout_on, wait_started_on, connection_id) - VALUES(:uuid,:session_type,:status,:responded, :output_url,:contact_id,:org_id, NOW(), :current_flow_id,:timeout_on,:wait_started_on,:connection_id) -RETURNING id -` - -// FlowSession creates a flow session for the passed in session object. It also populates the runs we know about -func (s *Session) FlowSession(cfg *runtime.Config, sa flows.SessionAssets, env envs.Environment) (flows.Session, error) { - session, err := goflow.Engine(cfg).ReadSession(sa, json.RawMessage(s.s.Output), assets.IgnoreMissing) - if err != nil { - return nil, errors.Wrapf(err, "unable to unmarshal session") - } - - // walk through our session, populate seen runs - s.seenRuns = make(map[flows.RunUUID]time.Time, len(session.Runs())) - for _, r := range session.Runs() { - s.seenRuns[r.UUID()] = r.ModifiedOn() - } - - return session, nil -} - -// calculates the timeout value for this session based on our waits -func (s *Session) calculateTimeout(fs flows.Session, sprint flows.Sprint) { - // if we are on a wait and it has a timeout - if fs.Wait() != nil && fs.Wait().TimeoutSeconds() != nil { - now := time.Now() - s.s.WaitStartedOn = &now - - seconds := time.Duration(*fs.Wait().TimeoutSeconds()) * time.Second - s.timeout = &seconds - - timeoutOn := now.Add(seconds) - s.s.TimeoutOn = &timeoutOn - } else { - s.s.WaitStartedOn = nil - s.s.TimeoutOn = nil - s.timeout = nil - } -} - -// WriteUpdatedSession updates the session based on the state passed in from our engine session, this also takes care of applying any event hooks -func (s *Session) WriteUpdatedSession(ctx context.Context, rt *runtime.Runtime, tx *sqlx.Tx, org *OrgAssets, fs flows.Session, sprint flows.Sprint, hook SessionCommitHook) error { - // make sure we have our seen runs - if s.seenRuns == nil { - return errors.Errorf("missing seen runs, cannot update session") - } - - output, err := json.Marshal(fs) - if err != nil { - return errors.Wrapf(err, "error marshalling flow session") - } - s.s.Output = null.String(output) - - // map our status over - status, found := sessionStatusMap[fs.Status()] - if !found { - return errors.Errorf("unknown session status: %s", fs.Status()) - } - s.s.Status = status - - // now build up our runs - for _, r := range fs.Runs() { - run, err := newRun(ctx, tx, org, s, r) - if err != nil { - return errors.Wrapf(err, "error creating run: %s", r.UUID()) - } - - // set the run on our session - s.runs = append(s.runs, run) - } - - // calculate our new timeout - s.calculateTimeout(fs, sprint) - - // set our sprint, wait and step finder - s.sprint = sprint - s.wait = fs.Wait() - s.findStep = fs.FindStep - - // run through our runs to figure out our current flow - for _, r := range fs.Runs() { - // if this run is waiting, save it as the current flow - if r.Status() == flows.RunStatusWaiting { - flowID, err := FlowIDForUUID(ctx, tx, org, r.FlowReference().UUID) - if err != nil { - return errors.Wrapf(err, "error loading flow: %s", r.FlowReference().UUID) - } - s.s.CurrentFlowID = flowID - } - - // if we haven't already been marked as responded, walk our runs looking for an input - if !s.s.Responded { - // run through events, see if any are received events - for _, e := range r.Events() { - if e.Type() == events.TypeMsgReceived { - s.s.Responded = true - break - } - } - } - } - - // apply all our pre write events - for _, e := range sprint.Events() { - err := ApplyPreWriteEvent(ctx, rt, tx, org, s.scene, e) - if err != nil { - return errors.Wrapf(err, "error applying event: %v", e) - } - } - - // the SQL statement we'll use to update this session - updateSQL := updateSessionSQL - - // if writing to S3, do so - sessionMode := org.Org().SessionStorageMode() - if sessionMode == S3Sessions { - err := WriteSessionOutputsToStorage(ctx, rt, []*Session{s}) - if err != nil { - logrus.WithError(err).Error("error writing session to s3") - } - - // don't write output in our SQL - updateSQL = updateSessionSQLNoOutput - } - - // write our new session state to the db - _, err = tx.NamedExecContext(ctx, updateSQL, s.s) - if err != nil { - return errors.Wrapf(err, "error updating session") - } - - // if this session is complete, so is any associated connection - if s.channelConnection != nil { - if s.Status() == SessionStatusCompleted || s.Status() == SessionStatusFailed { - err := s.channelConnection.UpdateStatus(ctx, tx, ConnectionStatusCompleted, 0, time.Now()) - if err != nil { - return errors.Wrapf(err, "error update channel connection") - } - } - } - - // figure out which runs are new and which are updated - updatedRuns := make([]interface{}, 0, 1) - newRuns := make([]interface{}, 0) - for _, r := range s.Runs() { - modified, found := s.seenRuns[r.UUID()] - if !found { - newRuns = append(newRuns, &r.r) - continue - } - - if r.ModifiedOn().After(modified) { - updatedRuns = append(updatedRuns, &r.r) - continue - } - } - - // call our global pre commit hook if present - if hook != nil { - err := hook(ctx, tx, rt.RP, org, []*Session{s}) - if err != nil { - return errors.Wrapf(err, "error calling commit hook: %v", hook) - } - } - - // update all modified runs at once - err = BulkQuery(ctx, "update runs", tx, updateRunSQL, updatedRuns) - if err != nil { - logrus.WithError(err).WithField("session", string(output)).Error("error while updating runs for session") - return errors.Wrapf(err, "error updating runs") - } - - // insert all new runs at once - err = BulkQuery(ctx, "insert runs", tx, insertRunSQL, newRuns) - if err != nil { - return errors.Wrapf(err, "error writing runs") - } - - // apply all our events - if s.Status() != SessionStatusFailed { - err = HandleEvents(ctx, rt, tx, org, s.scene, sprint.Events()) - if err != nil { - return errors.Wrapf(err, "error applying events: %d", s.ID()) - } - } - - // gather all our pre commit events, group them by hook and apply them - err = ApplyEventPreCommitHooks(ctx, rt, tx, org, []*Scene{s.scene}) - if err != nil { - return errors.Wrapf(err, "error applying pre commit hook: %T", hook) - } - - return nil -} - -const updateSessionSQL = ` -UPDATE - flows_flowsession -SET - output = :output, - output_url = :output_url, - status = :status, - ended_on = CASE WHEN :status = 'W' THEN NULL ELSE NOW() END, - responded = :responded, - current_flow_id = :current_flow_id, - timeout_on = :timeout_on, - wait_started_on = :wait_started_on -WHERE - id = :id -` - -const updateSessionSQLNoOutput = ` -UPDATE - flows_flowsession -SET - output_url = :output_url, - status = :status, - ended_on = CASE WHEN :status = 'W' THEN NULL ELSE NOW() END, - responded = :responded, - current_flow_id = :current_flow_id, - timeout_on = :timeout_on, - wait_started_on = :wait_started_on -WHERE - id = :id -` - -const updateRunSQL = ` -UPDATE - flows_flowrun fr -SET - is_active = r.is_active::bool, - exit_type = r.exit_type, - status = r.status, - exited_on = r.exited_on::timestamp with time zone, - expires_on = r.expires_on::timestamp with time zone, - responded = r.responded::bool, - results = r.results, - path = r.path::jsonb, - events = r.events::jsonb, - current_node_uuid = r.current_node_uuid::uuid, - modified_on = NOW() -FROM ( - VALUES(:uuid, :is_active, :exit_type, :status, :exited_on, :expires_on, :responded, :results, :path, :events, :current_node_uuid) -) AS - r(uuid, is_active, exit_type, status, exited_on, expires_on, responded, results, path, events, current_node_uuid) -WHERE - fr.uuid = r.uuid::uuid -` - -// WriteSessions writes the passed in session to our database, writes any runs that need to be created -// as well as appying any events created in the session -func WriteSessions(ctx context.Context, rt *runtime.Runtime, tx *sqlx.Tx, org *OrgAssets, ss []flows.Session, sprints []flows.Sprint, hook SessionCommitHook) ([]*Session, error) { - if len(ss) == 0 { - return nil, nil - } - - // create all our session objects - sessions := make([]*Session, 0, len(ss)) - completeSessionsI := make([]interface{}, 0, len(ss)) - incompleteSessionsI := make([]interface{}, 0, len(ss)) - completedConnectionIDs := make([]ConnectionID, 0, 1) - for i, s := range ss { - session, err := NewSession(ctx, tx, org, s, sprints[i]) - if err != nil { - return nil, errors.Wrapf(err, "error creating session objects") - } - sessions = append(sessions, session) - - if session.Status() == SessionStatusCompleted { - completeSessionsI = append(completeSessionsI, &session.s) - if session.channelConnection != nil { - completedConnectionIDs = append(completedConnectionIDs, session.channelConnection.ID()) - } - } else { - incompleteSessionsI = append(incompleteSessionsI, &session.s) - } - } - - // apply all our pre write events - for i := range ss { - for _, e := range sprints[i].Events() { - err := ApplyPreWriteEvent(ctx, rt, tx, org, sessions[i].scene, e) - if err != nil { - return nil, errors.Wrapf(err, "error applying event: %v", e) - } - } - } - - // call our global pre commit hook if present - if hook != nil { - err := hook(ctx, tx, rt.RP, org, sessions) - if err != nil { - return nil, errors.Wrapf(err, "error calling commit hook: %v", hook) - } - } - - // the SQL we'll use to do our insert of complete sessions - insertCompleteSQL := insertCompleteSessionSQL - insertIncompleteSQL := insertIncompleteSessionSQL - - // if writing our sessions to S3, do so - sessionMode := org.Org().SessionStorageMode() - if sessionMode == S3Sessions { - err := WriteSessionOutputsToStorage(ctx, rt, sessions) - if err != nil { - // for now, continue on for errors, we are still reading from the DB - logrus.WithError(err).Error("error writing sessions to s3") - } - - insertCompleteSQL = insertCompleteSessionSQLNoOutput - insertIncompleteSQL = insertIncompleteSessionSQLNoOutput - } - - // insert our complete sessions first - err := BulkQuery(ctx, "insert completed sessions", tx, insertCompleteSQL, completeSessionsI) - if err != nil { - return nil, errors.Wrapf(err, "error inserting completed sessions") - } - - // mark any connections that are done as complete as well - err = UpdateChannelConnectionStatuses(ctx, tx, completedConnectionIDs, ConnectionStatusCompleted) - if err != nil { - return nil, errors.Wrapf(err, "error updating channel connections to complete") - } - - // insert incomplete sessions - err = BulkQuery(ctx, "insert incomplete sessions", tx, insertIncompleteSQL, incompleteSessionsI) - if err != nil { - return nil, errors.Wrapf(err, "error inserting incomplete sessions") - } - - // for each session associate our run with each - runs := make([]interface{}, 0, len(sessions)) - for _, s := range sessions { - for _, r := range s.runs { - runs = append(runs, &r.r) - - // set our session id now that it is written - r.SetSessionID(s.ID()) - } - } - - // insert all runs - err = BulkQuery(ctx, "insert runs", tx, insertRunSQL, runs) - if err != nil { - return nil, errors.Wrapf(err, "error writing runs") - } - - // apply our all events for the session - scenes := make([]*Scene, 0, len(ss)) - for i := range sessions { - if ss[i].Status() == SessionStatusFailed { - continue - } - - err = HandleEvents(ctx, rt, tx, org, sessions[i].Scene(), sprints[i].Events()) - if err != nil { - return nil, errors.Wrapf(err, "error applying events for session: %d", sessions[i].ID()) - } - - scene := sessions[i].Scene() - scenes = append(scenes, scene) - } - - // gather all our pre commit events, group them by hook - err = ApplyEventPreCommitHooks(ctx, rt, tx, org, scenes) - if err != nil { - return nil, errors.Wrapf(err, "error applying pre commit hook: %T", hook) - } - - // return our session - return sessions, nil -} - const insertRunSQL = ` INSERT INTO flows_flowrun(uuid, is_active, created_on, modified_on, exited_on, exit_type, status, expires_on, responded, results, path, - events, current_node_uuid, contact_id, flow_id, org_id, session_id, start_id, parent_uuid, connection_id) + current_node_uuid, contact_id, flow_id, org_id, session_id, start_id, parent_uuid, connection_id) VALUES(:uuid, :is_active, :created_on, NOW(), :exited_on, :exit_type, :status, :expires_on, :responded, :results, :path, - :events, :current_node_uuid, :contact_id, :flow_id, :org_id, :session_id, :start_id, :parent_uuid, :connection_id) + :current_node_uuid, :contact_id, :flow_id, :org_id, :session_id, :start_id, :parent_uuid, :connection_id) RETURNING id ` // newRun writes the passed in flow run to our database, also applying any events in those runs as // appropriate. (IE, writing db messages etc..) -func newRun(ctx context.Context, tx *sqlx.Tx, org *OrgAssets, session *Session, fr flows.FlowRun) (*FlowRun, error) { +func newRun(ctx context.Context, tx *sqlx.Tx, oa *OrgAssets, session *Session, fr flows.Run) (*FlowRun, error) { // build our path elements path := make([]Step, len(fr.Path())) for i, p := range fr.Path() { @@ -909,12 +131,8 @@ func newRun(ctx context.Context, tx *sqlx.Tx, org *OrgAssets, session *Session, path[i].ArrivedOn = p.ArrivedOn() path[i].ExitUUID = p.ExitUUID() } - pathJSON, err := json.Marshal(path) - if err != nil { - return nil, err - } - flowID, err := FlowIDForUUID(ctx, tx, org, fr.FlowReference().UUID) + flowID, err := FlowIDForUUID(ctx, tx, oa, fr.FlowReference().UUID) if err != nil { return nil, errors.Wrapf(err, "unable to load flow with uuid: %s", fr.FlowReference().UUID) } @@ -932,50 +150,30 @@ func newRun(ctx context.Context, tx *sqlx.Tx, org *OrgAssets, session *Session, r.FlowID = flowID r.SessionID = session.ID() r.StartID = NilStartID - r.OrgID = org.OrgID() - r.Path = string(pathJSON) + r.OrgID = oa.OrgID() + r.Path = string(jsonx.MustMarshal(path)) + r.Results = string(jsonx.MustMarshal(fr.Results())) + if len(path) > 0 { r.CurrentNodeUUID = null.String(path[len(path)-1].NodeUUID) } run.run = fr - // set our exit type if we exited - // TODO: audit exit types + // TODO remove once we no longer need to write is_active or exit_type if fr.Status() != flows.RunStatusActive && fr.Status() != flows.RunStatusWaiting { - if fr.Status() == flows.RunStatusFailed { - r.ExitType = ExitInterrupted - } else { - r.ExitType = ExitCompleted - } + r.ExitType = runStatusToExitType[r.Status] r.IsActive = false } else { r.IsActive = true } - // we filter which events we write to our events json right now - filteredEvents := make([]flows.Event, 0) + // mark ourselves as responded if we received a message for _, e := range fr.Events() { - if keptEvents[e.Type()] { - filteredEvents = append(filteredEvents, e) - } - - // mark ourselves as responded if we received a message if e.Type() == events.TypeMsgReceived { r.Responded = true + break } } - eventJSON, err := json.Marshal(filteredEvents) - if err != nil { - return nil, errors.Wrapf(err, "error marshalling events for run: %s", run.UUID()) - } - r.Events = string(eventJSON) - - // write our results out - resultsJSON, err := json.Marshal(fr.Results()) - if err != nil { - return nil, errors.Wrapf(err, "error marshalling results for run: %s", run.UUID()) - } - r.Results = string(resultsJSON) // set our parent UUID if we have a parent if fr.Parent() != nil { @@ -1032,10 +230,10 @@ WHERE fs.contact_id = ANY($2) ` -// RunExpiration looks up the run expiration for the passed in run, can return nil if the run is no longer active +// RunExpiration looks up the run expiration for the passed in run, can return nil if the run is no longer waiting func RunExpiration(ctx context.Context, db *sqlx.DB, runID FlowRunID) (*time.Time, error) { var expiration time.Time - err := db.Get(&expiration, `SELECT expires_on FROM flows_flowrun WHERE id = $1 AND is_active = TRUE`, runID) + err := db.Get(&expiration, `SELECT expires_on FROM flows_flowrun WHERE id = $1 AND status = 'W'`, runID) if err == sql.ErrNoRows { return nil, nil } @@ -1044,172 +242,3 @@ func RunExpiration(ctx context.Context, db *sqlx.DB, runID FlowRunID) (*time.Tim } return &expiration, nil } - -// ExitSessions marks the passed in sessions as completed, also doing so for all associated runs -func ExitSessions(ctx context.Context, tx Queryer, sessionIDs []SessionID, exitType ExitType, now time.Time) error { - if len(sessionIDs) == 0 { - return nil - } - - // map exit type to statuses for sessions and runs - sessionStatus := exitToSessionStatusMap[exitType] - runStatus, found := exitToRunStatusMap[exitType] - if !found { - return errors.Errorf("unknown exit type: %s", exitType) - } - - // first interrupt our runs - start := time.Now() - res, err := tx.ExecContext(ctx, exitSessionRunsSQL, pq.Array(sessionIDs), exitType, now, runStatus) - if err != nil { - return errors.Wrapf(err, "error exiting session runs") - } - rows, _ := res.RowsAffected() - logrus.WithField("count", rows).WithField("elapsed", time.Since(start)).Debug("exited session runs") - - // then our sessions - start = time.Now() - - res, err = tx.ExecContext(ctx, exitSessionsSQL, pq.Array(sessionIDs), now, sessionStatus) - if err != nil { - return errors.Wrapf(err, "error exiting sessions") - } - rows, _ = res.RowsAffected() - logrus.WithField("count", rows).WithField("elapsed", time.Since(start)).Debug("exited sessions") - - return nil -} - -const exitSessionRunsSQL = ` -UPDATE - flows_flowrun -SET - is_active = FALSE, - exit_type = $2, - exited_on = $3, - status = $4, - modified_on = NOW() -WHERE - id = ANY (SELECT id FROM flows_flowrun WHERE session_id = ANY($1) AND is_active = TRUE) -` - -const exitSessionsSQL = ` -UPDATE - flows_flowsession -SET - ended_on = $2, - status = $3 -WHERE - id = ANY ($1) AND - status = 'W' -` - -// InterruptContactRuns interrupts all runs and sesions that exist for the passed in list of contacts -func InterruptContactRuns(ctx context.Context, tx Queryer, sessionType FlowType, contactIDs []flows.ContactID, now time.Time) error { - if len(contactIDs) == 0 { - return nil - } - - // first interrupt our runs - err := Exec(ctx, "interrupting contact runs", tx, interruptContactRunsSQL, sessionType, pq.Array(contactIDs), now) - if err != nil { - return err - } - - err = Exec(ctx, "interrupting contact sessions", tx, interruptContactSessionsSQL, sessionType, pq.Array(contactIDs), now) - if err != nil { - return err - } - - return nil -} - -const interruptContactRunsSQL = ` -UPDATE - flows_flowrun -SET - is_active = FALSE, - exited_on = $3, - exit_type = 'I', - status = 'I', - modified_on = NOW() -WHERE - id = ANY ( - SELECT - fr.id - FROM - flows_flowrun fr - JOIN flows_flow ff ON fr.flow_id = ff.id - WHERE - fr.contact_id = ANY($2) AND - fr.is_active = TRUE AND - ff.flow_type = $1 - ) -` - -const interruptContactSessionsSQL = ` -UPDATE - flows_flowsession -SET - status = 'I', - ended_on = $3 -WHERE - id = ANY (SELECT id FROM flows_flowsession WHERE session_type = $1 AND contact_id = ANY($2) AND status = 'W') -` - -// ExpireRunsAndSessions expires all the passed in runs and sessions. Note this should only be called -// for runs that have no parents or no way of continuing -func ExpireRunsAndSessions(ctx context.Context, db *sqlx.DB, runIDs []FlowRunID, sessionIDs []SessionID) error { - if len(runIDs) == 0 { - return nil - } - - tx, err := db.BeginTxx(ctx, nil) - if err != nil { - return errors.Wrapf(err, "error starting transaction to expire sessions") - } - - err = Exec(ctx, "expiring runs", tx, expireRunsSQL, pq.Array(runIDs)) - if err != nil { - tx.Rollback() - return errors.Wrapf(err, "error expiring runs") - } - - if len(sessionIDs) > 0 { - err = Exec(ctx, "expiring sessions", tx, expireSessionsSQL, pq.Array(sessionIDs)) - if err != nil { - tx.Rollback() - return errors.Wrapf(err, "error expiring sessions") - } - } - - err = tx.Commit() - if err != nil { - return errors.Wrapf(err, "error committing expiration of runs and sessions") - } - return nil -} - -const expireSessionsSQL = ` - UPDATE - flows_flowsession s - SET - timeout_on = NULL, - ended_on = NOW(), - status = 'X' - WHERE - id = ANY($1) -` - -const expireRunsSQL = ` - UPDATE - flows_flowrun fr - SET - is_active = FALSE, - exited_on = NOW(), - exit_type = 'E', - status = 'E', - modified_on = NOW() - WHERE - id = ANY($1) -` diff --git a/core/models/schedules.go b/core/models/schedules.go index b3dfe918a..7eb6eae81 100644 --- a/core/models/schedules.go +++ b/core/models/schedules.go @@ -6,7 +6,7 @@ import ( "fmt" "time" - "github.com/nyaruka/mailroom/utils/dbutil" + "github.com/nyaruka/gocommon/dbutil" "github.com/nyaruka/null" "github.com/pkg/errors" @@ -306,7 +306,7 @@ func GetUnfiredSchedules(ctx context.Context, db Queryer) ([]*Schedule, error) { unfired := make([]*Schedule, 0, 10) for rows.Next() { s := &Schedule{} - err := dbutil.ReadJSONRow(rows, &s.s) + err := dbutil.ScanJSON(rows, &s.s) if err != nil { return nil, errors.Wrapf(err, "error reading schedule") } diff --git a/core/models/search.go b/core/models/search.go index 1b6c849b8..4a7ab839f 100644 --- a/core/models/search.go +++ b/core/models/search.go @@ -17,10 +17,10 @@ import ( ) // BuildElasticQuery turns the passed in contact ql query into an elastic query -func BuildElasticQuery(org *OrgAssets, group assets.GroupUUID, status ContactStatus, excludeIDs []ContactID, query *contactql.ContactQuery) elastic.Query { +func BuildElasticQuery(oa *OrgAssets, group assets.GroupUUID, status ContactStatus, excludeIDs []ContactID, query *contactql.ContactQuery) elastic.Query { // filter by org and active contacts eq := elastic.NewBoolQuery().Must( - elastic.NewTermQuery("org_id", org.OrgID()), + elastic.NewTermQuery("org_id", oa.OrgID()), elastic.NewTermQuery("is_active", true), ) @@ -45,7 +45,7 @@ func BuildElasticQuery(org *OrgAssets, group assets.GroupUUID, status ContactSta // and by our query if present if query != nil { - q := es.ToElasticQuery(org.Env(), query) + q := es.ToElasticQuery(oa.Env(), query) eq = eq.Must(q) } @@ -53,8 +53,8 @@ func BuildElasticQuery(org *OrgAssets, group assets.GroupUUID, status ContactSta } // ContactIDsForQueryPage returns the ids of the contacts for the passed in query page -func ContactIDsForQueryPage(ctx context.Context, client *elastic.Client, org *OrgAssets, group assets.GroupUUID, excludeIDs []ContactID, query string, sort string, offset int, pageSize int) (*contactql.ContactQuery, []ContactID, int64, error) { - env := org.Env() +func ContactIDsForQueryPage(ctx context.Context, client *elastic.Client, oa *OrgAssets, group assets.GroupUUID, excludeIDs []ContactID, query string, sort string, offset int, pageSize int) (*contactql.ContactQuery, []ContactID, int64, error) { + env := oa.Env() start := time.Now() var parsed *contactql.ContactQuery var err error @@ -64,20 +64,20 @@ func ContactIDsForQueryPage(ctx context.Context, client *elastic.Client, org *Or } if query != "" { - parsed, err = contactql.ParseQuery(env, query, org.SessionAssets()) + parsed, err = contactql.ParseQuery(env, query, oa.SessionAssets()) if err != nil { return nil, nil, 0, errors.Wrapf(err, "error parsing query: %s", query) } } - eq := BuildElasticQuery(org, group, NilContactStatus, excludeIDs, parsed) + eq := BuildElasticQuery(oa, group, NilContactStatus, excludeIDs, parsed) - fieldSort, err := es.ToElasticFieldSort(sort, org.SessionAssets()) + fieldSort, err := es.ToElasticFieldSort(sort, oa.SessionAssets()) if err != nil { return nil, nil, 0, errors.Wrapf(err, "error parsing sort") } - s := client.Search("contacts").TrackTotalHits(true).Routing(strconv.FormatInt(int64(org.OrgID()), 10)) + s := client.Search("contacts").TrackTotalHits(true).Routing(strconv.FormatInt(int64(oa.OrgID()), 10)) s = s.Size(pageSize).From(offset).Query(eq).SortBy(fieldSort).FetchSource(false) results, err := s.Do(ctx) @@ -101,7 +101,7 @@ func ContactIDsForQueryPage(ctx context.Context, client *elastic.Client, org *Or } logrus.WithFields(logrus.Fields{ - "org_id": org.OrgID(), + "org_id": oa.OrgID(), "parsed": parsed, "group_uuid": group, "query": query, @@ -114,8 +114,8 @@ func ContactIDsForQueryPage(ctx context.Context, client *elastic.Client, org *Or } // ContactIDsForQuery returns the ids of all the contacts that match the passed in query -func ContactIDsForQuery(ctx context.Context, client *elastic.Client, org *OrgAssets, query string) ([]ContactID, error) { - env := org.Env() +func ContactIDsForQuery(ctx context.Context, client *elastic.Client, oa *OrgAssets, query string) ([]ContactID, error) { + env := oa.Env() start := time.Now() if client == nil { @@ -123,23 +123,23 @@ func ContactIDsForQuery(ctx context.Context, client *elastic.Client, org *OrgAss } // turn into elastic query - parsed, err := contactql.ParseQuery(env, query, org.SessionAssets()) + parsed, err := contactql.ParseQuery(env, query, oa.SessionAssets()) if err != nil { return nil, errors.Wrapf(err, "error parsing query: %s", query) } - eq := BuildElasticQuery(org, "", ContactStatusActive, nil, parsed) + eq := BuildElasticQuery(oa, "", ContactStatusActive, nil, parsed) ids := make([]ContactID, 0, 100) // iterate across our results, building up our contact ids - scroll := client.Scroll("contacts").Routing(strconv.FormatInt(int64(org.OrgID()), 10)) + scroll := client.Scroll("contacts").Routing(strconv.FormatInt(int64(oa.OrgID()), 10)) scroll = scroll.KeepAlive("15m").Size(10000).Query(eq).FetchSource(false) for { results, err := scroll.Do(ctx) if err == io.EOF { logrus.WithFields(logrus.Fields{ - "org_id": org.OrgID(), + "org_id": oa.OrgID(), "query": query, "elapsed": time.Since(start), "match_count": len(ids), diff --git a/core/models/sessions.go b/core/models/sessions.go new file mode 100644 index 000000000..b2ef40984 --- /dev/null +++ b/core/models/sessions.go @@ -0,0 +1,965 @@ +package models + +import ( + "context" + "crypto/md5" + "encoding/json" + "fmt" + "net/url" + "path" + "time" + + "github.com/aws/aws-sdk-go/service/s3" + "github.com/gomodule/redigo/redis" + "github.com/jmoiron/sqlx" + "github.com/lib/pq" + "github.com/nyaruka/gocommon/storage" + "github.com/nyaruka/gocommon/uuids" + "github.com/nyaruka/goflow/assets" + "github.com/nyaruka/goflow/envs" + "github.com/nyaruka/goflow/flows" + "github.com/nyaruka/goflow/flows/events" + "github.com/nyaruka/mailroom/core/goflow" + "github.com/nyaruka/mailroom/runtime" + "github.com/nyaruka/null" + "github.com/pkg/errors" + "github.com/sirupsen/logrus" +) + +type SessionID int64 +type SessionStatus string + +const ( + SessionStatusWaiting SessionStatus = "W" + SessionStatusCompleted SessionStatus = "C" + SessionStatusExpired SessionStatus = "X" + SessionStatusInterrupted SessionStatus = "I" + SessionStatusFailed SessionStatus = "F" + + storageTSFormat = "20060102T150405.999Z" +) + +var sessionStatusMap = map[flows.SessionStatus]SessionStatus{ + flows.SessionStatusWaiting: SessionStatusWaiting, + flows.SessionStatusCompleted: SessionStatusCompleted, + flows.SessionStatusFailed: SessionStatusFailed, +} + +type SessionCommitHook func(context.Context, *sqlx.Tx, *redis.Pool, *OrgAssets, []*Session) error + +// Session is the mailroom type for a FlowSession +type Session struct { + s struct { + ID SessionID `db:"id"` + UUID flows.SessionUUID `db:"uuid"` + SessionType FlowType `db:"session_type"` + Status SessionStatus `db:"status"` + Responded bool `db:"responded"` + Output null.String `db:"output"` + OutputURL null.String `db:"output_url"` + ContactID ContactID `db:"contact_id"` + OrgID OrgID `db:"org_id"` + CreatedOn time.Time `db:"created_on"` + EndedOn *time.Time `db:"ended_on"` + WaitStartedOn *time.Time `db:"wait_started_on"` + WaitTimeoutOn *time.Time `db:"timeout_on"` + WaitExpiresOn *time.Time `db:"wait_expires_on"` + WaitResumeOnExpire *bool `db:"wait_resume_on_expire"` + CurrentFlowID FlowID `db:"current_flow_id"` + ConnectionID *ConnectionID `db:"connection_id"` + } + + incomingMsgID MsgID + incomingExternalID null.String + + // any channel connection associated with this flow session + channelConnection *ChannelConnection + + // time after our last message is sent that we should timeout + timeout *time.Duration + + contact *flows.Contact + runs []*FlowRun + + seenRuns map[flows.RunUUID]time.Time + + // we keep around a reference to the sprint associated with this session + sprint flows.Sprint + + // the scene for our event hooks + scene *Scene + + findStep func(flows.StepUUID) (flows.Run, flows.Step) +} + +func (s *Session) ID() SessionID { return s.s.ID } +func (s *Session) UUID() flows.SessionUUID { return flows.SessionUUID(s.s.UUID) } +func (s *Session) SessionType() FlowType { return s.s.SessionType } +func (s *Session) Status() SessionStatus { return s.s.Status } +func (s *Session) Responded() bool { return s.s.Responded } +func (s *Session) Output() string { return string(s.s.Output) } +func (s *Session) OutputURL() string { return string(s.s.OutputURL) } +func (s *Session) ContactID() ContactID { return s.s.ContactID } +func (s *Session) OrgID() OrgID { return s.s.OrgID } +func (s *Session) CreatedOn() time.Time { return s.s.CreatedOn } +func (s *Session) EndedOn() *time.Time { return s.s.EndedOn } +func (s *Session) WaitStartedOn() *time.Time { return s.s.WaitStartedOn } +func (s *Session) WaitTimeoutOn() *time.Time { return s.s.WaitTimeoutOn } +func (s *Session) WaitExpiresOn() *time.Time { return s.s.WaitExpiresOn } +func (s *Session) WaitResumeOnExpire() *bool { return s.s.WaitResumeOnExpire } +func (s *Session) ClearTimeoutOn() { s.s.WaitTimeoutOn = nil } +func (s *Session) CurrentFlowID() FlowID { return s.s.CurrentFlowID } +func (s *Session) ConnectionID() *ConnectionID { return s.s.ConnectionID } +func (s *Session) IncomingMsgID() MsgID { return s.incomingMsgID } +func (s *Session) IncomingMsgExternalID() null.String { return s.incomingExternalID } +func (s *Session) Scene() *Scene { return s.scene } + +// StoragePath returns the path for the session +func (s *Session) StoragePath(cfg *runtime.Config) string { + ts := s.CreatedOn().UTC().Format(storageTSFormat) + + // example output: /orgs/1/c/20a5/20a5534c-b2ad-4f18-973a-f1aa3b4e6c74/session_20060102T150405.123Z_8a7fc501-177b-4567-a0aa-81c48e6de1c5_51df83ac21d3cf136d8341f0b11cb1a7.json" + return path.Join( + cfg.S3SessionPrefix, + "orgs", + fmt.Sprintf("%d", s.OrgID()), + "c", + string(s.ContactUUID()[:4]), + string(s.ContactUUID()), + fmt.Sprintf("%s_session_%s_%s.json", ts, s.UUID(), s.OutputMD5()), + ) +} + +// ContactUUID returns the UUID of our contact +func (s *Session) ContactUUID() flows.ContactUUID { + return s.contact.UUID() +} + +// Contact returns the contact for this session +func (s *Session) Contact() *flows.Contact { + return s.contact +} + +// Runs returns our flow run +func (s *Session) Runs() []*FlowRun { + return s.runs +} + +// Sprint returns the sprint associated with this session +func (s *Session) Sprint() flows.Sprint { + return s.sprint +} + +// FindStep finds the run and step with the given UUID +func (s *Session) FindStep(uuid flows.StepUUID) (flows.Run, flows.Step) { + return s.findStep(uuid) +} + +// Timeout returns the amount of time after our last message sends that we should timeout +func (s *Session) Timeout() *time.Duration { + return s.timeout +} + +// OutputMD5 returns the md5 of the passed in session +func (s *Session) OutputMD5() string { + return fmt.Sprintf("%x", md5.Sum([]byte(s.s.Output))) +} + +// SetIncomingMsg set the incoming message that this session should be associated with in this sprint +func (s *Session) SetIncomingMsg(id flows.MsgID, externalID null.String) { + s.incomingMsgID = MsgID(id) + s.incomingExternalID = externalID +} + +// SetChannelConnection sets the channel connection associated with this sprint +func (s *Session) SetChannelConnection(cc *ChannelConnection) { + connID := cc.ID() + s.s.ConnectionID = &connID + s.channelConnection = cc + + // also set it on all our runs + for _, r := range s.runs { + r.SetConnectionID(&connID) + } +} + +func (s *Session) ChannelConnection() *ChannelConnection { + return s.channelConnection +} + +// FlowSession creates a flow session for the passed in session object. It also populates the runs we know about +func (s *Session) FlowSession(cfg *runtime.Config, sa flows.SessionAssets, env envs.Environment) (flows.Session, error) { + session, err := goflow.Engine(cfg).ReadSession(sa, json.RawMessage(s.s.Output), assets.IgnoreMissing) + if err != nil { + return nil, errors.Wrapf(err, "unable to unmarshal session") + } + + // walk through our session, populate seen runs + s.seenRuns = make(map[flows.RunUUID]time.Time, len(session.Runs())) + for _, r := range session.Runs() { + s.seenRuns[r.UUID()] = r.ModifiedOn() + } + + return session, nil +} + +// looks for a wait event and updates wait fields if one exists +func (s *Session) updateWait(evts []flows.Event) { + boolPtr := func(b bool) *bool { return &b } + + s.s.WaitStartedOn = nil + s.s.WaitTimeoutOn = nil + s.s.WaitExpiresOn = nil + s.s.WaitResumeOnExpire = boolPtr(false) + s.timeout = nil + + now := time.Now() + + for _, e := range evts { + switch typed := e.(type) { + case *events.MsgWaitEvent: + run, _ := s.findStep(e.StepUUID()) + + s.s.WaitStartedOn = &now + s.s.WaitExpiresOn = typed.ExpiresOn + s.s.WaitResumeOnExpire = boolPtr(run.ParentInSession() != nil) + + if typed.TimeoutSeconds != nil { + seconds := time.Duration(*typed.TimeoutSeconds) * time.Second + timeoutOn := now.Add(seconds) + + s.s.WaitTimeoutOn = &timeoutOn + s.timeout = &seconds + } + case *events.DialWaitEvent: + run, _ := s.findStep(e.StepUUID()) + + s.s.WaitStartedOn = &now + s.s.WaitExpiresOn = typed.ExpiresOn + s.s.WaitResumeOnExpire = boolPtr(run.ParentInSession() != nil) + } + } +} + +const sqlUpdateSession = ` +UPDATE + flows_flowsession +SET + output = :output, + output_url = :output_url, + status = :status, + ended_on = :ended_on, + responded = :responded, + current_flow_id = :current_flow_id, + wait_started_on = :wait_started_on, + wait_expires_on = :wait_expires_on, + wait_resume_on_expire = :wait_resume_on_expire, + timeout_on = :timeout_on +WHERE + id = :id +` + +const sqlUpdateSessionNoOutput = ` +UPDATE + flows_flowsession +SET + output_url = :output_url, + status = :status, + ended_on = :ended_on, + responded = :responded, + current_flow_id = :current_flow_id, + wait_started_on = :wait_started_on, + wait_expires_on = :wait_expires_on, + wait_resume_on_expire = :wait_resume_on_expire, + timeout_on = :timeout_on +WHERE + id = :id +` + +const sqlUpdateRun = ` +UPDATE + flows_flowrun fr +SET + is_active = r.is_active::bool, + exit_type = r.exit_type, + status = r.status, + exited_on = r.exited_on::timestamp with time zone, + expires_on = r.expires_on::timestamp with time zone, + responded = r.responded::bool, + results = r.results, + path = r.path::jsonb, + current_node_uuid = r.current_node_uuid::uuid, + modified_on = NOW() +FROM ( + VALUES(:uuid, :is_active, :exit_type, :status, :exited_on, :expires_on, :responded, :results, :path, :current_node_uuid) +) AS + r(uuid, is_active, exit_type, status, exited_on, expires_on, responded, results, path, current_node_uuid) +WHERE + fr.uuid = r.uuid::uuid +` + +// Update updates the session based on the state passed in from our engine session, this also takes care of applying any event hooks +func (s *Session) Update(ctx context.Context, rt *runtime.Runtime, tx *sqlx.Tx, oa *OrgAssets, fs flows.Session, sprint flows.Sprint, hook SessionCommitHook) error { + // make sure we have our seen runs + if s.seenRuns == nil { + return errors.Errorf("missing seen runs, cannot update session") + } + + output, err := json.Marshal(fs) + if err != nil { + return errors.Wrapf(err, "error marshalling flow session") + } + s.s.Output = null.String(output) + + // map our status over + status, found := sessionStatusMap[fs.Status()] + if !found { + return errors.Errorf("unknown session status: %s", fs.Status()) + } + s.s.Status = status + + if s.s.Status != SessionStatusWaiting { + now := time.Now() + s.s.EndedOn = &now + } + + // now build up our runs + for _, r := range fs.Runs() { + run, err := newRun(ctx, tx, oa, s, r) + if err != nil { + return errors.Wrapf(err, "error creating run: %s", r.UUID()) + } + + // set the run on our session + s.runs = append(s.runs, run) + } + + // set our sprint, wait and step finder + s.sprint = sprint + s.findStep = fs.FindStep + s.s.CurrentFlowID = NilFlowID + + // update wait related fields + s.updateWait(sprint.Events()) + + // run through our runs to figure out our current flow + for _, r := range fs.Runs() { + // if this run is waiting, save it as the current flow + if r.Status() == flows.RunStatusWaiting { + flowID, err := FlowIDForUUID(ctx, tx, oa, r.FlowReference().UUID) + if err != nil { + return errors.Wrapf(err, "error loading flow: %s", r.FlowReference().UUID) + } + s.s.CurrentFlowID = flowID + } + + // if we haven't already been marked as responded, walk our runs looking for an input + if !s.s.Responded { + // run through events, see if any are received events + for _, e := range r.Events() { + if e.Type() == events.TypeMsgReceived { + s.s.Responded = true + break + } + } + } + } + + // apply all our pre write events + for _, e := range sprint.Events() { + err := ApplyPreWriteEvent(ctx, rt, tx, oa, s.scene, e) + if err != nil { + return errors.Wrapf(err, "error applying event: %v", e) + } + } + + // the SQL statement we'll use to update this session + updateSQL := sqlUpdateSession + + // if writing to S3, do so + if rt.Config.SessionStorage == "s3" { + err := WriteSessionOutputsToStorage(ctx, rt, []*Session{s}) + if err != nil { + logrus.WithError(err).Error("error writing session to s3") + } + + // don't write output in our SQL + updateSQL = sqlUpdateSessionNoOutput + } + + // write our new session state to the db + _, err = tx.NamedExecContext(ctx, updateSQL, s.s) + if err != nil { + return errors.Wrapf(err, "error updating session") + } + + // if this session is complete, so is any associated connection + if s.channelConnection != nil { + if s.Status() == SessionStatusCompleted || s.Status() == SessionStatusFailed { + err := s.channelConnection.UpdateStatus(ctx, tx, ConnectionStatusCompleted, 0, time.Now()) + if err != nil { + return errors.Wrapf(err, "error update channel connection") + } + } + } + + // figure out which runs are new and which are updated + updatedRuns := make([]interface{}, 0, 1) + newRuns := make([]interface{}, 0) + for _, r := range s.Runs() { + modified, found := s.seenRuns[r.UUID()] + if !found { + newRuns = append(newRuns, &r.r) + continue + } + + if r.ModifiedOn().After(modified) { + updatedRuns = append(updatedRuns, &r.r) + continue + } + } + + // call our global pre commit hook if present + if hook != nil { + err := hook(ctx, tx, rt.RP, oa, []*Session{s}) + if err != nil { + return errors.Wrapf(err, "error calling commit hook: %v", hook) + } + } + + // update all modified runs at once + err = BulkQuery(ctx, "update runs", tx, sqlUpdateRun, updatedRuns) + if err != nil { + logrus.WithError(err).WithField("session", string(output)).Error("error while updating runs for session") + return errors.Wrapf(err, "error updating runs") + } + + // insert all new runs at once + err = BulkQuery(ctx, "insert runs", tx, insertRunSQL, newRuns) + if err != nil { + return errors.Wrapf(err, "error writing runs") + } + + if err := RecordFlowStatistics(ctx, rt, tx, []flows.Session{fs}, []flows.Sprint{sprint}); err != nil { + return errors.Wrapf(err, "error saving flow statistics") + } + + // apply all our events + if s.Status() != SessionStatusFailed { + err = HandleEvents(ctx, rt, tx, oa, s.scene, sprint.Events()) + if err != nil { + return errors.Wrapf(err, "error applying events: %d", s.ID()) + } + } + + // gather all our pre commit events, group them by hook and apply them + err = ApplyEventPreCommitHooks(ctx, rt, tx, oa, []*Scene{s.scene}) + if err != nil { + return errors.Wrapf(err, "error applying pre commit hook: %T", hook) + } + + return nil +} + +// MarshalJSON is our custom marshaller so that our inner struct get output +func (s *Session) MarshalJSON() ([]byte, error) { + return json.Marshal(s.s) +} + +// UnmarshalJSON is our custom marshaller so that our inner struct get output +func (s *Session) UnmarshalJSON(b []byte) error { + return json.Unmarshal(b, &s.s) +} + +// NewSession a session objects from the passed in flow session. It does NOT +// commit said session to the database. +func NewSession(ctx context.Context, tx *sqlx.Tx, oa *OrgAssets, fs flows.Session, sprint flows.Sprint) (*Session, error) { + output, err := json.Marshal(fs) + if err != nil { + return nil, errors.Wrapf(err, "error marshalling flow session") + } + + // map our status over + sessionStatus, found := sessionStatusMap[fs.Status()] + if !found { + return nil, errors.Errorf("unknown session status: %s", fs.Status()) + } + + // session must have at least one run + if len(fs.Runs()) < 1 { + return nil, errors.Errorf("cannot write session that has no runs") + } + + // figure out our type + sessionType, found := flowTypeMapping[fs.Type()] + if !found { + return nil, errors.Errorf("unknown flow type: %s", fs.Type()) + } + + uuid := fs.UUID() + if uuid == "" { + uuid = flows.SessionUUID(uuids.New()) + } + + // create our session object + session := &Session{} + s := &session.s + s.UUID = uuid + s.Status = sessionStatus + s.SessionType = sessionType + s.Responded = false + s.Output = null.String(output) + s.ContactID = ContactID(fs.Contact().ID()) + s.OrgID = oa.OrgID() + s.CreatedOn = fs.Runs()[0].CreatedOn() + + if s.Status != SessionStatusWaiting { + now := time.Now() + s.EndedOn = &now + } + + session.contact = fs.Contact() + session.scene = NewSceneForSession(session) + + session.sprint = sprint + session.findStep = fs.FindStep + + // now build up our runs + for _, r := range fs.Runs() { + run, err := newRun(ctx, tx, oa, session, r) + if err != nil { + return nil, errors.Wrapf(err, "error creating run: %s", r.UUID()) + } + + // save the run to our session + session.runs = append(session.runs, run) + + // if this run is waiting, save it as the current flow + if r.Status() == flows.RunStatusWaiting { + flowID, err := FlowIDForUUID(ctx, tx, oa, r.FlowReference().UUID) + if err != nil { + return nil, errors.Wrapf(err, "error loading current flow for UUID: %s", r.FlowReference().UUID) + } + s.CurrentFlowID = flowID + } + } + + // calculate our timeout if any + session.updateWait(sprint.Events()) + + return session, nil +} + +const sqlInsertCompleteSession = ` +INSERT INTO + flows_flowsession( uuid, session_type, status, responded, output, output_url, contact_id, org_id, created_on, ended_on, wait_resume_on_expire, connection_id) + VALUES(:uuid, :session_type, :status, :responded, :output, :output_url, :contact_id, :org_id, NOW(), NOW(), FALSE, :connection_id) +RETURNING id +` + +const sqlInsertIncompleteSession = ` +INSERT INTO + flows_flowsession( uuid, session_type, status, responded, output, output_url, contact_id, org_id, created_on, current_flow_id, timeout_on, wait_started_on, wait_expires_on, wait_resume_on_expire, connection_id) + VALUES(:uuid,:session_type,:status,:responded, :output, :output_url, :contact_id, :org_id, NOW(), :current_flow_id, :timeout_on, :wait_started_on, :wait_expires_on, :wait_resume_on_expire, :connection_id) +RETURNING id +` + +const sqlInsertCompleteSessionNoOutput = ` +INSERT INTO + flows_flowsession( uuid, session_type, status, responded, output_url, contact_id, org_id, created_on, ended_on, wait_resume_on_expire, connection_id) + VALUES(:uuid,:session_type,:status,:responded, :output_url, :contact_id, :org_id, NOW(), NOW(), FALSE, :connection_id) +RETURNING id +` + +const sqlInsertIncompleteSessionNoOutput = ` +INSERT INTO + flows_flowsession( uuid, session_type, status, responded, output_url, contact_id, org_id, created_on, current_flow_id, timeout_on, wait_started_on, wait_expires_on, wait_resume_on_expire, connection_id) + VALUES(:uuid,:session_type,:status,:responded, :output_url, :contact_id, :org_id, NOW(), :current_flow_id, :timeout_on, :wait_started_on, :wait_expires_on, :wait_resume_on_expire, :connection_id) +RETURNING id +` + +// WriteSessions writes the passed in session to our database, writes any runs that need to be created +// as well as appying any events created in the session +func WriteSessions(ctx context.Context, rt *runtime.Runtime, tx *sqlx.Tx, oa *OrgAssets, ss []flows.Session, sprints []flows.Sprint, hook SessionCommitHook) ([]*Session, error) { + if len(ss) == 0 { + return nil, nil + } + + // create all our session objects + sessions := make([]*Session, 0, len(ss)) + completeSessionsI := make([]interface{}, 0, len(ss)) + incompleteSessionsI := make([]interface{}, 0, len(ss)) + completedConnectionIDs := make([]ConnectionID, 0, 1) + for i, s := range ss { + session, err := NewSession(ctx, tx, oa, s, sprints[i]) + if err != nil { + return nil, errors.Wrapf(err, "error creating session objects") + } + sessions = append(sessions, session) + + if session.Status() == SessionStatusCompleted { + completeSessionsI = append(completeSessionsI, &session.s) + if session.channelConnection != nil { + completedConnectionIDs = append(completedConnectionIDs, session.channelConnection.ID()) + } + } else { + incompleteSessionsI = append(incompleteSessionsI, &session.s) + } + } + + // apply all our pre write events + for i := range ss { + for _, e := range sprints[i].Events() { + err := ApplyPreWriteEvent(ctx, rt, tx, oa, sessions[i].scene, e) + if err != nil { + return nil, errors.Wrapf(err, "error applying event: %v", e) + } + } + } + + // call our global pre commit hook if present + if hook != nil { + err := hook(ctx, tx, rt.RP, oa, sessions) + if err != nil { + return nil, errors.Wrapf(err, "error calling commit hook: %v", hook) + } + } + + // the SQL we'll use to do our insert of complete sessions + insertCompleteSQL := sqlInsertCompleteSession + insertIncompleteSQL := sqlInsertIncompleteSession + + // if writing our sessions to S3, do so + if rt.Config.SessionStorage == "s3" { + err := WriteSessionOutputsToStorage(ctx, rt, sessions) + if err != nil { + // for now, continue on for errors, we are still reading from the DB + logrus.WithError(err).Error("error writing sessions to s3") + } + + insertCompleteSQL = sqlInsertCompleteSessionNoOutput + insertIncompleteSQL = sqlInsertIncompleteSessionNoOutput + } + + // insert our complete sessions first + err := BulkQuery(ctx, "insert completed sessions", tx, insertCompleteSQL, completeSessionsI) + if err != nil { + return nil, errors.Wrapf(err, "error inserting completed sessions") + } + + // mark any connections that are done as complete as well + err = UpdateChannelConnectionStatuses(ctx, tx, completedConnectionIDs, ConnectionStatusCompleted) + if err != nil { + return nil, errors.Wrapf(err, "error updating channel connections to complete") + } + + // insert incomplete sessions + err = BulkQuery(ctx, "insert incomplete sessions", tx, insertIncompleteSQL, incompleteSessionsI) + if err != nil { + return nil, errors.Wrapf(err, "error inserting incomplete sessions") + } + + // for each session associate our run with each + runs := make([]interface{}, 0, len(sessions)) + for _, s := range sessions { + for _, r := range s.runs { + runs = append(runs, &r.r) + + // set our session id now that it is written + r.SetSessionID(s.ID()) + } + } + + // insert all runs + err = BulkQuery(ctx, "insert runs", tx, insertRunSQL, runs) + if err != nil { + return nil, errors.Wrapf(err, "error writing runs") + } + + if err := RecordFlowStatistics(ctx, rt, tx, ss, sprints); err != nil { + return nil, errors.Wrapf(err, "error saving flow statistics") + } + + // apply our all events for the session + scenes := make([]*Scene, 0, len(ss)) + for i := range sessions { + if ss[i].Status() == flows.SessionStatusFailed { + continue + } + + err = HandleEvents(ctx, rt, tx, oa, sessions[i].Scene(), sprints[i].Events()) + if err != nil { + return nil, errors.Wrapf(err, "error applying events for session: %d", sessions[i].ID()) + } + + scene := sessions[i].Scene() + scenes = append(scenes, scene) + } + + // gather all our pre commit events, group them by hook + err = ApplyEventPreCommitHooks(ctx, rt, tx, oa, scenes) + if err != nil { + return nil, errors.Wrapf(err, "error applying pre commit hook: %T", hook) + } + + // return our session + return sessions, nil +} + +const sqlSelectWaitingSessionForContact = ` +SELECT + id, + uuid, + session_type, + status, + responded, + output, + output_url, + contact_id, + org_id, + created_on, + ended_on, + timeout_on, + wait_started_on, + wait_expires_on, + wait_resume_on_expire, + current_flow_id, + connection_id +FROM + flows_flowsession fs +WHERE + session_type = $1 AND + contact_id = $2 AND + status = 'W' +ORDER BY + created_on DESC +LIMIT 1 +` + +// FindWaitingSessionForContact returns the waiting session for the passed in contact, if any +func FindWaitingSessionForContact(ctx context.Context, db *sqlx.DB, st storage.Storage, oa *OrgAssets, sessionType FlowType, contact *flows.Contact) (*Session, error) { + rows, err := db.QueryxContext(ctx, sqlSelectWaitingSessionForContact, sessionType, contact.ID()) + if err != nil { + return nil, errors.Wrapf(err, "error selecting waiting session") + } + defer rows.Close() + + // no rows? no sessions! + if !rows.Next() { + return nil, nil + } + + // scan in our session + session := &Session{ + contact: contact, + } + session.scene = NewSceneForSession(session) + + if err := rows.StructScan(&session.s); err != nil { + return nil, errors.Wrapf(err, "error scanning session") + } + + // load our output from storage if necessary + if session.OutputURL() != "" { + // strip just the path out of our output URL + u, err := url.Parse(session.OutputURL()) + if err != nil { + return nil, errors.Wrapf(err, "error parsing output URL: %s", session.OutputURL()) + } + + start := time.Now() + + _, output, err := st.Get(ctx, u.Path) + if err != nil { + return nil, errors.Wrapf(err, "error reading session from storage: %s", session.OutputURL()) + } + + logrus.WithField("elapsed", time.Since(start)).WithField("output_url", session.OutputURL()).Debug("loaded session from storage") + session.s.Output = null.String(output) + } + + return session, nil +} + +// WriteSessionsToStorage writes the outputs of the passed in sessions to our storage (S3), updating the +// output_url for each on success. Failure of any will cause all to fail. +func WriteSessionOutputsToStorage(ctx context.Context, rt *runtime.Runtime, sessions []*Session) error { + start := time.Now() + + uploads := make([]*storage.Upload, len(sessions)) + for i, s := range sessions { + uploads[i] = &storage.Upload{ + Path: s.StoragePath(rt.Config), + Body: []byte(s.Output()), + ContentType: "application/json", + ACL: s3.ObjectCannedACLPrivate, + } + } + + err := rt.SessionStorage.BatchPut(ctx, uploads) + if err != nil { + return errors.Wrapf(err, "error writing sessions to storage") + } + + for i, s := range sessions { + s.s.OutputURL = null.String(uploads[i].URL) + } + + logrus.WithField("elapsed", time.Since(start)).WithField("count", len(sessions)).Debug("wrote sessions to s3") + return nil +} + +// ExitSessions exits sessions and their runs. It batches the given session ids and exits each batch in a transaction. +func ExitSessions(ctx context.Context, db *sqlx.DB, sessionIDs []SessionID, status SessionStatus) error { + if len(sessionIDs) == 0 { + return nil + } + + // split into batches and exit each batch in a transaction + for _, idBatch := range chunkSessionIDs(sessionIDs, 100) { + tx, err := db.BeginTxx(ctx, nil) + if err != nil { + return errors.Wrapf(err, "error starting transaction to exit sessions") + } + + if err := exitSessionBatch(ctx, tx, idBatch, status); err != nil { + return errors.Wrapf(err, "error exiting batch of sessions") + } + + if err := tx.Commit(); err != nil { + return errors.Wrapf(err, "error committing session exits") + } + } + + return nil +} + +const sqlExitSessions = ` + UPDATE flows_flowsession + SET status = $3, ended_on = $2, wait_started_on = NULL, wait_expires_on = NULL, timeout_on = NULL, current_flow_id = NULL + WHERE id = ANY ($1) AND status = 'W' +RETURNING contact_id` + +const sqlExitSessionRuns = ` +UPDATE flows_flowrun + SET is_active = FALSE, exit_type = $2, exited_on = $3, status = $4, modified_on = NOW() + WHERE id = ANY (SELECT id FROM flows_flowrun WHERE session_id = ANY($1) AND status IN ('A', 'W'))` + +const sqlExitSessionContacts = ` + UPDATE contacts_contact + SET current_flow_id = NULL, modified_on = NOW() + WHERE id = ANY($1)` + +// exits sessions and their runs inside the given transaction +func exitSessionBatch(ctx context.Context, tx *sqlx.Tx, sessionIDs []SessionID, status SessionStatus) error { + runStatus := RunStatus(status) // session status codes are subset of run status codes + exitType := runStatusToExitType[runStatus] // for compatibility + + contactIDs := make([]SessionID, 0, len(sessionIDs)) + + // first update the sessions themselves and get the contact ids + start := time.Now() + + err := tx.SelectContext(ctx, &contactIDs, sqlExitSessions, pq.Array(sessionIDs), time.Now(), status) + if err != nil { + return errors.Wrapf(err, "error exiting sessions") + } + + logrus.WithField("count", len(contactIDs)).WithField("elapsed", time.Since(start)).Debug("exited session batch") + + // then the runs that belong to these sessions + start = time.Now() + + res, err := tx.ExecContext(ctx, sqlExitSessionRuns, pq.Array(sessionIDs), exitType, time.Now(), runStatus) + if err != nil { + return errors.Wrapf(err, "error exiting session runs") + } + + rows, _ := res.RowsAffected() + logrus.WithField("count", rows).WithField("elapsed", time.Since(start)).Debug("exited session batch runs") + + // and finally the contacts from each session + start = time.Now() + + res, err = tx.ExecContext(ctx, sqlExitSessionContacts, pq.Array(contactIDs)) + if err != nil { + return errors.Wrapf(err, "error exiting sessions") + } + + rows, _ = res.RowsAffected() + logrus.WithField("count", rows).WithField("elapsed", time.Since(start)).Debug("exited session batch contacts") + + return nil +} + +// InterruptSessionsForContacts interrupts any waiting sessions for the given contacts +func InterruptSessionsForContacts(ctx context.Context, db *sqlx.DB, contactIDs []ContactID) error { + sessionIDs := make([]SessionID, 0, len(contactIDs)) + + err := db.SelectContext(ctx, &sessionIDs, `SELECT id FROM flows_flowsession WHERE status = 'W' AND contact_id = ANY($1)`, pq.Array(contactIDs)) + if err != nil { + return errors.Wrapf(err, "error selecting waiting sessions for contacts") + } + + return errors.Wrapf(ExitSessions(ctx, db, sessionIDs, SessionStatusInterrupted), "error exiting sessions") +} + +const sqlWaitingSessionIDsOfTypeForContacts = ` +SELECT id + FROM flows_flowsession + WHERE status = 'W' AND contact_id = ANY($1) AND session_type = $2;` + +// InterruptSessionsOfTypeForContacts interrupts any waiting sessions of the given type for the given contacts +func InterruptSessionsOfTypeForContacts(ctx context.Context, tx *sqlx.Tx, contactIDs []ContactID, sessionType FlowType) error { + sessionIDs := make([]SessionID, 0, len(contactIDs)) + + err := tx.SelectContext(ctx, &sessionIDs, sqlWaitingSessionIDsOfTypeForContacts, pq.Array(contactIDs), sessionType) + if err != nil { + return errors.Wrapf(err, "error selecting waiting sessions for contacts") + } + + return errors.Wrapf(exitSessionBatch(ctx, tx, sessionIDs, SessionStatusInterrupted), "error exiting sessions") +} + +const sqlWaitingSessionIDsForChannels = ` +SELECT fs.id + FROM flows_flowsession fs + JOIN channels_channelconnection cc ON fs.connection_id = cc.id + WHERE fs.status = 'W' AND cc.channel_id = ANY($1);` + +// InterruptSessionsForChannels interrupts any waiting sessions with connections on the given channels +func InterruptSessionsForChannels(ctx context.Context, db *sqlx.DB, channelIDs []ChannelID) error { + if len(channelIDs) == 0 { + return nil + } + + sessionIDs := make([]SessionID, 0, len(channelIDs)) + + err := db.SelectContext(ctx, &sessionIDs, sqlWaitingSessionIDsForChannels, pq.Array(channelIDs)) + if err != nil { + return errors.Wrapf(err, "error selecting waiting sessions for channels") + } + + return errors.Wrapf(ExitSessions(ctx, db, sessionIDs, SessionStatusInterrupted), "error exiting sessions") +} + +const sqlWaitingSessionIDsForFlows = ` +SELECT id + FROM flows_flowsession + WHERE status = 'W' AND current_flow_id = ANY($1);` + +// InterruptSessionsForFlows interrupts any waiting sessions currently in the given flows +func InterruptSessionsForFlows(ctx context.Context, db *sqlx.DB, flowIDs []FlowID) error { + if len(flowIDs) == 0 { + return nil + } + + sessionIDs := make([]SessionID, 0, len(flowIDs)) + + err := db.SelectContext(ctx, &sessionIDs, sqlWaitingSessionIDsForFlows, pq.Array(flowIDs)) + if err != nil { + return errors.Wrapf(err, "error selecting waiting sessions for flows") + } + + return errors.Wrapf(ExitSessions(ctx, db, sessionIDs, SessionStatusInterrupted), "error exiting sessions") +} diff --git a/core/models/sessions_test.go b/core/models/sessions_test.go new file mode 100644 index 000000000..5632b7ddd --- /dev/null +++ b/core/models/sessions_test.go @@ -0,0 +1,401 @@ +package models_test + +import ( + "context" + "os" + "testing" + + "github.com/buger/jsonparser" + "github.com/gomodule/redigo/redis" + "github.com/jmoiron/sqlx" + "github.com/nyaruka/gocommon/dbutil/assertdb" + "github.com/nyaruka/goflow/flows" + "github.com/nyaruka/goflow/test" + "github.com/nyaruka/mailroom/core/models" + "github.com/nyaruka/mailroom/testsuite" + "github.com/nyaruka/mailroom/testsuite/testdata" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSessionCreationAndUpdating(t *testing.T) { + ctx, rt, db, _ := testsuite.Get() + + defer testsuite.Reset(testsuite.ResetData) + + assetsJSON, err := os.ReadFile("testdata/session_test_flows.json") + require.NoError(t, err) + + flowJSON, _, _, err := jsonparser.Get(assetsJSON, "flows", "[0]") + require.NoError(t, err) + flow := testdata.InsertFlow(db, testdata.Org1, flowJSON) + + oa, err := models.GetOrgAssetsWithRefresh(ctx, rt, testdata.Org1.ID, models.RefreshFlows) + require.NoError(t, err) + + flowSession, sprint1 := test.NewSessionBuilder().WithAssets(assetsJSON).WithFlow("c49daa28-cf70-407a-a767-a4c1360f4b01"). + WithContact(testdata.Bob.UUID, flows.ContactID(testdata.Bob.ID), "Bob", "eng", "").MustBuild() + + tx := db.MustBegin() + + hookCalls := 0 + hook := func(context.Context, *sqlx.Tx, *redis.Pool, *models.OrgAssets, []*models.Session) error { + hookCalls++ + return nil + } + + modelSessions, err := models.WriteSessions(ctx, rt, tx, oa, []flows.Session{flowSession}, []flows.Sprint{sprint1}, hook) + require.NoError(t, err) + assert.Equal(t, 1, hookCalls) + + require.NoError(t, tx.Commit()) + + session := modelSessions[0] + + assert.Equal(t, models.FlowTypeMessaging, session.SessionType()) + assert.Equal(t, testdata.Bob.ID, session.ContactID()) + assert.Equal(t, models.SessionStatusWaiting, session.Status()) + assert.Equal(t, flow.ID, session.CurrentFlowID()) + assert.NotNil(t, session.CreatedOn()) + assert.Nil(t, session.EndedOn()) + assert.False(t, session.Responded()) + assert.NotNil(t, session.WaitStartedOn()) + assert.NotNil(t, session.WaitExpiresOn()) + assert.False(t, *session.WaitResumeOnExpire()) + assert.NotNil(t, session.Timeout()) + + // check that matches what is in the db + assertdb.Query(t, db, `SELECT status, session_type, current_flow_id, responded, ended_on, wait_resume_on_expire FROM flows_flowsession`). + Columns(map[string]interface{}{ + "status": "W", "session_type": "M", "current_flow_id": int64(flow.ID), "responded": false, "ended_on": nil, "wait_resume_on_expire": false, + }) + + flowSession, err = session.FlowSession(rt.Config, oa.SessionAssets(), oa.Env()) + require.NoError(t, err) + + flowSession, sprint2, err := test.ResumeSession(flowSession, assetsJSON, "no") + require.NoError(t, err) + + tx = db.MustBegin() + + err = session.Update(ctx, rt, tx, oa, flowSession, sprint2, hook) + require.NoError(t, err) + assert.Equal(t, 2, hookCalls) + + require.NoError(t, tx.Commit()) + + assert.Equal(t, models.SessionStatusWaiting, session.Status()) + assert.Equal(t, flow.ID, session.CurrentFlowID()) + assert.True(t, session.Responded()) + assert.NotNil(t, session.WaitStartedOn()) + assert.NotNil(t, session.WaitExpiresOn()) + assert.False(t, *session.WaitResumeOnExpire()) + assert.Nil(t, session.Timeout()) // this wait doesn't have a timeout + + flowSession, err = session.FlowSession(rt.Config, oa.SessionAssets(), oa.Env()) + require.NoError(t, err) + + flowSession, sprint3, err := test.ResumeSession(flowSession, assetsJSON, "yes") + require.NoError(t, err) + + tx = db.MustBegin() + + err = session.Update(ctx, rt, tx, oa, flowSession, sprint3, hook) + require.NoError(t, err) + assert.Equal(t, 3, hookCalls) + + require.NoError(t, tx.Commit()) + + assert.Equal(t, models.SessionStatusCompleted, session.Status()) + assert.Equal(t, models.NilFlowID, session.CurrentFlowID()) // no longer "in" a flow + assert.True(t, session.Responded()) + assert.NotNil(t, session.CreatedOn()) + assert.Nil(t, session.WaitStartedOn()) + assert.Nil(t, session.WaitExpiresOn()) + assert.False(t, *session.WaitResumeOnExpire()) // stays false + assert.Nil(t, session.Timeout()) + assert.NotNil(t, session.EndedOn()) + + // check that matches what is in the db + assertdb.Query(t, db, `SELECT status, session_type, current_flow_id, responded FROM flows_flowsession`). + Columns(map[string]interface{}{"status": "C", "session_type": "M", "current_flow_id": nil, "responded": true}) +} + +func TestSingleSprintSession(t *testing.T) { + ctx, rt, db, _ := testsuite.Get() + + defer testsuite.Reset(testsuite.ResetData) + + assetsJSON, err := os.ReadFile("testdata/session_test_flows.json") + require.NoError(t, err) + + flowJSON, _, _, err := jsonparser.Get(assetsJSON, "flows", "[1]") + require.NoError(t, err) + testdata.InsertFlow(db, testdata.Org1, flowJSON) + + oa, err := models.GetOrgAssetsWithRefresh(ctx, rt, testdata.Org1.ID, models.RefreshFlows) + require.NoError(t, err) + + flowSession, sprint1 := test.NewSessionBuilder().WithAssets(assetsJSON).WithFlow("8b1b02a0-e217-4d59-8ecb-3b20bec69cf4"). + WithContact(testdata.Bob.UUID, flows.ContactID(testdata.Bob.ID), "Bob", "eng", "").MustBuild() + + tx := db.MustBegin() + + hookCalls := 0 + hook := func(context.Context, *sqlx.Tx, *redis.Pool, *models.OrgAssets, []*models.Session) error { + hookCalls++ + return nil + } + + modelSessions, err := models.WriteSessions(ctx, rt, tx, oa, []flows.Session{flowSession}, []flows.Sprint{sprint1}, hook) + require.NoError(t, err) + assert.Equal(t, 1, hookCalls) + + require.NoError(t, tx.Commit()) + + session := modelSessions[0] + + assert.Equal(t, models.FlowTypeMessaging, session.SessionType()) + assert.Equal(t, testdata.Bob.ID, session.ContactID()) + assert.Equal(t, models.SessionStatusCompleted, session.Status()) + assert.Equal(t, models.NilFlowID, session.CurrentFlowID()) + assert.NotNil(t, session.CreatedOn()) + assert.NotNil(t, session.EndedOn()) + assert.False(t, session.Responded()) + assert.Nil(t, session.WaitStartedOn()) + assert.Nil(t, session.WaitExpiresOn()) + assert.Nil(t, session.Timeout()) + + // check that matches what is in the db + assertdb.Query(t, db, `SELECT status, session_type, current_flow_id, responded FROM flows_flowsession`). + Columns(map[string]interface{}{"status": "C", "session_type": "M", "current_flow_id": nil, "responded": false}) +} + +func TestSessionWithSubflows(t *testing.T) { + ctx, rt, db, _ := testsuite.Get() + + defer testsuite.Reset(testsuite.ResetData) + + assetsJSON, err := os.ReadFile("testdata/session_test_flows.json") + require.NoError(t, err) + + parentJSON, _, _, err := jsonparser.Get(assetsJSON, "flows", "[2]") + require.NoError(t, err) + testdata.InsertFlow(db, testdata.Org1, parentJSON) + + childJSON, _, _, err := jsonparser.Get(assetsJSON, "flows", "[3]") + require.NoError(t, err) + childFlow := testdata.InsertFlow(db, testdata.Org1, childJSON) + + oa, err := models.GetOrgAssetsWithRefresh(ctx, rt, testdata.Org1.ID, models.RefreshFlows) + require.NoError(t, err) + + flowSession, sprint1 := test.NewSessionBuilder().WithAssets(assetsJSON).WithFlow("f128803a-9027-42b1-a707-f1dbe4cf88bd"). + WithContact(testdata.Bob.UUID, flows.ContactID(testdata.Cathy.ID), "Cathy", "eng", "").MustBuild() + + tx := db.MustBegin() + + hookCalls := 0 + hook := func(context.Context, *sqlx.Tx, *redis.Pool, *models.OrgAssets, []*models.Session) error { + hookCalls++ + return nil + } + + modelSessions, err := models.WriteSessions(ctx, rt, tx, oa, []flows.Session{flowSession}, []flows.Sprint{sprint1}, hook) + require.NoError(t, err) + assert.Equal(t, 1, hookCalls) + + require.NoError(t, tx.Commit()) + + session := modelSessions[0] + + assert.Equal(t, models.FlowTypeMessaging, session.SessionType()) + assert.Equal(t, testdata.Cathy.ID, session.ContactID()) + assert.Equal(t, models.SessionStatusWaiting, session.Status()) + assert.Equal(t, childFlow.ID, session.CurrentFlowID()) + assert.NotNil(t, session.CreatedOn()) + assert.Nil(t, session.EndedOn()) + assert.False(t, session.Responded()) + assert.NotNil(t, session.WaitStartedOn()) + assert.NotNil(t, session.WaitExpiresOn()) + assert.True(t, *session.WaitResumeOnExpire()) // because we have a parent + assert.Nil(t, session.Timeout()) + + // check that matches what is in the db + assertdb.Query(t, db, `SELECT status, session_type, current_flow_id, responded, ended_on, wait_resume_on_expire FROM flows_flowsession`). + Columns(map[string]interface{}{ + "status": "W", "session_type": "M", "current_flow_id": int64(childFlow.ID), "responded": false, "ended_on": nil, "wait_resume_on_expire": true, + }) + + flowSession, err = session.FlowSession(rt.Config, oa.SessionAssets(), oa.Env()) + require.NoError(t, err) + + flowSession, sprint2, err := test.ResumeSession(flowSession, assetsJSON, "yes") + require.NoError(t, err) + + tx = db.MustBegin() + + err = session.Update(ctx, rt, tx, oa, flowSession, sprint2, hook) + require.NoError(t, err) + assert.Equal(t, 2, hookCalls) + + require.NoError(t, tx.Commit()) + + assert.Equal(t, models.SessionStatusCompleted, session.Status()) + assert.Equal(t, models.NilFlowID, session.CurrentFlowID()) + assert.True(t, session.Responded()) + assert.Nil(t, session.WaitStartedOn()) + assert.Nil(t, session.WaitExpiresOn()) + assert.False(t, *session.WaitResumeOnExpire()) + assert.Nil(t, session.Timeout()) +} + +func TestInterruptSessionsForContacts(t *testing.T) { + ctx, _, db, _ := testsuite.Get() + + defer testsuite.Reset(testsuite.ResetData) + + session1ID, _ := insertSessionAndRun(db, testdata.Cathy, models.FlowTypeMessaging, models.SessionStatusCompleted, testdata.Favorites, models.NilConnectionID) + session2ID, run2ID := insertSessionAndRun(db, testdata.Cathy, models.FlowTypeMessaging, models.SessionStatusWaiting, testdata.Favorites, models.NilConnectionID) + session3ID, _ := insertSessionAndRun(db, testdata.Bob, models.FlowTypeMessaging, models.SessionStatusWaiting, testdata.Favorites, models.NilConnectionID) + session4ID, _ := insertSessionAndRun(db, testdata.George, models.FlowTypeMessaging, models.SessionStatusWaiting, testdata.Favorites, models.NilConnectionID) + + // noop if no contacts + err := models.InterruptSessionsForContacts(ctx, db, []models.ContactID{}) + require.NoError(t, err) + + assertSessionAndRunStatus(t, db, session1ID, models.SessionStatusCompleted) + assertSessionAndRunStatus(t, db, session2ID, models.SessionStatusWaiting) + assertSessionAndRunStatus(t, db, session3ID, models.SessionStatusWaiting) + assertSessionAndRunStatus(t, db, session4ID, models.SessionStatusWaiting) + + err = models.InterruptSessionsForContacts(ctx, db, []models.ContactID{testdata.Cathy.ID, testdata.Bob.ID}) + require.NoError(t, err) + + assertSessionAndRunStatus(t, db, session1ID, models.SessionStatusCompleted) // wasn't waiting + assertSessionAndRunStatus(t, db, session2ID, models.SessionStatusInterrupted) + assertSessionAndRunStatus(t, db, session3ID, models.SessionStatusInterrupted) + assertSessionAndRunStatus(t, db, session4ID, models.SessionStatusWaiting) // contact not included + + // check other columns are correct on interrupted session, run and contact + assertdb.Query(t, db, `SELECT count(*) FROM flows_flowsession WHERE ended_on IS NOT NULL AND wait_started_on IS NULL AND wait_expires_on IS NULL AND timeout_on IS NULL AND current_flow_id IS NULL AND id = $1`, session2ID).Returns(1) + assertdb.Query(t, db, `SELECT is_active, exit_type FROM flows_flowrun WHERE id = $1`, run2ID).Columns(map[string]interface{}{"exit_type": "I", "is_active": false}) + assertdb.Query(t, db, `SELECT current_flow_id FROM contacts_contact WHERE id = $1`, testdata.Cathy.ID).Returns(nil) +} + +func TestInterruptSessionsOfTypeForContacts(t *testing.T) { + ctx, _, db, _ := testsuite.Get() + + defer testsuite.Reset(testsuite.ResetData) + + session1ID, _ := insertSessionAndRun(db, testdata.Cathy, models.FlowTypeMessaging, models.SessionStatusCompleted, testdata.Favorites, models.NilConnectionID) + session2ID, _ := insertSessionAndRun(db, testdata.Cathy, models.FlowTypeMessaging, models.SessionStatusWaiting, testdata.Favorites, models.NilConnectionID) + session3ID, _ := insertSessionAndRun(db, testdata.Bob, models.FlowTypeMessaging, models.SessionStatusWaiting, testdata.Favorites, models.NilConnectionID) + session4ID, _ := insertSessionAndRun(db, testdata.George, models.FlowTypeVoice, models.SessionStatusWaiting, testdata.Favorites, models.NilConnectionID) + + tx := db.MustBegin() + + err := models.InterruptSessionsOfTypeForContacts(ctx, tx, []models.ContactID{testdata.Cathy.ID, testdata.Bob.ID, testdata.George.ID}, models.FlowTypeMessaging) + require.NoError(t, err) + + require.NoError(t, tx.Commit()) + + assertSessionAndRunStatus(t, db, session1ID, models.SessionStatusCompleted) // wasn't waiting + assertSessionAndRunStatus(t, db, session2ID, models.SessionStatusInterrupted) + assertSessionAndRunStatus(t, db, session3ID, models.SessionStatusInterrupted) + assertSessionAndRunStatus(t, db, session4ID, models.SessionStatusWaiting) // wrong type + + // check other columns are correct on interrupted session and contact + assertdb.Query(t, db, `SELECT count(*) FROM flows_flowsession WHERE ended_on IS NOT NULL AND wait_started_on IS NULL AND wait_expires_on IS NULL AND timeout_on IS NULL AND current_flow_id IS NULL AND id = $1`, session2ID).Returns(1) + assertdb.Query(t, db, `SELECT current_flow_id FROM contacts_contact WHERE id = $1`, testdata.Cathy.ID).Returns(nil) +} + +func TestInterruptSessionsForChannels(t *testing.T) { + ctx, _, db, _ := testsuite.Get() + + defer testsuite.Reset(testsuite.ResetData) + + cathy1ConnectionID := testdata.InsertConnection(db, testdata.Org1, testdata.TwilioChannel, testdata.Cathy) + cathy2ConnectionID := testdata.InsertConnection(db, testdata.Org1, testdata.TwilioChannel, testdata.Cathy) + bobConnectionID := testdata.InsertConnection(db, testdata.Org1, testdata.TwilioChannel, testdata.Bob) + georgeConnectionID := testdata.InsertConnection(db, testdata.Org1, testdata.VonageChannel, testdata.George) + + session1ID, _ := insertSessionAndRun(db, testdata.Cathy, models.FlowTypeMessaging, models.SessionStatusCompleted, testdata.Favorites, cathy1ConnectionID) + session2ID, _ := insertSessionAndRun(db, testdata.Cathy, models.FlowTypeMessaging, models.SessionStatusWaiting, testdata.Favorites, cathy2ConnectionID) + session3ID, _ := insertSessionAndRun(db, testdata.Bob, models.FlowTypeMessaging, models.SessionStatusWaiting, testdata.Favorites, bobConnectionID) + session4ID, _ := insertSessionAndRun(db, testdata.George, models.FlowTypeMessaging, models.SessionStatusWaiting, testdata.Favorites, georgeConnectionID) + + // noop if no channels + err := models.InterruptSessionsForChannels(ctx, db, []models.ChannelID{}) + require.NoError(t, err) + + assertSessionAndRunStatus(t, db, session1ID, models.SessionStatusCompleted) + assertSessionAndRunStatus(t, db, session2ID, models.SessionStatusWaiting) + assertSessionAndRunStatus(t, db, session3ID, models.SessionStatusWaiting) + assertSessionAndRunStatus(t, db, session4ID, models.SessionStatusWaiting) + + err = models.InterruptSessionsForChannels(ctx, db, []models.ChannelID{testdata.TwilioChannel.ID}) + require.NoError(t, err) + + assertSessionAndRunStatus(t, db, session1ID, models.SessionStatusCompleted) // wasn't waiting + assertSessionAndRunStatus(t, db, session2ID, models.SessionStatusInterrupted) + assertSessionAndRunStatus(t, db, session3ID, models.SessionStatusInterrupted) + assertSessionAndRunStatus(t, db, session4ID, models.SessionStatusWaiting) // channel not included + + // check other columns are correct on interrupted session and contact + assertdb.Query(t, db, `SELECT count(*) FROM flows_flowsession WHERE ended_on IS NOT NULL AND wait_started_on IS NULL AND wait_expires_on IS NULL AND timeout_on IS NULL AND current_flow_id IS NULL AND id = $1`, session2ID).Returns(1) + assertdb.Query(t, db, `SELECT current_flow_id FROM contacts_contact WHERE id = $1`, testdata.Cathy.ID).Returns(nil) +} + +func TestInterruptSessionsForFlows(t *testing.T) { + ctx, _, db, _ := testsuite.Get() + + defer testsuite.Reset(testsuite.ResetData) + + cathy1ConnectionID := testdata.InsertConnection(db, testdata.Org1, testdata.TwilioChannel, testdata.Cathy) + cathy2ConnectionID := testdata.InsertConnection(db, testdata.Org1, testdata.TwilioChannel, testdata.Cathy) + bobConnectionID := testdata.InsertConnection(db, testdata.Org1, testdata.TwilioChannel, testdata.Bob) + georgeConnectionID := testdata.InsertConnection(db, testdata.Org1, testdata.VonageChannel, testdata.George) + + session1ID, _ := insertSessionAndRun(db, testdata.Cathy, models.FlowTypeMessaging, models.SessionStatusCompleted, testdata.Favorites, cathy1ConnectionID) + session2ID, _ := insertSessionAndRun(db, testdata.Cathy, models.FlowTypeMessaging, models.SessionStatusWaiting, testdata.Favorites, cathy2ConnectionID) + session3ID, _ := insertSessionAndRun(db, testdata.Bob, models.FlowTypeMessaging, models.SessionStatusWaiting, testdata.Favorites, bobConnectionID) + session4ID, _ := insertSessionAndRun(db, testdata.George, models.FlowTypeMessaging, models.SessionStatusWaiting, testdata.PickANumber, georgeConnectionID) + + // noop if no flows + err := models.InterruptSessionsForFlows(ctx, db, []models.FlowID{}) + require.NoError(t, err) + + assertSessionAndRunStatus(t, db, session1ID, models.SessionStatusCompleted) + assertSessionAndRunStatus(t, db, session2ID, models.SessionStatusWaiting) + assertSessionAndRunStatus(t, db, session3ID, models.SessionStatusWaiting) + assertSessionAndRunStatus(t, db, session4ID, models.SessionStatusWaiting) + + err = models.InterruptSessionsForFlows(ctx, db, []models.FlowID{testdata.Favorites.ID}) + require.NoError(t, err) + + assertSessionAndRunStatus(t, db, session1ID, models.SessionStatusCompleted) // wasn't waiting + assertSessionAndRunStatus(t, db, session2ID, models.SessionStatusInterrupted) + assertSessionAndRunStatus(t, db, session3ID, models.SessionStatusInterrupted) + assertSessionAndRunStatus(t, db, session4ID, models.SessionStatusWaiting) // flow not included + + // check other columns are correct on interrupted session and contact + assertdb.Query(t, db, `SELECT count(*) FROM flows_flowsession WHERE ended_on IS NOT NULL AND wait_started_on IS NULL AND wait_expires_on IS NULL AND timeout_on IS NULL AND current_flow_id IS NULL AND id = $1`, session2ID).Returns(1) + assertdb.Query(t, db, `SELECT current_flow_id FROM contacts_contact WHERE id = $1`, testdata.Cathy.ID).Returns(nil) +} + +func insertSessionAndRun(db *sqlx.DB, contact *testdata.Contact, sessionType models.FlowType, status models.SessionStatus, flow *testdata.Flow, connID models.ConnectionID) (models.SessionID, models.FlowRunID) { + // create session and add a run with same status + sessionID := testdata.InsertFlowSession(db, testdata.Org1, contact, sessionType, status, flow, connID, nil) + runID := testdata.InsertFlowRun(db, testdata.Org1, sessionID, contact, flow, models.RunStatus(status), "", nil) + + // mark contact as being in that flow + db.MustExec(`UPDATE contacts_contact SET current_flow_id = $2 WHERE id = $1`, contact.ID, flow.ID) + + return sessionID, runID +} + +func assertSessionAndRunStatus(t *testing.T, db *sqlx.DB, sessionID models.SessionID, status models.SessionStatus) { + assertdb.Query(t, db, `SELECT status FROM flows_flowsession WHERE id = $1`, sessionID).Columns(map[string]interface{}{"status": string(status)}) + assertdb.Query(t, db, `SELECT status FROM flows_flowrun WHERE session_id = $1`, sessionID).Columns(map[string]interface{}{"status": string(status)}) +} diff --git a/core/models/starts_test.go b/core/models/starts_test.go index 74f1ebdbd..1866408b9 100644 --- a/core/models/starts_test.go +++ b/core/models/starts_test.go @@ -5,6 +5,7 @@ import ( "fmt" "testing" + "github.com/nyaruka/gocommon/dbutil/assertdb" "github.com/nyaruka/gocommon/jsonx" "github.com/nyaruka/gocommon/uuids" "github.com/nyaruka/goflow/flows" @@ -64,8 +65,8 @@ func TestStarts(t *testing.T) { err = models.MarkStartStarted(ctx, db, startID, 2, []models.ContactID{testdata.George.ID}) require.NoError(t, err) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM flows_flowstart WHERE id = $1 AND status = 'S' AND contact_count = 2`, startID).Returns(1) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM flows_flowstart_contacts WHERE flowstart_id = $1`, startID).Returns(3) + assertdb.Query(t, db, `SELECT count(*) FROM flows_flowstart WHERE id = $1 AND status = 'S' AND contact_count = 2`, startID).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM flows_flowstart_contacts WHERE flowstart_id = $1`, startID).Returns(3) batch := start.CreateBatch([]models.ContactID{testdata.Cathy.ID, testdata.Bob.ID}, false, 3) assert.Equal(t, startID, batch.StartID()) @@ -92,7 +93,7 @@ func TestStarts(t *testing.T) { err = models.MarkStartComplete(ctx, db, startID) require.NoError(t, err) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM flows_flowstart WHERE id = $1 AND status = 'C'`, startID).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM flows_flowstart WHERE id = $1 AND status = 'C'`, startID).Returns(1) } func TestStartsBuilding(t *testing.T) { diff --git a/core/models/templates.go b/core/models/templates.go index 137bd7c6e..16afc5dc3 100644 --- a/core/models/templates.go +++ b/core/models/templates.go @@ -5,9 +5,9 @@ import ( "encoding/json" "time" + "github.com/nyaruka/gocommon/dbutil" "github.com/nyaruka/goflow/assets" "github.com/nyaruka/goflow/envs" - "github.com/nyaruka/mailroom/utils/dbutil" "github.com/nyaruka/null" "github.com/jmoiron/sqlx" @@ -76,7 +76,7 @@ func loadTemplates(ctx context.Context, db sqlx.Queryer, orgID OrgID) ([]assets. templates := make([]assets.Template, 0) for rows.Next() { template := &Template{} - err = dbutil.ReadJSONRow(rows, &template.t) + err = dbutil.ScanAndValidateJSON(rows, &template.t) if err != nil { return nil, errors.Wrap(err, "error reading group row") } diff --git a/core/models/testdata/flow_stats_test.json b/core/models/testdata/flow_stats_test.json new file mode 100644 index 000000000..48beb7240 --- /dev/null +++ b/core/models/testdata/flow_stats_test.json @@ -0,0 +1,346 @@ +{ + "flows": [ + { + "uuid": "19eab6aa-4a88-42a1-8882-b9956823c680", + "name": "Flow Stats Test", + "revision": 75, + "spec_version": "13.1.0", + "type": "messaging", + "expire_after_minutes": 10080, + "language": "eng", + "localization": {}, + "nodes": [ + { + "uuid": "001b4eee-812f-403e-a004-737b948b3c18", + "actions": [ + { + "uuid": "d64f25cf-8b02-4ca9-8df8-3c457ccc1090", + "type": "send_msg", + "attachments": [], + "text": "Hi there! What's your favorite color?", + "quick_replies": [] + } + ], + "exits": [ + { + "uuid": "5fd2e537-0534-4c12-8425-bef87af09d46", + "destination_uuid": "072b95b3-61c3-4e0e-8dd1-eb7481083f94" + } + ] + }, + { + "uuid": "8712db6b-25ff-4789-892c-581f24eeeb95", + "actions": [ + { + "uuid": "1e65bf7a-fae7-4bac-94ae-662da02dfab8", + "type": "send_msg", + "attachments": [], + "text": "I'm sorry I don't know that color", + "quick_replies": [] + } + ], + "exits": [ + { + "uuid": "0a4f2ea9-c47f-4e9c-a242-89ae5b38d679", + "destination_uuid": "072b95b3-61c3-4e0e-8dd1-eb7481083f94" + } + ] + }, + { + "uuid": "072b95b3-61c3-4e0e-8dd1-eb7481083f94", + "actions": [], + "router": { + "type": "switch", + "default_category_uuid": "d7e0c791-c410-400b-bc34-985a537a425a", + "cases": [ + { + "arguments": [ + "red" + ], + "type": "has_any_word", + "uuid": "d75d478b-9713-46bd-8e8f-94a3ee9d4b86", + "category_uuid": "cb45b0d2-c55e-413d-b881-ed8805280a73" + }, + { + "arguments": [ + "green" + ], + "type": "has_any_word", + "uuid": "93918283-2438-403e-8160-3eea170c6f1e", + "category_uuid": "4cf5dd07-d5ac-4236-aef5-ae9ca6cb3e84" + }, + { + "arguments": [ + "blue" + ], + "type": "has_any_word", + "uuid": "990a80eb-ec55-43c5-b3e1-600bfe755556", + "category_uuid": "70c03b9d-923d-4a13-b844-1138b211f49b" + } + ], + "categories": [ + { + "uuid": "cb45b0d2-c55e-413d-b881-ed8805280a73", + "name": "Red", + "exit_uuid": "2fddfbe8-b239-47e6-8480-f22908e53b98" + }, + { + "uuid": "4cf5dd07-d5ac-4236-aef5-ae9ca6cb3e84", + "name": "Green", + "exit_uuid": "d24888e3-f2e1-4d44-8fb5-0362f8892563" + }, + { + "uuid": "70c03b9d-923d-4a13-b844-1138b211f49b", + "name": "Blue", + "exit_uuid": "c02fc3ba-369a-4c87-9bc4-c3b376bda6d2" + }, + { + "uuid": "d7e0c791-c410-400b-bc34-985a537a425a", + "name": "Other", + "exit_uuid": "ea6c38dc-11e2-4616-9f3e-577e44765d44" + } + ], + "operand": "@input.text", + "wait": { + "type": "msg" + }, + "result_name": "Color" + }, + "exits": [ + { + "uuid": "2fddfbe8-b239-47e6-8480-f22908e53b98", + "destination_uuid": "57b50d33-2b5a-4726-82de-9848c61eff6e" + }, + { + "uuid": "d24888e3-f2e1-4d44-8fb5-0362f8892563", + "destination_uuid": "57b50d33-2b5a-4726-82de-9848c61eff6e" + }, + { + "uuid": "c02fc3ba-369a-4c87-9bc4-c3b376bda6d2", + "destination_uuid": "57b50d33-2b5a-4726-82de-9848c61eff6e" + }, + { + "uuid": "ea6c38dc-11e2-4616-9f3e-577e44765d44", + "destination_uuid": "8712db6b-25ff-4789-892c-581f24eeeb95" + } + ] + }, + { + "uuid": "0e1fe072-6f03-4f29-98aa-7bedbe930dab", + "actions": [], + "router": { + "type": "switch", + "default_category_uuid": "e90847e5-a09d-4a5a-8c96-2157c5466576", + "categories": [ + { + "uuid": "ac3a2908-5141-47a6-9944-ec26e07c7b44", + "name": ">= 1", + "exit_uuid": "0bdbf661-1e6d-42fb-bd94-3bdac885b582" + }, + { + "uuid": "e90847e5-a09d-4a5a-8c96-2157c5466576", + "name": "Other", + "exit_uuid": "2b698218-87e5-4ab8-922e-e65f91d12c10" + } + ], + "cases": [ + { + "arguments": [ + "1" + ], + "type": "has_number_gte", + "uuid": "e61fb8a6-08e4-460a-8a07-1d37fb4a1827", + "category_uuid": "ac3a2908-5141-47a6-9944-ec26e07c7b44" + } + ], + "operand": "@(count(contact.tickets))" + }, + "exits": [ + { + "uuid": "0bdbf661-1e6d-42fb-bd94-3bdac885b582", + "destination_uuid": "88d8bf00-51ce-4e5e-aae8-4f957a0761a0" + }, + { + "uuid": "2b698218-87e5-4ab8-922e-e65f91d12c10", + "destination_uuid": "88d8bf00-51ce-4e5e-aae8-4f957a0761a0" + } + ] + }, + { + "uuid": "57b50d33-2b5a-4726-82de-9848c61eff6e", + "actions": [ + { + "attachments": [], + "text": "@results.color.category is a great color!", + "type": "send_msg", + "quick_replies": [], + "uuid": "d45b3ae7-52e3-4b93-a8bc-59502d364e5c" + } + ], + "exits": [ + { + "uuid": "97cd44ce-dec2-4e19-8ca2-4e20db51dc08", + "destination_uuid": "0e1fe072-6f03-4f29-98aa-7bedbe930dab" + } + ] + }, + { + "uuid": "88d8bf00-51ce-4e5e-aae8-4f957a0761a0", + "actions": [ + { + "attachments": [], + "text": "You have @(count(contact.tickets)) open tickets", + "type": "send_msg", + "quick_replies": [], + "uuid": "426773cc-cbc8-44e6-82ca-f6265862b1bb" + } + ], + "exits": [ + { + "uuid": "614e7451-e0bd-43d9-b317-2aded3c8d790", + "destination_uuid": "a1e649db-91e0-47c4-ab14-eba0d1475116" + } + ] + }, + { + "uuid": "a1e649db-91e0-47c4-ab14-eba0d1475116", + "actions": [], + "router": { + "type": "switch", + "cases": [ + { + "uuid": "52b4fb60-e998-467d-bd30-eaf0745bde71", + "type": "has_group", + "arguments": [ + "83452c2e-d6e4-4fae-950a-156064e40068", + "Customers" + ], + "category_uuid": "9687bab4-a6c2-4414-ba48-3d1bf3767acd" + } + ], + "categories": [ + { + "uuid": "9687bab4-a6c2-4414-ba48-3d1bf3767acd", + "name": "Customers", + "exit_uuid": "e7f24a98-6f75-4670-b6c3-fcaf7b4e29a6" + }, + { + "uuid": "052469bd-3fc6-4a10-ba97-c4c8763ecab8", + "name": "Other", + "exit_uuid": "574672d0-5976-4512-9173-1880aa0da2d7" + } + ], + "default_category_uuid": "052469bd-3fc6-4a10-ba97-c4c8763ecab8", + "operand": "@contact.groups", + "result_name": "" + }, + "exits": [ + { + "uuid": "e7f24a98-6f75-4670-b6c3-fcaf7b4e29a6", + "destination_uuid": "459dff50-c1e4-405f-84fa-8ed2b08df728" + }, + { + "uuid": "574672d0-5976-4512-9173-1880aa0da2d7", + "destination_uuid": null + } + ] + }, + { + "uuid": "459dff50-c1e4-405f-84fa-8ed2b08df728", + "actions": [ + { + "attachments": [], + "text": "You are a customer", + "type": "send_msg", + "quick_replies": [], + "uuid": "74fc312d-567c-4133-a95d-75f8b54ed594" + } + ], + "exits": [ + { + "uuid": "37918794-fa3e-4652-98ae-5549a2379af8", + "destination_uuid": null + } + ] + } + ], + "_ui": { + "nodes": { + "001b4eee-812f-403e-a004-737b948b3c18": { + "position": { + "left": 80, + "top": 0 + }, + "type": "execute_actions" + }, + "8712db6b-25ff-4789-892c-581f24eeeb95": { + "position": { + "left": 540, + "top": 80 + }, + "type": "execute_actions" + }, + "072b95b3-61c3-4e0e-8dd1-eb7481083f94": { + "type": "wait_for_response", + "position": { + "left": 220, + "top": 140 + }, + "config": { + "cases": {} + } + }, + "57b50d33-2b5a-4726-82de-9848c61eff6e": { + "position": { + "left": 140, + "top": 360 + }, + "type": "execute_actions" + }, + "0e1fe072-6f03-4f29-98aa-7bedbe930dab": { + "type": "split_by_expression", + "position": { + "left": 480, + "top": 340 + }, + "config": { + "cases": {} + } + }, + "88d8bf00-51ce-4e5e-aae8-4f957a0761a0": { + "position": { + "left": 440, + "top": 500 + }, + "type": "execute_actions" + }, + "a1e649db-91e0-47c4-ab14-eba0d1475116": { + "type": "split_by_groups", + "position": { + "left": 240, + "top": 660 + }, + "config": { + "cases": {} + } + }, + "459dff50-c1e4-405f-84fa-8ed2b08df728": { + "position": { + "left": 320, + "top": 800 + }, + "type": "execute_actions" + } + }, + "stickies": {} + } + } + ], + "groups": [ + { + "uuid": "83452c2e-d6e4-4fae-950a-156064e40068", + "name": "Customers", + "query": null + } + ] +} \ No newline at end of file diff --git a/core/models/testdata/session_test_flows.json b/core/models/testdata/session_test_flows.json new file mode 100644 index 000000000..e656479b2 --- /dev/null +++ b/core/models/testdata/session_test_flows.json @@ -0,0 +1,600 @@ +{ + "flows": [ + { + "name": "Two Questions", + "uuid": "c49daa28-cf70-407a-a767-a4c1360f4b01", + "spec_version": "13.1.0", + "language": "eng", + "type": "messaging", + "nodes": [ + { + "uuid": "8d3e3b71-0932-4e44-b8c8-99e15bac1f15", + "actions": [ + { + "attachments": [], + "text": "Do you like dogs?", + "type": "send_msg", + "quick_replies": [], + "uuid": "62aa7c4e-8b5d-436f-9799-efea2ee4736e" + } + ], + "exits": [ + { + "uuid": "18d3827d-6154-4ef1-890a-ee03cf26462c", + "destination_uuid": "cbff02b0-cd93-481d-a430-b335ab66779e" + } + ] + }, + { + "uuid": "f6d76a2a-2140-4283-bb6e-911adeb674f9", + "actions": [ + { + "attachments": [], + "text": "Sorry didn't understand that. Do you like dogs?", + "type": "send_msg", + "quick_replies": [], + "uuid": "f5697ef8-93fb-4bcf-8e61-8b0009562a76" + } + ], + "exits": [ + { + "uuid": "f1dcbf3e-180a-47fe-9a45-008d2de91539", + "destination_uuid": "cbff02b0-cd93-481d-a430-b335ab66779e" + } + ] + }, + { + "uuid": "cbff02b0-cd93-481d-a430-b335ab66779e", + "actions": [], + "router": { + "type": "switch", + "default_category_uuid": "b3c8664b-6fd8-4c80-b792-290ebaa82e16", + "cases": [ + { + "arguments": [ + "yes" + ], + "type": "has_any_word", + "uuid": "4ac9e7b0-decf-428a-b37f-09316be09198", + "category_uuid": "b0f5f049-f6a5-4901-ab8b-bfed481bc896" + }, + { + "arguments": [ + "no" + ], + "type": "has_any_word", + "uuid": "5bc0e00b-7f1e-4eac-9dba-bd682f0d4345", + "category_uuid": "efc08358-c694-4d1e-9b2b-9449df0f979c" + } + ], + "categories": [ + { + "uuid": "b0f5f049-f6a5-4901-ab8b-bfed481bc896", + "name": "Yes", + "exit_uuid": "6ba8ef10-829d-44ff-a7dc-07310c88c601" + }, + { + "uuid": "efc08358-c694-4d1e-9b2b-9449df0f979c", + "name": "No", + "exit_uuid": "2139a6a6-1861-4a32-96e9-691da424033e" + }, + { + "uuid": "b3c8664b-6fd8-4c80-b792-290ebaa82e16", + "name": "Other", + "exit_uuid": "6914d7c5-9784-47df-9b55-936692d6e9e7" + }, + { + "uuid": "799eac96-b7f6-4545-8e9c-46ebb4fc520b", + "name": "No Response", + "exit_uuid": "43ac015c-8614-4749-b24c-f4a4b0fc7dc3" + } + ], + "operand": "@input.text", + "wait": { + "type": "msg", + "timeout": { + "seconds": 300, + "category_uuid": "799eac96-b7f6-4545-8e9c-46ebb4fc520b" + } + }, + "result_name": "Likes Dogs" + }, + "exits": [ + { + "uuid": "6ba8ef10-829d-44ff-a7dc-07310c88c601", + "destination_uuid": "5e9edc6b-b0e9-4c02-a235-addcb331647f" + }, + { + "uuid": "2139a6a6-1861-4a32-96e9-691da424033e", + "destination_uuid": "5e9edc6b-b0e9-4c02-a235-addcb331647f" + }, + { + "uuid": "6914d7c5-9784-47df-9b55-936692d6e9e7", + "destination_uuid": "f6d76a2a-2140-4283-bb6e-911adeb674f9" + }, + { + "uuid": "43ac015c-8614-4749-b24c-f4a4b0fc7dc3", + "destination_uuid": "5e9edc6b-b0e9-4c02-a235-addcb331647f" + } + ] + }, + { + "uuid": "5e9edc6b-b0e9-4c02-a235-addcb331647f", + "actions": [ + { + "attachments": [], + "text": "Do you like cats?", + "type": "send_msg", + "quick_replies": [], + "uuid": "9d53826f-4e5c-4fd7-8e37-73d1163f2840" + } + ], + "exits": [ + { + "uuid": "bc9a0344-e817-483c-b942-1eb4d8bc7eec", + "destination_uuid": "bd8de388-811e-4116-ab41-8c2260d5514e" + } + ] + }, + { + "uuid": "93406d78-13ac-4447-97dc-021dfd79ba6f", + "actions": [ + { + "attachments": [], + "text": "Sorry didn't understand that. Do you like cats?", + "type": "send_msg", + "quick_replies": [], + "uuid": "1aabeffe-1ced-4ef2-a511-ac9ba1dde798" + } + ], + "exits": [ + { + "uuid": "7f911909-5c0c-4514-b8ea-2c227ffe60a1", + "destination_uuid": "bd8de388-811e-4116-ab41-8c2260d5514e" + } + ] + }, + { + "uuid": "bd8de388-811e-4116-ab41-8c2260d5514e", + "actions": [], + "router": { + "type": "switch", + "default_category_uuid": "f4a641dd-4e8b-4e92-9733-b03931bb4d2e", + "cases": [ + { + "arguments": [ + "yes" + ], + "type": "has_any_word", + "uuid": "4dcb05e1-cfb4-42b4-8b4d-cc35dc72f418", + "category_uuid": "f5b5de12-b11d-47b7-ba70-f2dc952f112d" + }, + { + "arguments": [ + "no" + ], + "type": "has_any_word", + "uuid": "b4b890c6-a2fe-431b-b6ce-7f8c8398b94f", + "category_uuid": "10eb0d04-3616-423f-bd91-4a59b50dc6d6" + } + ], + "categories": [ + { + "uuid": "f5b5de12-b11d-47b7-ba70-f2dc952f112d", + "name": "Yes", + "exit_uuid": "a792d8cb-53dd-4cd3-9ca7-99b67a645f61" + }, + { + "uuid": "10eb0d04-3616-423f-bd91-4a59b50dc6d6", + "name": "No", + "exit_uuid": "854e3bfd-828f-4537-a639-9b717e19b591" + }, + { + "uuid": "f4a641dd-4e8b-4e92-9733-b03931bb4d2e", + "name": "Other", + "exit_uuid": "7686cfaa-1d6b-403a-bf56-fc8fb1277390" + } + ], + "operand": "@input.text", + "wait": { + "type": "msg" + }, + "result_name": "Likes Cats" + }, + "exits": [ + { + "uuid": "a792d8cb-53dd-4cd3-9ca7-99b67a645f61", + "destination_uuid": "5953e6c9-e6be-4ecb-92a2-bfd6003b2bad" + }, + { + "uuid": "854e3bfd-828f-4537-a639-9b717e19b591", + "destination_uuid": "5953e6c9-e6be-4ecb-92a2-bfd6003b2bad" + }, + { + "uuid": "7686cfaa-1d6b-403a-bf56-fc8fb1277390", + "destination_uuid": "93406d78-13ac-4447-97dc-021dfd79ba6f" + } + ] + }, + { + "uuid": "5953e6c9-e6be-4ecb-92a2-bfd6003b2bad", + "actions": [ + { + "attachments": [], + "text": "Thank you", + "type": "send_msg", + "quick_replies": [], + "uuid": "2b565c34-5846-48dc-927e-876ea2d65288" + } + ], + "exits": [ + { + "uuid": "89c5c2e2-e2af-414b-b03e-30327f84da12", + "destination_uuid": null + } + ] + } + ], + "_ui": { + "nodes": { + "8d3e3b71-0932-4e44-b8c8-99e15bac1f15": { + "position": { + "left": 100, + "top": 0 + }, + "type": "execute_actions" + }, + "cbff02b0-cd93-481d-a430-b335ab66779e": { + "type": "wait_for_response", + "position": { + "left": 100, + "top": 120 + }, + "config": { + "cases": {} + } + }, + "f6d76a2a-2140-4283-bb6e-911adeb674f9": { + "position": { + "left": 420, + "top": 60 + }, + "type": "execute_actions" + }, + "5e9edc6b-b0e9-4c02-a235-addcb331647f": { + "position": { + "left": 100, + "top": 320 + }, + "type": "execute_actions" + }, + "bd8de388-811e-4116-ab41-8c2260d5514e": { + "type": "wait_for_response", + "position": { + "left": 100, + "top": 440 + }, + "config": { + "cases": {} + } + }, + "5953e6c9-e6be-4ecb-92a2-bfd6003b2bad": { + "position": { + "left": 100, + "top": 620 + }, + "type": "execute_actions" + }, + "93406d78-13ac-4447-97dc-021dfd79ba6f": { + "position": { + "left": 420, + "top": 380 + }, + "type": "execute_actions" + } + } + }, + "revision": 31, + "expire_after_minutes": 10080, + "localization": {} + }, + { + "name": "Single Message", + "uuid": "8b1b02a0-e217-4d59-8ecb-3b20bec69cf4", + "spec_version": "13.1.0", + "language": "eng", + "type": "messaging", + "nodes": [ + { + "uuid": "7e5c2d93-dfcd-4531-8048-8ec7aa5f6cd6", + "actions": [ + { + "attachments": [], + "text": "Just wanted to say hi", + "type": "send_msg", + "quick_replies": [], + "uuid": "30ed7a4d-d5d3-41c6-942f-4d82ed3cb86c" + } + ], + "exits": [ + { + "uuid": "472a3585-c0e0-442c-9a9b-064ec4c15088", + "destination_uuid": null + } + ] + } + ], + "revision": 31, + "expire_after_minutes": 10080, + "localization": {} + }, + { + "name": "Subflow: Parent", + "uuid": "f128803a-9027-42b1-a707-f1dbe4cf88bd", + "spec_version": "13.1.0", + "language": "eng", + "type": "messaging", + "nodes": [ + { + "uuid": "69710037-4f39-495a-91b2-2eae89ca69f0", + "actions": [ + { + "attachments": [], + "text": "This is the parent flow", + "type": "send_msg", + "quick_replies": [], + "uuid": "44976a5a-2872-4567-972b-9823a2cb617c" + } + ], + "exits": [ + { + "uuid": "b014074b-c6cf-4531-ad4f-0e1bb9a0b1f1", + "destination_uuid": "ef926afe-d42a-4d5b-8867-9dbfaeb5f176" + } + ] + }, + { + "uuid": "ef926afe-d42a-4d5b-8867-9dbfaeb5f176", + "actions": [ + { + "uuid": "c4fb407f-1864-4878-b930-c1b97ac9482a", + "type": "enter_flow", + "flow": { + "uuid": "4403b147-61ba-41ec-a2d2-11a38f910761", + "name": "Subflow: Child" + } + } + ], + "router": { + "type": "switch", + "operand": "@child.status", + "cases": [ + { + "uuid": "2700b4e9-3017-4cd2-8914-8303fb05883e", + "type": "has_only_text", + "arguments": [ + "completed" + ], + "category_uuid": "261583e4-9ea5-425f-ba60-73461d2cdae1" + }, + { + "uuid": "62b7c542-0aa0-4812-b9d5-d11db737b835", + "arguments": [ + "expired" + ], + "type": "has_only_text", + "category_uuid": "b554a9f1-11f9-48fc-9d4c-5206f671e026" + } + ], + "categories": [ + { + "uuid": "261583e4-9ea5-425f-ba60-73461d2cdae1", + "name": "Complete", + "exit_uuid": "d1c90ac8-385b-42b2-930e-9d85dca8670a" + }, + { + "uuid": "b554a9f1-11f9-48fc-9d4c-5206f671e026", + "name": "Expired", + "exit_uuid": "d13f2800-43b8-442d-a245-4f199c869ed6" + } + ], + "default_category_uuid": "b554a9f1-11f9-48fc-9d4c-5206f671e026" + }, + "exits": [ + { + "uuid": "d1c90ac8-385b-42b2-930e-9d85dca8670a", + "destination_uuid": "2886a2a0-ad95-4811-81ed-f955c8e6f239" + }, + { + "uuid": "d13f2800-43b8-442d-a245-4f199c869ed6", + "destination_uuid": "772b5eea-40a0-4786-8b1d-1cac08ed2912" + } + ] + }, + { + "uuid": "2886a2a0-ad95-4811-81ed-f955c8e6f239", + "actions": [ + { + "attachments": [], + "text": "You completed the child flow", + "type": "send_msg", + "quick_replies": [], + "uuid": "815a98d5-9bc4-446b-be64-0e154babff64" + } + ], + "exits": [ + { + "uuid": "32e99848-42af-4d24-b86c-c9be82c383cb", + "destination_uuid": null + } + ] + }, + { + "uuid": "772b5eea-40a0-4786-8b1d-1cac08ed2912", + "actions": [ + { + "attachments": [], + "text": "You expired from the child flow", + "type": "send_msg", + "quick_replies": [], + "uuid": "a05a0bdc-451d-42fa-8b76-a2d8df2d9c5a" + } + ], + "exits": [ + { + "uuid": "c96f5fbb-3ccf-41be-a55d-32b22f86382e", + "destination_uuid": null + } + ] + } + ], + "_ui": { + "nodes": { + "69710037-4f39-495a-91b2-2eae89ca69f0": { + "position": { + "left": 180, + "top": 0 + }, + "type": "execute_actions" + }, + "ef926afe-d42a-4d5b-8867-9dbfaeb5f176": { + "type": "split_by_subflow", + "position": { + "left": 180, + "top": 120 + }, + "config": {} + }, + "2886a2a0-ad95-4811-81ed-f955c8e6f239": { + "position": { + "left": 60, + "top": 280 + }, + "type": "execute_actions" + }, + "772b5eea-40a0-4786-8b1d-1cac08ed2912": { + "position": { + "left": 280, + "top": 280 + }, + "type": "execute_actions" + } + } + }, + "revision": 8, + "expire_after_minutes": 10080, + "localization": {} + }, + { + "name": "Subflow: Child", + "uuid": "4403b147-61ba-41ec-a2d2-11a38f910761", + "spec_version": "13.1.0", + "language": "eng", + "type": "messaging", + "nodes": [ + { + "uuid": "7525b836-b61c-4fbb-9b89-8539d75d7304", + "actions": [ + { + "attachments": [], + "text": "This is the child flow. Do you like it?", + "type": "send_msg", + "quick_replies": [], + "uuid": "0947e077-fe92-4520-8f2a-6a2e9dd4c881" + } + ], + "exits": [ + { + "uuid": "0f3913af-1aa4-4909-96b6-e79ba17986ae", + "destination_uuid": "03068be2-4748-48e5-b19b-228b5412ebd5" + } + ] + }, + { + "uuid": "03068be2-4748-48e5-b19b-228b5412ebd5", + "actions": [], + "router": { + "type": "switch", + "default_category_uuid": "f87bdf83-d6ca-4b11-814e-13a23e4a874b", + "cases": [ + { + "arguments": [ + "yes" + ], + "type": "has_any_word", + "uuid": "149d2946-633f-4ccc-b2a5-d4e8d26e2492", + "category_uuid": "9f7fd998-fc6e-4c50-9ace-a4027badcbbc" + }, + { + "arguments": [ + "no" + ], + "type": "has_any_word", + "uuid": "63c7cbb8-4386-410b-b74f-e4c7febeb625", + "category_uuid": "9ea7ea2c-8508-4214-84e8-af4f865f7205" + } + ], + "categories": [ + { + "uuid": "9f7fd998-fc6e-4c50-9ace-a4027badcbbc", + "name": "Yes", + "exit_uuid": "55c98810-b9a6-4ff2-b789-264770d4f313" + }, + { + "uuid": "9ea7ea2c-8508-4214-84e8-af4f865f7205", + "name": "No", + "exit_uuid": "81414700-fed0-425b-89c9-bc9f0a6a40c9" + }, + { + "uuid": "f87bdf83-d6ca-4b11-814e-13a23e4a874b", + "name": "Other", + "exit_uuid": "3d22e845-678a-44a6-a3b7-1471b841f198" + } + ], + "operand": "@input.text", + "wait": { + "type": "msg" + }, + "result_name": "Result 1" + }, + "exits": [ + { + "uuid": "55c98810-b9a6-4ff2-b789-264770d4f313" + }, + { + "uuid": "81414700-fed0-425b-89c9-bc9f0a6a40c9" + }, + { + "uuid": "3d22e845-678a-44a6-a3b7-1471b841f198", + "destination_uuid": null + } + ] + } + ], + "_ui": { + "nodes": { + "7525b836-b61c-4fbb-9b89-8539d75d7304": { + "position": { + "left": 0, + "top": 0 + }, + "type": "execute_actions" + }, + "03068be2-4748-48e5-b19b-228b5412ebd5": { + "type": "wait_for_response", + "position": { + "left": 160, + "top": 140 + }, + "config": { + "cases": {} + } + } + } + }, + "revision": 5, + "expire_after_minutes": 10080, + "localization": {} + } + ] +} \ No newline at end of file diff --git a/core/models/ticket_events_test.go b/core/models/ticket_events_test.go index 891caa004..34ca21ca4 100644 --- a/core/models/ticket_events_test.go +++ b/core/models/ticket_events_test.go @@ -3,6 +3,7 @@ package models_test import ( "testing" + "github.com/nyaruka/gocommon/dbutil/assertdb" "github.com/nyaruka/mailroom/core/models" "github.com/nyaruka/mailroom/testsuite" "github.com/nyaruka/mailroom/testsuite/testdata" @@ -54,7 +55,7 @@ func TestTicketEvents(t *testing.T) { err := models.InsertTicketEvents(ctx, db, []*models.TicketEvent{e1, e2, e3, e4, e5}) require.NoError(t, err) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM tickets_ticketevent`).Returns(5) - testsuite.AssertQuery(t, db, `SELECT assignee_id, note FROM tickets_ticketevent WHERE id = $1`, e2.ID()). + assertdb.Query(t, db, `SELECT count(*) FROM tickets_ticketevent`).Returns(5) + assertdb.Query(t, db, `SELECT assignee_id, note FROM tickets_ticketevent WHERE id = $1`, e2.ID()). Columns(map[string]interface{}{"assignee_id": int64(testdata.Agent.ID), "note": "please handle"}) } diff --git a/core/models/tickets.go b/core/models/tickets.go index 044547e72..f1dbc458e 100644 --- a/core/models/tickets.go +++ b/core/models/tickets.go @@ -4,11 +4,11 @@ import ( "context" "database/sql" "database/sql/driver" - "fmt" "net/http" "time" "github.com/nyaruka/gocommon/dates" + "github.com/nyaruka/gocommon/dbutil" "github.com/nyaruka/gocommon/httpx" "github.com/nyaruka/goflow/assets" "github.com/nyaruka/goflow/flows" @@ -16,7 +16,6 @@ import ( "github.com/nyaruka/goflow/utils" "github.com/nyaruka/mailroom/core/goflow" "github.com/nyaruka/mailroom/runtime" - "github.com/nyaruka/mailroom/utils/dbutil" "github.com/nyaruka/null" "github.com/jmoiron/sqlx" @@ -153,8 +152,8 @@ func (t *Ticket) FlowTicket(oa *OrgAssets) (*flows.Ticket, error) { } // ForwardIncoming forwards an incoming message from a contact to this ticket -func (t *Ticket) ForwardIncoming(ctx context.Context, rt *runtime.Runtime, org *OrgAssets, msgUUID flows.MsgUUID, text string, attachments []utils.Attachment) error { - ticketer := org.TicketerByID(t.t.TicketerID) +func (t *Ticket) ForwardIncoming(ctx context.Context, rt *runtime.Runtime, oa *OrgAssets, msgUUID flows.MsgUUID, text string, attachments []utils.Attachment) error { + ticketer := oa.TicketerByID(t.t.TicketerID) if ticketer == nil { return errors.Errorf("can't find ticketer with id %d", t.t.TicketerID) } @@ -479,7 +478,6 @@ func TicketsChangeTopic(ctx context.Context, db Queryer, oa *OrgAssets, userID U now := dates.Now() for _, ticket := range tickets { - fmt.Printf("ticket #%d topic=%d\n", ticket.ID(), ticket.TopicID()) if ticket.TopicID() != topicID { ids = append(ids, ticket.ID()) t := &ticket.t @@ -788,7 +786,7 @@ func LookupTicketerByUUID(ctx context.Context, db Queryer, uuid assets.TicketerU } ticketer := &Ticketer{} - err = dbutil.ReadJSONRow(rows, &ticketer.t) + err = dbutil.ScanJSON(rows, &ticketer.t) if err != nil { return nil, errors.Wrapf(err, "error unmarshalling ticketer") } @@ -827,7 +825,7 @@ func loadTicketers(ctx context.Context, db sqlx.Queryer, orgID OrgID) ([]assets. ticketers := make([]assets.Ticketer, 0, 2) for rows.Next() { ticketer := &Ticketer{} - err := dbutil.ReadJSONRow(rows, &ticketer.t) + err := dbutil.ScanJSON(rows, &ticketer.t) if err != nil { return nil, errors.Wrapf(err, "error unmarshalling ticketer") } diff --git a/core/models/tickets_test.go b/core/models/tickets_test.go index 0d5fad535..10eaeb4ce 100644 --- a/core/models/tickets_test.go +++ b/core/models/tickets_test.go @@ -5,6 +5,7 @@ import ( "time" "github.com/nyaruka/gocommon/dates" + "github.com/nyaruka/gocommon/dbutil/assertdb" "github.com/nyaruka/gocommon/httpx" "github.com/nyaruka/goflow/flows" "github.com/nyaruka/mailroom/core/models" @@ -111,7 +112,7 @@ func TestTickets(t *testing.T) { assert.NoError(t, err) // check all tickets were created - testsuite.AssertQuery(t, db, `SELECT count(*) FROM tickets_ticket WHERE status = 'O' AND closed_on IS NULL`).Returns(3) + assertdb.Query(t, db, `SELECT count(*) FROM tickets_ticket WHERE status = 'O' AND closed_on IS NULL`).Returns(3) // can lookup a ticket by UUID tk1, err := models.LookupTicketByUUID(ctx, db, "2ef57efc-d85f-4291-b330-e4afe68af5fe") @@ -143,16 +144,16 @@ func TestUpdateTicketConfig(t *testing.T) { modelTicket := ticket.Load(db) // empty configs are null - testsuite.AssertQuery(t, db, `SELECT count(*) FROM tickets_ticket WHERE config IS NULL AND id = $1`, ticket.ID).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM tickets_ticket WHERE config IS NULL AND id = $1`, ticket.ID).Returns(1) models.UpdateTicketConfig(ctx, db, modelTicket, map[string]string{"foo": "2352", "bar": "abc"}) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM tickets_ticket WHERE config='{"foo": "2352", "bar": "abc"}'::jsonb AND id = $1`, ticket.ID).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM tickets_ticket WHERE config='{"foo": "2352", "bar": "abc"}'::jsonb AND id = $1`, ticket.ID).Returns(1) // updates are additive models.UpdateTicketConfig(ctx, db, modelTicket, map[string]string{"foo": "6547", "zed": "xyz"}) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM tickets_ticket WHERE config='{"foo": "6547", "bar": "abc", "zed": "xyz"}'::jsonb AND id = $1`, ticket.ID).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM tickets_ticket WHERE config='{"foo": "6547", "bar": "abc", "zed": "xyz"}'::jsonb AND id = $1`, ticket.ID).Returns(1) } func TestUpdateTicketLastActivity(t *testing.T) { @@ -172,7 +173,7 @@ func TestUpdateTicketLastActivity(t *testing.T) { assert.Equal(t, now, modelTicket.LastActivityOn()) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM tickets_ticket WHERE id = $1 AND last_activity_on = $2`, ticket.ID, modelTicket.LastActivityOn()).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM tickets_ticket WHERE id = $1 AND last_activity_on = $2`, ticket.ID, modelTicket.LastActivityOn()).Returns(1) } @@ -199,13 +200,13 @@ func TestTicketsAssign(t *testing.T) { assert.Equal(t, models.TicketEventTypeAssigned, evts[modelTicket2].EventType()) // check tickets are now assigned - testsuite.AssertQuery(t, db, `SELECT assignee_id FROM tickets_ticket WHERE id = $1`, ticket1.ID).Columns(map[string]interface{}{"assignee_id": int64(testdata.Agent.ID)}) - testsuite.AssertQuery(t, db, `SELECT assignee_id FROM tickets_ticket WHERE id = $1`, ticket2.ID).Columns(map[string]interface{}{"assignee_id": int64(testdata.Agent.ID)}) + assertdb.Query(t, db, `SELECT assignee_id FROM tickets_ticket WHERE id = $1`, ticket1.ID).Columns(map[string]interface{}{"assignee_id": int64(testdata.Agent.ID)}) + assertdb.Query(t, db, `SELECT assignee_id FROM tickets_ticket WHERE id = $1`, ticket2.ID).Columns(map[string]interface{}{"assignee_id": int64(testdata.Agent.ID)}) // and there are new assigned events - testsuite.AssertQuery(t, db, `SELECT count(*) FROM tickets_ticketevent WHERE event_type = 'A' AND note = 'please handle these'`).Returns(2) + assertdb.Query(t, db, `SELECT count(*) FROM tickets_ticketevent WHERE event_type = 'A' AND note = 'please handle these'`).Returns(2) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM notifications_notification WHERE user_id = $1 AND notification_type = 'tickets:activity'`, testdata.Agent.ID).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM notifications_notification WHERE user_id = $1 AND notification_type = 'tickets:activity'`, testdata.Agent.ID).Returns(1) } func TestTicketsAddNote(t *testing.T) { @@ -231,9 +232,9 @@ func TestTicketsAddNote(t *testing.T) { assert.Equal(t, models.TicketEventTypeNoteAdded, evts[modelTicket2].EventType()) // check there are new note events - testsuite.AssertQuery(t, db, `SELECT count(*) FROM tickets_ticketevent WHERE event_type = 'N' AND note = 'spam'`).Returns(2) + assertdb.Query(t, db, `SELECT count(*) FROM tickets_ticketevent WHERE event_type = 'N' AND note = 'spam'`).Returns(2) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM notifications_notification WHERE user_id = $1 AND notification_type = 'tickets:activity'`, testdata.Agent.ID).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM notifications_notification WHERE user_id = $1 AND notification_type = 'tickets:activity'`, testdata.Agent.ID).Returns(1) } func TestTicketsChangeTopic(t *testing.T) { @@ -262,8 +263,8 @@ func TestTicketsChangeTopic(t *testing.T) { assert.Equal(t, models.TicketEventTypeTopicChanged, evts[modelTicket3].EventType()) // check tickets are updated and we have events - testsuite.AssertQuery(t, db, `SELECT count(*) FROM tickets_ticket WHERE topic_id = $1`, testdata.SupportTopic.ID).Returns(3) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM tickets_ticketevent WHERE event_type = 'T' AND topic_id = $1`, testdata.SupportTopic.ID).Returns(2) + assertdb.Query(t, db, `SELECT count(*) FROM tickets_ticket WHERE topic_id = $1`, testdata.SupportTopic.ID).Returns(3) + assertdb.Query(t, db, `SELECT count(*) FROM tickets_ticketevent WHERE event_type = 'T' AND topic_id = $1`, testdata.SupportTopic.ID).Returns(2) } func TestCloseTickets(t *testing.T) { @@ -307,16 +308,16 @@ func TestCloseTickets(t *testing.T) { assert.Equal(t, models.TicketEventTypeClosed, evts[modelTicket1].EventType()) // check ticket #1 is now closed - testsuite.AssertQuery(t, db, `SELECT count(*) FROM tickets_ticket WHERE id = $1 AND status = 'C' AND closed_on IS NOT NULL`, ticket1.ID).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM tickets_ticket WHERE id = $1 AND status = 'C' AND closed_on IS NOT NULL`, ticket1.ID).Returns(1) // and there's closed event for it - testsuite.AssertQuery(t, db, `SELECT count(*) FROM tickets_ticketevent WHERE org_id = $1 AND ticket_id = $2 AND event_type = 'C'`, + assertdb.Query(t, db, `SELECT count(*) FROM tickets_ticketevent WHERE org_id = $1 AND ticket_id = $2 AND event_type = 'C'`, []interface{}{testdata.Org1.ID, ticket1.ID}, 1) // and the logger has an http log it can insert for that ticketer require.NoError(t, logger.Insert(ctx, db)) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM request_logs_httplog WHERE ticketer_id = $1`, testdata.Mailgun.ID).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM request_logs_httplog WHERE ticketer_id = $1`, testdata.Mailgun.ID).Returns(1) // reload Cathy and check they're no longer in the tickets group _, cathy = testdata.Cathy.Load(db, oa) @@ -324,7 +325,7 @@ func TestCloseTickets(t *testing.T) { assert.Equal(t, "Doctors", cathy.Groups().All()[0].Name()) // but no events for ticket #2 which was already closed - testsuite.AssertQuery(t, db, `SELECT count(*) FROM tickets_ticketevent WHERE ticket_id = $1 AND event_type = 'C'`, ticket2.ID).Returns(0) + assertdb.Query(t, db, `SELECT count(*) FROM tickets_ticketevent WHERE ticket_id = $1 AND event_type = 'C'`, ticket2.ID).Returns(0) // can close tickets without a user ticket3 := testdata.InsertOpenTicket(db, testdata.Org1, testdata.Cathy, testdata.Mailgun, testdata.DefaultTopic, "Where my shoes", "123", nil) @@ -335,7 +336,7 @@ func TestCloseTickets(t *testing.T) { assert.Equal(t, 1, len(evts)) assert.Equal(t, models.TicketEventTypeClosed, evts[modelTicket3].EventType()) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM tickets_ticketevent WHERE ticket_id = $1 AND event_type = 'C' AND created_by_id IS NULL`, ticket3.ID).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM tickets_ticketevent WHERE ticket_id = $1 AND event_type = 'C' AND created_by_id IS NULL`, ticket3.ID).Returns(1) } func TestReopenTickets(t *testing.T) { @@ -371,18 +372,18 @@ func TestReopenTickets(t *testing.T) { assert.Equal(t, models.TicketEventTypeReopened, evts[modelTicket1].EventType()) // check ticket #1 is now closed - testsuite.AssertQuery(t, db, `SELECT count(*) FROM tickets_ticket WHERE id = $1 AND status = 'O' AND closed_on IS NULL`, ticket1.ID).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM tickets_ticket WHERE id = $1 AND status = 'O' AND closed_on IS NULL`, ticket1.ID).Returns(1) // and there's reopened event for it - testsuite.AssertQuery(t, db, `SELECT count(*) FROM tickets_ticketevent WHERE org_id = $1 AND ticket_id = $2 AND event_type = 'R'`, testdata.Org1.ID, ticket1.ID).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM tickets_ticketevent WHERE org_id = $1 AND ticket_id = $2 AND event_type = 'R'`, testdata.Org1.ID, ticket1.ID).Returns(1) // and the logger has an http log it can insert for that ticketer require.NoError(t, logger.Insert(ctx, db)) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM request_logs_httplog WHERE ticketer_id = $1`, testdata.Mailgun.ID).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM request_logs_httplog WHERE ticketer_id = $1`, testdata.Mailgun.ID).Returns(1) // but no events for ticket #2 which waas already open - testsuite.AssertQuery(t, db, `SELECT count(*) FROM tickets_ticketevent WHERE ticket_id = $1 AND event_type = 'R'`, ticket2.ID).Returns(0) + assertdb.Query(t, db, `SELECT count(*) FROM tickets_ticketevent WHERE ticket_id = $1 AND event_type = 'R'`, ticket2.ID).Returns(0) // check Cathy is now in the two tickets group _, cathy := testdata.Cathy.Load(db, oa) diff --git a/core/models/topics.go b/core/models/topics.go index 238ebcdaa..9b01aac23 100644 --- a/core/models/topics.go +++ b/core/models/topics.go @@ -8,8 +8,8 @@ import ( "github.com/jmoiron/sqlx" "github.com/nyaruka/gocommon/dates" + "github.com/nyaruka/gocommon/dbutil" "github.com/nyaruka/goflow/assets" - "github.com/nyaruka/mailroom/utils/dbutil" "github.com/nyaruka/null" "github.com/pkg/errors" "github.com/sirupsen/logrus" @@ -74,7 +74,7 @@ func loadTopics(ctx context.Context, db sqlx.Queryer, orgID OrgID) ([]assets.Top topics := make([]assets.Topic, 0, 2) for rows.Next() { topic := &Topic{} - err := dbutil.ReadJSONRow(rows, &topic.t) + err := dbutil.ScanJSON(rows, &topic.t) if err != nil { return nil, errors.Wrapf(err, "error unmarshalling topic") } diff --git a/core/models/triggers.go b/core/models/triggers.go index 094c267ba..0e969800e 100644 --- a/core/models/triggers.go +++ b/core/models/triggers.go @@ -6,12 +6,11 @@ import ( "strings" "time" + "github.com/lib/pq" + "github.com/nyaruka/gocommon/dbutil" "github.com/nyaruka/goflow/flows" "github.com/nyaruka/goflow/flows/triggers" "github.com/nyaruka/goflow/utils" - "github.com/nyaruka/mailroom/utils/dbutil" - - "github.com/lib/pq" "github.com/pkg/errors" "github.com/sirupsen/logrus" ) @@ -105,7 +104,7 @@ func loadTriggers(ctx context.Context, db Queryer, orgID OrgID) ([]*Trigger, err triggers := make([]*Trigger, 0, 10) for rows.Next() { trigger := &Trigger{} - err = dbutil.ReadJSONRow(rows, &trigger.t) + err = dbutil.ScanJSON(rows, &trigger.t) if err != nil { return nil, errors.Wrap(err, "error scanning label row") } diff --git a/core/models/users.go b/core/models/users.go index a3af281b5..551fc0442 100644 --- a/core/models/users.go +++ b/core/models/users.go @@ -8,8 +8,8 @@ import ( "time" "github.com/jmoiron/sqlx" + "github.com/nyaruka/gocommon/dbutil" "github.com/nyaruka/goflow/assets" - "github.com/nyaruka/mailroom/utils/dbutil" "github.com/nyaruka/null" "github.com/pkg/errors" "github.com/sirupsen/logrus" @@ -126,7 +126,7 @@ func loadUsers(ctx context.Context, db sqlx.Queryer, orgID OrgID) ([]assets.User users := make([]assets.User, 0, 10) for rows.Next() { user := &User{} - err := dbutil.ReadJSONRow(rows, &user.u) + err := dbutil.ScanJSON(rows, &user.u) if err != nil { return nil, errors.Wrapf(err, "error unmarshalling user") } diff --git a/core/models/utils.go b/core/models/utils.go index 2c8d3307b..a6d76b2e8 100644 --- a/core/models/utils.go +++ b/core/models/utils.go @@ -7,7 +7,7 @@ import ( "time" "github.com/jmoiron/sqlx" - "github.com/nyaruka/mailroom/utils/dbutil" + "github.com/nyaruka/gocommon/dbutil" "github.com/pkg/errors" "github.com/sirupsen/logrus" ) @@ -18,6 +18,7 @@ type Queryer interface { ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) NamedExecContext(ctx context.Context, query string, arg interface{}) (sql.Result, error) + SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error GetContext(ctx context.Context, value interface{}, query string, args ...interface{}) error } @@ -90,3 +91,17 @@ func chunkSlice(slice []interface{}, size int) [][]interface{} { } return chunks } + +// chunks a slice of session IDs.. hurry up go generics +func chunkSessionIDs(ids []SessionID, size int) [][]SessionID { + chunks := make([][]SessionID, 0, len(ids)/size+1) + + for i := 0; i < len(ids); i += size { + end := i + size + if end > len(ids) { + end = len(ids) + } + chunks = append(chunks, ids[i:end]) + } + return chunks +} diff --git a/core/models/utils_test.go b/core/models/utils_test.go index 4299d0398..4e7b4ea58 100644 --- a/core/models/utils_test.go +++ b/core/models/utils_test.go @@ -3,6 +3,7 @@ package models_test import ( "testing" + "github.com/nyaruka/gocommon/dbutil/assertdb" "github.com/nyaruka/mailroom/core/models" "github.com/nyaruka/mailroom/testsuite" @@ -35,8 +36,8 @@ func TestBulkQueryBatches(t *testing.T) { assert.NoError(t, err) assert.Equal(t, 1, foo1.ID) assert.Equal(t, 2, foo2.ID) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM foo WHERE name = 'A' AND age = 30`).Returns(1) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM foo WHERE name = 'B' AND age = 31`).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM foo WHERE name = 'A' AND age = 30`).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM foo WHERE name = 'B' AND age = 31`).Returns(1) // test when multiple batches are required foo3 := &foo{Name: "C", Age: 32} @@ -51,10 +52,10 @@ func TestBulkQueryBatches(t *testing.T) { assert.Equal(t, 5, foo5.ID) assert.Equal(t, 6, foo6.ID) assert.Equal(t, 7, foo7.ID) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM foo WHERE name = 'C' AND age = 32`).Returns(1) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM foo WHERE name = 'D' AND age = 33`).Returns(1) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM foo WHERE name = 'E' AND age = 34`).Returns(1) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM foo WHERE name = 'F' AND age = 35`).Returns(1) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM foo WHERE name = 'G' AND age = 36`).Returns(1) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM foo `).Returns(7) + assertdb.Query(t, db, `SELECT count(*) FROM foo WHERE name = 'C' AND age = 32`).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM foo WHERE name = 'D' AND age = 33`).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM foo WHERE name = 'E' AND age = 34`).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM foo WHERE name = 'F' AND age = 35`).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM foo WHERE name = 'G' AND age = 36`).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM foo `).Returns(7) } diff --git a/core/models/webhook_event_test.go b/core/models/webhook_event_test.go index 562686134..761a5ebcd 100644 --- a/core/models/webhook_event_test.go +++ b/core/models/webhook_event_test.go @@ -4,6 +4,7 @@ import ( "testing" "time" + "github.com/nyaruka/gocommon/dbutil/assertdb" "github.com/nyaruka/mailroom/core/models" "github.com/nyaruka/mailroom/testsuite" "github.com/nyaruka/mailroom/testsuite/testdata" @@ -32,6 +33,6 @@ func TestWebhookEvents(t *testing.T) { assert.NoError(t, err) assert.NotZero(t, e.ID()) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM api_webhookevent WHERE org_id = $1 AND resthook_id = $2 AND data = $3`, tc.OrgID, tc.ResthookID, tc.Data).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM api_webhookevent WHERE org_id = $1 AND resthook_id = $2 AND data = $3`, tc.OrgID, tc.ResthookID, tc.Data).Returns(1) } } diff --git a/core/msgio/android.go b/core/msgio/android.go index 6b50f0ceb..aa1fb6be9 100644 --- a/core/msgio/android.go +++ b/core/msgio/android.go @@ -19,6 +19,8 @@ func SyncAndroidChannels(fc *fcm.Client, channels []*models.Channel) { } for _, channel := range channels { + assert(channel.Type() == models.ChannelTypeAndroid, "can't sync a non-android channel") + // no FCM ID for this channel, noop, we can't trigger a sync fcmID := channel.ConfigValue(models.ChannelConfigFCMID, "") if fcmID == "" { diff --git a/core/msgio/courier.go b/core/msgio/courier.go index f4d603d0c..b7b4021b5 100644 --- a/core/msgio/courier.go +++ b/core/msgio/courier.go @@ -1,52 +1,85 @@ package msgio import ( - "encoding/json" "strconv" "time" + "github.com/nyaruka/gocommon/dates" + "github.com/nyaruka/gocommon/jsonx" "github.com/nyaruka/gocommon/urns" "github.com/nyaruka/mailroom/core/models" "github.com/gomodule/redigo/redis" - "github.com/pkg/errors" "github.com/sirupsen/logrus" ) const ( - highPriority = 1 - defaultPriority = 0 + bulkPriority = 0 + highPriority = 1 ) +var queuePushScript = redis.NewScript(6, ` +-- KEYS: [QueueType, QueueName, TPS, Priority, Items, EpochSecs] +local queueType, queueName, tps, priority, items, epochSecs = KEYS[1], KEYS[2], tonumber(KEYS[3]), KEYS[4], KEYS[5], KEYS[6] + +-- first construct the base key for this queue from the type + name + tps, e.g. "msgs:0a77a158-1dcb-4c06-9aee-e15bdf64653e|10" +local queueKey = queueType .. ":" .. queueName .. "|" .. tps + +-- each queue than has two sorted sets for bulk and high priority items, e.g. "msgs:0a77..653e|10/0" vs msgs:0a77..653e|10/1" +local priorityQueueKey = queueKey .. "/" .. priority + +-- add the items to the sorted set using the full timestamp (e.g. 1636556789.123456) as the score +redis.call("ZADD", priorityQueueKey, epochSecs, items) + +-- if we have a TPS limit, check the transaction counter for this epoch second to see if have already reached it +local curr = -1 +if tps > 0 then + local tpsKey = queueKey .. ":tps:" .. math.floor(epochSecs) -- e.g. "msgs:0a77..4653e|10:tps:1636556789" + curr = tonumber(redis.call("GET", tpsKey)) +end + +-- if we haven't hit the limit, add this queue to set of active queues +if not curr or curr < tps then + redis.call("ZINCRBY", queueType .. ":active", 0, queueKey) + return 1 +else + return 0 +end +`) + +// PushCourierBatch pushes a batch of messages for a single contact and channel onto the appropriate courier queue +func PushCourierBatch(rc redis.Conn, ch *models.Channel, batch []*models.Msg, timestamp string) error { + priority := bulkPriority + if batch[0].HighPriority() { + priority = highPriority + } + batchJSON := jsonx.MustMarshal(batch) + + _, err := queuePushScript.Do(rc, "msgs", ch.UUID(), ch.TPS(), priority, batchJSON, timestamp) + return err +} + // QueueCourierMessages queues messages for a single contact to Courier func QueueCourierMessages(rc redis.Conn, contactID models.ContactID, msgs []*models.Msg) error { if len(msgs) == 0 { return nil } - now := time.Now() - epochMS := strconv.FormatFloat(float64(now.UnixNano()/int64(time.Microsecond))/float64(1000000), 'f', 6, 64) - - priority := defaultPriority + // get the time in seconds since the epoch as a floating point number + // e.g. 2021-11-10T15:10:49.123456+00:00 => "1636557205.123456" + now := dates.Now() + epochSeconds := strconv.FormatFloat(float64(now.UnixNano()/int64(time.Microsecond))/float64(1000000), 'f', 6, 64) // we batch msgs by channel uuid batch := make([]*models.Msg, 0, len(msgs)) currentChannel := msgs[0].Channel() + currentPriority := msgs[0].HighPriority() // commits our batch to redis commitBatch := func() error { if len(batch) > 0 { - priority = defaultPriority - if batch[0].HighPriority() { - priority = highPriority - } - - batchJSON, err := json.Marshal(batch) - if err != nil { - return err - } start := time.Now() - _, err = queueMsg.Do(rc, epochMS, "msgs", currentChannel.UUID(), currentChannel.TPS(), priority, batchJSON) + err := PushCourierBatch(rc, currentChannel, batch, epochSeconds) if err != nil { return err } @@ -61,39 +94,21 @@ func QueueCourierMessages(rc redis.Conn, contactID models.ContactID, msgs []*mod } for _, msg := range msgs { - // android messages should never get in here - if msg.Channel() != nil && msg.Channel().Type() == models.ChannelTypeAndroid { - panic("trying to queue android messages to courier") - } - - // ignore any message without a channel or already marked as failed (maybe org is suspended) - if msg.ChannelUUID() == "" || msg.Status() == models.MsgStatusFailed { - continue - } - - // nil channel object but have channel UUID? that's an error - if msg.Channel() == nil { - return errors.Errorf("msg passed in without channel set") - } - - // no contact urn id or urn, also an error - if msg.URN() == urns.NilURN || msg.ContactURNID() == nil { - return errors.Errorf("msg passed with nil urn: %s", msg.URN()) - } + // sanity check the state of the msg we're about to queue... + assert(msg.Channel() != nil && msg.ChannelUUID() != "", "can't queue a message to courier without a channel") + assert(msg.Channel().Type() != models.ChannelTypeAndroid, "can't queue an android message to courier") + assert(msg.URN() != urns.NilURN && msg.ContactURNID() != nil, "can't queue a message to courier without a URN") - // same channel? add to batch - if msg.Channel() == currentChannel { + // if this msg is the same channel and priority, add to current batch, otherwise start new batch + if msg.Channel() == currentChannel && msg.HighPriority() == currentPriority { batch = append(batch, msg) - } - - // different channel? queue it up - if msg.Channel() != currentChannel { - err := commitBatch() - if err != nil { + } else { + if err := commitBatch(); err != nil { return err } currentChannel = msg.Channel() + currentPriority = msg.HighPriority() batch = []*models.Msg{msg} } } @@ -101,31 +116,3 @@ func QueueCourierMessages(rc redis.Conn, contactID models.ContactID, msgs []*mod // any remaining in our batch, queue it up return commitBatch() } - -var queueMsg = redis.NewScript(6, ` --- KEYS: [EpochMS, QueueType, QueueName, TPS, Priority, Value] - --- first push onto our specific queue --- our queue name is built from the type, name and tps, usually something like: "msgs:uuid1-uuid2-uuid3-uuid4|tps" -local queueKey = KEYS[2] .. ":" .. KEYS[3] .. "|" .. KEYS[4] - --- our priority queue name also includes the priority of the message (we have one queue for default and one for bulk) -local priorityQueueKey = queueKey .. "/" .. KEYS[5] -redis.call("zadd", priorityQueueKey, KEYS[1], KEYS[6]) -local tps = tonumber(KEYS[4]) - --- if we have a TPS, check whether we are currently throttled -local curr = -1 -if tps > 0 then - local tpsKey = queueKey .. ":tps:" .. math.floor(KEYS[1]) - curr = tonumber(redis.call("get", tpsKey)) -end - --- if we aren't then add to our active -if not curr or curr < tps then -redis.call("zincrby", KEYS[2] .. ":active", 0, queueKey) - return 1 -else - return 0 -end -`) diff --git a/core/msgio/courier_test.go b/core/msgio/courier_test.go index 2605f062e..eb06ac870 100644 --- a/core/msgio/courier_test.go +++ b/core/msgio/courier_test.go @@ -1,8 +1,11 @@ package msgio_test import ( + "encoding/json" "testing" + "github.com/gomodule/redigo/redis" + "github.com/nyaruka/gocommon/jsonx" "github.com/nyaruka/mailroom/core/models" "github.com/nyaruka/mailroom/core/msgio" "github.com/nyaruka/mailroom/testsuite" @@ -17,7 +20,7 @@ func TestQueueCourierMessages(t *testing.T) { rc := rp.Get() defer rc.Close() - defer testsuite.Reset(testsuite.ResetAll) + defer testsuite.Reset(testsuite.ResetData | testsuite.ResetRedis) // create an Andoid channel androidChannel := testdata.InsertChannel(db, testdata.Org1, "A", "Android 1", []string{"tel"}, "SR", map[string]interface{}{"FCM_ID": "FCMID"}) @@ -25,76 +28,104 @@ func TestQueueCourierMessages(t *testing.T) { oa, err := models.GetOrgAssetsWithRefresh(ctx, rt, testdata.Org1.ID, models.RefreshOrg|models.RefreshChannels) require.NoError(t, err) - tests := []struct { - Description string - Msgs []msgSpec - QueueSizes map[string][]int - }{ - { - Description: "2 queueable messages", - Msgs: []msgSpec{ - { - ChannelID: testdata.TwilioChannel.ID, - ContactID: testdata.Cathy.ID, - URNID: testdata.Cathy.URNID, - }, - { - ChannelID: testdata.TwilioChannel.ID, - ContactID: testdata.Cathy.ID, - URNID: testdata.Cathy.URNID, - }, - }, - QueueSizes: map[string][]int{ - "msgs:74729f45-7f29-4868-9dc4-90e491e3c7d8|10/0": {2}, - }, - }, - { - Description: "1 queueable message and 1 failed", - Msgs: []msgSpec{ - { - ChannelID: testdata.TwilioChannel.ID, - ContactID: testdata.Cathy.ID, - URNID: testdata.Cathy.URNID, - Failed: true, - }, - { - ChannelID: testdata.TwilioChannel.ID, - ContactID: testdata.Cathy.ID, - URNID: testdata.Cathy.URNID, - }, - }, - QueueSizes: map[string][]int{ - "msgs:74729f45-7f29-4868-9dc4-90e491e3c7d8|10/0": {1}, - }, - }, - { - Description: "0 messages", - Msgs: []msgSpec{}, - QueueSizes: map[string][]int{}, - }, + // noop if no messages provided + msgio.QueueCourierMessages(rc, testdata.Cathy.ID, []*models.Msg{}) + testsuite.AssertCourierQueues(t, map[string][]int{}) + + // queue 3 messages for Cathy.. + msgs := []*models.Msg{ + (&msgSpec{Channel: testdata.TwilioChannel, Contact: testdata.Cathy}).createMsg(t, rt, oa), + (&msgSpec{Channel: testdata.TwilioChannel, Contact: testdata.Cathy}).createMsg(t, rt, oa), + (&msgSpec{Channel: testdata.TwilioChannel, Contact: testdata.Cathy, HighPriority: true}).createMsg(t, rt, oa), + (&msgSpec{Channel: testdata.VonageChannel, Contact: testdata.Cathy}).createMsg(t, rt, oa), } - for _, tc := range tests { - var contactID models.ContactID - msgs := make([]*models.Msg, len(tc.Msgs)) - for i, ms := range tc.Msgs { - msgs[i] = ms.createMsg(t, rt, oa) - contactID = ms.ContactID - } + msgio.QueueCourierMessages(rc, testdata.Cathy.ID, msgs) - rc.Do("FLUSHDB") - msgio.QueueCourierMessages(rc, contactID, msgs) + testsuite.AssertCourierQueues(t, map[string][]int{ + "msgs:74729f45-7f29-4868-9dc4-90e491e3c7d8|10/0": {2}, // twilio, bulk priority + "msgs:74729f45-7f29-4868-9dc4-90e491e3c7d8|10/1": {1}, // twilio, high priority + "msgs:19012bfd-3ce3-4cae-9bb9-76cf92c73d49|10/0": {1}, // vonage, bulk priority + }) - testsuite.AssertCourierQueues(t, tc.QueueSizes, "courier queue sizes mismatch in '%s'", tc.Description) - } + // check that trying to queue a message without a channel will panic + assert.Panics(t, func() { + ms := msgSpec{Channel: nil, Contact: testdata.Cathy} + msgio.QueueCourierMessages(rc, testdata.Cathy.ID, []*models.Msg{ms.createMsg(t, rt, oa)}) + }) - // check that trying to queue a courier message will panic + // check that trying to queue an Android message will panic assert.Panics(t, func() { - ms := msgSpec{ - ChannelID: androidChannel.ID, - ContactID: testdata.Cathy.ID, - URNID: testdata.Cathy.URNID, - } + ms := msgSpec{Channel: androidChannel, Contact: testdata.Cathy} msgio.QueueCourierMessages(rc, testdata.Cathy.ID, []*models.Msg{ms.createMsg(t, rt, oa)}) }) } + +func TestPushCourierBatch(t *testing.T) { + ctx, rt, _, rp := testsuite.Get() + rc := rp.Get() + defer rc.Close() + + defer testsuite.Reset(testsuite.ResetData | testsuite.ResetRedis) + + oa, err := models.GetOrgAssetsWithRefresh(ctx, rt, testdata.Org1.ID, models.RefreshChannels) + require.NoError(t, err) + + channel := oa.ChannelByID(testdata.TwilioChannel.ID) + + msg1 := (&msgSpec{Channel: testdata.TwilioChannel, Contact: testdata.Cathy}).createMsg(t, rt, oa) + msg2 := (&msgSpec{Channel: testdata.TwilioChannel, Contact: testdata.Cathy}).createMsg(t, rt, oa) + + err = msgio.PushCourierBatch(rc, channel, []*models.Msg{msg1, msg2}, "1636557205.123456") + require.NoError(t, err) + + // check that channel has been added to active list + msgsActive, err := redis.Strings(rc.Do("ZRANGE", "msgs:active", 0, -1)) + assert.NoError(t, err) + assert.Equal(t, []string{"msgs:74729f45-7f29-4868-9dc4-90e491e3c7d8|10"}, msgsActive) + + // and that msgs were added as single batch to bulk priority (0) queue + queued, err := redis.ByteSlices(rc.Do("ZRANGE", "msgs:74729f45-7f29-4868-9dc4-90e491e3c7d8|10/0", 0, -1)) + assert.NoError(t, err) + assert.Equal(t, 1, len(queued)) + + unmarshaled, err := jsonx.DecodeGeneric(queued[0]) + assert.NoError(t, err) + assert.Equal(t, 2, len(unmarshaled.([]interface{}))) + + item1ID, _ := unmarshaled.([]interface{})[0].(map[string]interface{})["id"].(json.Number).Int64() + item2ID, _ := unmarshaled.([]interface{})[1].(map[string]interface{})["id"].(json.Number).Int64() + assert.Equal(t, int64(msg1.ID()), item1ID) + assert.Equal(t, int64(msg2.ID()), item2ID) + + // push another batch in the same epoch second with transaction counter still below limit + rc.Do("SET", "msgs:74729f45-7f29-4868-9dc4-90e491e3c7d8|10:tps:1636557205", "5") + + msg3 := (&msgSpec{Channel: testdata.TwilioChannel, Contact: testdata.Cathy}).createMsg(t, rt, oa) + + err = msgio.PushCourierBatch(rc, channel, []*models.Msg{msg3}, "1636557205.234567") + require.NoError(t, err) + + queued, err = redis.ByteSlices(rc.Do("ZRANGE", "msgs:74729f45-7f29-4868-9dc4-90e491e3c7d8|10/0", 0, -1)) + assert.NoError(t, err) + assert.Equal(t, 2, len(queued)) + + // simulate channel having been throttled + rc.Do("ZREM", "msgs:active", "msgs:74729f45-7f29-4868-9dc4-90e491e3c7d8|10") + rc.Do("SET", "msgs:74729f45-7f29-4868-9dc4-90e491e3c7d8|10:tps:1636557205", "11") + + msg4 := (&msgSpec{Channel: testdata.TwilioChannel, Contact: testdata.Cathy}).createMsg(t, rt, oa) + + err = msgio.PushCourierBatch(rc, channel, []*models.Msg{msg4}, "1636557205.345678") + require.NoError(t, err) + + // check that channel has *not* been added to active list + msgsActive, err = redis.Strings(rc.Do("ZRANGE", "msgs:active", 0, -1)) + assert.NoError(t, err) + assert.Equal(t, []string{}, msgsActive) + + // but msg was still added to queue + queued, err = redis.ByteSlices(rc.Do("ZRANGE", "msgs:74729f45-7f29-4868-9dc4-90e491e3c7d8|10/0", 0, -1)) + assert.NoError(t, err) + assert.Equal(t, 3, len(queued)) +} diff --git a/core/msgio/send.go b/core/msgio/send.go index 0d8b6d6cd..c118048e9 100644 --- a/core/msgio/send.go +++ b/core/msgio/send.go @@ -24,6 +24,11 @@ func SendMessages(ctx context.Context, rt *runtime.Runtime, tx models.Queryer, f // walk through our messages, separate by whether they have a channel and if it's Android for _, msg := range msgs { + // ignore any message already marked as failed (maybe org is suspended) + if msg.Status() == models.MsgStatusFailed { + continue + } + channel := msg.Channel() if channel != nil { if channel.Type() == models.ChannelTypeAndroid { @@ -75,3 +80,9 @@ func SendMessages(ctx context.Context, rt *runtime.Runtime, tx models.Queryer, f } } } + +func assert(c bool, m string) { + if !c { + panic(m) + } +} diff --git a/core/msgio/send_test.go b/core/msgio/send_test.go index 0b4eaf029..e3ff4d859 100644 --- a/core/msgio/send_test.go +++ b/core/msgio/send_test.go @@ -2,13 +2,9 @@ package msgio_test import ( "context" - "fmt" "testing" - "time" - "github.com/nyaruka/gocommon/urns" - "github.com/nyaruka/goflow/assets" - "github.com/nyaruka/goflow/flows" + "github.com/nyaruka/gocommon/dbutil/assertdb" "github.com/nyaruka/mailroom/core/models" "github.com/nyaruka/mailroom/core/msgio" "github.com/nyaruka/mailroom/runtime" @@ -20,36 +16,29 @@ import ( ) type msgSpec struct { - ChannelID models.ChannelID - ContactID models.ContactID - URNID models.URNID - Failed bool + Channel *testdata.Channel + Contact *testdata.Contact + Failed bool + HighPriority bool } func (m *msgSpec) createMsg(t *testing.T, rt *runtime.Runtime, oa *models.OrgAssets) *models.Msg { - // Only way to create a failed outgoing message is to suspend the org and reload the org. - // However the channels have to be fetched from the same org assets thus why this uses its - // own org assets instance. - ctx := context.Background() - rt.DB.MustExec(`UPDATE orgs_org SET is_suspended = $1 WHERE id = $2`, m.Failed, testdata.Org1.ID) - oaOrg, _ := models.GetOrgAssetsWithRefresh(ctx, rt, testdata.Org1.ID, models.RefreshOrg) - - var channel *models.Channel - var channelRef *assets.ChannelReference - - if m.ChannelID != models.NilChannelID { - channel = oa.ChannelByID(m.ChannelID) - channelRef = channel.ChannelReference() + status := models.MsgStatusQueued + if m.Failed { + status = models.MsgStatusFailed } - urn := urns.URN(fmt.Sprintf("tel:+250700000001?id=%d", m.URNID)) - flowMsg := flows.NewMsgOut(urn, channelRef, "Hello", nil, nil, nil, flows.NilMsgTopic) - msg, err := models.NewOutgoingMsg(rt.Config, oaOrg.Org(), channel, m.ContactID, flowMsg, time.Now()) + flowMsg := testdata.InsertOutgoingMsg(rt.DB, testdata.Org1, m.Channel, m.Contact, "Hello", nil, status, m.HighPriority) + msgs, err := models.GetMessagesByID(context.Background(), rt.DB, testdata.Org1.ID, models.DirectionOut, []models.MsgID{models.MsgID(flowMsg.ID())}) require.NoError(t, err) - models.InsertMessages(ctx, rt.DB, []*models.Msg{msg}) - require.NoError(t, err) + msg := msgs[0] + msg.SetURN(m.Contact.URN) + // use the channel instances in org assets so they're shared between msg instances + if msg.ChannelID() != models.NilChannelID { + msg.SetChannel(oa.ChannelByID(msg.ChannelID())) + } return msg } @@ -58,6 +47,8 @@ func TestSendMessages(t *testing.T) { rc := rp.Get() defer rc.Close() + defer testsuite.Reset(testsuite.ResetData) + mockFCM := newMockFCMEndpoint("FCMID3") defer mockFCM.Stop() @@ -78,27 +69,37 @@ func TestSendMessages(t *testing.T) { FCMTokensSynced []string PendingMsgs int }{ + { + Description: "no messages", + Msgs: []msgSpec{}, + QueueSizes: map[string][]int{}, + FCMTokensSynced: []string{}, + PendingMsgs: 0, + }, { Description: "2 messages for Courier, and 1 Android", Msgs: []msgSpec{ { - ChannelID: testdata.TwilioChannel.ID, - ContactID: testdata.Cathy.ID, - URNID: testdata.Cathy.URNID, + Channel: testdata.TwilioChannel, + Contact: testdata.Cathy, + }, + { + Channel: androidChannel1, + Contact: testdata.Bob, }, { - ChannelID: androidChannel1.ID, - ContactID: testdata.Bob.ID, - URNID: testdata.Bob.URNID, + Channel: testdata.TwilioChannel, + Contact: testdata.Cathy, }, { - ChannelID: testdata.TwilioChannel.ID, - ContactID: testdata.Cathy.ID, - URNID: testdata.Cathy.URNID, + Channel: testdata.TwilioChannel, + Contact: testdata.Bob, + HighPriority: true, }, }, QueueSizes: map[string][]int{ - "msgs:74729f45-7f29-4868-9dc4-90e491e3c7d8|10/0": {2}, + "msgs:74729f45-7f29-4868-9dc4-90e491e3c7d8|10/0": {2}, // 2 default priority messages for Cathy + "msgs:74729f45-7f29-4868-9dc4-90e491e3c7d8|10/1": {1}, // 1 high priority message for Bob }, FCMTokensSynced: []string{"FCMID1"}, PendingMsgs: 0, @@ -107,32 +108,41 @@ func TestSendMessages(t *testing.T) { Description: "each Android channel synced once", Msgs: []msgSpec{ { - ChannelID: androidChannel1.ID, - ContactID: testdata.Cathy.ID, - URNID: testdata.Cathy.URNID, + Channel: androidChannel1, + Contact: testdata.Cathy, }, { - ChannelID: androidChannel2.ID, - ContactID: testdata.Bob.ID, - URNID: testdata.Bob.URNID, + Channel: androidChannel2, + Contact: testdata.Bob, }, { - ChannelID: androidChannel1.ID, - ContactID: testdata.Cathy.ID, - URNID: testdata.Cathy.URNID, + Channel: androidChannel1, + Contact: testdata.Cathy, }, }, QueueSizes: map[string][]int{}, FCMTokensSynced: []string{"FCMID1", "FCMID2"}, PendingMsgs: 0, }, + { + Description: "messages with FAILED status ignored", + Msgs: []msgSpec{ + { + Channel: testdata.TwilioChannel, + Contact: testdata.Cathy, + Failed: true, + }, + }, + QueueSizes: map[string][]int{}, + FCMTokensSynced: []string{}, + PendingMsgs: 0, + }, { Description: "messages without channels set to PENDING", Msgs: []msgSpec{ { - ChannelID: models.NilChannelID, - ContactID: testdata.Cathy.ID, - URNID: testdata.Cathy.URNID, + Channel: nil, + Contact: testdata.Cathy, }, }, QueueSizes: map[string][]int{}, @@ -162,6 +172,6 @@ func TestSendMessages(t *testing.T) { assert.Equal(t, tc.FCMTokensSynced, actualTokens, "FCM tokens mismatch in '%s'", tc.Description) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM msgs_msg WHERE status = 'P'`).Returns(tc.PendingMsgs, `pending messages mismatch in '%s'`, tc.Description) + assertdb.Query(t, db, `SELECT count(*) FROM msgs_msg WHERE status = 'P'`).Returns(tc.PendingMsgs, `pending messages mismatch in '%s'`, tc.Description) } } diff --git a/core/runner/runner.go b/core/runner/runner.go index 3f259327c..2c1c7c71b 100644 --- a/core/runner/runner.go +++ b/core/runner/runner.go @@ -73,7 +73,7 @@ func ResumeFlow(ctx context.Context, rt *runtime.Runtime, oa *models.OrgAssets, // if this flow just isn't available anymore, log this error if err == models.ErrNotFound { logrus.WithField("contact_uuid", session.Contact().UUID()).WithField("session_id", session.ID()).WithField("flow_id", session.CurrentFlowID()).Error("unable to find flow in resume") - return nil, models.ExitSessions(ctx, rt.DB, []models.SessionID{session.ID()}, models.ExitFailed, time.Now()) + return nil, models.ExitSessions(ctx, rt.DB, []models.SessionID{session.ID()}, models.SessionStatusFailed) } return nil, errors.Wrapf(err, "error loading session flow: %d", session.CurrentFlowID()) } @@ -104,7 +104,7 @@ func ResumeFlow(ctx context.Context, rt *runtime.Runtime, oa *models.OrgAssets, } // write our updated session and runs - err = session.WriteUpdatedSession(txCTX, rt, tx, oa, fs, sprint, hook) + err = session.Update(txCTX, rt, tx, oa, fs, sprint, hook) if err != nil { tx.Rollback() return nil, errors.Wrapf(err, "error updating session for resume") @@ -204,14 +204,14 @@ func StartFlowBatch( // this will build our trigger for each contact started triggerBuilder := func(contact *flows.Contact) flows.Trigger { if batch.ParentSummary() != nil { - tb := triggers.NewBuilder(oa.Env(), flow.FlowReference(), contact).FlowAction(history, batch.ParentSummary()) + tb := triggers.NewBuilder(oa.Env(), flow.Reference(), contact).FlowAction(history, batch.ParentSummary()) if batchStart { tb = tb.AsBatch() } return tb.Build() } - tb := triggers.NewBuilder(oa.Env(), flow.FlowReference(), contact).Manual() + tb := triggers.NewBuilder(oa.Env(), flow.Reference(), contact).Manual() if batch.Extra() != nil { tb = tb.WithParams(params) } @@ -581,14 +581,14 @@ func StartFlowForContacts( } // build our list of contact ids - contactIDs := make([]flows.ContactID, len(triggers)) + contactIDs := make([]models.ContactID, len(triggers)) for i := range triggers { - contactIDs[i] = triggers[i].Contact().ID() + contactIDs[i] = models.ContactID(triggers[i].Contact().ID()) } // interrupt all our contacts if desired if interrupt { - err = models.InterruptContactRuns(txCTX, tx, flow.FlowType(), contactIDs, start) + err = models.InterruptSessionsOfTypeForContacts(txCTX, tx, contactIDs, flow.FlowType()) if err != nil { tx.Rollback() return nil, errors.Wrap(err, "error interrupting contacts") @@ -628,7 +628,7 @@ func StartFlowForContacts( // interrupt this contact if appropriate if interrupt { - err = models.InterruptContactRuns(txCTX, tx, flow.FlowType(), []flows.ContactID{session.Contact().ID()}, start) + err = models.InterruptSessionsOfTypeForContacts(txCTX, tx, []models.ContactID{models.ContactID(session.Contact().ID())}, flow.FlowType()) if err != nil { tx.Rollback() log.WithField("contact_uuid", session.Contact().UUID()).WithError(err).Errorf("error interrupting contact") diff --git a/core/runner/runner_test.go b/core/runner/runner_test.go index 1121eb8b0..a87067fc7 100644 --- a/core/runner/runner_test.go +++ b/core/runner/runner_test.go @@ -1,11 +1,16 @@ package runner_test import ( + "context" "encoding/json" "testing" "time" + "github.com/gomodule/redigo/redis" + "github.com/jmoiron/sqlx" + "github.com/nyaruka/gocommon/dbutil/assertdb" "github.com/nyaruka/gocommon/uuids" + "github.com/nyaruka/goflow/envs" "github.com/nyaruka/goflow/flows" "github.com/nyaruka/goflow/flows/resumes" "github.com/nyaruka/goflow/flows/triggers" @@ -14,6 +19,7 @@ import ( "github.com/nyaruka/mailroom/core/runner" "github.com/nyaruka/mailroom/testsuite" "github.com/nyaruka/mailroom/testsuite/testdata" + "github.com/nyaruka/mailroom/utils/test" "github.com/lib/pq" "github.com/stretchr/testify/assert" @@ -29,13 +35,15 @@ func TestCampaignStarts(t *testing.T) { // create our event fires now := time.Now() - db.MustExec(`INSERT INTO campaigns_eventfire(event_id, scheduled, contact_id) VALUES($1, $2, $3),($1, $2, $4),($1, $2, $5);`, testdata.RemindersEvent2.ID, now, testdata.Cathy.ID, testdata.Bob.ID, testdata.Alexandria.ID) + testdata.InsertEventFire(rt.DB, testdata.Cathy, testdata.RemindersEvent2, now) + testdata.InsertEventFire(rt.DB, testdata.Bob, testdata.RemindersEvent2, now) + testdata.InsertEventFire(rt.DB, testdata.Alexandria, testdata.RemindersEvent2, now) - // create an active session for Alexandria to test skipping - db.MustExec(`INSERT INTO flows_flowsession(uuid, session_type, org_id, contact_id, status, responded, created_on, current_flow_id) VALUES($1, 'M', $2, $3, 'W', FALSE, NOW(), $4);`, uuids.New(), testdata.Org1.ID, testdata.Alexandria.ID, testdata.PickANumber.ID) + // create an waiting session for Alexandria to test skipping + testdata.InsertFlowSession(db, testdata.Org1, testdata.Alexandria, models.FlowTypeMessaging, models.SessionStatusWaiting, testdata.PickANumber, models.NilConnectionID, nil) // create an active voice call for Cathy to make sure it doesn't get interrupted or cause skipping - db.MustExec(`INSERT INTO flows_flowsession(uuid, session_type, org_id, contact_id, status, responded, created_on, current_flow_id) VALUES($1, 'V', $2, $3, 'W', FALSE, NOW(), $4);`, uuids.New(), testdata.Org1.ID, testdata.Cathy.ID, testdata.IVRFlow.ID) + testdata.InsertFlowSession(db, testdata.Org1, testdata.Cathy, models.FlowTypeVoice, models.SessionStatusWaiting, testdata.IVRFlow, models.NilConnectionID, nil) // set our event to skip db.MustExec(`UPDATE campaigns_campaignevent SET start_mode = 'S' WHERE id= $1`, testdata.RemindersEvent2.ID) @@ -65,35 +73,34 @@ func TestCampaignStarts(t *testing.T) { assert.NoError(t, err) assert.Equal(t, 2, len(sessions), "expected only two sessions to be created") - testsuite.AssertQuery(t, db, + assertdb.Query(t, db, `SELECT count(*) FROM flows_flowsession WHERE contact_id = ANY($1) AND status = 'C' AND responded = FALSE AND org_id = 1 AND connection_id IS NULL AND output IS NOT NULL`, pq.Array(contacts)). Returns(2, "expected only two sessions to be created") - testsuite.AssertQuery(t, db, + assertdb.Query(t, db, `SELECT count(*) FROM flows_flowrun WHERE contact_id = ANY($1) and flow_id = $2 - AND is_active = FALSE AND responded = FALSE AND org_id = 1 AND parent_id IS NULL AND exit_type = 'C' AND status = 'C' - AND results IS NOT NULL AND path IS NOT NULL AND events IS NOT NULL - AND session_id IS NOT NULL`, + AND is_active = FALSE AND responded = FALSE AND org_id = 1 AND exit_type = 'C' AND status = 'C' + AND results IS NOT NULL AND path IS NOT NULL AND session_id IS NOT NULL`, pq.Array(contacts), testdata.CampaignFlow.ID).Returns(2, "expected only two runs to be created") - testsuite.AssertQuery(t, db, + assertdb.Query(t, db, `SELECT count(*) FROM msgs_msg WHERE contact_id = ANY($1) AND text like '% it is time to consult with your patients.' AND org_id = 1 AND status = 'Q' AND queued_on IS NOT NULL AND direction = 'O' AND topup_id IS NOT NULL AND msg_type = 'F' AND channel_id = $2`, pq.Array(contacts), testdata.TwilioChannel.ID).Returns(2, "expected only two messages to be sent") - testsuite.AssertQuery(t, db, `SELECT count(*) from campaigns_eventfire WHERE fired IS NULL`). + assertdb.Query(t, db, `SELECT count(*) from campaigns_eventfire WHERE fired IS NULL`). Returns(0, "expected all events to be fired") - testsuite.AssertQuery(t, db, + assertdb.Query(t, db, `SELECT count(*) from campaigns_eventfire WHERE fired IS NOT NULL AND contact_id IN ($1,$2) AND event_id = $3 AND fired_result = 'F'`, testdata.Cathy.ID, testdata.Bob.ID, testdata.RemindersEvent2.ID). Returns(2, "expected bob and cathy to have their event sent to fired") - testsuite.AssertQuery(t, db, + assertdb.Query(t, db, `SELECT count(*) from campaigns_eventfire WHERE fired IS NOT NULL AND contact_id IN ($1) AND event_id = $2 AND fired_result = 'S'`, testdata.Alexandria.ID, testdata.RemindersEvent2.ID). Returns(1, "expected alexandria to have her event set to skipped") - testsuite.AssertQuery(t, db, + assertdb.Query(t, db, `SELECT count(*) from flows_flowsession WHERE status = 'W' AND contact_id = $1 AND session_type = 'V'`, testdata.Cathy.ID).Returns(1) } @@ -144,19 +151,18 @@ func TestBatchStart(t *testing.T) { require.NoError(t, err) assert.Equal(t, tc.Count, len(sessions), "%d: unexpected number of sessions created", i) - testsuite.AssertQuery(t, db, + assertdb.Query(t, db, `SELECT count(*) FROM flows_flowsession WHERE contact_id = ANY($1) AND status = 'C' AND responded = FALSE AND org_id = 1 AND connection_id IS NULL AND output IS NOT NULL AND created_on > $2`, pq.Array(contactIDs), last). Returns(tc.Count, "%d: unexpected number of sessions", i) - testsuite.AssertQuery(t, db, + assertdb.Query(t, db, `SELECT count(*) FROM flows_flowrun WHERE contact_id = ANY($1) and flow_id = $2 - AND is_active = FALSE AND responded = FALSE AND org_id = 1 AND parent_id IS NULL AND exit_type = 'C' AND status = 'C' - AND results IS NOT NULL AND path IS NOT NULL AND events IS NOT NULL - AND session_id IS NOT NULL`, pq.Array(contactIDs), tc.Flow). + AND is_active = FALSE AND responded = FALSE AND org_id = 1 AND exit_type = 'C' AND status = 'C' + AND results IS NOT NULL AND path IS NOT NULL AND session_id IS NOT NULL`, pq.Array(contactIDs), tc.Flow). Returns(tc.TotalCount, "%d: unexpected number of runs", i) - testsuite.AssertQuery(t, db, + assertdb.Query(t, db, `SELECT count(*) FROM msgs_msg WHERE contact_id = ANY($1) AND text = $2 AND org_id = 1 AND status = 'Q' AND queued_on IS NOT NULL AND direction = 'O' AND topup_id IS NOT NULL AND msg_type = 'F' AND channel_id = $3`, pq.Array(contactIDs), tc.Msg, testdata.TwilioChannel.ID). @@ -169,10 +175,10 @@ func TestBatchStart(t *testing.T) { func TestResume(t *testing.T) { ctx, rt, db, _ := testsuite.Get() - defer testsuite.Reset(testsuite.ResetAll) + defer testsuite.Reset(testsuite.ResetData | testsuite.ResetStorage) // write sessions to s3 storage - db.MustExec(`UPDATE orgs_org set config = '{"session_storage_mode": "s3"}' WHERE id = 1`) + rt.Config.SessionStorage = "s3" oa, err := models.GetOrgAssetsWithRefresh(ctx, rt, testdata.Org1.ID, models.RefreshOrg) require.NoError(t, err) @@ -182,32 +188,31 @@ func TestResume(t *testing.T) { _, contact := testdata.Cathy.Load(db, oa) - trigger := triggers.NewBuilder(oa.Env(), flow.FlowReference(), contact).Manual().Build() + trigger := triggers.NewBuilder(oa.Env(), flow.Reference(), contact).Manual().Build() sessions, err := runner.StartFlowForContacts(ctx, rt, oa, flow, []flows.Trigger{trigger}, nil, true) assert.NoError(t, err) assert.NotNil(t, sessions) - testsuite.AssertQuery(t, db, + assertdb.Query(t, db, `SELECT count(*) FROM flows_flowsession WHERE contact_id = $1 AND current_flow_id = $2 AND status = 'W' AND responded = FALSE AND org_id = 1 AND connection_id IS NULL AND output IS NULL`, contact.ID(), flow.ID()).Returns(1) - testsuite.AssertQuery(t, db, + assertdb.Query(t, db, `SELECT count(*) FROM flows_flowrun WHERE contact_id = $1 AND flow_id = $2 AND is_active = TRUE AND responded = FALSE AND org_id = 1`, contact.ID(), flow.ID()).Returns(1) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM msgs_msg WHERE contact_id = $1 AND direction = 'O' AND text like '%favorite color%'`, contact.ID()).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM msgs_msg WHERE contact_id = $1 AND direction = 'O' AND text like '%favorite color%'`, contact.ID()).Returns(1) tcs := []struct { Message string - SessionStatus flows.SessionStatus + SessionStatus models.SessionStatus RunStatus models.RunStatus Substring string PathLength int - EventLength int }{ - {"Red", models.SessionStatusWaiting, models.RunStatusWaiting, "%I like Red too%", 4, 3}, - {"Mutzig", models.SessionStatusWaiting, models.RunStatusWaiting, "%they made red Mutzig%", 6, 5}, - {"Luke", models.SessionStatusCompleted, models.RunStatusCompleted, "%Thanks Luke%", 7, 7}, + {"Red", models.SessionStatusWaiting, models.RunStatusWaiting, "%I like Red too%", 4}, + {"Mutzig", models.SessionStatusWaiting, models.RunStatusWaiting, "%they made red Mutzig%", 6}, + {"Luke", models.SessionStatusCompleted, models.RunStatusCompleted, "%Thanks Luke%", 7}, } session := sessions[0] @@ -221,17 +226,16 @@ func TestResume(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, session) - testsuite.AssertQuery(t, db, - `SELECT count(*) FROM flows_flowsession WHERE contact_id = $1 AND current_flow_id = $2 - AND status = $3 AND responded = TRUE AND org_id = 1 AND connection_id IS NULL AND output IS NULL AND output_url IS NOT NULL`, contact.ID(), flow.ID(), tc.SessionStatus). + assertdb.Query(t, db, + `SELECT count(*) FROM flows_flowsession WHERE contact_id = $1 + AND status = $2 AND responded = TRUE AND org_id = 1 AND connection_id IS NULL AND output IS NULL AND output_url IS NOT NULL`, contact.ID(), tc.SessionStatus). Returns(1, "%d: didn't find expected session", i) runIsActive := tc.RunStatus == models.RunStatusActive || tc.RunStatus == models.RunStatusWaiting runQuery := `SELECT count(*) FROM flows_flowrun WHERE contact_id = $1 AND flow_id = $2 AND status = $3 AND is_active = $4 AND responded = TRUE AND org_id = 1 AND current_node_uuid IS NOT NULL - AND json_array_length(path::json) = $5 AND json_array_length(events::json) = $6 - AND session_id IS NOT NULL` + AND json_array_length(path::json) = $5 AND session_id IS NOT NULL` if runIsActive { runQuery += ` AND expires_on IS NOT NULL` @@ -239,10 +243,56 @@ func TestResume(t *testing.T) { runQuery += ` AND expires_on IS NULL` } - testsuite.AssertQuery(t, db, runQuery, contact.ID(), flow.ID(), tc.RunStatus, runIsActive, tc.PathLength, tc.EventLength). + assertdb.Query(t, db, runQuery, contact.ID(), flow.ID(), tc.RunStatus, runIsActive, tc.PathLength). Returns(1, "%d: didn't find expected run", i) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM msgs_msg WHERE contact_id = $1 AND direction = 'O' AND text like $2`, contact.ID(), tc.Substring). + assertdb.Query(t, db, `SELECT count(*) FROM msgs_msg WHERE contact_id = $1 AND direction = 'O' AND text like $2`, contact.ID(), tc.Substring). Returns(1, "%d: didn't find expected message", i) } } + +func TestStartFlowConcurrency(t *testing.T) { + ctx, rt, db, _ := testsuite.Get() + + defer testsuite.Reset(testsuite.ResetData | testsuite.ResetRedis) + + // check everything works with big ids + db.MustExec(`ALTER SEQUENCE flows_flowrun_id_seq RESTART WITH 5000000000;`) + db.MustExec(`ALTER SEQUENCE flows_flowsession_id_seq RESTART WITH 5000000000;`) + + // create a flow which has a send_broadcast action which will mean handlers grabbing redis connections + flow := testdata.InsertFlow(db, testdata.Org1, testsuite.ReadFile("testdata/broadcast_flow.json")) + + oa := testdata.Org1.Load(rt) + + dbFlow, err := oa.FlowByID(flow.ID) + require.NoError(t, err) + flowRef := testdata.Favorites.Reference() + + // create a lot of contacts... + contacts := make([]*testdata.Contact, 100) + for i := range contacts { + contacts[i] = testdata.InsertContact(db, testdata.Org1, flows.ContactUUID(uuids.New()), "Jim", envs.NilLanguage) + } + + options := &runner.StartOptions{ + RestartParticipants: true, + IncludeActive: true, + TriggerBuilder: func(contact *flows.Contact) flows.Trigger { + return triggers.NewBuilder(oa.Env(), flowRef, contact).Manual().Build() + }, + CommitHook: func(ctx context.Context, tx *sqlx.Tx, rp *redis.Pool, oa *models.OrgAssets, session []*models.Session) error { + return nil + }, + } + + // start each contact in the flow at the same time... + test.RunConcurrently(len(contacts), func(i int) { + sessions, err := runner.StartFlow(ctx, rt, oa, dbFlow, []models.ContactID{contacts[i].ID}, options) + assert.NoError(t, err) + assert.Equal(t, 1, len(sessions)) + }) + + assertdb.Query(t, db, `SELECT count(*) FROM flows_flowrun`).Returns(len(contacts)) + assertdb.Query(t, db, `SELECT count(*) FROM flows_flowsession`).Returns(len(contacts)) +} diff --git a/core/runner/testdata/broadcast_flow.json b/core/runner/testdata/broadcast_flow.json new file mode 100644 index 000000000..55bb2c78a --- /dev/null +++ b/core/runner/testdata/broadcast_flow.json @@ -0,0 +1,54 @@ +{ + "uuid": "0fad12a0-d53c-4ba0-811c-6bfde03554e2", + "name": "Broadcast Test", + "revision": 55, + "spec_version": "13.1.0", + "type": "messaging", + "expire_after_minutes": 10080, + "language": "eng", + "localization": {}, + "nodes": [ + { + "uuid": "001b4eee-812f-403e-a004-737b948b3c18", + "actions": [ + { + "uuid": "d64f25cf-8b02-4ca9-8df8-3c457ccc1090", + "type": "send_msg", + "attachments": [], + "text": "Hi there", + "quick_replies": [] + } + ], + "exits": [ + { + "uuid": "5fd2e537-0534-4c12-8425-bef87af09d46", + "destination_uuid": "788b904f-dae2-4f78-9e96-468a5b861002" + } + ] + }, + { + "uuid": "788b904f-dae2-4f78-9e96-468a5b861002", + "actions": [ + { + "uuid": "33640e44-6dc9-4aaf-b753-8bf57036cf06", + "type": "send_broadcast", + "legacy_vars": [], + "contacts": [], + "groups": [ + { + "uuid": "c153e265-f7c9-4539-9dbc-9b358714b638", + "name": "Doctors" + } + ], + "text": "This is a broadcast!" + } + ], + "exits": [ + { + "uuid": "0a1fa072-c8be-4b4c-b97a-9dad68807dbf", + "destination_uuid": null + } + ] + } + ] +} \ No newline at end of file diff --git a/core/tasks/analytics/cron.go b/core/tasks/analytics/cron.go new file mode 100644 index 000000000..8774a81f1 --- /dev/null +++ b/core/tasks/analytics/cron.go @@ -0,0 +1,98 @@ +package analytics + +import ( + "context" + "sync" + "time" + + "github.com/nyaruka/librato" + "github.com/nyaruka/mailroom" + "github.com/nyaruka/mailroom/core/queue" + "github.com/nyaruka/mailroom/runtime" + "github.com/nyaruka/mailroom/utils/cron" + "github.com/sirupsen/logrus" +) + +func init() { + mailroom.AddInitFunction(StartAnalyticsCron) +} + +// StartAnalyticsCron starts our cron job of posting stats every minute +func StartAnalyticsCron(rt *runtime.Runtime, wg *sync.WaitGroup, quit chan bool) error { + cron.Start(quit, rt, "stats", time.Second*60, true, + func() error { + ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) + defer cancel() + return reportAnalytics(ctx, rt) + }, + ) + return nil +} + +var ( + // both sqlx and redis provide wait stats which are cummulative that we need to make into increments + dbWaitDuration time.Duration + dbWaitCount int64 + redisWaitDuration time.Duration + redisWaitCount int64 +) + +// calculates a bunch of stats every minute and both logs them and sends them to librato +func reportAnalytics(ctx context.Context, rt *runtime.Runtime) error { + // We wait 15 seconds since we fire at the top of the minute, the same as expirations. + // That way any metrics related to the size of our queue are a bit more accurate (all expirations can + // usually be handled in 15 seconds). Something more complicated would take into account the age of + // the items in our queues. + time.Sleep(time.Second * 15) + + rc := rt.RP.Get() + defer rc.Close() + + // calculate size of batch queue + batchSize, err := queue.Size(rc, queue.BatchQueue) + if err != nil { + logrus.WithError(err).Error("error calculating batch queue size") + } + + // and size of handler queue + handlerSize, err := queue.Size(rc, queue.HandlerQueue) + if err != nil { + logrus.WithError(err).Error("error calculating handler queue size") + } + + // get our DB and redis stats + dbStats := rt.DB.Stats() + redisStats := rt.RP.Stats() + + dbWaitDurationInPeriod := dbStats.WaitDuration - dbWaitDuration + dbWaitCountInPeriod := dbStats.WaitCount - dbWaitCount + redisWaitDurationInPeriod := redisStats.WaitDuration - redisWaitDuration + redisWaitCountInPeriod := redisStats.WaitCount - redisWaitCount + + dbWaitDuration = dbStats.WaitDuration + dbWaitCount = dbStats.WaitCount + redisWaitDuration = redisStats.WaitDuration + redisWaitCount = redisStats.WaitCount + + librato.Gauge("mr.db_busy", float64(dbStats.InUse)) + librato.Gauge("mr.db_idle", float64(dbStats.Idle)) + librato.Gauge("mr.db_wait_ms", float64(dbWaitDurationInPeriod/time.Millisecond)) + librato.Gauge("mr.db_wait_count", float64(dbWaitCountInPeriod)) + librato.Gauge("mr.redis_wait_ms", float64(redisWaitDurationInPeriod/time.Millisecond)) + librato.Gauge("mr.redis_wait_count", float64(redisWaitCountInPeriod)) + librato.Gauge("mr.handler_queue", float64(handlerSize)) + librato.Gauge("mr.batch_queue", float64(batchSize)) + + logrus.WithFields(logrus.Fields{ + "db_busy": dbStats.InUse, + "db_idle": dbStats.Idle, + "db_wait_time": dbWaitDurationInPeriod, + "db_wait_count": dbWaitCountInPeriod, + "redis_wait_time": dbWaitDurationInPeriod, + "redis_wait_count": dbWaitCountInPeriod, + "handler_size": handlerSize, + "batch_size": batchSize, + }).Info("current analytics") + + return nil +} diff --git a/core/tasks/campaigns/cron.go b/core/tasks/campaigns/cron.go index ed82d7d2f..4eaf92bc0 100644 --- a/core/tasks/campaigns/cron.go +++ b/core/tasks/campaigns/cron.go @@ -6,6 +6,7 @@ import ( "sync" "time" + "github.com/gomodule/redigo/redis" "github.com/nyaruka/goflow/assets" "github.com/nyaruka/librato" "github.com/nyaruka/mailroom" @@ -13,38 +14,38 @@ import ( "github.com/nyaruka/mailroom/core/queue" "github.com/nyaruka/mailroom/runtime" "github.com/nyaruka/mailroom/utils/cron" - "github.com/nyaruka/mailroom/utils/marker" + "github.com/nyaruka/redisx" "github.com/pkg/errors" "github.com/sirupsen/logrus" ) const ( - campaignsLock = "campaign_event" - maxBatchSize = 100 ) +var campaignsMarker = redisx.NewIntervalSet("campaign_event", time.Hour*24, 2) + func init() { mailroom.AddInitFunction(StartCampaignCron) } // StartCampaignCron starts our cron job of firing expired campaign events func StartCampaignCron(rt *runtime.Runtime, wg *sync.WaitGroup, quit chan bool) error { - cron.StartCron(quit, rt.RP, campaignsLock, time.Second*60, - func(lockName string, lockValue string) error { + cron.Start(quit, rt, "campaign_event", time.Second*60, false, + func() error { ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) defer cancel() - return fireCampaignEvents(ctx, rt, lockName, lockValue) + return QueueEventFires(ctx, rt) }, ) return nil } -// fireCampaignEvents looks for all expired campaign event fires and queues them to be started -func fireCampaignEvents(ctx context.Context, rt *runtime.Runtime, lockName string, lockValue string) error { - log := logrus.WithField("comp", "campaign_events").WithField("lock", lockValue) +// QueueEventFires looks for all due campaign event fires and queues them to be started +func QueueEventFires(ctx context.Context, rt *runtime.Runtime) error { + log := logrus.WithField("comp", "campaign_events") start := time.Now() // find all events that need to be fired @@ -60,43 +61,9 @@ func fireCampaignEvents(ctx context.Context, rt *runtime.Runtime, lockName strin rc := rt.RP.Get() defer rc.Close() - queued := 0 - queueTask := func(orgID models.OrgID, task *FireCampaignEventTask) error { - if task.EventID == 0 { - return nil - } - - fireIDs := task.FireIDs - for len(fireIDs) > 0 { - batchSize := maxBatchSize - if batchSize > len(fireIDs) { - batchSize = len(fireIDs) - } - task.FireIDs = fireIDs[:batchSize] - fireIDs = fireIDs[batchSize:] - - err = queue.AddTask(rc, queue.BatchQueue, TypeFireCampaignEvent, int(orgID), task, queue.DefaultPriority) - if err != nil { - return errors.Wrap(err, "error queuing task") - } - - // mark each of these fires as queued - for _, id := range task.FireIDs { - err = marker.AddTask(rc, campaignsLock, fmt.Sprintf("%d", id)) - if err != nil { - return errors.Wrap(err, "error marking event as queued") - } - } - log.WithField("task", fmt.Sprintf("%vvv", task)).WithField("fire_count", len(task.FireIDs)).Debug("added event fire task") - queued += len(task.FireIDs) - } - - return nil - } - - // while we have rows orgID := models.NilOrgID - task := &FireCampaignEventTask{} + var task *FireCampaignEventTask + numFires, numDupes, numTasks := 0, 0, 0 for rows.Next() { row := &eventFireRow{} @@ -105,34 +72,40 @@ func fireCampaignEvents(ctx context.Context, rt *runtime.Runtime, lockName strin return errors.Wrapf(err, "error reading event fire row") } + numFires++ + // check whether this event has already been queued to fire taskID := fmt.Sprintf("%d", row.FireID) - dupe, err := marker.HasTask(rc, campaignsLock, taskID) + dupe, err := campaignsMarker.Contains(rc, taskID) if err != nil { return errors.Wrap(err, "error checking task lock") } - // this has already been queued, move on + // this has already been queued, skip if dupe { + numDupes++ continue } - // if this is the same event as our current task, add it there - if row.EventID == task.EventID { + // if this is the same event as our current task, and we haven't reached the fire per task limit, add it there + if task != nil && row.EventID == task.EventID && len(task.FireIDs) < maxBatchSize { task.FireIDs = append(task.FireIDs, row.FireID) continue } - // different task, queue up our current task - err = queueTask(orgID, task) - if err != nil { - return errors.Wrapf(err, "error queueing task") + // if not, queue up current task... + if task != nil { + err = queueFiresTask(rt.RP, orgID, task) + if err != nil { + return errors.Wrapf(err, "error queueing task") + } + numTasks++ } // and create a new one based on this row orgID = row.OrgID task = &FireCampaignEventTask{ - FireIDs: []int64{row.FireID}, + FireIDs: []models.FireID{row.FireID}, EventID: row.EventID, EventUUID: row.EventUUID, FlowUUID: row.FlowUUID, @@ -141,20 +114,48 @@ func fireCampaignEvents(ctx context.Context, rt *runtime.Runtime, lockName strin } } - // queue our last task - err = queueTask(orgID, task) - if err != nil { - return errors.Wrapf(err, "error queueing task") + // queue our last task if we have one + if task != nil { + if err := queueFiresTask(rt.RP, orgID, task); err != nil { + return errors.Wrapf(err, "error queueing task") + } + numTasks++ } librato.Gauge("mr.campaign_event_cron_elapsed", float64(time.Since(start))/float64(time.Second)) - librato.Gauge("mr.campaign_event_cron_count", float64(queued)) - log.WithField("elapsed", time.Since(start)).WithField("queued", queued).Info("campaign event fire queuing complete") + librato.Gauge("mr.campaign_event_cron_count", float64(numFires)) + log.WithFields(logrus.Fields{ + "elapsed": time.Since(start), + "fires": numFires, + "dupes": numDupes, + "tasks": numTasks, + }).Info("campaign event fire queuing complete") + return nil +} + +func queueFiresTask(rp *redis.Pool, orgID models.OrgID, task *FireCampaignEventTask) error { + rc := rp.Get() + defer rc.Close() + + err := queue.AddTask(rc, queue.BatchQueue, TypeFireCampaignEvent, int(orgID), task, queue.DefaultPriority) + if err != nil { + return errors.Wrap(err, "error queuing task") + } + + // mark each of these fires as queued + for _, id := range task.FireIDs { + err = campaignsMarker.Add(rc, fmt.Sprintf("%d", id)) + if err != nil { + return errors.Wrap(err, "error marking fire as queued") + } + } + + logrus.WithField("comp", "campaign_events").WithField("event", task.EventUUID).WithField("fires", len(task.FireIDs)).Debug("queued campaign event fire task") return nil } type eventFireRow struct { - FireID int64 `db:"fire_id"` + FireID models.FireID `db:"fire_id"` EventID int64 `db:"event_id"` EventUUID string `db:"event_uuid"` FlowUUID assets.FlowUUID `db:"flow_uuid"` diff --git a/core/tasks/campaigns/cron_test.go b/core/tasks/campaigns/cron_test.go index e92479718..2919c3591 100644 --- a/core/tasks/campaigns/cron_test.go +++ b/core/tasks/campaigns/cron_test.go @@ -1,13 +1,23 @@ -package campaigns +package campaigns_test import ( + "encoding/json" + "fmt" + "strconv" "testing" "time" + "github.com/gomodule/redigo/redis" + "github.com/nyaruka/gocommon/dbutil/assertdb" + "github.com/nyaruka/gocommon/jsonx" + "github.com/nyaruka/gocommon/uuids" + "github.com/nyaruka/goflow/envs" + "github.com/nyaruka/goflow/flows" _ "github.com/nyaruka/mailroom/core/handlers" "github.com/nyaruka/mailroom/core/models" "github.com/nyaruka/mailroom/core/queue" "github.com/nyaruka/mailroom/core/tasks" + "github.com/nyaruka/mailroom/core/tasks/campaigns" "github.com/nyaruka/mailroom/testsuite" "github.com/nyaruka/mailroom/testsuite/testdata" @@ -15,20 +25,87 @@ import ( "github.com/stretchr/testify/require" ) -func TestCampaigns(t *testing.T) { +func TestQueueEventFires(t *testing.T) { ctx, rt, db, rp := testsuite.Get() rc := rp.Get() defer rc.Close() - defer testsuite.Reset(testsuite.ResetAll) + defer testsuite.Reset(testsuite.ResetData | testsuite.ResetRedis) + + org2Campaign := testdata.InsertCampaign(db, testdata.Org2, "Org 2", testdata.DoctorsGroup) + org2CampaignEvent := testdata.InsertCampaignFlowEvent(db, org2Campaign, testdata.Org2Favorites, testdata.AgeField, 1, "D") + + // try with zero fires + err := campaigns.QueueEventFires(ctx, rt) + assert.NoError(t, err) - // let's create a campaign event fire for one of our contacts (for now this is totally hacked, they aren't in the group and - // their relative to date isn't relative, but this still tests execution) - rt.DB.MustExec(`INSERT INTO campaigns_eventfire(scheduled, contact_id, event_id) VALUES (NOW(), $1, $3), (NOW(), $2, $3);`, testdata.Cathy.ID, testdata.George.ID, testdata.RemindersEvent1.ID) - time.Sleep(10 * time.Millisecond) + assertFireTasks(t, rp, testdata.Org1, [][]models.FireID{}) + assertFireTasks(t, rp, testdata.Org2, [][]models.FireID{}) + + // create event fires due now for 2 contacts and in the future for another contact + fire1ID := testdata.InsertEventFire(rt.DB, testdata.Cathy, testdata.RemindersEvent1, time.Now().Add(-time.Minute)) + fire2ID := testdata.InsertEventFire(rt.DB, testdata.George, testdata.RemindersEvent1, time.Now().Add(-time.Minute)) + fire3ID := testdata.InsertEventFire(rt.DB, testdata.Org2Contact, org2CampaignEvent, time.Now().Add(-time.Minute)) + fire4ID := testdata.InsertEventFire(rt.DB, testdata.Alexandria, testdata.RemindersEvent2, time.Now().Add(-time.Minute)) + testdata.InsertEventFire(rt.DB, testdata.Alexandria, testdata.RemindersEvent1, time.Now().Add(time.Hour*24)) // in future // schedule our campaign to be started - err := fireCampaignEvents(ctx, rt, campaignsLock, "lock") + err = campaigns.QueueEventFires(ctx, rt) + assert.NoError(t, err) + + assertFireTasks(t, rp, testdata.Org1, [][]models.FireID{{fire1ID, fire2ID}, {fire4ID}}) + assertFireTasks(t, rp, testdata.Org2, [][]models.FireID{{fire3ID}}) + + // running again won't double add those fires + err = campaigns.QueueEventFires(ctx, rt) + assert.NoError(t, err) + + assertFireTasks(t, rp, testdata.Org1, [][]models.FireID{{fire1ID, fire2ID}, {fire4ID}}) + assertFireTasks(t, rp, testdata.Org2, [][]models.FireID{{fire3ID}}) + + // clear queued tasks + rc.Do("DEL", "batch:active") + rc.Do("DEL", "batch:1") + + // add 110 scheduled event fires to test batch limits + for i := 0; i < 110; i++ { + contact := testdata.InsertContact(db, testdata.Org1, flows.ContactUUID(uuids.New()), fmt.Sprintf("Jim %d", i), envs.NilLanguage) + testdata.InsertEventFire(rt.DB, contact, testdata.RemindersEvent1, time.Now().Add(-time.Minute)) + } + + err = campaigns.QueueEventFires(ctx, rt) + assert.NoError(t, err) + + queuedTasks := testsuite.CurrentOrgTasks(t, rp) + org1Tasks := queuedTasks[testdata.Org1.ID] + + assert.Equal(t, 2, len(org1Tasks)) + + tk1 := struct { + FireIDs []models.FireID `json:"fire_ids"` + }{} + jsonx.MustUnmarshal(org1Tasks[0].Task, &tk1) + tk2 := struct { + FireIDs []models.FireID `json:"fire_ids"` + }{} + jsonx.MustUnmarshal(org1Tasks[1].Task, &tk2) + + assert.Equal(t, 100, len(tk1.FireIDs)) + assert.Equal(t, 10, len(tk2.FireIDs)) +} +func TestFireCampaignEvents(t *testing.T) { + ctx, rt, db, rp := testsuite.Get() + rc := rp.Get() + defer rc.Close() + + defer testsuite.Reset(testsuite.ResetData | testsuite.ResetRedis) + + // create due fires for Cathy and George + testdata.InsertEventFire(rt.DB, testdata.Cathy, testdata.RemindersEvent1, time.Now().Add(-time.Minute)) + testdata.InsertEventFire(rt.DB, testdata.George, testdata.RemindersEvent1, time.Now().Add(-time.Minute)) + + // queue the event task + err := campaigns.QueueEventFires(ctx, rt) assert.NoError(t, err) // then actually work on the event @@ -44,8 +121,8 @@ func TestCampaigns(t *testing.T) { assert.NoError(t, err) // should now have a flow run for that contact and flow - testsuite.AssertQuery(t, db, `SELECT COUNT(*) from flows_flowrun WHERE contact_id = $1 AND flow_id = $2;`, testdata.Cathy.ID, testdata.Favorites.ID).Returns(1) - testsuite.AssertQuery(t, db, `SELECT COUNT(*) from flows_flowrun WHERE contact_id = $1 AND flow_id = $2;`, testdata.George.ID, testdata.Favorites.ID).Returns(1) + assertdb.Query(t, db, `SELECT COUNT(*) from flows_flowrun WHERE contact_id = $1 AND flow_id = $2;`, testdata.Cathy.ID, testdata.Favorites.ID).Returns(1) + assertdb.Query(t, db, `SELECT COUNT(*) from flows_flowrun WHERE contact_id = $1 AND flow_id = $2;`, testdata.George.ID, testdata.Favorites.ID).Returns(1) } func TestIVRCampaigns(t *testing.T) { @@ -55,14 +132,14 @@ func TestIVRCampaigns(t *testing.T) { defer testsuite.Reset(testsuite.ResetAll) - // let's create a campaign event fire for one of our contacts (for now this is totally hacked, they aren't in the group and - // their relative to date isn't relative, but this still tests execution) + // turn a campaign event into an IVR flow event rt.DB.MustExec(`UPDATE campaigns_campaignevent SET flow_id = $1 WHERE id = $2`, testdata.IVRFlow.ID, testdata.RemindersEvent1.ID) - rt.DB.MustExec(`INSERT INTO campaigns_eventfire(scheduled, contact_id, event_id) VALUES (NOW(), $1, $3), (NOW(), $2, $3);`, testdata.Cathy.ID, testdata.George.ID, testdata.RemindersEvent1.ID) - time.Sleep(10 * time.Millisecond) + + testdata.InsertEventFire(rt.DB, testdata.Cathy, testdata.RemindersEvent1, time.Now().Add(-time.Minute)) + testdata.InsertEventFire(rt.DB, testdata.George, testdata.RemindersEvent1, time.Now().Add(-time.Minute)) // schedule our campaign to be started - err := fireCampaignEvents(ctx, rt, campaignsLock, "lock") + err := campaigns.QueueEventFires(ctx, rt) assert.NoError(t, err) // then actually work on the event @@ -78,12 +155,12 @@ func TestIVRCampaigns(t *testing.T) { assert.NoError(t, err) // should now have a flow start created - testsuite.AssertQuery(t, db, `SELECT COUNT(*) from flows_flowstart WHERE flow_id = $1 AND start_type = 'T' AND status = 'P';`, testdata.IVRFlow.ID).Returns(1) - testsuite.AssertQuery(t, db, `SELECT COUNT(*) from flows_flowstart_contacts WHERE contact_id = $1 AND flowstart_id = 1;`, testdata.Cathy.ID).Returns(1) - testsuite.AssertQuery(t, db, `SELECT COUNT(*) from flows_flowstart_contacts WHERE contact_id = $1 AND flowstart_id = 1;`, testdata.George.ID).Returns(1) + assertdb.Query(t, db, `SELECT COUNT(*) from flows_flowstart WHERE flow_id = $1 AND start_type = 'T' AND status = 'P';`, testdata.IVRFlow.ID).Returns(1) + assertdb.Query(t, db, `SELECT COUNT(*) from flows_flowstart_contacts WHERE contact_id = $1 AND flowstart_id = 1;`, testdata.Cathy.ID).Returns(1) + assertdb.Query(t, db, `SELECT COUNT(*) from flows_flowstart_contacts WHERE contact_id = $1 AND flowstart_id = 1;`, testdata.George.ID).Returns(1) // event should be marked as fired - testsuite.AssertQuery(t, db, `SELECT COUNT(*) from campaigns_eventfire WHERE event_id = $1 AND fired IS NOT NULL;`, testdata.RemindersEvent1.ID).Returns(2) + assertdb.Query(t, db, `SELECT COUNT(*) from campaigns_eventfire WHERE event_id = $1 AND fired IS NOT NULL;`, testdata.RemindersEvent1.ID).Returns(2) // pop our next task, should be the start task, err = queue.PopNextTask(rc, queue.BatchQueue) @@ -92,3 +169,23 @@ func TestIVRCampaigns(t *testing.T) { assert.Equal(t, task.Type, queue.StartIVRFlowBatch) } + +func assertFireTasks(t *testing.T, rp *redis.Pool, org *testdata.Org, expected [][]models.FireID) { + allTasks := testsuite.CurrentOrgTasks(t, rp) + actual := make([][]models.FireID, len(allTasks[org.ID])) + + for i, task := range allTasks[org.ID] { + payload, err := jsonx.DecodeGeneric(task.Task) + require.NoError(t, err) + + taskFireInts := payload.(map[string]interface{})["fire_ids"].([]interface{}) + taskFireIDs := make([]models.FireID, len(taskFireInts)) + for i := range taskFireInts { + id, _ := strconv.Atoi(string(taskFireInts[i].(json.Number))) + taskFireIDs[i] = models.FireID(int64(id)) + } + actual[i] = taskFireIDs + } + + assert.ElementsMatch(t, expected, actual) +} diff --git a/core/tasks/campaigns/fire_campaign_event.go b/core/tasks/campaigns/fire_campaign_event.go index a034cf85c..19ab9d8b5 100644 --- a/core/tasks/campaigns/fire_campaign_event.go +++ b/core/tasks/campaigns/fire_campaign_event.go @@ -11,7 +11,6 @@ import ( "github.com/nyaruka/mailroom/core/runner" "github.com/nyaruka/mailroom/core/tasks" "github.com/nyaruka/mailroom/runtime" - "github.com/nyaruka/mailroom/utils/marker" "github.com/pkg/errors" "github.com/sirupsen/logrus" @@ -26,7 +25,7 @@ func init() { // FireCampaignEventTask is the task to handle firing campaign events type FireCampaignEventTask struct { - FireIDs []int64 `json:"fire_ids"` + FireIDs []models.FireID `json:"fire_ids"` EventID int64 `json:"event_id"` EventUUID string `json:"event_uuid"` FlowUUID assets.FlowUUID `json:"flow_uuid"` @@ -58,7 +57,7 @@ func (t *FireCampaignEventTask) Perform(ctx context.Context, rt *runtime.Runtime // unmark all these fires as fires so they can retry rc := rp.Get() for _, id := range t.FireIDs { - rerr := marker.RemoveTask(rc, campaignsLock, fmt.Sprintf("%d", id)) + rerr := campaignsMarker.Remove(rc, fmt.Sprintf("%d", id)) if rerr != nil { log.WithError(rerr).WithField("fire_id", id).Error("error unmarking campaign fire") } @@ -93,7 +92,7 @@ func (t *FireCampaignEventTask) Perform(ctx context.Context, rt *runtime.Runtime if len(contactMap) > 0 { rc := rp.Get() for _, failed := range contactMap { - rerr := marker.RemoveTask(rc, campaignsLock, fmt.Sprintf("%d", failed.FireID)) + rerr := campaignsMarker.Remove(rc, fmt.Sprintf("%d", failed.FireID)) if rerr != nil { log.WithError(rerr).WithField("fire_id", failed.FireID).Error("error unmarking campaign fire") } diff --git a/core/tasks/campaigns/schedule_campaign_event.go b/core/tasks/campaigns/schedule_campaign_event.go index 2789eecd5..31df1e1d1 100644 --- a/core/tasks/campaigns/schedule_campaign_event.go +++ b/core/tasks/campaigns/schedule_campaign_event.go @@ -8,15 +8,14 @@ import ( "github.com/nyaruka/mailroom/core/models" "github.com/nyaruka/mailroom/core/tasks" "github.com/nyaruka/mailroom/runtime" - "github.com/nyaruka/mailroom/utils/locker" - + "github.com/nyaruka/redisx" "github.com/pkg/errors" ) // TypeScheduleCampaignEvent is the type of the schedule event task const TypeScheduleCampaignEvent = "schedule_campaign_event" -const scheduleLockKey string = "schedule_campaign_event_%d" +const scheduleLockKey string = "lock:schedule_campaign_event_%d" func init() { tasks.RegisterType(TypeScheduleCampaignEvent, func() tasks.Task { return &ScheduleCampaignEventTask{} }) @@ -34,14 +33,12 @@ func (t *ScheduleCampaignEventTask) Timeout() time.Duration { // Perform creates the actual event fires to schedule the given campaign event func (t *ScheduleCampaignEventTask) Perform(ctx context.Context, rt *runtime.Runtime, orgID models.OrgID) error { - rp := rt.RP - lockKey := fmt.Sprintf(scheduleLockKey, t.CampaignEventID) - - lock, err := locker.GrabLock(rp, lockKey, time.Hour, time.Minute*5) + locker := redisx.NewLocker(fmt.Sprintf(scheduleLockKey, t.CampaignEventID), time.Hour) + lock, err := locker.Grab(rt.RP, time.Minute*5) if err != nil { return errors.Wrapf(err, "error grabbing lock to schedule campaign event %d", t.CampaignEventID) } - defer locker.ReleaseLock(rp, lockKey, lock) + defer locker.Release(rt.RP, lock) err = models.ScheduleCampaignEvent(ctx, rt, orgID, t.CampaignEventID) if err != nil { diff --git a/core/tasks/campaigns/schedule_campaign_event_test.go b/core/tasks/campaigns/schedule_campaign_event_test.go index 29c85343c..a04d4f3fb 100644 --- a/core/tasks/campaigns/schedule_campaign_event_test.go +++ b/core/tasks/campaigns/schedule_campaign_event_test.go @@ -5,7 +5,6 @@ import ( "time" "github.com/jmoiron/sqlx" - "github.com/nyaruka/gocommon/uuids" "github.com/nyaruka/mailroom/core/models" "github.com/nyaruka/mailroom/core/tasks/campaigns" "github.com/nyaruka/mailroom/testsuite" @@ -70,42 +69,32 @@ func TestScheduleCampaignEvent(t *testing.T) { db.MustExec(`UPDATE contacts_contact SET created_on = '2035-01-01T00:00:00Z' WHERE id = $1 OR id = $2`, testdata.Cathy.ID, testdata.Alexandria.ID) // create new campaign event based on created_on + 5 minutes - event3 := insertCampaignEvent(t, db, testdata.RemindersCampaign.ID, testdata.Favorites.ID, testdata.CreatedOnField.ID, 5, "M") + event3 := testdata.InsertCampaignFlowEvent(db, testdata.RemindersCampaign, testdata.Favorites, testdata.CreatedOnField, 5, "M") - task = &campaigns.ScheduleCampaignEventTask{CampaignEventID: event3} + task = &campaigns.ScheduleCampaignEventTask{CampaignEventID: event3.ID} err = task.Perform(ctx, rt, testdata.Org1.ID) require.NoError(t, err) // only cathy is in the group and new enough to have a fire - assertContactFires(t, db, event3, map[models.ContactID]time.Time{ + assertContactFires(t, db, event3.ID, map[models.ContactID]time.Time{ testdata.Cathy.ID: time.Date(2035, 1, 1, 0, 5, 0, 0, time.UTC), }) // create new campaign event based on last_seen_on + 1 day - event4 := insertCampaignEvent(t, db, testdata.RemindersCampaign.ID, testdata.Favorites.ID, testdata.LastSeenOnField.ID, 1, "D") + event4 := testdata.InsertCampaignFlowEvent(db, testdata.RemindersCampaign, testdata.Favorites, testdata.LastSeenOnField, 1, "D") // bump last_seen_on for bob db.MustExec(`UPDATE contacts_contact SET last_seen_on = '2040-01-01T00:00:00Z' WHERE id = $1`, testdata.Bob.ID) - task = &campaigns.ScheduleCampaignEventTask{CampaignEventID: event4} + task = &campaigns.ScheduleCampaignEventTask{CampaignEventID: event4.ID} err = task.Perform(ctx, rt, testdata.Org1.ID) require.NoError(t, err) - assertContactFires(t, db, event4, map[models.ContactID]time.Time{ + assertContactFires(t, db, event4.ID, map[models.ContactID]time.Time{ testdata.Bob.ID: time.Date(2040, 1, 2, 0, 0, 0, 0, time.UTC), }) } -func insertCampaignEvent(t *testing.T, db *sqlx.DB, campaignID models.CampaignID, flowID models.FlowID, relativeToID models.FieldID, offset int, unit string) models.CampaignEventID { - var eventID models.CampaignEventID - err := db.Get(&eventID, ` - INSERT INTO campaigns_campaignevent(is_active, created_on, modified_on, uuid, "offset", unit, event_type, delivery_hour, campaign_id, created_by_id, modified_by_id, flow_id, relative_to_id, start_mode) - VALUES(TRUE, NOW(), NOW(), $1, $5, $6, 'F', -1, $2, 1, 1, $3, $4, 'I') RETURNING id`, uuids.New(), campaignID, flowID, relativeToID, offset, unit) - require.NoError(t, err) - - return eventID -} - func assertContactFires(t *testing.T, db *sqlx.DB, eventID models.CampaignEventID, expected map[models.ContactID]time.Time) { type idAndTime struct { ContactID models.ContactID `db:"contact_id"` diff --git a/core/tasks/contacts/import_contact_batch_test.go b/core/tasks/contacts/import_contact_batch_test.go index f50e23e14..cce2d2b60 100644 --- a/core/tasks/contacts/import_contact_batch_test.go +++ b/core/tasks/contacts/import_contact_batch_test.go @@ -4,6 +4,7 @@ import ( "fmt" "testing" + "github.com/nyaruka/gocommon/dbutil/assertdb" _ "github.com/nyaruka/mailroom/core/handlers" "github.com/nyaruka/mailroom/core/tasks/contacts" "github.com/nyaruka/mailroom/testsuite" @@ -36,21 +37,21 @@ func TestImportContactBatch(t *testing.T) { require.NoError(t, err) // import is still in progress - testsuite.AssertQuery(t, db, `SELECT status FROM contacts_contactimport WHERE id = $1`, importID).Columns(map[string]interface{}{"status": "O"}) + assertdb.Query(t, db, `SELECT status FROM contacts_contactimport WHERE id = $1`, importID).Columns(map[string]interface{}{"status": "O"}) // perform second batch task... task2 := &contacts.ImportContactBatchTask{ContactImportBatchID: batch2ID} err = task2.Perform(ctx, rt, testdata.Org1.ID) require.NoError(t, err) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM contacts_contact WHERE id >= 30000`).Returns(3) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM contacts_contact WHERE name = 'Norbert' AND language = 'eng'`).Returns(1) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM contacts_contact WHERE name = 'Leah' AND language IS NULL`).Returns(1) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM contacts_contact WHERE name = 'Rowan' AND language = 'spa'`).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM contacts_contact WHERE id >= 30000`).Returns(3) + assertdb.Query(t, db, `SELECT count(*) FROM contacts_contact WHERE name = 'Norbert' AND language = 'eng'`).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM contacts_contact WHERE name = 'Leah' AND language IS NULL`).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM contacts_contact WHERE name = 'Rowan' AND language = 'spa'`).Returns(1) // import is now complete and there is a notification for the creator - testsuite.AssertQuery(t, db, `SELECT status FROM contacts_contactimport WHERE id = $1`, importID).Columns(map[string]interface{}{"status": "C"}) - testsuite.AssertQuery(t, db, `SELECT org_id, notification_type, scope, user_id FROM notifications_notification WHERE contact_import_id = $1`, importID). + assertdb.Query(t, db, `SELECT status FROM contacts_contactimport WHERE id = $1`, importID).Columns(map[string]interface{}{"status": "C"}) + assertdb.Query(t, db, `SELECT org_id, notification_type, scope, user_id FROM notifications_notification WHERE contact_import_id = $1`, importID). Columns(map[string]interface{}{ "org_id": int64(testdata.Org1.ID), "notification_type": "import:finished", diff --git a/core/tasks/contacts/populate_dynamic_group.go b/core/tasks/contacts/populate_dynamic_group.go index 2e8c78bc8..0ee825dd3 100644 --- a/core/tasks/contacts/populate_dynamic_group.go +++ b/core/tasks/contacts/populate_dynamic_group.go @@ -8,7 +8,7 @@ import ( "github.com/nyaruka/mailroom/core/models" "github.com/nyaruka/mailroom/core/tasks" "github.com/nyaruka/mailroom/runtime" - "github.com/nyaruka/mailroom/utils/locker" + "github.com/nyaruka/redisx" "github.com/pkg/errors" "github.com/sirupsen/logrus" @@ -17,7 +17,7 @@ import ( // TypePopulateDynamicGroup is the type of the populate group task const TypePopulateDynamicGroup = "populate_dynamic_group" -const populateLockKey string = "pop_dyn_group_%d" +const populateLockKey string = "lock:pop_dyn_group_%d" func init() { tasks.RegisterType(TypePopulateDynamicGroup, func() tasks.Task { return &PopulateDynamicGroupTask{} }) @@ -36,12 +36,12 @@ func (t *PopulateDynamicGroupTask) Timeout() time.Duration { // Perform figures out the membership for a query based group then repopulates it func (t *PopulateDynamicGroupTask) Perform(ctx context.Context, rt *runtime.Runtime, orgID models.OrgID) error { - lockKey := fmt.Sprintf(populateLockKey, t.GroupID) - lock, err := locker.GrabLock(rt.RP, lockKey, time.Hour, time.Minute*5) + locker := redisx.NewLocker(fmt.Sprintf(populateLockKey, t.GroupID), time.Hour) + lock, err := locker.Grab(rt.RP, time.Minute*5) if err != nil { return errors.Wrapf(err, "error grabbing lock to repopulate dynamic group: %d", t.GroupID) } - defer locker.ReleaseLock(rt.RP, lockKey, lock) + defer locker.Release(rt.RP, lock) start := time.Now() log := logrus.WithFields(logrus.Fields{ diff --git a/core/tasks/contacts/populate_dynamic_group_test.go b/core/tasks/contacts/populate_dynamic_group_test.go index 722a26dee..5398f0167 100644 --- a/core/tasks/contacts/populate_dynamic_group_test.go +++ b/core/tasks/contacts/populate_dynamic_group_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/nyaruka/gocommon/dates" + "github.com/nyaruka/gocommon/dbutil/assertdb" "github.com/nyaruka/mailroom/core/tasks/contacts" "github.com/nyaruka/mailroom/testsuite" "github.com/nyaruka/mailroom/testsuite/testdata" @@ -64,7 +65,7 @@ func TestPopulateTask(t *testing.T) { err = task.Perform(ctx, rt, testdata.Org1.ID) require.NoError(t, err) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM contacts_contactgroup_contacts WHERE contactgroup_id = $1`, group.ID).Returns(1) - testsuite.AssertQuery(t, db, `SELECT contact_id FROM contacts_contactgroup_contacts WHERE contactgroup_id = $1`, group.ID).Returns(int64(testdata.Cathy.ID)) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM contacts_contact WHERE id = $1 AND modified_on > $2`, testdata.Cathy.ID, start).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM contacts_contactgroup_contacts WHERE contactgroup_id = $1`, group.ID).Returns(1) + assertdb.Query(t, db, `SELECT contact_id FROM contacts_contactgroup_contacts WHERE contactgroup_id = $1`, group.ID).Returns(int64(testdata.Cathy.ID)) + assertdb.Query(t, db, `SELECT count(*) FROM contacts_contact WHERE id = $1 AND modified_on > $2`, testdata.Cathy.ID, start).Returns(1) } diff --git a/core/tasks/expirations/cron.go b/core/tasks/expirations/cron.go index 11e5e1141..9393841c0 100644 --- a/core/tasks/expirations/cron.go +++ b/core/tasks/expirations/cron.go @@ -12,7 +12,7 @@ import ( "github.com/nyaruka/mailroom/core/tasks/handler" "github.com/nyaruka/mailroom/runtime" "github.com/nyaruka/mailroom/utils/cron" - "github.com/nyaruka/mailroom/utils/marker" + "github.com/nyaruka/redisx" "github.com/pkg/errors" "github.com/sirupsen/logrus" @@ -20,36 +20,36 @@ import ( const ( expirationLock = "run_expirations" - markerGroup = "run_expirations" expireBatchSize = 500 ) +var expirationsMarker = redisx.NewIntervalSet("run_expirations", time.Hour*24, 2) + func init() { mailroom.AddInitFunction(StartExpirationCron) } // StartExpirationCron starts our cron job of expiring runs every minute func StartExpirationCron(rt *runtime.Runtime, wg *sync.WaitGroup, quit chan bool) error { - cron.StartCron(quit, rt.RP, expirationLock, time.Second*60, - func(lockName string, lockValue string) error { + cron.Start(quit, rt, expirationLock, time.Second*60, false, + func() error { ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) defer cancel() - return expireRuns(ctx, rt, lockName, lockValue) + return expireRuns(ctx, rt) }, ) return nil } // expireRuns expires all the runs that have an expiration in the past -func expireRuns(ctx context.Context, rt *runtime.Runtime, lockName string, lockValue string) error { - log := logrus.WithField("comp", "expirer").WithField("lock", lockValue) +func expireRuns(ctx context.Context, rt *runtime.Runtime) error { + log := logrus.WithField("comp", "expirer") start := time.Now() rc := rt.RP.Get() defer rc.Close() - // we expire runs and sessions that have no continuation in batches - expiredRuns := make([]models.FlowRunID, 0, expireBatchSize) + // we expire sessions that have no continuation in batches expiredSessions := make([]models.SessionID, 0, expireBatchSize) // select our expired runs @@ -69,9 +69,8 @@ func expireRuns(ctx context.Context, rt *runtime.Runtime, lockName string, lockV count++ - // no parent id? we can add this to our batch - if expiration.ParentUUID == nil || expiration.SessionID == nil { - expiredRuns = append(expiredRuns, expiration.RunID) + // no parent? we can add this to our batch + if expiration.ParentUUID == nil { if expiration.SessionID != nil { expiredSessions = append(expiredSessions, *expiration.SessionID) @@ -80,12 +79,11 @@ func expireRuns(ctx context.Context, rt *runtime.Runtime, lockName string, lockV } // batch is full? commit it - if len(expiredRuns) == expireBatchSize { - err = models.ExpireRunsAndSessions(ctx, rt.DB, expiredRuns, expiredSessions) + if len(expiredSessions) == expireBatchSize { + err = models.ExitSessions(ctx, rt.DB, expiredSessions, models.SessionStatusExpired) if err != nil { return errors.Wrapf(err, "error expiring runs and sessions") } - expiredRuns = expiredRuns[:0] expiredSessions = expiredSessions[:0] } @@ -94,7 +92,7 @@ func expireRuns(ctx context.Context, rt *runtime.Runtime, lockName string, lockV // need to continue this session and flow, create a task for that taskID := fmt.Sprintf("%d:%s", expiration.RunID, expiration.ExpiresOn.Format(time.RFC3339)) - queued, err := marker.HasTask(rc, markerGroup, taskID) + queued, err := expirationsMarker.Contains(rc, taskID) if err != nil { return errors.Wrapf(err, "error checking whether expiration is queued") } @@ -112,15 +110,15 @@ func expireRuns(ctx context.Context, rt *runtime.Runtime, lockName string, lockV } // and mark it as queued - err = marker.AddTask(rc, markerGroup, taskID) + err = expirationsMarker.Add(rc, taskID) if err != nil { return errors.Wrapf(err, "error marking expiration task as queued") } } // commit any stragglers - if len(expiredRuns) > 0 { - err = models.ExpireRunsAndSessions(ctx, rt.DB, expiredRuns, expiredSessions) + if len(expiredSessions) > 0 { + err = models.ExitSessions(ctx, rt.DB, expiredSessions, models.SessionStatusExpired) if err != nil { return errors.Wrapf(err, "error expiring runs and sessions") } diff --git a/core/tasks/expirations/cron_test.go b/core/tasks/expirations/cron_test.go index ba7e52a9f..4d6b7892f 100644 --- a/core/tasks/expirations/cron_test.go +++ b/core/tasks/expirations/cron_test.go @@ -5,14 +5,13 @@ import ( "testing" "time" + "github.com/nyaruka/gocommon/dbutil/assertdb" _ "github.com/nyaruka/mailroom/core/handlers" "github.com/nyaruka/mailroom/core/models" "github.com/nyaruka/mailroom/core/queue" "github.com/nyaruka/mailroom/core/tasks/handler" "github.com/nyaruka/mailroom/testsuite" "github.com/nyaruka/mailroom/testsuite/testdata" - "github.com/nyaruka/mailroom/utils/marker" - "github.com/stretchr/testify/assert" ) @@ -21,15 +20,12 @@ func TestExpirations(t *testing.T) { rc := rp.Get() defer rc.Close() - defer testsuite.Reset(testsuite.ResetAll) - - err := marker.ClearTasks(rc, expirationLock) - assert.NoError(t, err) + defer testsuite.Reset(testsuite.ResetData | testsuite.ResetRedis) // create a few sessions - s1 := testdata.InsertFlowSession(db, testdata.Org1, testdata.Cathy, models.SessionStatusWaiting, nil) - s2 := testdata.InsertFlowSession(db, testdata.Org1, testdata.George, models.SessionStatusWaiting, nil) - s3 := testdata.InsertFlowSession(db, testdata.Org1, testdata.Bob, models.SessionStatusWaiting, nil) + s1 := testdata.InsertFlowSession(db, testdata.Org1, testdata.Cathy, models.FlowTypeMessaging, models.SessionStatusWaiting, testdata.Favorites, models.NilConnectionID, nil) + s2 := testdata.InsertFlowSession(db, testdata.Org1, testdata.George, models.FlowTypeMessaging, models.SessionStatusWaiting, testdata.Favorites, models.NilConnectionID, nil) + s3 := testdata.InsertFlowSession(db, testdata.Org1, testdata.Bob, models.FlowTypeMessaging, models.SessionStatusWaiting, testdata.Favorites, models.NilConnectionID, nil) // simple run, no parent r1ExpiresOn := time.Now() @@ -52,30 +48,30 @@ func TestExpirations(t *testing.T) { time.Sleep(10 * time.Millisecond) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM flows_flowrun WHERE is_active = TRUE AND contact_id = $1;`, testdata.Cathy.ID).Returns(2) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM flows_flowsession WHERE status = 'X' AND contact_id = $1;`, testdata.Cathy.ID).Returns(0) + assertdb.Query(t, db, `SELECT count(*) FROM flows_flowrun WHERE is_active = TRUE AND contact_id = $1;`, testdata.Cathy.ID).Returns(2) + assertdb.Query(t, db, `SELECT count(*) FROM flows_flowsession WHERE status = 'X' AND contact_id = $1;`, testdata.Cathy.ID).Returns(0) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM flows_flowrun WHERE is_active = TRUE AND contact_id = $1;`, testdata.George.ID).Returns(2) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM flows_flowsession WHERE status = 'X' AND contact_id = $1;`, testdata.George.ID).Returns(0) + assertdb.Query(t, db, `SELECT count(*) FROM flows_flowrun WHERE is_active = TRUE AND contact_id = $1;`, testdata.George.ID).Returns(2) + assertdb.Query(t, db, `SELECT count(*) FROM flows_flowsession WHERE status = 'X' AND contact_id = $1;`, testdata.George.ID).Returns(0) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM flows_flowrun WHERE is_active = TRUE AND contact_id = $1;`, testdata.Bob.ID).Returns(1) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM flows_flowsession WHERE status = 'X' AND contact_id = $1;`, testdata.Bob.ID).Returns(0) + assertdb.Query(t, db, `SELECT count(*) FROM flows_flowrun WHERE is_active = TRUE AND contact_id = $1;`, testdata.Bob.ID).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM flows_flowsession WHERE status = 'X' AND contact_id = $1;`, testdata.Bob.ID).Returns(0) // expire our runs - err = expireRuns(ctx, rt, expirationLock, "foo") + err := expireRuns(ctx, rt) assert.NoError(t, err) - // shouldn't have any active runs or sessions - testsuite.AssertQuery(t, db, `SELECT count(*) FROM flows_flowrun WHERE is_active = TRUE AND contact_id = $1;`, testdata.Cathy.ID).Returns(0) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM flows_flowsession WHERE status = 'X' AND contact_id = $1;`, testdata.Cathy.ID).Returns(1) + // shouldn't have any active runs or sessions (except the sessionless run) + assertdb.Query(t, db, `SELECT count(*) FROM flows_flowrun WHERE is_active = TRUE AND contact_id = $1;`, testdata.Cathy.ID).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM flows_flowsession WHERE status = 'X' AND contact_id = $1;`, testdata.Cathy.ID).Returns(1) // should still have two active runs for George as it needs to continue - testsuite.AssertQuery(t, db, `SELECT count(*) FROM flows_flowrun WHERE is_active = TRUE AND contact_id = $1;`, testdata.George.ID).Returns(2) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM flows_flowsession WHERE status = 'X' AND contact_id = $1;`, testdata.George.ID).Returns(0) + assertdb.Query(t, db, `SELECT count(*) FROM flows_flowrun WHERE is_active = TRUE AND contact_id = $1;`, testdata.George.ID).Returns(2) + assertdb.Query(t, db, `SELECT count(*) FROM flows_flowsession WHERE status = 'X' AND contact_id = $1;`, testdata.George.ID).Returns(0) // runs without expires_on won't be expired - testsuite.AssertQuery(t, db, `SELECT count(*) FROM flows_flowrun WHERE is_active = TRUE AND contact_id = $1;`, testdata.Bob.ID).Returns(1) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM flows_flowsession WHERE status = 'X' AND contact_id = $1;`, testdata.Bob.ID).Returns(0) + assertdb.Query(t, db, `SELECT count(*) FROM flows_flowrun WHERE is_active = TRUE AND contact_id = $1;`, testdata.Bob.ID).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM flows_flowsession WHERE status = 'X' AND contact_id = $1;`, testdata.Bob.ID).Returns(0) // should have created one task task, err := queue.PopNextTask(rc, queue.HandlerQueue) diff --git a/core/tasks/handler/cron.go b/core/tasks/handler/cron.go index fae111e24..944bd0c92 100644 --- a/core/tasks/handler/cron.go +++ b/core/tasks/handler/cron.go @@ -12,7 +12,7 @@ import ( "github.com/nyaruka/mailroom/core/queue" "github.com/nyaruka/mailroom/runtime" "github.com/nyaruka/mailroom/utils/cron" - "github.com/nyaruka/mailroom/utils/marker" + "github.com/nyaruka/redisx" "github.com/pkg/errors" "github.com/sirupsen/logrus" @@ -20,32 +20,33 @@ import ( const ( retryLock = "retry_msgs" - markerKey = "retried_msgs" ) +var retriedMsgs = redisx.NewIntervalSet("retried_msgs", time.Hour*24, 2) + func init() { mailroom.AddInitFunction(StartRetryCron) } // StartRetryCron starts our cron job of retrying pending incoming messages func StartRetryCron(rt *runtime.Runtime, wg *sync.WaitGroup, quit chan bool) error { - cron.StartCron(quit, rt.RP, retryLock, time.Minute*5, - func(lockName string, lockValue string) error { + cron.Start(quit, rt, retryLock, time.Minute*5, false, + func() error { ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) defer cancel() - return RetryPendingMsgs(ctx, rt, lockName, lockValue) + return RetryPendingMsgs(ctx, rt) }, ) return nil } // RetryPendingMsgs looks for any pending msgs older than five minutes and queues them to be handled again -func RetryPendingMsgs(ctx context.Context, rt *runtime.Runtime, lockName string, lockValue string) error { +func RetryPendingMsgs(ctx context.Context, rt *runtime.Runtime) error { if !rt.Config.RetryPendingMessages { return nil } - log := logrus.WithField("comp", "handler_retrier").WithField("lock", lockValue) + log := logrus.WithField("comp", "handler_retrier") start := time.Now() rc := rt.RP.Get() @@ -85,7 +86,7 @@ func RetryPendingMsgs(ctx context.Context, rt *runtime.Runtime, lockName string, // our key is built such that we will only retry once an hour key := fmt.Sprintf("%d_%d", msgID, time.Now().Hour()) - dupe, err := marker.HasTask(rc, markerKey, key) + dupe, err := retriedMsgs.Contains(rc, key) if err != nil { return errors.Wrapf(err, "error checking for dupe retry") } @@ -109,7 +110,7 @@ func RetryPendingMsgs(ctx context.Context, rt *runtime.Runtime, lockName string, } // mark it as queued - err = marker.AddTask(rc, markerKey, key) + err = retriedMsgs.Add(rc, key) if err != nil { return errors.Wrapf(err, "error marking task for retry") } diff --git a/core/tasks/handler/cron_test.go b/core/tasks/handler/cron_test.go index f8bcb245b..ebca11d80 100644 --- a/core/tasks/handler/cron_test.go +++ b/core/tasks/handler/cron_test.go @@ -4,15 +4,14 @@ import ( "testing" "time" + "github.com/nyaruka/gocommon/dbutil/assertdb" "github.com/nyaruka/gocommon/uuids" - _ "github.com/nyaruka/mailroom/core/handlers" "github.com/nyaruka/mailroom/core/models" "github.com/nyaruka/mailroom/core/queue" "github.com/nyaruka/mailroom/core/tasks/handler" "github.com/nyaruka/mailroom/testsuite" "github.com/nyaruka/mailroom/testsuite/testdata" - "github.com/stretchr/testify/assert" ) @@ -24,7 +23,7 @@ func TestRetryMsgs(t *testing.T) { defer testsuite.Reset(testsuite.ResetAll) // noop does nothing - err := handler.RetryPendingMsgs(ctx, rt, "test", "test") + err := handler.RetryPendingMsgs(ctx, rt) assert.NoError(t, err) testMsgs := []struct { @@ -44,7 +43,7 @@ func TestRetryMsgs(t *testing.T) { uuids.New(), testdata.Org1.ID, testdata.TwilioChannel.ID, testdata.Cathy.ID, testdata.Cathy.URNID, msg.Text, models.DirectionIn, msg.Status, msg.CreatedOn) } - err = handler.RetryPendingMsgs(ctx, rt, "test", "test") + err = handler.RetryPendingMsgs(ctx, rt) assert.NoError(t, err) // should have one message requeued @@ -54,7 +53,7 @@ func TestRetryMsgs(t *testing.T) { assert.NoError(t, err) // message should be handled now - testsuite.AssertQuery(t, db, `SELECT count(*) from msgs_msg WHERE text = 'pending' AND status = 'H'`).Returns(1) + assertdb.Query(t, db, `SELECT count(*) from msgs_msg WHERE text = 'pending' AND status = 'H'`).Returns(1) // only one message was queued task, _ = queue.PopNextTask(rc, queue.HandlerQueue) diff --git a/core/tasks/handler/handler_test.go b/core/tasks/handler/handler_test.go index e7ac6af4a..aed9fcbd5 100644 --- a/core/tasks/handler/handler_test.go +++ b/core/tasks/handler/handler_test.go @@ -6,6 +6,8 @@ import ( "testing" "time" + "github.com/nyaruka/gocommon/dbutil/assertdb" + "github.com/nyaruka/gocommon/jsonx" "github.com/nyaruka/gocommon/urns" "github.com/nyaruka/gocommon/uuids" "github.com/nyaruka/goflow/flows" @@ -60,49 +62,201 @@ func TestMsgEvents(t *testing.T) { dbMsg := testdata.InsertIncomingMsg(db, testdata.Org1, testdata.TwilioChannel, testdata.Cathy, "", models.MsgStatusPending) tcs := []struct { - Hook func() - Org *testdata.Org - Channel *testdata.Channel - Contact *testdata.Contact - Text string - ExpectedReply string - ExpectedType models.MsgType + preHook func() + org *testdata.Org + channel *testdata.Channel + contact *testdata.Contact + text string + expectedReply string + expectedType models.MsgType + expectedFlow *testdata.Flow }{ - {nil, testdata.Org1, testdata.TwitterChannel, testdata.Cathy, "noop", "", models.MsgTypeInbox}, - {nil, testdata.Org1, testdata.TwitterChannel, testdata.Cathy, "start other", "", models.MsgTypeInbox}, - {nil, testdata.Org1, testdata.TwitterChannel, testdata.Cathy, "start", "What is your favorite color?", models.MsgTypeFlow}, - {nil, testdata.Org1, testdata.TwitterChannel, testdata.Cathy, "purple", "I don't know that color. Try again.", models.MsgTypeFlow}, - {nil, testdata.Org1, testdata.TwitterChannel, testdata.Cathy, "blue", "Good choice, I like Blue too! What is your favorite beer?", models.MsgTypeFlow}, - {nil, testdata.Org1, testdata.TwitterChannel, testdata.Cathy, "MUTZIG", "Mmmmm... delicious Mutzig. If only they made blue Mutzig! Lastly, what is your name?", models.MsgTypeFlow}, - {nil, testdata.Org1, testdata.TwitterChannel, testdata.Cathy, "Cathy", "Thanks Cathy, we are all done!", models.MsgTypeFlow}, - {nil, testdata.Org1, testdata.TwitterChannel, testdata.Cathy, "noop", "", models.MsgTypeInbox}, - - {nil, testdata.Org2, testdata.Org2Channel, testdata.Org2Contact, "other", "Hey, how are you?", models.MsgTypeFlow}, - {nil, testdata.Org2, testdata.Org2Channel, testdata.Org2Contact, "start", "What is your favorite color?", models.MsgTypeFlow}, - {nil, testdata.Org2, testdata.Org2Channel, testdata.Org2Contact, "green", "Good choice, I like Green too! What is your favorite beer?", models.MsgTypeFlow}, - {nil, testdata.Org2, testdata.Org2Channel, testdata.Org2Contact, "primus", "Mmmmm... delicious Primus. If only they made green Primus! Lastly, what is your name?", models.MsgTypeFlow}, - {nil, testdata.Org2, testdata.Org2Channel, testdata.Org2Contact, "george", "Thanks george, we are all done!", models.MsgTypeFlow}, - {nil, testdata.Org2, testdata.Org2Channel, testdata.Org2Contact, "blargh", "Hey, how are you?", models.MsgTypeFlow}, - - {nil, testdata.Org1, testdata.TwitterChannel, testdata.Bob, "ivr", "", models.MsgTypeFlow}, + { + org: testdata.Org1, + channel: testdata.TwitterChannel, + contact: testdata.Cathy, + text: "noop", + expectedReply: "", + expectedType: models.MsgTypeInbox, + }, + { + org: testdata.Org1, + channel: testdata.TwitterChannel, + contact: testdata.Cathy, + text: "start other", + expectedReply: "", + expectedType: models.MsgTypeInbox, + }, + { + org: testdata.Org1, + channel: testdata.TwitterChannel, + contact: testdata.Cathy, + text: "start", + expectedReply: "What is your favorite color?", + expectedType: models.MsgTypeFlow, + expectedFlow: testdata.Favorites, + }, + { + org: testdata.Org1, + channel: testdata.TwitterChannel, + contact: testdata.Cathy, + text: "purple", + expectedReply: "I don't know that color. Try again.", + expectedType: models.MsgTypeFlow, + expectedFlow: testdata.Favorites, + }, + { + org: testdata.Org1, + channel: testdata.TwitterChannel, + contact: testdata.Cathy, + text: "blue", + expectedReply: "Good choice, I like Blue too! What is your favorite beer?", + expectedType: models.MsgTypeFlow, + expectedFlow: testdata.Favorites, + }, + { + org: testdata.Org1, + channel: testdata.TwitterChannel, + contact: testdata.Cathy, + text: "MUTZIG", + expectedReply: "Mmmmm... delicious Mutzig. If only they made blue Mutzig! Lastly, what is your name?", + expectedType: models.MsgTypeFlow, + expectedFlow: testdata.Favorites, + }, + { + org: testdata.Org1, + channel: testdata.TwitterChannel, + contact: testdata.Cathy, + text: "Cathy", + expectedReply: "Thanks Cathy, we are all done!", + expectedType: models.MsgTypeFlow, + expectedFlow: testdata.Favorites, + }, + { + org: testdata.Org1, + channel: testdata.TwitterChannel, + contact: testdata.Cathy, + text: "noop", + expectedReply: "", + expectedType: models.MsgTypeInbox, + }, + + { + org: testdata.Org2, + channel: testdata.Org2Channel, + contact: testdata.Org2Contact, + text: "other", + expectedReply: "Hey, how are you?", + expectedType: models.MsgTypeFlow, + expectedFlow: testdata.Org2SingleMessage, + }, + { + org: testdata.Org2, + channel: testdata.Org2Channel, + contact: testdata.Org2Contact, + text: "start", + expectedReply: "What is your favorite color?", + expectedType: models.MsgTypeFlow, + expectedFlow: testdata.Org2Favorites, + }, + { + org: testdata.Org2, + channel: testdata.Org2Channel, + contact: testdata.Org2Contact, + text: "green", + expectedReply: "Good choice, I like Green too! What is your favorite beer?", + expectedType: models.MsgTypeFlow, + expectedFlow: testdata.Org2Favorites, + }, + { + org: testdata.Org2, + channel: testdata.Org2Channel, + contact: testdata.Org2Contact, + text: "primus", + expectedReply: "Mmmmm... delicious Primus. If only they made green Primus! Lastly, what is your name?", + expectedType: models.MsgTypeFlow, + expectedFlow: testdata.Org2Favorites, + }, + { + org: testdata.Org2, + channel: testdata.Org2Channel, + contact: testdata.Org2Contact, + text: "george", + expectedReply: "Thanks george, we are all done!", + expectedType: models.MsgTypeFlow, + expectedFlow: testdata.Org2Favorites, + }, + { + org: testdata.Org2, + channel: testdata.Org2Channel, + contact: testdata.Org2Contact, + text: "blargh", + expectedReply: "Hey, how are you?", + expectedType: models.MsgTypeFlow, + expectedFlow: testdata.Org2SingleMessage, + }, + + { + org: testdata.Org1, + channel: testdata.TwitterChannel, + contact: testdata.Bob, + text: "ivr", + expectedReply: "", + expectedType: models.MsgTypeFlow, + expectedFlow: testdata.IVRFlow, + }, // no URN on contact but handle event, session gets started but no message created - {nil, testdata.Org1, testdata.TwilioChannel, testdata.Alexandria, "start", "", models.MsgTypeFlow}, + { + org: testdata.Org1, + channel: testdata.TwilioChannel, + contact: testdata.Alexandria, + text: "start", + expectedReply: "", + expectedType: models.MsgTypeFlow, + expectedFlow: testdata.Favorites, + }, // start Fred back in our favorite flow, then make it inactive, will be handled by catch-all - {nil, testdata.Org2, testdata.Org2Channel, testdata.Org2Contact, "start", "What is your favorite color?", models.MsgTypeFlow}, - {func() { - db.MustExec(`UPDATE flows_flow SET is_active = FALSE WHERE id = $1`, testdata.Org2Favorites.ID) - }, testdata.Org2, testdata.Org2Channel, testdata.Org2Contact, "red", "Hey, how are you?", models.MsgTypeFlow}, + { + org: testdata.Org2, + channel: testdata.Org2Channel, + contact: testdata.Org2Contact, + text: "start", + expectedReply: "What is your favorite color?", + expectedType: models.MsgTypeFlow, + expectedFlow: testdata.Org2Favorites, + }, + { + preHook: func() { + db.MustExec(`UPDATE flows_flow SET is_active = FALSE WHERE id = $1`, testdata.Org2Favorites.ID) + }, + org: testdata.Org2, + channel: testdata.Org2Channel, + contact: testdata.Org2Contact, + text: "red", + expectedReply: "Hey, how are you?", + expectedType: models.MsgTypeFlow, + expectedFlow: testdata.Org2SingleMessage, + }, // start Fred back in our favorites flow to test retries - {func() { - db.MustExec(`UPDATE flows_flow SET is_active = TRUE WHERE id = $1`, testdata.Org2Favorites.ID) - }, testdata.Org2, testdata.Org2Channel, testdata.Org2Contact, "start", "What is your favorite color?", models.MsgTypeFlow}, + { + preHook: func() { + db.MustExec(`UPDATE flows_flow SET is_active = TRUE WHERE id = $1`, testdata.Org2Favorites.ID) + }, + org: testdata.Org2, + channel: testdata.Org2Channel, + contact: testdata.Org2Contact, + text: "start", + expectedReply: "What is your favorite color?", + expectedType: models.MsgTypeFlow, + expectedFlow: testdata.Org2Favorites, + }, } makeMsgTask := func(org *testdata.Org, channel *testdata.Channel, contact *testdata.Contact, text string) *queue.Task { - event := &handler.MsgEvent{ + return &queue.Task{Type: handler.MsgEventType, OrgID: int(org.ID), Task: jsonx.MustMarshal(&handler.MsgEvent{ ContactID: contact.ID, OrgID: org.ID, ChannelID: channel.ID, @@ -111,18 +265,7 @@ func TestMsgEvents(t *testing.T) { URN: contact.URN, URNID: contact.URNID, Text: text, - } - - eventJSON, err := json.Marshal(event) - assert.NoError(t, err) - - task := &queue.Task{ - Type: handler.MsgEventType, - OrgID: int(org.ID), - Task: eventJSON, - } - - return task + })} } last := time.Now() @@ -133,14 +276,14 @@ func TestMsgEvents(t *testing.T) { // reset our dummy db message into an unhandled state db.MustExec(`UPDATE msgs_msg SET status = 'P', msg_type = NULL WHERE id = $1`, dbMsg.ID()) - // run our hook if we have one - if tc.Hook != nil { - tc.Hook() + // run our setup hook if we have one + if tc.preHook != nil { + tc.preHook() } - task := makeMsgTask(tc.Org, tc.Channel, tc.Contact, tc.Text) + task := makeMsgTask(tc.org, tc.channel, tc.contact, tc.text) - err := handler.QueueHandleTask(rc, tc.Contact.ID, task) + err := handler.QueueHandleTask(rc, tc.contact.ID, task) assert.NoError(t, err, "%d: error adding task", i) task, err = queue.PopNextTask(rc, queue.HandlerQueue) @@ -149,30 +292,38 @@ func TestMsgEvents(t *testing.T) { err = handler.HandleEvent(ctx, rt, task) assert.NoError(t, err, "%d: error when handling event", i) + var expectedFlowID interface{} + if tc.expectedFlow != nil { + expectedFlowID = int64(tc.expectedFlow.ID) + } + // check that message is marked as handled with expected type - testsuite.AssertQuery(t, db, `SELECT msg_type, status FROM msgs_msg WHERE id = $1`, dbMsg.ID()). - Columns(map[string]interface{}{"msg_type": string(tc.ExpectedType), "status": "H"}, "%d: msg state mismatch", i) + assertdb.Query(t, db, `SELECT status, msg_type, flow_id FROM msgs_msg WHERE id = $1`, dbMsg.ID()). + Columns(map[string]interface{}{"status": "H", "msg_type": string(tc.expectedType), "flow_id": expectedFlowID}, "%d: msg state mismatch", i) // if we are meant to have a reply, check it - if tc.ExpectedReply != "" { - testsuite.AssertQuery(t, db, `SELECT text FROM msgs_msg WHERE contact_id = $1 AND created_on > $2 ORDER BY id DESC LIMIT 1`, tc.Contact.ID, last). - Returns(tc.ExpectedReply, "%d: response mismatch", i) + if tc.expectedReply != "" { + assertdb.Query(t, db, `SELECT text FROM msgs_msg WHERE contact_id = $1 AND created_on > $2 ORDER BY id DESC LIMIT 1`, tc.contact.ID, last). + Returns(tc.expectedReply, "%d: response mismatch", i) } // check any open tickets for this contact where updated - numOpenTickets := len(openTickets[tc.Contact]) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM tickets_ticket WHERE contact_id = $1 AND status = 'O' AND last_activity_on > $2`, tc.Contact.ID, last). + numOpenTickets := len(openTickets[tc.contact]) + assertdb.Query(t, db, `SELECT count(*) FROM tickets_ticket WHERE contact_id = $1 AND status = 'O' AND last_activity_on > $2`, tc.contact.ID, last). Returns(numOpenTickets, "%d: updated open ticket mismatch", i) // check any closed tickets are unchanged - numClosedTickets := len(closedTickets[tc.Contact]) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM tickets_ticket WHERE contact_id = $1 AND status = 'C' AND last_activity_on = '2021-01-01T00:00:00Z'`, tc.Contact.ID). + numClosedTickets := len(closedTickets[tc.contact]) + assertdb.Query(t, db, `SELECT count(*) FROM tickets_ticket WHERE contact_id = $1 AND status = 'C' AND last_activity_on = '2021-01-01T00:00:00Z'`, tc.contact.ID). Returns(numClosedTickets, "%d: unchanged closed ticket mismatch", i) last = time.Now() } // should have one remaining IVR task to handle for Bob + orgTasks := testsuite.CurrentOrgTasks(t, rp) + assert.Equal(t, 1, len(orgTasks[testdata.Org1.ID])) + task, err := queue.PopNextTask(rc, queue.BatchQueue) assert.NoError(t, err) assert.NotNil(t, task) @@ -185,10 +336,11 @@ func TestMsgEvents(t *testing.T) { }) // Fred's sessions should not have a timeout because courier will set them - testsuite.AssertQuery(t, db, `SELECT count(*) from flows_flowsession where contact_id = $1 and timeout_on IS NULL AND wait_started_on IS NOT NULL`, testdata.Org2Contact.ID).Returns(2) + assertdb.Query(t, db, `SELECT count(*) from flows_flowsession where contact_id = $1`, testdata.Org2Contact.ID).Returns(6) + assertdb.Query(t, db, `SELECT count(*) from flows_flowsession where contact_id = $1 and timeout_on IS NULL`, testdata.Org2Contact.ID).Returns(6) // force an error by marking our run for fred as complete (our session is still active so this will blow up) - db.MustExec(`UPDATE flows_flowrun SET is_active = FALSE WHERE contact_id = $1`, testdata.Org2Contact.ID) + db.MustExec(`UPDATE flows_flowrun SET status = 'C' WHERE contact_id = $1`, testdata.Org2Contact.ID) task = makeMsgTask(testdata.Org2, testdata.Org2Channel, testdata.Org2Contact, "red") handler.QueueHandleTask(rc, testdata.Org2Contact.ID, task) @@ -218,11 +370,11 @@ func TestMsgEvents(t *testing.T) { assert.NoError(t, err) // should get our catch all trigger - testsuite.AssertQuery(t, db, `SELECT text FROM msgs_msg WHERE contact_id = $1 AND direction = 'O' ORDER BY id DESC LIMIT 1`, testdata.Org2Contact.ID).Returns("Hey, how are you?") + assertdb.Query(t, db, `SELECT text FROM msgs_msg WHERE contact_id = $1 AND direction = 'O' ORDER BY id DESC LIMIT 1`, testdata.Org2Contact.ID).Returns("Hey, how are you?") previous := time.Now() // and should have failed previous session - testsuite.AssertQuery(t, db, `SELECT count(*) from flows_flowsession where contact_id = $1 and status = 'F' and current_flow_id = $2`, testdata.Org2Contact.ID, testdata.Org2Favorites.ID).Returns(2) + assertdb.Query(t, db, `SELECT count(*) from flows_flowsession where contact_id = $1 and status = 'F'`, testdata.Org2Contact.ID).Returns(2) // trigger should also not start a new session task = makeMsgTask(testdata.Org2, testdata.Org2Channel, testdata.Org2Contact, "start") @@ -231,7 +383,7 @@ func TestMsgEvents(t *testing.T) { err = handler.HandleEvent(ctx, rt, task) assert.NoError(t, err) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM msgs_msg WHERE contact_id = $1 AND direction = 'O' AND created_on > $2`, testdata.Org2Contact.ID, previous).Returns(0) + assertdb.Query(t, db, `SELECT count(*) FROM msgs_msg WHERE contact_id = $1 AND direction = 'O' AND created_on > $2`, testdata.Org2Contact.ID, previous).Returns(0) } func TestChannelEvents(t *testing.T) { @@ -292,7 +444,7 @@ func TestChannelEvents(t *testing.T) { // if we are meant to have a response if tc.Response != "" { - testsuite.AssertQuery(t, db, `SELECT text FROM msgs_msg WHERE contact_id = $1 AND contact_urn_id = $2 AND created_on > $3 ORDER BY id DESC LIMIT 1`, tc.ContactID, tc.URNID, start). + assertdb.Query(t, db, `SELECT text FROM msgs_msg WHERE contact_id = $1 AND contact_urn_id = $2 AND created_on > $3 ORDER BY id DESC LIMIT 1`, tc.ContactID, tc.URNID, start). Returns(tc.Response, "%d: response mismatch", i) } @@ -329,7 +481,7 @@ func TestTicketEvents(t *testing.T) { err = handler.HandleEvent(ctx, rt, task) require.NoError(t, err) - testsuite.AssertQuery(t, rt.DB, `SELECT count(*) FROM msgs_msg WHERE contact_id = $1 AND direction = 'O' AND text = 'What is your favorite color?'`, testdata.Cathy.ID).Returns(1) + assertdb.Query(t, rt.DB, `SELECT count(*) FROM msgs_msg WHERE contact_id = $1 AND direction = 'O' AND text = 'What is your favorite color?'`, testdata.Cathy.ID).Returns(1) } func TestStopEvent(t *testing.T) { @@ -340,7 +492,8 @@ func TestStopEvent(t *testing.T) { defer testsuite.Reset(testsuite.ResetAll) // schedule an event for cathy and george - db.MustExec(`INSERT INTO campaigns_eventfire(scheduled, contact_id, event_id) VALUES (NOW(), $1, $3), (NOW(), $2, $3);`, testdata.Cathy.ID, testdata.George.ID, testdata.RemindersEvent1.ID) + testdata.InsertEventFire(rt.DB, testdata.Cathy, testdata.RemindersEvent1, time.Now()) + testdata.InsertEventFire(rt.DB, testdata.George, testdata.RemindersEvent1, time.Now()) // and george to doctors group, cathy is already part of it db.MustExec(`INSERT INTO contacts_contactgroup_contacts(contactgroup_id, contact_id) VALUES($1, $2);`, testdata.DoctorsGroup.ID, testdata.George.ID) @@ -364,15 +517,15 @@ func TestStopEvent(t *testing.T) { assert.NoError(t, err, "error when handling event") // check that only george is in our group - testsuite.AssertQuery(t, db, `SELECT count(*) from contacts_contactgroup_contacts WHERE contactgroup_id = $1 AND contact_id = $2`, testdata.DoctorsGroup.ID, testdata.Cathy.ID).Returns(0) - testsuite.AssertQuery(t, db, `SELECT count(*) from contacts_contactgroup_contacts WHERE contactgroup_id = $1 AND contact_id = $2`, testdata.DoctorsGroup.ID, testdata.George.ID).Returns(1) + assertdb.Query(t, db, `SELECT count(*) from contacts_contactgroup_contacts WHERE contactgroup_id = $1 AND contact_id = $2`, testdata.DoctorsGroup.ID, testdata.Cathy.ID).Returns(0) + assertdb.Query(t, db, `SELECT count(*) from contacts_contactgroup_contacts WHERE contactgroup_id = $1 AND contact_id = $2`, testdata.DoctorsGroup.ID, testdata.George.ID).Returns(1) // that cathy is stopped - testsuite.AssertQuery(t, db, `SELECT count(*) FROM contacts_contact WHERE id = $1 AND status = 'S'`, testdata.Cathy.ID).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM contacts_contact WHERE id = $1 AND status = 'S'`, testdata.Cathy.ID).Returns(1) // and has no upcoming events - testsuite.AssertQuery(t, db, `SELECT count(*) FROM campaigns_eventfire WHERE contact_id = $1`, testdata.Cathy.ID).Returns(0) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM campaigns_eventfire WHERE contact_id = $1`, testdata.George.ID).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM campaigns_eventfire WHERE contact_id = $1`, testdata.Cathy.ID).Returns(0) + assertdb.Query(t, db, `SELECT count(*) FROM campaigns_eventfire WHERE contact_id = $1`, testdata.George.ID).Returns(1) } func TestTimedEvents(t *testing.T) { @@ -486,7 +639,7 @@ func TestTimedEvents(t *testing.T) { assert.NoError(t, err, "%d: error when handling event", i) if tc.Response != "" { - testsuite.AssertQuery(t, db, `SELECT text FROM msgs_msg WHERE contact_id = $1 AND created_on > $2 ORDER BY id DESC LIMIT 1`, tc.Contact.ID, last). + assertdb.Query(t, db, `SELECT text FROM msgs_msg WHERE contact_id = $1 AND created_on > $2 ORDER BY id DESC LIMIT 1`, tc.Contact.ID, last). Returns(tc.Response, "%d: response: mismatch", i) } @@ -514,7 +667,7 @@ func TestTimedEvents(t *testing.T) { // set both to be active (this requires us to disable the path change trigger for a bit which asserts flows can't cross back into active status) db.MustExec(`ALTER TABLE flows_flowrun DISABLE TRIGGER temba_flowrun_path_change`) db.MustExec(`UPDATE flows_flowrun SET is_active = TRUE, status = 'W', expires_on = $2 WHERE id = $1`, runID, expiration) - db.MustExec(`UPDATE flows_flowsession SET status = 'W' WHERE id = $1`, sessionID) + db.MustExec(`UPDATE flows_flowsession SET status = 'W', wait_started_on = NOW(), wait_expires_on = NOW() WHERE id = $1`, sessionID) db.MustExec(`ALTER TABLE flows_flowrun ENABLE TRIGGER temba_flowrun_path_change`) // try to expire the run @@ -535,6 +688,6 @@ func TestTimedEvents(t *testing.T) { err = handler.HandleEvent(ctx, rt, task) assert.NoError(t, err) - testsuite.AssertQuery(t, db, `SELECT count(*) from flows_flowrun WHERE is_active = FALSE AND status = 'F' AND id = $1`, runID).Returns(1) - testsuite.AssertQuery(t, db, `SELECT count(*) from flows_flowsession WHERE status = 'F' AND id = $1`, sessionID).Returns(1) + assertdb.Query(t, db, `SELECT count(*) from flows_flowrun WHERE is_active = FALSE AND status = 'F' AND id = $1`, runID).Returns(1) + assertdb.Query(t, db, `SELECT count(*) from flows_flowsession WHERE status = 'F' AND id = $1`, sessionID).Returns(1) } diff --git a/core/tasks/handler/worker.go b/core/tasks/handler/worker.go index 00087ac83..afa664882 100644 --- a/core/tasks/handler/worker.go +++ b/core/tasks/handler/worker.go @@ -8,6 +8,7 @@ import ( "github.com/gomodule/redigo/redis" "github.com/jmoiron/sqlx" + "github.com/nyaruka/gocommon/dbutil" "github.com/nyaruka/gocommon/urns" "github.com/nyaruka/goflow/excellent/types" "github.com/nyaruka/goflow/flows" @@ -21,7 +22,6 @@ import ( "github.com/nyaruka/mailroom/core/queue" "github.com/nyaruka/mailroom/core/runner" "github.com/nyaruka/mailroom/runtime" - "github.com/nyaruka/mailroom/utils/dbutil" "github.com/nyaruka/mailroom/utils/locker" "github.com/nyaruka/null" "github.com/pkg/errors" @@ -172,7 +172,8 @@ func handleContactEvent(ctx context.Context, rt *runtime.Runtime, task *queue.Ta }) if qerr := dbutil.AsQueryError(err); qerr != nil { - log.WithFields(qerr.Fields()) + query, params := qerr.Query() + log = log.WithFields(logrus.Fields{"sql": query, "sql_params": params}) } contactEvent.ErrorCount++ @@ -226,8 +227,8 @@ func handleTimedEvent(ctx context.Context, rt *runtime.Runtime, eventType string return errors.Wrapf(err, "error creating flow contact") } - // get the active session for this contact - session, err := models.ActiveSessionForContact(ctx, rt.DB, rt.SessionStorage, oa, models.FlowTypeMessaging, contact) + // look for a waiting session for this contact + session, err := models.FindWaitingSessionForContact(ctx, rt.DB, rt.SessionStorage, oa, models.FlowTypeMessaging, contact) if err != nil { return errors.Wrapf(err, "error loading active session for contact") } @@ -235,7 +236,7 @@ func handleTimedEvent(ctx context.Context, rt *runtime.Runtime, eventType string // if we didn't find a session or it is another session then this flow got interrupted and this is a race, fail it if session == nil || session.ID() != event.SessionID { log.Error("expiring run with mismatched session, session for run no longer active, failing runs and session") - err = models.ExitSessions(ctx, rt.DB, []models.SessionID{event.SessionID}, models.ExitFailed, time.Now()) + err = models.ExitSessions(ctx, rt.DB, []models.SessionID{event.SessionID}, models.SessionStatusFailed) if err != nil { return errors.Wrapf(err, "error failing expired runs for session that is no longer active") } @@ -266,13 +267,13 @@ func handleTimedEvent(ctx context.Context, rt *runtime.Runtime, eventType string resume = resumes.NewRunExpiration(oa.Env(), contact) case TimeoutEventType: - if session.TimeoutOn() == nil { + if session.WaitTimeoutOn() == nil { log.WithField("session_id", session.ID()).Info("ignoring session timeout, has no timeout set") return nil } // check that the timeout is the same - timeout := *session.TimeoutOn() + timeout := *session.WaitTimeoutOn() if !timeout.Equal(event.Time) { log.WithField("event_timeout", event.Time).WithField("session_timeout", timeout).Info("ignoring timeout, has been updated") return nil @@ -413,14 +414,14 @@ func HandleChannelEvent(ctx context.Context, rt *runtime.Runtime, eventType mode switch eventType { case models.NewConversationEventType, models.ReferralEventType, models.MOMissEventType: - flowTrigger = triggers.NewBuilder(oa.Env(), flow.FlowReference(), contact). + flowTrigger = triggers.NewBuilder(oa.Env(), flow.Reference(), contact). Channel(channel.ChannelReference(), triggers.ChannelEventType(eventType)). WithParams(params). Build() case models.MOCallEventType: urn := contacts[0].URNForID(event.URNID()) - flowTrigger = triggers.NewBuilder(oa.Env(), flow.FlowReference(), contact). + flowTrigger = triggers.NewBuilder(oa.Env(), flow.Reference(), contact). Channel(channel.ChannelReference(), triggers.ChannelEventTypeIncomingCall). WithConnection(urn). Build() @@ -498,7 +499,7 @@ func handleMsgEvent(ctx context.Context, rt *runtime.Runtime, event *MsgEvent) e // contact has been deleted, ignore this message but mark it as handled if len(contacts) == 0 { - err := models.UpdateMessage(ctx, rt.DB, event.MsgID, models.MsgStatusHandled, models.VisibilityArchived, models.MsgTypeInbox, topupID) + err := models.UpdateMessage(ctx, rt.DB, event.MsgID, models.MsgStatusHandled, models.VisibilityArchived, models.MsgTypeInbox, models.NilFlowID, topupID) if err != nil { return errors.Wrapf(err, "error updating message for deleted contact") } @@ -526,7 +527,7 @@ func handleMsgEvent(ctx context.Context, rt *runtime.Runtime, event *MsgEvent) e // if this channel is no longer active or this contact is blocked, ignore this message (mark it as handled) if channel == nil || modelContact.Status() == models.ContactStatusBlocked { - err := models.UpdateMessage(ctx, rt.DB, event.MsgID, models.MsgStatusHandled, models.VisibilityArchived, models.MsgTypeInbox, topupID) + err := models.UpdateMessage(ctx, rt.DB, event.MsgID, models.MsgStatusHandled, models.VisibilityArchived, models.MsgTypeInbox, models.NilFlowID, topupID) if err != nil { return errors.Wrapf(err, "error marking blocked or nil channel message as handled") } @@ -564,8 +565,8 @@ func handleMsgEvent(ctx context.Context, rt *runtime.Runtime, event *MsgEvent) e // find any matching triggers trigger := models.FindMatchingMsgTrigger(oa, contact, event.Text) - // get any active session for this contact - session, err := models.ActiveSessionForContact(ctx, rt.DB, rt.SessionStorage, oa, models.FlowTypeMessaging, contact) + // look for a waiting session for this contact + session, err := models.FindWaitingSessionForContact(ctx, rt.DB, rt.SessionStorage, oa, models.FlowTypeMessaging, contact) if err != nil { return errors.Wrapf(err, "error loading active session for contact") } @@ -577,7 +578,7 @@ func handleMsgEvent(ctx context.Context, rt *runtime.Runtime, event *MsgEvent) e // flow this session is in is gone, interrupt our session and reset it if err == models.ErrNotFound { - err = models.ExitSessions(ctx, rt.DB, []models.SessionID{session.ID()}, models.ExitFailed, time.Now()) + err = models.ExitSessions(ctx, rt.DB, []models.SessionID{session.ID()}, models.SessionStatusFailed) session = nil } @@ -598,14 +599,14 @@ func handleMsgEvent(ctx context.Context, rt *runtime.Runtime, event *MsgEvent) e } sessions[0].SetIncomingMsg(event.MsgID, event.MsgExternalID) - return markMsgHandled(ctx, tx, contact, msgIn, models.MsgTypeFlow, topupID, tickets) + return markMsgHandled(ctx, tx, contact, msgIn, flow, topupID, tickets) } // we found a trigger and their session is nil or doesn't ignore keywords if (trigger != nil && trigger.TriggerType() != models.CatchallTriggerType && (flow == nil || !flow.IgnoreTriggers())) || (trigger != nil && trigger.TriggerType() == models.CatchallTriggerType && (flow == nil)) { // load our flow - flow, err := oa.FlowByID(trigger.FlowID()) + flow, err = oa.FlowByID(trigger.FlowID()) if err != nil && err != models.ErrNotFound { return errors.Wrapf(err, "error loading flow for trigger") } @@ -615,7 +616,7 @@ func handleMsgEvent(ctx context.Context, rt *runtime.Runtime, event *MsgEvent) e // if this is an IVR flow, we need to trigger that start (which happens in a different queue) if flow.FlowType() == models.FlowTypeVoice { ivrMsgHook := func(ctx context.Context, tx *sqlx.Tx) error { - return markMsgHandled(ctx, tx, contact, msgIn, models.MsgTypeFlow, topupID, tickets) + return markMsgHandled(ctx, tx, contact, msgIn, flow, topupID, tickets) } err = runner.TriggerIVRFlow(ctx, rt, oa.OrgID(), flow.ID(), []models.ContactID{modelContact.ID()}, ivrMsgHook) if err != nil { @@ -625,7 +626,7 @@ func handleMsgEvent(ctx context.Context, rt *runtime.Runtime, event *MsgEvent) e } // otherwise build the trigger and start the flow directly - trigger := triggers.NewBuilder(oa.Env(), flow.FlowReference(), contact).Msg(msgIn).WithMatch(trigger.Match()).Build() + trigger := triggers.NewBuilder(oa.Env(), flow.Reference(), contact).Msg(msgIn).WithMatch(trigger.Match()).Build() _, err = runner.StartFlowForContacts(ctx, rt, oa, flow, []flows.Trigger{trigger}, flowMsgHook, true) if err != nil { return errors.Wrapf(err, "error starting flow for contact") @@ -734,7 +735,7 @@ func handleTicketEvent(ctx context.Context, rt *runtime.Runtime, event *models.T switch event.EventType() { case models.TicketEventTypeClosed: - flowTrigger = triggers.NewBuilder(oa.Env(), flow.FlowReference(), contact). + flowTrigger = triggers.NewBuilder(oa.Env(), flow.Reference(), contact). Ticket(ticket, triggers.TicketEventTypeClosed). Build() default: @@ -761,12 +762,19 @@ func handleAsInbox(ctx context.Context, rt *runtime.Runtime, oa *models.OrgAsset return errors.Wrap(err, "error handling inbox message events") } - return markMsgHandled(ctx, rt.DB, contact, msg, models.MsgTypeInbox, topupID, tickets) + return markMsgHandled(ctx, rt.DB, contact, msg, nil, topupID, tickets) } // utility to mark as message as handled and update any open contact tickets -func markMsgHandled(ctx context.Context, db models.Queryer, contact *flows.Contact, msg *flows.MsgIn, msgType models.MsgType, topupID models.TopupID, tickets []*models.Ticket) error { - err := models.UpdateMessage(ctx, db, msg.ID(), models.MsgStatusHandled, models.VisibilityVisible, msgType, topupID) +func markMsgHandled(ctx context.Context, db models.Queryer, contact *flows.Contact, msg *flows.MsgIn, flow *models.Flow, topupID models.TopupID, tickets []*models.Ticket) error { + msgType := models.MsgTypeInbox + flowID := models.NilFlowID + if flow != nil { + msgType = models.MsgTypeFlow + flowID = flow.ID() + } + + err := models.UpdateMessage(ctx, db, msg.ID(), models.MsgStatusHandled, models.VisibilityVisible, msgType, flowID, topupID) if err != nil { return errors.Wrapf(err, "error marking message as handled") } diff --git a/core/tasks/incidents/end_incidents.go b/core/tasks/incidents/end_incidents.go new file mode 100644 index 000000000..9cc37442b --- /dev/null +++ b/core/tasks/incidents/end_incidents.go @@ -0,0 +1,121 @@ +package incidents + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/gomodule/redigo/redis" + "github.com/nyaruka/goflow/flows" + "github.com/nyaruka/mailroom" + "github.com/nyaruka/mailroom/core/models" + "github.com/nyaruka/mailroom/runtime" + "github.com/nyaruka/mailroom/utils/cron" + "github.com/pkg/errors" + "github.com/sirupsen/logrus" +) + +func init() { + mailroom.AddInitFunction(startEndCron) +} + +func startEndCron(rt *runtime.Runtime, wg *sync.WaitGroup, quit chan bool) error { + cron.Start(quit, rt, "end_incidents", time.Minute*3, false, + func() error { + ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) + defer cancel() + return EndIncidents(ctx, rt) + }, + ) + return nil +} + +// EndIncidents checks open incidents and end any that no longer apply +func EndIncidents(ctx context.Context, rt *runtime.Runtime) error { + incidents, err := models.GetOpenIncidents(ctx, rt.DB, []models.IncidentType{models.IncidentTypeWebhooksUnhealthy}) + if err != nil { + return errors.Wrap(err, "error fetching open incidents") + } + + for _, incident := range incidents { + if incident.Type == models.IncidentTypeWebhooksUnhealthy { + if err := checkWebhookIncident(ctx, rt, incident); err != nil { + return errors.Wrapf(err, "error checking webhook incident #%d", incident.ID) + } + } + } + + return nil +} + +func checkWebhookIncident(ctx context.Context, rt *runtime.Runtime, incident *models.Incident) error { + nodeUUIDs, err := getWebhookIncidentNodes(rt, incident) + + if err != nil { + return errors.Wrap(err, "error getting webhook nodes") + } + + healthyNodeUUIDs := make([]flows.NodeUUID, 0, len(nodeUUIDs)) + + for _, nodeUUID := range nodeUUIDs { + node := models.WebhookNode{UUID: flows.NodeUUID(nodeUUID)} + healthy, err := node.Healthy(rt) + if err != nil { + return errors.Wrap(err, "error getting health of webhook nodes") + } + + if healthy { + healthyNodeUUIDs = append(healthyNodeUUIDs, nodeUUID) + } + } + + if len(healthyNodeUUIDs) > 0 { + if err := removeWebhookIncidentNodes(rt, incident, healthyNodeUUIDs); err != nil { + return errors.Wrap(err, "error removing nodes from webhook incident") + } + } + + log := logrus.WithFields(logrus.Fields{"incident_id": incident.ID, "unhealthy": len(nodeUUIDs) - len(healthyNodeUUIDs), "healthy": len(healthyNodeUUIDs)}) + + // if all of the nodes are now healthy the incident has ended + if len(healthyNodeUUIDs) == len(nodeUUIDs) { + if err := incident.End(ctx, rt.DB); err != nil { + return errors.Wrap(err, "error ending incident") + } + log.Info("ended webhook incident") + } else { + log.Debug("checked webhook incident") + } + + return nil +} + +func getWebhookIncidentNodes(rt *runtime.Runtime, incident *models.Incident) ([]flows.NodeUUID, error) { + rc := rt.RP.Get() + defer rc.Close() + + nodesKey := fmt.Sprintf("incident:%d:nodes", incident.ID) + nodes, err := redis.Strings(rc.Do("SMEMBERS", nodesKey)) + if err != nil { + return nil, err + } + + nodeUUIDs := make([]flows.NodeUUID, len(nodes)) + for i := range nodes { + nodeUUIDs[i] = flows.NodeUUID(nodes[i]) + } + return nodeUUIDs, nil +} + +func removeWebhookIncidentNodes(rt *runtime.Runtime, incident *models.Incident, nodes []flows.NodeUUID) error { + rc := rt.RP.Get() + defer rc.Close() + + nodesKey := fmt.Sprintf("incident:%d:nodes", incident.ID) + _, err := rc.Do("SREM", redis.Args{}.Add(nodesKey).AddFlat(nodes)...) + if err != nil { + return err + } + return nil +} diff --git a/core/tasks/incidents/end_incidents_test.go b/core/tasks/incidents/end_incidents_test.go new file mode 100644 index 000000000..dfffb40f4 --- /dev/null +++ b/core/tasks/incidents/end_incidents_test.go @@ -0,0 +1,63 @@ +package incidents_test + +import ( + "fmt" + "net/http" + "testing" + "time" + + "github.com/nyaruka/gocommon/dates" + "github.com/nyaruka/gocommon/dbutil/assertdb" + "github.com/nyaruka/gocommon/httpx" + "github.com/nyaruka/goflow/flows" + "github.com/nyaruka/goflow/flows/events" + "github.com/nyaruka/mailroom/core/models" + "github.com/nyaruka/mailroom/core/tasks/incidents" + "github.com/nyaruka/mailroom/testsuite" + "github.com/nyaruka/mailroom/testsuite/testdata" + "github.com/nyaruka/redisx/assertredis" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestEndIncidents(t *testing.T) { + ctx, rt, db, rp := testsuite.Get() + + defer testsuite.Reset(testsuite.ResetData | testsuite.ResetRedis) + + oa1 := testdata.Org1.Load(rt) + oa2 := testdata.Org2.Load(rt) + + createWebhookEvents := func(count int, elapsed time.Duration) []*events.WebhookCalledEvent { + evts := make([]*events.WebhookCalledEvent, count) + for i := range evts { + req, _ := http.NewRequest("GET", "http://example.com", nil) + trace := &httpx.Trace{Request: req, StartTime: dates.Now(), EndTime: dates.Now().Add(elapsed)} + evts[i] = events.NewWebhookCalled(&flows.WebhookCall{Trace: trace}, flows.CallStatusSuccess, "") + } + return evts + } + + node1 := &models.WebhookNode{UUID: "3c703019-8c92-4d28-9be0-a926a934486b"} + node1.Record(rt, createWebhookEvents(10, time.Second*30)) + + // create incident for org 1 based on node which is still unhealthy + id1, err := models.IncidentWebhooksUnhealthy(ctx, db, rp, oa1, []flows.NodeUUID{"3c703019-8c92-4d28-9be0-a926a934486b"}) + require.NoError(t, err) + + node2 := &models.WebhookNode{UUID: "07d69080-475b-4395-aa96-ea6c28ea6cb6"} + node2.Record(rt, createWebhookEvents(10, time.Second*1)) + + // create incident for org 2 based on node which is now healthy + id2, err := models.IncidentWebhooksUnhealthy(ctx, db, rp, oa2, []flows.NodeUUID{"07d69080-475b-4395-aa96-ea6c28ea6cb6"}) + require.NoError(t, err) + + err = incidents.EndIncidents(ctx, rt) + assert.NoError(t, err) + + assertdb.Query(t, db, `SELECT count(*) FROM notifications_incident WHERE id = $1 AND ended_on IS NULL`, id1).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM notifications_incident WHERE id = $1 AND ended_on IS NOT NULL`, id2).Returns(1) + + assertredis.SMembers(t, rp, fmt.Sprintf("incident:%d:nodes", id1), []string{"3c703019-8c92-4d28-9be0-a926a934486b"}) + assertredis.SMembers(t, rp, fmt.Sprintf("incident:%d:nodes", id2), []string{}) // healthy node removed +} diff --git a/core/tasks/interrupts/interrupt_sessions.go b/core/tasks/interrupts/interrupt_sessions.go index dd25a4423..ffde674f4 100644 --- a/core/tasks/interrupts/interrupt_sessions.go +++ b/core/tasks/interrupts/interrupt_sessions.go @@ -8,7 +8,6 @@ import ( "github.com/nyaruka/mailroom/core/tasks" "github.com/nyaruka/mailroom/runtime" - "github.com/lib/pq" "github.com/pkg/errors" ) @@ -27,37 +26,6 @@ type InterruptSessionsTask struct { FlowIDs []models.FlowID `json:"flow_ids,omitempty"` } -const activeSessionIDsForChannelsSQL = ` -SELECT - fs.id -FROM - flows_flowsession fs - JOIN channels_channelconnection cc ON fs.connection_id = cc.id -WHERE - fs.status = 'W' AND - cc.channel_id = ANY($1); -` - -const activeSessionIDsForContactsSQL = ` -SELECT - id -FROM - flows_flowsession fs -WHERE - fs.status = 'W' AND - fs.contact_id = ANY($1); -` - -const activeSessionIDsForFlowsSQL = ` -SELECT - id -FROM - flows_flowsession fs -WHERE - fs.status = 'W' AND - fs.current_flow_id = ANY($1); -` - // Timeout is the maximum amount of time the task can run for func (t *InterruptSessionsTask) Timeout() time.Duration { return time.Hour @@ -66,62 +34,26 @@ func (t *InterruptSessionsTask) Timeout() time.Duration { func (t *InterruptSessionsTask) Perform(ctx context.Context, rt *runtime.Runtime, orgID models.OrgID) error { db := rt.DB - sessionIDs := make(map[models.SessionID]bool) - for _, sid := range t.SessionIDs { - sessionIDs[sid] = true - } - - // if we have ivr channel ids, explode those to session ids - if len(t.ChannelIDs) > 0 { - channelSessionIDs := make([]models.SessionID, 0, len(t.ChannelIDs)) - - err := db.SelectContext(ctx, &channelSessionIDs, activeSessionIDsForChannelsSQL, pq.Array(t.ChannelIDs)) - if err != nil { - return errors.Wrapf(err, "error selecting sessions for channels") - } - - for _, sid := range channelSessionIDs { - sessionIDs[sid] = true - } - } - - // if we have contact ids, explode those to session ids if len(t.ContactIDs) > 0 { - contactSessionIDs := make([]models.SessionID, 0, len(t.ContactIDs)) - - err := db.SelectContext(ctx, &contactSessionIDs, activeSessionIDsForContactsSQL, pq.Array(t.ContactIDs)) - if err != nil { - return errors.Wrapf(err, "error selecting sessions for contacts") + if err := models.InterruptSessionsForContacts(ctx, db, t.ContactIDs); err != nil { + return err } - - for _, sid := range contactSessionIDs { - sessionIDs[sid] = true + } + if len(t.ChannelIDs) > 0 { + if err := models.InterruptSessionsForChannels(ctx, db, t.ChannelIDs); err != nil { + return err } } - - // if we have flow ids, explode those to session ids if len(t.FlowIDs) > 0 { - flowSessionIDs := make([]models.SessionID, 0, len(t.FlowIDs)) - - err := db.SelectContext(ctx, &flowSessionIDs, activeSessionIDsForFlowsSQL, pq.Array(t.FlowIDs)) - if err != nil { - return errors.Wrapf(err, "error selecting sessions for flows") - } - - for _, sid := range flowSessionIDs { - sessionIDs[sid] = true + if err := models.InterruptSessionsForFlows(ctx, db, t.FlowIDs); err != nil { + return err } } - - uniqueSessionIDs := make([]models.SessionID, 0, len(sessionIDs)) - for id := range sessionIDs { - uniqueSessionIDs = append(uniqueSessionIDs, id) + if len(t.SessionIDs) > 0 { + if err := models.ExitSessions(ctx, db, t.SessionIDs, models.SessionStatusInterrupted); err != nil { + return errors.Wrapf(err, "error interrupting sessions") + } } - // interrupt all sessions and their associated runs - err := models.ExitSessions(ctx, db, uniqueSessionIDs, models.ExitInterrupted, time.Now()) - if err != nil { - return errors.Wrapf(err, "error interrupting sessions") - } return nil } diff --git a/core/tasks/interrupts/interrupt_sessions_test.go b/core/tasks/interrupts/interrupt_sessions_test.go index 079b931cc..6bdbb251d 100644 --- a/core/tasks/interrupts/interrupt_sessions_test.go +++ b/core/tasks/interrupts/interrupt_sessions_test.go @@ -3,7 +3,7 @@ package interrupts import ( "testing" - "github.com/nyaruka/gocommon/uuids" + "github.com/nyaruka/gocommon/dbutil/assertdb" _ "github.com/nyaruka/mailroom/core/handlers" "github.com/nyaruka/mailroom/core/models" "github.com/nyaruka/mailroom/testsuite" @@ -17,29 +17,11 @@ func TestInterrupts(t *testing.T) { defer testsuite.Reset(testsuite.ResetAll) - insertConnection := func(orgID models.OrgID, channelID models.ChannelID, contactID models.ContactID, urnID models.URNID) models.ConnectionID { - var connectionID models.ConnectionID - err := db.Get(&connectionID, - `INSERT INTO channels_channelconnection(created_on, modified_on, external_id, status, direction, connection_type, error_count, org_id, channel_id, contact_id, contact_urn_id) - VALUES(NOW(), NOW(), 'ext1', 'I', 'I', 'V', 0, $1, $2, $3, $4) RETURNING id`, - orgID, channelID, contactID, urnID, - ) - assert.NoError(t, err) - return connectionID - } - - insertSession := func(orgID models.OrgID, contactID models.ContactID, connectionID models.ConnectionID, currentFlowID models.FlowID) models.SessionID { - var sessionID models.SessionID - err := db.Get(&sessionID, - `INSERT INTO flows_flowsession(uuid, status, responded, created_on, org_id, contact_id, connection_id, current_flow_id, session_type) - VALUES($1, 'W', false, NOW(), $2, $3, $4, $5, 'M') RETURNING id`, - uuids.New(), orgID, contactID, connectionID, currentFlowID) - assert.NoError(t, err) - - // give session one active run too - db.MustExec(`INSERT INTO flows_flowrun(uuid, is_active, status, created_on, modified_on, responded, contact_id, flow_id, session_id, org_id) - VALUES($1, TRUE, 'W', now(), now(), FALSE, $2, $3, $4, 1);`, uuids.New(), contactID, currentFlowID, sessionID) + insertSession := func(org *testdata.Org, contact *testdata.Contact, flow *testdata.Flow, connectionID models.ConnectionID) models.SessionID { + sessionID := testdata.InsertFlowSession(db, org, contact, models.FlowTypeMessaging, models.SessionStatusWaiting, flow, connectionID, nil) + // give session one waiting run too + testdata.InsertFlowRun(db, org, sessionID, contact, flow, models.RunStatusWaiting, "", nil) return sessionID } @@ -80,18 +62,18 @@ func TestInterrupts(t *testing.T) { db.MustExec(`UPDATE flows_flowsession SET status='C', ended_on=NOW() WHERE status = 'W';`) // twilio connection - twilioConnectionID := insertConnection(testdata.Org1.ID, testdata.TwilioChannel.ID, testdata.Alexandria.ID, testdata.Alexandria.URNID) + twilioConnectionID := testdata.InsertConnection(db, testdata.Org1, testdata.TwilioChannel, testdata.Alexandria) sessionIDs := make([]models.SessionID, 5) // insert our dummy contact sessions - sessionIDs[0] = insertSession(testdata.Org1.ID, testdata.Cathy.ID, models.NilConnectionID, testdata.Favorites.ID) - sessionIDs[1] = insertSession(testdata.Org1.ID, testdata.George.ID, models.NilConnectionID, testdata.Favorites.ID) - sessionIDs[2] = insertSession(testdata.Org1.ID, testdata.Alexandria.ID, twilioConnectionID, testdata.Favorites.ID) - sessionIDs[3] = insertSession(testdata.Org1.ID, testdata.Bob.ID, models.NilConnectionID, testdata.PickANumber.ID) + sessionIDs[0] = insertSession(testdata.Org1, testdata.Cathy, testdata.Favorites, models.NilConnectionID) + sessionIDs[1] = insertSession(testdata.Org1, testdata.George, testdata.Favorites, models.NilConnectionID) + sessionIDs[2] = insertSession(testdata.Org1, testdata.Alexandria, testdata.Favorites, twilioConnectionID) + sessionIDs[3] = insertSession(testdata.Org1, testdata.Bob, testdata.PickANumber, models.NilConnectionID) // a session we always end explicitly - sessionIDs[4] = insertSession(testdata.Org1.ID, testdata.Bob.ID, models.NilConnectionID, testdata.Favorites.ID) + sessionIDs[4] = insertSession(testdata.Org1, testdata.Bob, testdata.Favorites, models.NilConnectionID) // create our task task := &InterruptSessionsTask{ @@ -113,7 +95,7 @@ func TestInterrupts(t *testing.T) { assert.Equal(t, tc.StatusesAfter[j], status, "%d: status mismatch for session #%d", i, j) // check for runs with a different status to the session - testsuite.AssertQuery(t, db, `SELECT count(*) FROM flows_flowrun WHERE session_id = $1 AND status != $2`, sID, tc.StatusesAfter[j]). + assertdb.Query(t, db, `SELECT count(*) FROM flows_flowrun WHERE session_id = $1 AND status != $2`, sID, tc.StatusesAfter[j]). Returns(0, "%d: unexpected un-interrupted runs for session #%d", i, j) } } diff --git a/core/tasks/ivr/cron.go b/core/tasks/ivr/cron.go index e1657f5f5..6b46959b4 100644 --- a/core/tasks/ivr/cron.go +++ b/core/tasks/ivr/cron.go @@ -27,19 +27,19 @@ func init() { // StartIVRCron starts our cron job of retrying errored calls func StartIVRCron(rt *runtime.Runtime, wg *sync.WaitGroup, quit chan bool) error { - cron.StartCron(quit, rt.RP, retryIVRLock, time.Minute, - func(lockName string, lockValue string) error { - ctx, cancel := context.WithTimeout(context.Background(), time.Minute*10) + cron.Start(quit, rt, retryIVRLock, time.Minute, false, + func() error { + ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) defer cancel() - return retryCalls(ctx, rt, retryIVRLock, lockValue) + return retryCalls(ctx, rt) }, ) - cron.StartCron(quit, rt.RP, expireIVRLock, time.Minute, - func(lockName string, lockValue string) error { - ctx, cancel := context.WithTimeout(context.Background(), time.Minute*10) + cron.Start(quit, rt, expireIVRLock, time.Minute, false, + func() error { + ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) defer cancel() - return expireCalls(ctx, rt, expireIVRLock, lockValue) + return expireCalls(ctx, rt) }, ) @@ -47,8 +47,8 @@ func StartIVRCron(rt *runtime.Runtime, wg *sync.WaitGroup, quit chan bool) error } // retryCalls looks for calls that need to be retried and retries them -func retryCalls(ctx context.Context, rt *runtime.Runtime, lockName string, lockValue string) error { - log := logrus.WithField("comp", "ivr_cron_retryer").WithField("lock", lockValue) +func retryCalls(ctx context.Context, rt *runtime.Runtime) error { + log := logrus.WithField("comp", "ivr_cron_retryer") start := time.Now() // find all calls that need restarting @@ -114,8 +114,8 @@ func retryCalls(ctx context.Context, rt *runtime.Runtime, lockName string, lockV } // expireCalls looks for calls that should be expired and ends them -func expireCalls(ctx context.Context, rt *runtime.Runtime, lockName string, lockValue string) error { - log := logrus.WithField("comp", "ivr_cron_expirer").WithField("lock", lockValue) +func expireCalls(ctx context.Context, rt *runtime.Runtime) error { + log := logrus.WithField("comp", "ivr_cron_expirer") start := time.Now() ctx, cancel := context.WithTimeout(ctx, time.Minute*10) @@ -128,7 +128,6 @@ func expireCalls(ctx context.Context, rt *runtime.Runtime, lockName string, lock } defer rows.Close() - expiredRuns := make([]models.FlowRunID, 0, 100) expiredSessions := make([]models.SessionID, 0, 100) for rows.Next() { @@ -138,8 +137,7 @@ func expireCalls(ctx context.Context, rt *runtime.Runtime, lockName string, lock return errors.Wrapf(err, "error scanning expired run") } - // add the run and session to those we need to expire - expiredRuns = append(expiredRuns, exp.RunID) + // add the session to those we need to expire expiredSessions = append(expiredSessions, exp.SessionID) // load our connection @@ -157,12 +155,12 @@ func expireCalls(ctx context.Context, rt *runtime.Runtime, lockName string, lock } // now expire our runs and sessions - if len(expiredRuns) > 0 { - err := models.ExpireRunsAndSessions(ctx, rt.DB, expiredRuns, expiredSessions) + if len(expiredSessions) > 0 { + err := models.ExitSessions(ctx, rt.DB, expiredSessions, models.SessionStatusExpired) if err != nil { - log.WithError(err).Error("error expiring runs and sessions for expired calls") + log.WithError(err).Error("error expiring sessions for expired calls") } - log.WithField("count", len(expiredRuns)).WithField("elapsed", time.Since(start)).Info("expired and hung up on channel connections") + log.WithField("count", len(expiredSessions)).WithField("elapsed", time.Since(start)).Info("expired and hung up on channel connections") } return nil diff --git a/core/tasks/ivr/cron_test.go b/core/tasks/ivr/cron_test.go index 85c2de43a..f4abfce95 100644 --- a/core/tasks/ivr/cron_test.go +++ b/core/tasks/ivr/cron_test.go @@ -4,6 +4,7 @@ import ( "encoding/json" "testing" + "github.com/nyaruka/gocommon/dbutil/assertdb" "github.com/nyaruka/mailroom/core/ivr" "github.com/nyaruka/mailroom/core/models" "github.com/nyaruka/mailroom/core/queue" @@ -46,18 +47,18 @@ func TestRetries(t *testing.T) { client.callID = ivr.CallID("call1") err = HandleFlowStartBatch(ctx, rt, batch) assert.NoError(t, err) - testsuite.AssertQuery(t, db, `SELECT COUNT(*) FROM channels_channelconnection WHERE contact_id = $1 AND status = $2 AND external_id = $3`, + assertdb.Query(t, db, `SELECT COUNT(*) FROM channels_channelconnection WHERE contact_id = $1 AND status = $2 AND external_id = $3`, testdata.Cathy.ID, models.ConnectionStatusWired, "call1").Returns(1) // change our call to be errored instead of wired db.MustExec(`UPDATE channels_channelconnection SET status = 'E', next_attempt = NOW() WHERE external_id = 'call1';`) // fire our retries - err = retryCalls(ctx, rt, "retry_test", "retry_test") + err = retryCalls(ctx, rt) assert.NoError(t, err) // should now be in wired state - testsuite.AssertQuery(t, db, `SELECT COUNT(*) FROM channels_channelconnection WHERE contact_id = $1 AND status = $2 AND external_id = $3`, + assertdb.Query(t, db, `SELECT COUNT(*) FROM channels_channelconnection WHERE contact_id = $1 AND status = $2 AND external_id = $3`, testdata.Cathy.ID, models.ConnectionStatusWired, "call1").Returns(1) // back to retry and make the channel inactive @@ -65,10 +66,10 @@ func TestRetries(t *testing.T) { db.MustExec(`UPDATE channels_channel SET is_active = FALSE WHERE id = $1`, testdata.TwilioChannel.ID) models.FlushCache() - err = retryCalls(ctx, rt, "retry_test", "retry_test") + err = retryCalls(ctx, rt) assert.NoError(t, err) // this time should be failed - testsuite.AssertQuery(t, db, `SELECT COUNT(*) FROM channels_channelconnection WHERE contact_id = $1 AND status = $2 AND external_id = $3`, + assertdb.Query(t, db, `SELECT COUNT(*) FROM channels_channelconnection WHERE contact_id = $1 AND status = $2 AND external_id = $3`, testdata.Cathy.ID, models.ConnectionStatusFailed, "call1").Returns(1) } diff --git a/core/tasks/ivr/worker_test.go b/core/tasks/ivr/worker_test.go index 8903305f8..23eac0582 100644 --- a/core/tasks/ivr/worker_test.go +++ b/core/tasks/ivr/worker_test.go @@ -6,6 +6,7 @@ import ( "net/http" "testing" + "github.com/nyaruka/gocommon/dbutil/assertdb" "github.com/nyaruka/gocommon/httpx" "github.com/nyaruka/gocommon/urns" "github.com/nyaruka/mailroom/core/ivr" @@ -51,20 +52,20 @@ func TestIVR(t *testing.T) { client.callError = errors.Errorf("unable to create call") err = HandleFlowStartBatch(ctx, rt, batch) assert.NoError(t, err) - testsuite.AssertQuery(t, db, `SELECT COUNT(*) FROM channels_channelconnection WHERE contact_id = $1 AND status = $2`, testdata.Cathy.ID, models.ConnectionStatusFailed).Returns(1) + assertdb.Query(t, db, `SELECT COUNT(*) FROM channels_channelconnection WHERE contact_id = $1 AND status = $2`, testdata.Cathy.ID, models.ConnectionStatusFailed).Returns(1) client.callError = nil client.callID = ivr.CallID("call1") err = HandleFlowStartBatch(ctx, rt, batch) assert.NoError(t, err) - testsuite.AssertQuery(t, db, `SELECT COUNT(*) FROM channels_channelconnection WHERE contact_id = $1 AND status = $2 AND external_id = $3`, testdata.Cathy.ID, models.ConnectionStatusWired, "call1").Returns(1) + assertdb.Query(t, db, `SELECT COUNT(*) FROM channels_channelconnection WHERE contact_id = $1 AND status = $2 AND external_id = $3`, testdata.Cathy.ID, models.ConnectionStatusWired, "call1").Returns(1) // trying again should put us in a throttled state (queued) client.callError = nil client.callID = ivr.CallID("call1") err = HandleFlowStartBatch(ctx, rt, batch) assert.NoError(t, err) - testsuite.AssertQuery(t, db, `SELECT COUNT(*) FROM channels_channelconnection WHERE contact_id = $1 AND status = $2 AND next_attempt IS NOT NULL;`, testdata.Cathy.ID, models.ConnectionStatusQueued).Returns(1) + assertdb.Query(t, db, `SELECT COUNT(*) FROM channels_channelconnection WHERE contact_id = $1 AND status = $2 AND next_attempt IS NOT NULL;`, testdata.Cathy.ID, models.ConnectionStatusQueued).Returns(1) } var client = &MockProvider{} diff --git a/core/tasks/msgs/retries.go b/core/tasks/msgs/retries.go new file mode 100644 index 000000000..660963519 --- /dev/null +++ b/core/tasks/msgs/retries.go @@ -0,0 +1,61 @@ +package msgs + +import ( + "context" + "sync" + "time" + + "github.com/nyaruka/mailroom" + "github.com/nyaruka/mailroom/core/models" + "github.com/nyaruka/mailroom/core/msgio" + "github.com/nyaruka/mailroom/runtime" + "github.com/nyaruka/mailroom/utils/cron" + "github.com/pkg/errors" + "github.com/sirupsen/logrus" +) + +const ( + retryMessagesLock = "retry_errored_messages" +) + +func init() { + mailroom.AddInitFunction(startCrons) +} + +func startCrons(rt *runtime.Runtime, wg *sync.WaitGroup, quit chan bool) error { + cron.Start(quit, rt, retryMessagesLock, time.Second*60, false, + func() error { + ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) + defer cancel() + return RetryErroredMessages(ctx, rt) + }, + ) + + return nil +} + +func RetryErroredMessages(ctx context.Context, rt *runtime.Runtime) error { + rc := rt.RP.Get() + defer rc.Close() + + start := time.Now() + + msgs, err := models.GetMessagesForRetry(ctx, rt.DB) + if err != nil { + return errors.Wrap(err, "error fetching errored messages to retry") + } + if len(msgs) == 0 { + return nil // nothing to retry + } + + err = models.MarkMessagesQueued(ctx, rt.DB, msgs) + if err != nil { + return errors.Wrap(err, "error marking messages as queued") + } + + msgio.SendMessages(ctx, rt, rt.DB, nil, msgs) + + logrus.WithField("count", len(msgs)).WithField("elapsed", time.Since(start)).Info("retried errored messages") + + return nil +} diff --git a/core/tasks/msgs/retries_test.go b/core/tasks/msgs/retries_test.go new file mode 100644 index 000000000..5f4b3b9c5 --- /dev/null +++ b/core/tasks/msgs/retries_test.go @@ -0,0 +1,55 @@ +package msgs_test + +import ( + "testing" + "time" + + "github.com/nyaruka/gocommon/dbutil/assertdb" + "github.com/nyaruka/mailroom/core/models" + "github.com/nyaruka/mailroom/core/tasks/msgs" + "github.com/nyaruka/mailroom/testsuite" + "github.com/nyaruka/mailroom/testsuite/testdata" + "github.com/stretchr/testify/require" +) + +func TestRetryErroredMessages(t *testing.T) { + ctx, rt, db, rp := testsuite.Get() + rc := rp.Get() + defer rc.Close() + + defer testsuite.Reset(testsuite.ResetData | testsuite.ResetRedis) + + // nothing to retry + err := msgs.RetryErroredMessages(ctx, rt) + require.NoError(t, err) + + testsuite.AssertCourierQueues(t, map[string][]int{}) + + // a non-errored outgoing message (should be ignored) + testdata.InsertOutgoingMsg(db, testdata.Org1, testdata.TwilioChannel, testdata.Cathy, "Hi", nil, models.MsgStatusDelivered, false) + + // an errored message with a next-attempt in the future (should be ignored) + testdata.InsertErroredOutgoingMsg(db, testdata.Org1, testdata.TwilioChannel, testdata.Cathy, "Hi", 2, time.Now().Add(time.Hour), false) + + // errored messages with a next-attempt in the past + testdata.InsertErroredOutgoingMsg(db, testdata.Org1, testdata.TwilioChannel, testdata.Cathy, "Hi", 1, time.Now().Add(-time.Hour), false) + testdata.InsertErroredOutgoingMsg(db, testdata.Org1, testdata.VonageChannel, testdata.Bob, "Hi", 2, time.Now().Add(-time.Minute), false) + testdata.InsertErroredOutgoingMsg(db, testdata.Org1, testdata.VonageChannel, testdata.Bob, "Hi", 2, time.Now().Add(-time.Minute), false) + testdata.InsertErroredOutgoingMsg(db, testdata.Org1, testdata.VonageChannel, testdata.Bob, "Hi", 2, time.Now().Add(-time.Minute), true) // high priority + + assertdb.Query(t, db, `SELECT count(*) FROM msgs_msg WHERE status = 'E'`).Returns(5) + + // try again... + err = msgs.RetryErroredMessages(ctx, rt) + require.NoError(t, err) + + assertdb.Query(t, db, `SELECT count(*) FROM msgs_msg WHERE status = 'D'`).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM msgs_msg WHERE status = 'E'`).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM msgs_msg WHERE status = 'Q'`).Returns(4) + + testsuite.AssertCourierQueues(t, map[string][]int{ + "msgs:74729f45-7f29-4868-9dc4-90e491e3c7d8|10/0": {1}, // twilio, bulk priority + "msgs:19012bfd-3ce3-4cae-9bb9-76cf92c73d49|10/0": {2}, // vonage, bulk priority + "msgs:19012bfd-3ce3-4cae-9bb9-76cf92c73d49|10/1": {1}, // vonage, high priority + }) +} diff --git a/core/tasks/broadcasts/worker.go b/core/tasks/msgs/send_broadcast.go similarity index 99% rename from core/tasks/broadcasts/worker.go rename to core/tasks/msgs/send_broadcast.go index 5f9bcb089..529b3ccd2 100644 --- a/core/tasks/broadcasts/worker.go +++ b/core/tasks/msgs/send_broadcast.go @@ -1,4 +1,4 @@ -package broadcasts +package msgs import ( "context" diff --git a/core/tasks/broadcasts/worker_test.go b/core/tasks/msgs/send_broadcast_test.go similarity index 89% rename from core/tasks/broadcasts/worker_test.go rename to core/tasks/msgs/send_broadcast_test.go index 7283a27bb..448e423f4 100644 --- a/core/tasks/broadcasts/worker_test.go +++ b/core/tasks/msgs/send_broadcast_test.go @@ -1,10 +1,11 @@ -package broadcasts +package msgs_test import ( "encoding/json" "testing" "time" + "github.com/nyaruka/gocommon/dbutil/assertdb" "github.com/nyaruka/gocommon/urns" "github.com/nyaruka/goflow/assets" "github.com/nyaruka/goflow/envs" @@ -13,6 +14,7 @@ import ( _ "github.com/nyaruka/mailroom/core/handlers" "github.com/nyaruka/mailroom/core/models" "github.com/nyaruka/mailroom/core/queue" + "github.com/nyaruka/mailroom/core/tasks/msgs" "github.com/nyaruka/mailroom/testsuite" "github.com/nyaruka/mailroom/testsuite/testdata" @@ -84,7 +86,7 @@ func TestBroadcastEvents(t *testing.T) { bcast, err := models.NewBroadcastFromEvent(ctx, db, oa, event) assert.NoError(t, err) - err = CreateBroadcastBatches(ctx, rt, bcast) + err = msgs.CreateBroadcastBatches(ctx, rt, bcast) assert.NoError(t, err) // pop all our tasks and execute them @@ -103,7 +105,7 @@ func TestBroadcastEvents(t *testing.T) { err = json.Unmarshal(task.Task, batch) assert.NoError(t, err) - err = SendBroadcastBatch(ctx, rt, batch) + err = msgs.SendBroadcastBatch(ctx, rt, batch) assert.NoError(t, err) } @@ -111,7 +113,7 @@ func TestBroadcastEvents(t *testing.T) { assert.Equal(t, tc.BatchCount, count, "%d: unexpected batch count", i) // assert our count of total msgs created - testsuite.AssertQuery(t, db, `SELECT count(*) FROM msgs_msg WHERE org_id = 1 AND created_on > $1 AND topup_id IS NOT NULL AND text = $2`, lastNow, tc.MsgText). + assertdb.Query(t, db, `SELECT count(*) FROM msgs_msg WHERE org_id = 1 AND created_on > $1 AND topup_id IS NOT NULL AND text = $2`, lastNow, tc.MsgText). Returns(tc.MsgCount, "%d: unexpected msg count", i) lastNow = time.Now() @@ -230,7 +232,7 @@ func TestBroadcastTask(t *testing.T) { for i, tc := range tcs { // handle our start task bcast := models.NewBroadcast(oa.OrgID(), tc.BroadcastID, tc.Translations, tc.TemplateState, tc.BaseLanguage, tc.URNs, tc.ContactIDs, tc.GroupIDs, tc.TicketID) - err = CreateBroadcastBatches(ctx, rt, bcast) + err = msgs.CreateBroadcastBatches(ctx, rt, bcast) assert.NoError(t, err) // pop all our tasks and execute them @@ -249,7 +251,7 @@ func TestBroadcastTask(t *testing.T) { err = json.Unmarshal(task.Task, batch) assert.NoError(t, err) - err = SendBroadcastBatch(ctx, rt, batch) + err = msgs.SendBroadcastBatch(ctx, rt, batch) assert.NoError(t, err) } @@ -257,18 +259,18 @@ func TestBroadcastTask(t *testing.T) { assert.Equal(t, tc.BatchCount, count, "%d: unexpected batch count", i) // assert our count of total msgs created - testsuite.AssertQuery(t, db, `SELECT count(*) FROM msgs_msg WHERE org_id = 1 AND created_on > $1 AND topup_id IS NOT NULL AND text = $2`, lastNow, tc.MsgText). + assertdb.Query(t, db, `SELECT count(*) FROM msgs_msg WHERE org_id = 1 AND created_on > $1 AND topup_id IS NOT NULL AND text = $2`, lastNow, tc.MsgText). Returns(tc.MsgCount, "%d: unexpected msg count", i) // make sure our broadcast is marked as sent if tc.BroadcastID != models.NilBroadcastID { - testsuite.AssertQuery(t, db, `SELECT count(*) FROM msgs_broadcast WHERE id = $1 AND status = 'S'`, tc.BroadcastID). + assertdb.Query(t, db, `SELECT count(*) FROM msgs_broadcast WHERE id = $1 AND status = 'S'`, tc.BroadcastID). Returns(1, "%d: broadcast not marked as sent", i) } // if we had a ticket, make sure its last_activity_on was updated if tc.TicketID != models.NilTicketID { - testsuite.AssertQuery(t, db, `SELECT count(*) FROM tickets_ticket WHERE id = $1 AND last_activity_on > $2`, tc.TicketID, modelTicket.LastActivityOn()). + assertdb.Query(t, db, `SELECT count(*) FROM tickets_ticket WHERE id = $1 AND last_activity_on > $2`, tc.TicketID, modelTicket.LastActivityOn()). Returns(1, "%d: ticket last_activity_on not updated", i) } diff --git a/core/tasks/schedules/cron.go b/core/tasks/schedules/cron.go index ec8d15dbc..262c3999e 100644 --- a/core/tasks/schedules/cron.go +++ b/core/tasks/schedules/cron.go @@ -24,22 +24,22 @@ func init() { // StartCheckSchedules starts our cron job of firing schedules every minute func StartCheckSchedules(rt *runtime.Runtime, wg *sync.WaitGroup, quit chan bool) error { - cron.StartCron(quit, rt.RP, scheduleLock, time.Minute*1, - func(lockName string, lockValue string) error { + cron.Start(quit, rt, scheduleLock, time.Minute*1, false, + func() error { ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) defer cancel() // we sleep 1 second since we fire right on the minute and want to make sure to fire // things that are schedules right at the minute as well (and DB time may be slightly drifted) time.Sleep(time.Second * 1) - return checkSchedules(ctx, rt, lockName, lockValue) + return checkSchedules(ctx, rt) }, ) return nil } // checkSchedules looks up any expired schedules and fires them, setting the next fire as needed -func checkSchedules(ctx context.Context, rt *runtime.Runtime, lockName string, lockValue string) error { - log := logrus.WithField("comp", "schedules_cron").WithField("lock", lockValue) +func checkSchedules(ctx context.Context, rt *runtime.Runtime) error { + log := logrus.WithField("comp", "schedules_cron") start := time.Now() rc := rt.RP.Get() diff --git a/core/tasks/schedules/cron_test.go b/core/tasks/schedules/cron_test.go index 575fdc6c6..893a35d1e 100644 --- a/core/tasks/schedules/cron_test.go +++ b/core/tasks/schedules/cron_test.go @@ -3,6 +3,7 @@ package schedules import ( "testing" + "github.com/nyaruka/gocommon/dbutil/assertdb" "github.com/nyaruka/goflow/envs" "github.com/nyaruka/mailroom/core/models" "github.com/nyaruka/mailroom/core/queue" @@ -70,27 +71,27 @@ func TestCheckSchedules(t *testing.T) { assert.NoError(t, err) // run our task - err = checkSchedules(ctx, rt, "lock", "lock") + err = checkSchedules(ctx, rt) assert.NoError(t, err) // should have one flow start added to our DB ready to go - testsuite.AssertQuery(t, db, `SELECT count(*) FROM flows_flowstart WHERE flow_id = $1 AND start_type = 'T' AND status = 'P'`, testdata.Favorites.ID).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM flows_flowstart WHERE flow_id = $1 AND start_type = 'T' AND status = 'P'`, testdata.Favorites.ID).Returns(1) // with the right count of groups and contacts - testsuite.AssertQuery(t, db, `SELECT count(*) from flows_flowstart_contacts WHERE flowstart_id = 1`).Returns(2) - testsuite.AssertQuery(t, db, `SELECT count(*) from flows_flowstart_groups WHERE flowstart_id = 1`).Returns(1) + assertdb.Query(t, db, `SELECT count(*) from flows_flowstart_contacts WHERE flowstart_id = 1`).Returns(2) + assertdb.Query(t, db, `SELECT count(*) from flows_flowstart_groups WHERE flowstart_id = 1`).Returns(1) // and one broadcast as well - testsuite.AssertQuery(t, db, `SELECT count(*) FROM msgs_broadcast WHERE org_id = $1 AND parent_id = $2 + assertdb.Query(t, db, `SELECT count(*) FROM msgs_broadcast WHERE org_id = $1 AND parent_id = $2 AND text = hstore(ARRAY['eng','Test message', 'fra', 'Un Message']) AND status = 'Q' AND base_language = 'eng'`, testdata.Org1.ID, b1).Returns(1) // with the right count of groups, contacts, urns - testsuite.AssertQuery(t, db, `SELECT count(*) from msgs_broadcast_urns WHERE broadcast_id = 2`).Returns(1) - testsuite.AssertQuery(t, db, `SELECT count(*) from msgs_broadcast_contacts WHERE broadcast_id = 2`).Returns(2) - testsuite.AssertQuery(t, db, `SELECT count(*) from msgs_broadcast_groups WHERE broadcast_id = 2`).Returns(1) + assertdb.Query(t, db, `SELECT count(*) from msgs_broadcast_urns WHERE broadcast_id = 2`).Returns(1) + assertdb.Query(t, db, `SELECT count(*) from msgs_broadcast_contacts WHERE broadcast_id = 2`).Returns(2) + assertdb.Query(t, db, `SELECT count(*) from msgs_broadcast_groups WHERE broadcast_id = 2`).Returns(1) // we shouldn't have any pending schedules since there were all one time fires, but all should have last fire - testsuite.AssertQuery(t, db, `SELECT count(*) FROM schedules_schedule WHERE next_fire IS NULL and last_fire < NOW();`).Returns(3) + assertdb.Query(t, db, `SELECT count(*) FROM schedules_schedule WHERE next_fire IS NULL and last_fire < NOW();`).Returns(3) // check the tasks created task, err := queue.PopNextTask(rc, queue.BatchQueue) diff --git a/core/tasks/starts/worker_test.go b/core/tasks/starts/worker_test.go index 1009bf518..69be12388 100644 --- a/core/tasks/starts/worker_test.go +++ b/core/tasks/starts/worker_test.go @@ -5,6 +5,7 @@ import ( "fmt" "testing" + "github.com/nyaruka/gocommon/dbutil/assertdb" "github.com/nyaruka/gocommon/uuids" _ "github.com/nyaruka/mailroom/core/handlers" "github.com/nyaruka/mailroom/core/models" @@ -305,20 +306,20 @@ func TestStarts(t *testing.T) { assert.Equal(t, tc.expectedBatchCount, count, "unexpected batch count in '%s'", tc.label) // assert our count of total flow runs created - testsuite.AssertQuery(t, db, `SELECT count(*) FROM flows_flowrun WHERE flow_id = $1 AND start_id = $2`, tc.flowID, start.ID()).Returns(tc.expectedTotalCount, "unexpected total run count in '%s'", tc.label) + assertdb.Query(t, db, `SELECT count(*) FROM flows_flowrun WHERE flow_id = $1 AND start_id = $2`, tc.flowID, start.ID()).Returns(tc.expectedTotalCount, "unexpected total run count in '%s'", tc.label) // assert final status - testsuite.AssertQuery(t, db, `SELECT count(*) FROM flows_flowstart where status = $2 AND id = $1`, start.ID(), tc.expectedStatus).Returns(1, "status mismatch in '%s'", tc.label) + assertdb.Query(t, db, `SELECT count(*) FROM flows_flowstart where status = $2 AND id = $1`, start.ID(), tc.expectedStatus).Returns(1, "status mismatch in '%s'", tc.label) // assert final contact count if tc.expectedStatus != models.StartStatusFailed { - testsuite.AssertQuery(t, db, `SELECT count(*) FROM flows_flowstart where contact_count = $2 AND id = $1`, + assertdb.Query(t, db, `SELECT count(*) FROM flows_flowstart where contact_count = $2 AND id = $1`, []interface{}{start.ID(), tc.expectedContactCount}, 1, "contact count mismatch in '%s'", tc.label) } // assert count of active runs by flow for flowID, activeRuns := range tc.expectedActiveRuns { - testsuite.AssertQuery(t, db, `SELECT count(*) FROM flows_flowrun WHERE status = 'W' AND flow_id = $1`, flowID).Returns(activeRuns, "active runs mismatch for flow #%d in '%s'", flowID, tc.label) + assertdb.Query(t, db, `SELECT count(*) FROM flows_flowrun WHERE status = 'W' AND flow_id = $1`, flowID).Returns(activeRuns, "active runs mismatch for flow #%d in '%s'", flowID, tc.label) } } } diff --git a/core/tasks/stats/cron.go b/core/tasks/stats/cron.go deleted file mode 100644 index 0b92bdeaf..000000000 --- a/core/tasks/stats/cron.go +++ /dev/null @@ -1,87 +0,0 @@ -package stats - -import ( - "context" - "sync" - "time" - - "github.com/nyaruka/librato" - "github.com/nyaruka/mailroom" - "github.com/nyaruka/mailroom/core/queue" - "github.com/nyaruka/mailroom/runtime" - "github.com/nyaruka/mailroom/utils/cron" - "github.com/sirupsen/logrus" -) - -const ( - expirationLock = "stats" -) - -func init() { - mailroom.AddInitFunction(StartStatsCron) -} - -// StartStatsCron starts our cron job of posting stats every minute -func StartStatsCron(rt *runtime.Runtime, wg *sync.WaitGroup, quit chan bool) error { - cron.StartCron(quit, rt.RP, expirationLock, time.Second*60, - func(lockName string, lockValue string) error { - ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) - defer cancel() - return dumpStats(ctx, rt, lockName, lockValue) - }, - ) - return nil -} - -var ( - waitDuration time.Duration - waitCount int64 -) - -// dumpStats calculates a bunch of stats every minute and both logs them and posts them to librato -func dumpStats(ctx context.Context, rt *runtime.Runtime, lockName string, lockValue string) error { - // We wait 15 seconds since we fire at the top of the minute, the same as expirations. - // That way any metrics related to the size of our queue are a bit more accurate (all expirations can - // usually be handled in 15 seconds). Something more complicated would take into account the age of - // the items in our queues. - time.Sleep(time.Second * 15) - - // get our DB status - stats := rt.DB.Stats() - - rc := rt.RP.Get() - defer rc.Close() - - // calculate size of batch queue - batchSize, err := queue.Size(rc, queue.BatchQueue) - if err != nil { - logrus.WithError(err).Error("error calculating batch queue size") - } - - // and size of handler queue - handlerSize, err := queue.Size(rc, queue.HandlerQueue) - if err != nil { - logrus.WithError(err).Error("error calculating handler queue size") - } - - logrus.WithFields(logrus.Fields{ - "db_idle": stats.Idle, - "db_busy": stats.InUse, - "db_waiting": stats.WaitCount - waitCount, - "db_wait": stats.WaitDuration - waitDuration, - "batch_size": batchSize, - "handler_size": handlerSize, - }).Info("current stats") - - librato.Gauge("mr.handler_queue", float64(handlerSize)) - librato.Gauge("mr.batch_queue", float64(batchSize)) - librato.Gauge("mr.db_busy", float64(stats.InUse)) - librato.Gauge("mr.db_idle", float64(stats.Idle)) - librato.Gauge("mr.db_waiting", float64(stats.WaitCount-waitCount)) - librato.Gauge("mr.db_wait_ms", float64((stats.WaitDuration-waitDuration)/time.Millisecond)) - - waitCount = stats.WaitCount - waitDuration = stats.WaitDuration - - return nil -} diff --git a/core/tasks/timeouts/cron.go b/core/tasks/timeouts/cron.go index 9a91a9564..c43caf92f 100644 --- a/core/tasks/timeouts/cron.go +++ b/core/tasks/timeouts/cron.go @@ -11,7 +11,7 @@ import ( "github.com/nyaruka/mailroom/core/tasks/handler" "github.com/nyaruka/mailroom/runtime" "github.com/nyaruka/mailroom/utils/cron" - "github.com/nyaruka/mailroom/utils/marker" + "github.com/nyaruka/redisx" "github.com/pkg/errors" "github.com/sirupsen/logrus" @@ -19,20 +19,21 @@ import ( const ( timeoutLock = "sessions_timeouts" - markerGroup = "session_timeouts" ) +var marker = redisx.NewIntervalSet("session_timeouts", time.Hour*24, 2) + func init() { mailroom.AddInitFunction(StartTimeoutCron) } // StartTimeoutCron starts our cron job of continuing timed out sessions every minute func StartTimeoutCron(rt *runtime.Runtime, wg *sync.WaitGroup, quit chan bool) error { - cron.StartCron(quit, rt.RP, timeoutLock, time.Second*time.Duration(rt.Config.TimeoutTime), - func(lockName string, lockValue string) error { + cron.Start(quit, rt, timeoutLock, time.Second*60, false, + func() error { ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) defer cancel() - return timeoutSessions(ctx, rt, lockName, lockValue) + return timeoutSessions(ctx, rt) }, ) return nil @@ -40,8 +41,8 @@ func StartTimeoutCron(rt *runtime.Runtime, wg *sync.WaitGroup, quit chan bool) e // timeoutRuns looks for any runs that have timed out and schedules for them to continue // TODO: extend lock -func timeoutSessions(ctx context.Context, rt *runtime.Runtime, lockName string, lockValue string) error { - log := logrus.WithField("comp", "timeout").WithField("lock", lockValue) +func timeoutSessions(ctx context.Context, rt *runtime.Runtime) error { + log := logrus.WithField("comp", "timeout") start := time.Now() // find all sessions that need to be expired (we exclude IVR runs) @@ -65,7 +66,7 @@ func timeoutSessions(ctx context.Context, rt *runtime.Runtime, lockName string, // check whether we've already queued this taskID := fmt.Sprintf("%d:%s", timeout.SessionID, timeout.TimeoutOn.Format(time.RFC3339)) - queued, err := marker.HasTask(rc, markerGroup, taskID) + queued, err := marker.Contains(rc, taskID) if err != nil { return errors.Wrapf(err, "error checking whether task is queued") } @@ -83,7 +84,7 @@ func timeoutSessions(ctx context.Context, rt *runtime.Runtime, lockName string, } // and mark it as queued - err = marker.AddTask(rc, markerGroup, taskID) + err = marker.Add(rc, taskID) if err != nil { return errors.Wrapf(err, "error marking timeout task as queued") } diff --git a/core/tasks/timeouts/cron_test.go b/core/tasks/timeouts/cron_test.go index 34591283d..6d324e380 100644 --- a/core/tasks/timeouts/cron_test.go +++ b/core/tasks/timeouts/cron_test.go @@ -11,7 +11,6 @@ import ( "github.com/nyaruka/mailroom/core/tasks/handler" "github.com/nyaruka/mailroom/testsuite" "github.com/nyaruka/mailroom/testsuite/testdata" - "github.com/nyaruka/mailroom/utils/marker" "github.com/stretchr/testify/assert" ) @@ -21,21 +20,18 @@ func TestTimeouts(t *testing.T) { rc := rp.Get() defer rc.Close() - defer testsuite.Reset(testsuite.ResetAll) - - err := marker.ClearTasks(rc, timeoutLock) - assert.NoError(t, err) + defer testsuite.Reset(testsuite.ResetData | testsuite.ResetRedis) // need to create a session that has an expired timeout s1TimeoutOn := time.Now() - testdata.InsertFlowSession(db, testdata.Org1, testdata.Cathy, models.SessionStatusWaiting, &s1TimeoutOn) + testdata.InsertFlowSession(db, testdata.Org1, testdata.Cathy, models.FlowTypeMessaging, models.SessionStatusWaiting, testdata.Favorites, models.NilConnectionID, &s1TimeoutOn) s2TimeoutOn := time.Now().Add(time.Hour * 24) - testdata.InsertFlowSession(db, testdata.Org1, testdata.George, models.SessionStatusWaiting, &s2TimeoutOn) + testdata.InsertFlowSession(db, testdata.Org1, testdata.George, models.FlowTypeMessaging, models.SessionStatusWaiting, testdata.Favorites, models.NilConnectionID, &s2TimeoutOn) time.Sleep(10 * time.Millisecond) // schedule our timeouts - err = timeoutSessions(ctx, rt, timeoutLock, "foo") + err := timeoutSessions(ctx, rt) assert.NoError(t, err) // should have created one task diff --git a/go.mod b/go.mod index 046d213f0..ead42669d 100644 --- a/go.mod +++ b/go.mod @@ -1,5 +1,7 @@ module github.com/nyaruka/mailroom +go 1.17 + require ( github.com/Masterminds/semver v1.5.0 github.com/apex/log v1.1.4 @@ -11,18 +13,18 @@ require ( github.com/go-chi/chi v4.1.2+incompatible github.com/golang-jwt/jwt v3.2.2+incompatible github.com/golang/protobuf v1.4.0 - github.com/gomodule/redigo v2.0.0+incompatible + github.com/gomodule/redigo v1.8.8 github.com/gorilla/schema v1.1.0 - github.com/jmoiron/sqlx v1.2.0 + github.com/jmoiron/sqlx v1.3.4 github.com/kylelemons/godebug v1.1.0 // indirect - github.com/lib/pq v1.4.0 - github.com/mattn/go-sqlite3 v1.10.0 // indirect + github.com/lib/pq v1.10.4 github.com/nyaruka/ezconf v0.2.1 - github.com/nyaruka/gocommon v1.14.1 - github.com/nyaruka/goflow v0.140.1 + github.com/nyaruka/gocommon v1.17.0 + github.com/nyaruka/goflow v0.148.0 github.com/nyaruka/librato v1.0.0 github.com/nyaruka/logrus_sentry v0.8.2-0.20190129182604-c2962b80ba7d github.com/nyaruka/null v1.2.0 + github.com/nyaruka/redisx v0.2.1 github.com/olivere/elastic/v7 v7.0.22 github.com/patrickmn/go-cache v2.1.0+incompatible github.com/pkg/errors v0.9.1 @@ -40,13 +42,13 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/fatih/structs v1.0.0 // indirect github.com/go-mail/mail v2.3.1+incompatible // indirect - github.com/go-playground/locales v0.13.0 // indirect - github.com/go-playground/universal-translator v0.17.0 // indirect + github.com/go-playground/locales v0.14.0 // indirect + github.com/go-playground/universal-translator v0.18.0 // indirect github.com/gofrs/uuid v3.3.0+incompatible // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/konsorten/go-windows-terminal-sequences v1.0.1 // indirect - github.com/leodido/go-urn v1.2.0 // indirect + github.com/leodido/go-urn v1.2.1 // indirect github.com/mailru/easyjson v0.7.6 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect github.com/naoina/go-stringutil v0.1.0 // indirect @@ -61,5 +63,3 @@ require ( gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect ) - -go 1.17 diff --git a/go.sum b/go.sum index a15cf5a30..13c27bb57 100644 --- a/go.sum +++ b/go.sum @@ -14,6 +14,7 @@ github.com/apex/logs v0.0.4/go.mod h1:XzxuLZ5myVHDy9SAmYpamKKRNApGj54PfYLcFrXqDw github.com/aphistic/golf v0.0.0-20180712155816-02c07f170c5a/go.mod h1:3NqKYiepwy8kCu4PNA+aP7WUV72eXWJeP9/r3/K9aLE= github.com/aphistic/sweet v0.2.0/go.mod h1:fWDlIh/isSE9n6EPsRmC0det+whmX6dJid3stzu0Xys= github.com/aws/aws-sdk-go v1.20.6/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= +github.com/aws/aws-sdk-go v1.34.31/go.mod h1:H7NKnBqNVzoTJpGfLrQkkD+ytBA93eiDYi/+8rV9s48= github.com/aws/aws-sdk-go v1.35.20/go.mod h1:tlPOdRjfxPBpNIwqDj61rmsnA85v9jc0Ps9+muhnW+k= github.com/aws/aws-sdk-go v1.40.56 h1:FM2yjR0UUYFzDTMx+mH9Vyw1k1EUUxsAFzk+BjkzANA= github.com/aws/aws-sdk-go v1.40.56/go.mod h1:585smgzpB/KqRA+K3y/NL/oYRqQvpNJYvLm+LY1U59Q= @@ -48,11 +49,10 @@ github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9 github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= github.com/go-mail/mail v2.3.1+incompatible h1:UzNOn0k5lpfVtO31cK3hn6I4VEVGhe3lX8AJBAxXExM= github.com/go-mail/mail v2.3.1+incompatible/go.mod h1:VPWjmmNyRsWXQZHVHT3g0YbIINUkSmuKOiLIDkWbL6M= -github.com/go-playground/locales v0.13.0 h1:HyWk6mgj5qFqCT5fjGBuRArbVDfE4hi8+e8ceBS/t7Q= -github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= -github.com/go-playground/universal-translator v0.17.0 h1:icxd5fm+REJzpZx7ZfpaD876Lmtgy7VtROAbHHXk8no= -github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= -github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= +github.com/go-playground/locales v0.14.0 h1:u50s323jtVGugKlcYeyzC0etD1HifMjqmJqb8WugfUU= +github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs= +github.com/go-playground/universal-translator v0.18.0 h1:82dyy6p4OuJq4/CByFNOn/jYrnRPArHwAcmLoJZxyho= +github.com/go-playground/universal-translator v0.18.0/go.mod h1:UvRDBj+xPUEGrFYl+lu/H90nyDXpg0fqeB/AQUGNTVA= github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs= github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= @@ -73,8 +73,8 @@ github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrU github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= github.com/golang/protobuf v1.4.0 h1:oOuy+ugB+P/kBdUnG5QaMXSIyJ1q38wWSojYCb3z5VQ= github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= -github.com/gomodule/redigo v2.0.0+incompatible h1:K/R+8tc58AaqLkqG2Ol3Qk+DR/TlNuhuh457pBFPtt0= -github.com/gomodule/redigo v2.0.0+incompatible/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4= +github.com/gomodule/redigo v1.8.8 h1:f6cXq6RRfiyrOJEV7p3JhLDlmawGBVBBP1MggY8Mo4E= +github.com/gomodule/redigo v1.8.8/go.mod h1:7ArFNvsTjH8GMMzB4uy1snslv2BwmginuMs06a1uzZE= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= @@ -89,8 +89,8 @@ github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9Y github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= -github.com/jmoiron/sqlx v1.2.0 h1:41Ip0zITnmWNR/vHV+S4m+VoUivnWY5E4OJfLZjCJMA= -github.com/jmoiron/sqlx v1.2.0/go.mod h1:1FEQNm3xlJgrMD+FBdI9+xvCksHtbpVBBw5dYhBSsks= +github.com/jmoiron/sqlx v1.3.4 h1:wv+0IJZfL5z0uZoUjlpKgHkgaFSYD+r9CfrXjEXsO7w= +github.com/jmoiron/sqlx v1.3.4/go.mod h1:2BljVx/86SuTyjE+aPYlHCTNvZrnJXghYGpNiXLBMCQ= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/jpillora/backoff v0.0.0-20180909062703-3050d21c67d7/go.mod h1:2iMrUgbbvHEiQClaW2NsSzMyGHqN+rDFqY705q49KG0= @@ -106,20 +106,20 @@ github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= -github.com/leodido/go-urn v1.2.0 h1:hpXL4XnriNwQ/ABnpepYM/1vCLWNDfUNts8dX3xTG6Y= -github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= +github.com/leodido/go-urn v1.2.1 h1:BqpAaACuzVSgi/VLzGZIobT2z4v53pjosyNd9Yv6n/w= +github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY= github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= -github.com/lib/pq v1.4.0 h1:TmtCFbH+Aw0AixwyttznSMQDgbR5Yed/Gg6S8Funrhc= -github.com/lib/pq v1.4.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.10.4 h1:SO9z7FRPzA03QhHKJrH5BXA6HU1rS4V2nIVrrNC1iYk= +github.com/lib/pq v1.10.4/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/mailru/easyjson v0.7.6 h1:8yTIVnZgCoiM1TgqoeTl+LfU5Jg6/xL3QhGQnimLYnA= github.com/mailru/easyjson v0.7.6/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= -github.com/mattn/go-sqlite3 v1.9.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= -github.com/mattn/go-sqlite3 v1.10.0 h1:jbhqpg7tQe4SupckyijYiy0mJJ/pRyHvXf7JdWK860o= -github.com/mattn/go-sqlite3 v1.10.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= +github.com/mattn/go-sqlite3 v1.14.6 h1:dNPt6NO46WmLVt2DLNpwczCmdV5boIZ6g/tlDrlRUbg= +github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE= @@ -132,18 +132,22 @@ github.com/naoina/toml v0.1.1 h1:PT/lllxVVN0gzzSqSlHEmP8MJB4MY2U7STGxiouV4X8= github.com/naoina/toml v0.1.1/go.mod h1:NBIhNtsFMo3G2szEBne+bO4gS192HuIYRqfvOWb4i1E= github.com/nyaruka/ezconf v0.2.1 h1:TDXWoqjqYya1uhou1mAJZg7rgFYL98EB0Tb3+BWtUh0= github.com/nyaruka/ezconf v0.2.1/go.mod h1:ey182kYkw2MIi4XiWe1FR/mzI33WCmTWuceDYYxgnQw= -github.com/nyaruka/gocommon v1.14.1 h1:/ScvLmg4zzVAuZ78TaENrvSEvW3WnUdqRd/t9hX7z7E= -github.com/nyaruka/gocommon v1.14.1/go.mod h1:R1Vr7PwrYCSu+vcU0t8t/5C4TsCwcWoqiuIQCxcMqxs= -github.com/nyaruka/goflow v0.140.1 h1:B/ikb/eOgqzEIoKWYjTSQtb5h3AHpnf/xrTS0H2lJLA= -github.com/nyaruka/goflow v0.140.1/go.mod h1:s3f7q2k6IKZicOcu2mu2EcuKgK3hava43Zb3cagtpVM= +github.com/nyaruka/gocommon v1.5.3/go.mod h1:2ZeBZF9yt20IaAJ4aC1ujojAsFhJBk2IuDvSl7KuQDw= +github.com/nyaruka/gocommon v1.17.0 h1:cTiDLSUgmYJ9OZw752jva0P2rz0utRtv5WGuKFc9kxw= +github.com/nyaruka/gocommon v1.17.0/go.mod h1:nmYyb7MZDM0iW4DYJKiBzfKuE9nbnx+xSHZasuIBOT0= +github.com/nyaruka/goflow v0.148.0 h1:1PEzdywQzmGodELpJ3jt2G0v9KUbDi0J1+9akKF6oEM= +github.com/nyaruka/goflow v0.148.0/go.mod h1:uJN4MWdW5Yw6bP5jMKpzmtwb2j5gDPMh/uJhbEcq9MY= github.com/nyaruka/librato v1.0.0 h1:Vznj9WCeC1yZXbBYyYp40KnbmXLbEkjKmHesV/v2SR0= github.com/nyaruka/librato v1.0.0/go.mod h1:pkRNLFhFurOz0QqBz6/DuTFhHHxAubWxs4Jx+J7yUgg= github.com/nyaruka/logrus_sentry v0.8.2-0.20190129182604-c2962b80ba7d h1:hyp9u36KIwbTCo2JAJ+TuJcJBc+UZzEig7RI/S5Dvkc= github.com/nyaruka/logrus_sentry v0.8.2-0.20190129182604-c2962b80ba7d/go.mod h1:FGdPJVDTNqbRAD+2RvnK9YoO2HcEW7ogSMPzc90b638= github.com/nyaruka/null v1.2.0 h1:uEbkyy4Z+zPB2Pr3ryQh/0N2965I9kEsXq/cGpyJ7PA= github.com/nyaruka/null v1.2.0/go.mod h1:HSAFbLNOaEhHnoU0VCveCPz0GDtJ3GEtFWhvnBNkhPE= +github.com/nyaruka/phonenumbers v1.0.58/go.mod h1:sDaTZ/KPX5f8qyV9qN+hIm+4ZBARJrupC6LuhshJq1U= github.com/nyaruka/phonenumbers v1.0.71 h1:itkCGhxkQkHrJ6OyZSApdjQVlPmrWs88MF283pPvbFU= github.com/nyaruka/phonenumbers v1.0.71/go.mod h1:sDaTZ/KPX5f8qyV9qN+hIm+4ZBARJrupC6LuhshJq1U= +github.com/nyaruka/redisx v0.2.1 h1:BavpQRCsK5xV2uxPdJJ26yVmjSo+q6bdjWqeNNf0s5w= +github.com/nyaruka/redisx v0.2.1/go.mod h1:cdbAm4y/+oFWu7qFzH2ERPeqRXJC2CtgRhwcBacM4Oc= github.com/olivere/elastic/v7 v7.0.22 h1:esBA6JJwvYgfms0EVlH7Z+9J4oQ/WUADF2y/nCNDw7s= github.com/olivere/elastic/v7 v7.0.22/go.mod h1:VDexNy9NjmtAkrjNoI7tImv7FR4tf5zUA3ickqu5Pc8= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= @@ -189,6 +193,7 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/tj/assert v0.0.0-20171129193455-018094318fb0/go.mod h1:mZ9/Rh9oLWpLLDRpvE+3b7gP/C2YyLFYxNmcLnPTMe0= @@ -199,6 +204,7 @@ go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190426145343-a29dc8fdc734/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= @@ -213,6 +219,7 @@ golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200925080053-05aa5d4ee321/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20210614182718-04defd469f4e h1:XpT3nA5TvE525Ne3hInMh6+GETgn27Zfm9dxsThnX2Q= golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -229,12 +236,12 @@ golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da h1:b3NXsE2LusjYGGjL5bxEVZZORm/YEFFrWFjR8eFrw/c= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= diff --git a/mailroom.go b/mailroom.go index 967075f41..b17b7c4ae 100644 --- a/mailroom.go +++ b/mailroom.go @@ -3,7 +3,6 @@ package mailroom import ( "context" "net/url" - "os" "strings" "sync" "time" @@ -165,8 +164,7 @@ func (mr *Mailroom) Start() error { // if we have a librato token, configure it if c.LibratoToken != "" { - host, _ := os.Hostname() - librato.Configure(c.LibratoUsername, c.LibratoToken, host, time.Second, mr.wg) + librato.Configure(c.LibratoUsername, c.LibratoToken, c.InstanceName, time.Second, mr.wg) librato.Start() } @@ -196,7 +194,12 @@ func (mr *Mailroom) Stop() error { mr.webserver.Stop() mr.wg.Wait() - mr.rt.ES.Stop() + + // stop ES client if we have one + if mr.rt.ES != nil { + mr.rt.ES.Stop() + } + logrus.Info("mailroom stopped") return nil } diff --git a/mailroom_test.dump b/mailroom_test.dump index 849956616..bdec91fcf 100644 Binary files a/mailroom_test.dump and b/mailroom_test.dump differ diff --git a/runtime/config.go b/runtime/config.go index 1c502a9fb..bcd19ab43 100644 --- a/runtime/config.go +++ b/runtime/config.go @@ -4,12 +4,18 @@ import ( "encoding/csv" "io" "net" + "os" "strings" "github.com/nyaruka/goflow/utils" "github.com/pkg/errors" + "gopkg.in/go-playground/validator.v9" ) +func init() { + utils.RegisterValidatorAlias("session_storage", "eq=db|eq=s3", func(e validator.FieldError) string { return "is not a valid session storage mode" }) +} + // Config is our top level configuration object type Config struct { DB string `validate:"url,startswith=postgres:" help:"URL for your Postgres database"` @@ -29,16 +35,19 @@ type Config struct { HandlerWorkers int `help:"the number of go routines that will be used to handle messages"` RetryPendingMessages bool `help:"whether to requeue pending messages older than five minutes to retry"` - WebhooksTimeout int `help:"the timeout in milliseconds for webhook calls from engine"` - WebhooksMaxRetries int `help:"the number of times to retry a failed webhook call"` - WebhooksMaxBodyBytes int `help:"the maximum size of bytes to a webhook call response body"` - WebhooksInitialBackoff int `help:"the initial backoff in milliseconds when retrying a failed webhook call"` - WebhooksBackoffJitter float64 `help:"the amount of jitter to apply to backoff times"` - SMTPServer string `help:"the smtp configuration for sending emails ex: smtp://user%40password@server:port/?from=foo%40gmail.com"` - DisallowedNetworks string `help:"comma separated list of IP addresses and networks which engine can't make HTTP calls to"` - MaxStepsPerSprint int `help:"the maximum number of steps allowed per engine sprint"` - MaxResumesPerSession int `help:"the maximum number of resumes allowed per engine session"` - MaxValueLength int `help:"the maximum size in characters for contact field values and run result values"` + WebhooksTimeout int `help:"the timeout in milliseconds for webhook calls from engine"` + WebhooksMaxRetries int `help:"the number of times to retry a failed webhook call"` + WebhooksMaxBodyBytes int `help:"the maximum size of bytes to a webhook call response body"` + WebhooksInitialBackoff int `help:"the initial backoff in milliseconds when retrying a failed webhook call"` + WebhooksBackoffJitter float64 `help:"the amount of jitter to apply to backoff times"` + WebhooksHealthyResponseLimit int `help:"the limit in milliseconds for webhook response to be considered healthy"` + + SMTPServer string `help:"the smtp configuration for sending emails ex: smtp://user%40password@server:port/?from=foo%40gmail.com"` + DisallowedNetworks string `help:"comma separated list of IP addresses and networks which engine can't make HTTP calls to"` + MaxStepsPerSprint int `help:"the maximum number of steps allowed per engine sprint"` + MaxResumesPerSession int `help:"the maximum number of resumes allowed per engine session"` + MaxValueLength int `help:"the maximum size in characters for contact field values and run result values"` + SessionStorage string `validate:"omitempty,session_storage" help:"where to store session output (s3|db)"` S3Endpoint string `help:"the S3 endpoint we will write attachments to"` S3Region string `help:"the S3 region we will write attachments to"` @@ -58,14 +67,17 @@ type Config struct { FCMKey string `help:"the FCM API key used to notify Android relayers to sync"` MailgunSigningKey string `help:"the signing key used to validate requests from mailgun"` - TimeoutTime int `help:"the amount of time to between every timeout queued"` - LogLevel string `help:"the logging level courier should use"` - UUIDSeed int `help:"seed to use for UUID generation in a testing environment"` - Version string `help:"the version of this mailroom install"` + InstanceName string `help:"the unique name of this instance used for analytics"` + TimeoutTime int `help:"the amount of time to between every timeout queued"` + LogLevel string `help:"the logging level courier should use"` + UUIDSeed int `help:"seed to use for UUID generation in a testing environment"` + Version string `help:"the version of this mailroom install"` } // NewDefaultConfig returns a new default configuration object func NewDefaultConfig() *Config { + hostname, _ := os.Hostname() + return &Config{ DB: "postgres://temba:temba@localhost/temba?sslmode=disable&Timezone=UTC", ReadonlyDB: "", @@ -80,16 +92,19 @@ func NewDefaultConfig() *Config { HandlerWorkers: 32, RetryPendingMessages: true, - WebhooksTimeout: 15000, - WebhooksMaxRetries: 2, - WebhooksMaxBodyBytes: 1024 * 1024, // 1MB - WebhooksInitialBackoff: 5000, - WebhooksBackoffJitter: 0.5, - SMTPServer: "", - DisallowedNetworks: `127.0.0.1,::1,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16,169.254.0.0/16,fe80::/10`, - MaxStepsPerSprint: 100, - MaxResumesPerSession: 250, - MaxValueLength: 640, + WebhooksTimeout: 15000, + WebhooksMaxRetries: 2, + WebhooksMaxBodyBytes: 1024 * 1024, // 1MB + WebhooksInitialBackoff: 5000, + WebhooksBackoffJitter: 0.5, + WebhooksHealthyResponseLimit: 10000, + + SMTPServer: "", + DisallowedNetworks: `127.0.0.1,::1,10.0.0.0/8,172.16.0.0/12,192.168.0.0/16,169.254.0.0/16,fe80::/10`, + MaxStepsPerSprint: 100, + MaxResumesPerSession: 250, + MaxValueLength: 640, + SessionStorage: "db", S3Endpoint: "https://s3.amazonaws.com", S3Region: "us-east-1", @@ -103,10 +118,11 @@ func NewDefaultConfig() *Config { AWSAccessKeyID: "", AWSSecretAccessKey: "", - TimeoutTime: 60, - LogLevel: "error", - UUIDSeed: 0, - Version: "Dev", + InstanceName: hostname, + TimeoutTime: 60, + LogLevel: "error", + UUIDSeed: 0, + Version: "Dev", } } diff --git a/runtime/config_test.go b/runtime/config_test.go index 1ba5cfd79..f287a9646 100644 --- a/runtime/config_test.go +++ b/runtime/config_test.go @@ -17,7 +17,8 @@ func TestValidate(t *testing.T) { c.ReadonlyDB = "??" c.Redis = "??" c.Elastic = "??" - assert.EqualError(t, c.Validate(), "field 'DB' is not a valid URL, field 'ReadonlyDB' is not a valid URL, field 'Redis' is not a valid URL, field 'Elastic' is not a valid URL") + c.SessionStorage = "??" + assert.EqualError(t, c.Validate(), "field 'DB' is not a valid URL, field 'ReadonlyDB' is not a valid URL, field 'Redis' is not a valid URL, field 'Elastic' is not a valid URL, field 'SessionStorage' is not a valid session storage mode") c = runtime.NewDefaultConfig() c.DB = "mysql://temba:temba@localhost/temba" diff --git a/services/ivr/twiml/service.go b/services/ivr/twiml/service.go index d9bcd3060..bfa851e56 100644 --- a/services/ivr/twiml/service.go +++ b/services/ivr/twiml/service.go @@ -19,7 +19,6 @@ import ( "github.com/nyaruka/goflow/envs" "github.com/nyaruka/goflow/flows" "github.com/nyaruka/goflow/flows/events" - "github.com/nyaruka/goflow/flows/routers/waits" "github.com/nyaruka/goflow/flows/routers/waits/hints" "github.com/nyaruka/goflow/utils" "github.com/nyaruka/mailroom/core/ivr" @@ -365,7 +364,7 @@ func (s *service) WriteSessionResponse(ctx context.Context, rt *runtime.Runtime, } // get our response - response, err := ResponseForSprint(rt.Config, number, resumeURL, session.Wait(), sprint.Events(), true) + response, err := ResponseForSprint(rt.Config, number, resumeURL, sprint.Events(), true) if err != nil { return errors.Wrap(err, "unable to build response for IVR call") } @@ -445,9 +444,10 @@ func twCalculateSignature(url string, form url.Values, authToken string) ([]byte // TWIML building utilities -func ResponseForSprint(cfg *runtime.Config, number urns.URN, resumeURL string, w flows.ActivatedWait, es []flows.Event, indent bool) (string, error) { +func ResponseForSprint(cfg *runtime.Config, number urns.URN, resumeURL string, es []flows.Event, indent bool) (string, error) { r := &Response{} commands := make([]interface{}, 0) + hasWait := false for _, e := range es { switch event := e.(type) { @@ -467,14 +467,10 @@ func ResponseForSprint(cfg *runtime.Config, number urns.URN, resumeURL string, w commands = append(commands, Play{URL: a.URL()}) } } - } - } - if w != nil { - switch wait := w.(type) { - - case *waits.ActivatedMsgWait: - switch hint := wait.Hint().(type) { + case *events.MsgWaitEvent: + hasWait = true + switch hint := event.Hint.(type) { case *hints.DigitsHint: resumeURL = resumeURL + "&wait_type=gather" gather := &Gather{ @@ -496,21 +492,18 @@ func ResponseForSprint(cfg *runtime.Config, number urns.URN, resumeURL string, w r.Commands = commands default: - return "", errors.Errorf("unable to use hint in IVR call, unknow type: %s", wait.Hint().Type()) + return "", errors.Errorf("unable to use hint in IVR call, unknown type: %s", event.Hint.Type()) } - case *waits.ActivatedDialWait: - dial := Dial{Action: resumeURL + "&wait_type=dial", Number: wait.URN().Path()} - if w.TimeoutSeconds() != nil { - dial.Timeout = *w.TimeoutSeconds() - } + case *events.DialWaitEvent: + hasWait = true + dial := Dial{Action: resumeURL + "&wait_type=dial", Number: event.URN.Path()} commands = append(commands, dial) r.Commands = commands - - default: - return "", fmt.Errorf("unable to use wait type in Twilio call: %x", w) } - } else { + } + + if !hasWait { // no wait? call is over, hang up commands = append(commands, Hangup{}) r.Commands = commands diff --git a/services/ivr/twiml/service_test.go b/services/ivr/twiml/service_test.go index e1f8bbd1a..447872114 100644 --- a/services/ivr/twiml/service_test.go +++ b/services/ivr/twiml/service_test.go @@ -6,12 +6,12 @@ import ( "strconv" "strings" "testing" + "time" "github.com/nyaruka/gocommon/urns" "github.com/nyaruka/gocommon/uuids" "github.com/nyaruka/goflow/assets" "github.com/nyaruka/goflow/flows/events" - "github.com/nyaruka/goflow/flows/routers/waits" "github.com/nyaruka/goflow/flows/routers/waits/hints" "github.com/nyaruka/goflow/utils" "github.com/nyaruka/mailroom/services/ivr/twiml" @@ -25,6 +25,7 @@ func TestResponseForSprint(t *testing.T) { _, rt, _, _ := testsuite.Get() urn := urns.URN("tel:+12067799294") + expiresOn := time.Now().Add(time.Hour) channelRef := assets.NewChannelReference(assets.ChannelUUID(uuids.New()), "Twilio Channel") resumeURL := "http://temba.io/resume?session=1" @@ -34,33 +35,37 @@ func TestResponseForSprint(t *testing.T) { defer func() { rt.Config.AttachmentDomain = "" }() tcs := []struct { - Events []flows.Event - Wait flows.ActivatedWait - Expected string + events []flows.Event + expected string }{ { - []flows.Event{events.NewIVRCreated(flows.NewMsgOut(urn, channelRef, "hello world", nil, nil, nil, flows.NilMsgTopic))}, - nil, + []flows.Event{ + events.NewIVRCreated(flows.NewMsgOut(urn, channelRef, "hello world", nil, nil, nil, flows.NilMsgTopic)), + }, `hello world`, }, { - []flows.Event{events.NewIVRCreated(flows.NewIVRMsgOut(urn, channelRef, "hello world", "eng", ""))}, - nil, + []flows.Event{ + events.NewIVRCreated(flows.NewIVRMsgOut(urn, channelRef, "hello world", "eng", "")), + }, `hello world`, }, { - []flows.Event{events.NewIVRCreated(flows.NewIVRMsgOut(urn, channelRef, "hello world", "ben", ""))}, - nil, + []flows.Event{ + events.NewIVRCreated(flows.NewIVRMsgOut(urn, channelRef, "hello world", "ben", "")), + }, `hello world`, }, { - []flows.Event{events.NewIVRCreated(flows.NewMsgOut(urn, channelRef, "hello world", []utils.Attachment{utils.Attachment("audio:/recordings/foo.wav")}, nil, nil, flows.NilMsgTopic))}, - nil, + []flows.Event{ + events.NewIVRCreated(flows.NewMsgOut(urn, channelRef, "hello world", []utils.Attachment{utils.Attachment("audio:/recordings/foo.wav")}, nil, nil, flows.NilMsgTopic)), + }, `https://mailroom.io/recordings/foo.wav`, }, { - []flows.Event{events.NewIVRCreated(flows.NewMsgOut(urn, channelRef, "hello world", []utils.Attachment{utils.Attachment("audio:https://temba.io/recordings/foo.wav")}, nil, nil, flows.NilMsgTopic))}, - nil, + []flows.Event{ + events.NewIVRCreated(flows.NewMsgOut(urn, channelRef, "hello world", []utils.Attachment{utils.Attachment("audio:https://temba.io/recordings/foo.wav")}, nil, nil, flows.NilMsgTopic)), + }, `https://temba.io/recordings/foo.wav`, }, { @@ -68,30 +73,41 @@ func TestResponseForSprint(t *testing.T) { events.NewIVRCreated(flows.NewMsgOut(urn, channelRef, "hello world", nil, nil, nil, flows.NilMsgTopic)), events.NewIVRCreated(flows.NewMsgOut(urn, channelRef, "goodbye", nil, nil, nil, flows.NilMsgTopic)), }, - nil, `hello worldgoodbye`, }, { - []flows.Event{events.NewIVRCreated(flows.NewMsgOut(urn, channelRef, "enter a number", nil, nil, nil, flows.NilMsgTopic))}, - waits.NewActivatedMsgWait(nil, hints.NewFixedDigitsHint(1)), + []flows.Event{ + events.NewIVRCreated(flows.NewMsgOut(urn, channelRef, "enter a number", nil, nil, nil, flows.NilMsgTopic)), + events.NewMsgWait(nil, nil, hints.NewFixedDigitsHint(1)), + }, `enter a numberhttp://temba.io/resume?session=1&wait_type=gather&timeout=true`, }, { - []flows.Event{events.NewIVRCreated(flows.NewMsgOut(urn, channelRef, "enter a number, then press #", nil, nil, nil, flows.NilMsgTopic))}, - waits.NewActivatedMsgWait(nil, hints.NewTerminatedDigitsHint("#")), + []flows.Event{ + events.NewIVRCreated(flows.NewMsgOut(urn, channelRef, "enter a number, then press #", nil, nil, nil, flows.NilMsgTopic)), + events.NewMsgWait(nil, nil, hints.NewTerminatedDigitsHint("#")), + }, `enter a number, then press #http://temba.io/resume?session=1&wait_type=gather&timeout=true`, }, { - []flows.Event{events.NewIVRCreated(flows.NewMsgOut(urn, channelRef, "say something", nil, nil, nil, flows.NilMsgTopic))}, - waits.NewActivatedMsgWait(nil, hints.NewAudioHint()), + []flows.Event{ + events.NewIVRCreated(flows.NewMsgOut(urn, channelRef, "say something", nil, nil, nil, flows.NilMsgTopic)), + events.NewMsgWait(nil, nil, hints.NewAudioHint()), + }, `say somethinghttp://temba.io/resume?session=1&wait_type=record&empty=true`, }, + { + []flows.Event{ + events.NewDialWait(urns.URN(`tel:+1234567890`), &expiresOn), + }, + `+1234567890`, + }, } for i, tc := range tcs { - response, err := twiml.ResponseForSprint(rt.Config, urn, resumeURL, tc.Wait, tc.Events, false) + response, err := twiml.ResponseForSprint(rt.Config, urn, resumeURL, tc.events, false) assert.NoError(t, err, "%d: unexpected error") - assert.Equal(t, xml.Header+tc.Expected, response, "%d: unexpected response", i) + assert.Equal(t, xml.Header+tc.expected, response, "%d: unexpected response", i) } } diff --git a/services/ivr/vonage/service.go b/services/ivr/vonage/service.go index a6cae08ff..47eb595ed 100644 --- a/services/ivr/vonage/service.go +++ b/services/ivr/vonage/service.go @@ -19,11 +19,11 @@ import ( "time" "github.com/nyaruka/gocommon/httpx" + "github.com/nyaruka/gocommon/jsonx" "github.com/nyaruka/gocommon/urns" "github.com/nyaruka/gocommon/uuids" "github.com/nyaruka/goflow/flows" "github.com/nyaruka/goflow/flows/events" - "github.com/nyaruka/goflow/flows/routers/waits" "github.com/nyaruka/goflow/flows/routers/waits/hints" "github.com/nyaruka/goflow/utils" "github.com/nyaruka/mailroom/core/ivr" @@ -389,7 +389,7 @@ func (s *service) RequestCall(number urns.URN, resumeURL string, statusURL strin } if trace.Response.StatusCode != http.StatusCreated { - return ivr.NilCallID, trace, errors.Errorf("received non 200 status for call start: %d", trace.Response.StatusCode) + return ivr.NilCallID, trace, errors.Errorf("received non 201 status for call start: %d", trace.Response.StatusCode) } // parse out our call sid @@ -603,7 +603,7 @@ func (s *service) WriteSessionResponse(ctx context.Context, rt *runtime.Runtime, } // get our response - response, err := s.responseForSprint(ctx, rt.RP, channel, conn, resumeURL, session.Wait(), sprint.Events()) + response, err := s.responseForSprint(ctx, rt.RP, channel, conn, resumeURL, sprint.Events()) if err != nil { return errors.Wrap(err, "unable to build response for IVR call") } @@ -653,11 +653,7 @@ func (s *service) MakeEmptyResponseBody(msg string) []byte { } func (s *service) makeRequest(method string, sendURL string, body interface{}) (*httpx.Trace, error) { - bb, err := json.Marshal(body) - if err != nil { - return nil, errors.Wrapf(err, "error json encoding request") - } - + bb := jsonx.MustMarshal(body) req, _ := http.NewRequest(method, sendURL, bytes.NewReader(bb)) token, err := s.generateToken() if err != nil { @@ -731,14 +727,22 @@ func (s *service) generateToken() (string, error) { // NCCO building utilities -func (s *service) responseForSprint(ctx context.Context, rp *redis.Pool, channel *models.Channel, conn *models.ChannelConnection, resumeURL string, w flows.ActivatedWait, es []flows.Event) (string, error) { +func (s *service) responseForSprint(ctx context.Context, rp *redis.Pool, channel *models.Channel, conn *models.ChannelConnection, resumeURL string, es []flows.Event) (string, error) { actions := make([]interface{}, 0, 1) waitActions := make([]interface{}, 0, 1) - if w != nil { - switch wait := w.(type) { - case *waits.ActivatedMsgWait: - switch hint := wait.Hint().(type) { + var waitEvent flows.Event + for _, e := range es { + switch event := e.(type) { + case *events.MsgWaitEvent, *events.DialWaitEvent: + waitEvent = event + } + } + + if waitEvent != nil { + switch wait := waitEvent.(type) { + case *events.MsgWaitEvent: + switch hint := wait.Hint.(type) { case *hints.DigitsHint: eventURL := resumeURL + "&wait_type=gather" eventURL = eventURL + "&sig=" + url.QueryEscape(s.calculateSignature(eventURL)) @@ -794,10 +798,10 @@ func (s *service) responseForSprint(ctx context.Context, rp *redis.Pool, channel waitActions = append(waitActions, input) default: - return "", errors.Errorf("unable to use wait in IVR call, unknow hint type: %s", wait.Hint().Type()) + return "", errors.Errorf("unable to use wait in IVR call, unknow hint type: %s", wait.Hint.Type()) } - case *waits.ActivatedDialWait: + case *events.DialWaitEvent: // Vonage handles forwards a bit differently. We have to create a new call to the forwarded number, then // join the current call with the call we are starting. // @@ -814,12 +818,9 @@ func (s *service) responseForSprint(ctx context.Context, rp *redis.Pool, channel // create our outbound call with the same conversation UUID call := CallRequest{} - call.To = append(call.To, Phone{Type: "phone", Number: strings.TrimLeft(wait.URN().Path(), "+")}) + call.To = append(call.To, Phone{Type: "phone", Number: strings.TrimLeft(wait.URN.Path(), "+")}) call.From = Phone{Type: "phone", Number: strings.TrimLeft(channel.Address(), "+")} call.NCCO = append(call.NCCO, NCCO{Action: "conversation", Name: conversationUUID}) - if wait.TimeoutSeconds() != nil { - call.RingingTimer = *wait.TimeoutSeconds() - } trace, err := s.makeRequest(http.MethodPost, s.callURL, call) logrus.WithField("trace", trace).Debug("initiated new call for transfer") @@ -849,9 +850,6 @@ func (s *service) responseForSprint(ctx context.Context, rp *redis.Pool, channel return "", errors.Wrapf(err, "error inserting transfer ID into redis") } logrus.WithField("transferUUID", transferUUID).WithField("callID", conn.ExternalID()).WithField("redisKey", redisKey).WithField("redisValue", redisValue).Debug("saved away call id") - - default: - return "", errors.Errorf("unable to use wait in IVR call, unknow wait type: %s", w) } } diff --git a/services/ivr/vonage/service_test.go b/services/ivr/vonage/service_test.go index 64b151496..11b4c8552 100644 --- a/services/ivr/vonage/service_test.go +++ b/services/ivr/vonage/service_test.go @@ -3,13 +3,14 @@ package vonage import ( "net/http" "testing" + "time" + "github.com/nyaruka/gocommon/httpx" "github.com/nyaruka/gocommon/urns" "github.com/nyaruka/gocommon/uuids" "github.com/nyaruka/goflow/assets" "github.com/nyaruka/goflow/flows" "github.com/nyaruka/goflow/flows/events" - "github.com/nyaruka/goflow/flows/routers/waits" "github.com/nyaruka/goflow/flows/routers/waits/hints" "github.com/nyaruka/goflow/utils" "github.com/nyaruka/mailroom/core/models" @@ -17,17 +18,25 @@ import ( "github.com/nyaruka/mailroom/testsuite/testdata" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestResponseForSprint(t *testing.T) { ctx, rt, db, rp := testsuite.Get() + rc := rp.Get() + defer rc.Close() defer testsuite.Reset(testsuite.ResetAll) - rc := rp.Get() - defer rc.Close() + defer httpx.SetRequestor(httpx.DefaultRequestor) + httpx.SetRequestor(httpx.NewMockRequestor(map[string][]httpx.MockResponse{ + "https://api.nexmo.com/v1/calls": { + httpx.NewMockResponse(201, nil, `{"uuid": "63f61863-4a51-4f6b-86e1-46edebcf9356", "status": "started", "direction": "outbound"}`), + }, + })) urn := urns.URN("tel:+12067799294") + expiresOn := time.Now().Add(time.Hour) channelRef := assets.NewChannelReference(testdata.VonageChannel.UUID, "Vonage Channel") resumeURL := "http://temba.io/resume?session=1" @@ -42,36 +51,41 @@ func TestResponseForSprint(t *testing.T) { uuids.SetGenerator(uuids.NewSeededGenerator(0)) oa, err := models.GetOrgAssets(ctx, rt, testdata.Org1.ID) - assert.NoError(t, err) + require.NoError(t, err) channel := oa.ChannelByUUID(testdata.VonageChannel.UUID) assert.NotNil(t, channel) p, err := NewServiceFromChannel(http.DefaultClient, channel) - assert.NoError(t, err) + require.NoError(t, err) provider := p.(*service) + conn, err := models.InsertIVRConnection(ctx, db, testdata.Org1.ID, testdata.VonageChannel.ID, models.NilStartID, testdata.Bob.ID, testdata.Bob.URNID, models.ConnectionDirectionOut, models.ConnectionStatusInProgress, "EX123") + require.NoError(t, err) + indentMarshal = false tcs := []struct { - Events []flows.Event - Wait flows.ActivatedWait - Expected string + events []flows.Event + expected string }{ { - []flows.Event{events.NewIVRCreated(flows.NewMsgOut(urn, channelRef, "hello world", nil, nil, nil, flows.NilMsgTopic))}, - nil, + []flows.Event{ + events.NewIVRCreated(flows.NewMsgOut(urn, channelRef, "hello world", nil, nil, nil, flows.NilMsgTopic)), + }, `[{"action":"talk","text":"hello world"}]`, }, { - []flows.Event{events.NewIVRCreated(flows.NewMsgOut(urn, channelRef, "hello world", []utils.Attachment{utils.Attachment("audio:/recordings/foo.wav")}, nil, nil, flows.NilMsgTopic))}, - nil, + []flows.Event{ + events.NewIVRCreated(flows.NewMsgOut(urn, channelRef, "hello world", []utils.Attachment{utils.Attachment("audio:/recordings/foo.wav")}, nil, nil, flows.NilMsgTopic)), + }, `[{"action":"stream","streamUrl":["/recordings/foo.wav"]}]`, }, { - []flows.Event{events.NewIVRCreated(flows.NewMsgOut(urn, channelRef, "hello world", []utils.Attachment{utils.Attachment("audio:https://temba.io/recordings/foo.wav")}, nil, nil, flows.NilMsgTopic))}, - nil, + []flows.Event{ + events.NewIVRCreated(flows.NewMsgOut(urn, channelRef, "hello world", []utils.Attachment{utils.Attachment("audio:https://temba.io/recordings/foo.wav")}, nil, nil, flows.NilMsgTopic)), + }, `[{"action":"stream","streamUrl":["https://temba.io/recordings/foo.wav"]}]`, }, { @@ -79,29 +93,40 @@ func TestResponseForSprint(t *testing.T) { events.NewIVRCreated(flows.NewMsgOut(urn, channelRef, "hello world", nil, nil, nil, flows.NilMsgTopic)), events.NewIVRCreated(flows.NewMsgOut(urn, channelRef, "goodbye", nil, nil, nil, flows.NilMsgTopic)), }, - nil, `[{"action":"talk","text":"hello world"},{"action":"talk","text":"goodbye"}]`, }, { - []flows.Event{events.NewIVRCreated(flows.NewMsgOut(urn, channelRef, "enter a number", nil, nil, nil, flows.NilMsgTopic))}, - waits.NewActivatedMsgWait(nil, hints.NewFixedDigitsHint(1)), + []flows.Event{ + events.NewIVRCreated(flows.NewMsgOut(urn, channelRef, "enter a number", nil, nil, nil, flows.NilMsgTopic)), + events.NewMsgWait(nil, nil, hints.NewFixedDigitsHint(1)), + }, `[{"action":"talk","text":"enter a number","bargeIn":true},{"action":"input","maxDigits":1,"submitOnHash":true,"timeOut":30,"eventUrl":["http://temba.io/resume?session=1\u0026wait_type=gather\u0026sig=OjsMUDhaBTUVLq1e6I4cM0SKYpk%3D"],"eventMethod":"POST"}]`, }, { - []flows.Event{events.NewIVRCreated(flows.NewMsgOut(urn, channelRef, "enter a number, then press #", nil, nil, nil, flows.NilMsgTopic))}, - waits.NewActivatedMsgWait(nil, hints.NewTerminatedDigitsHint("#")), + []flows.Event{ + events.NewIVRCreated(flows.NewMsgOut(urn, channelRef, "enter a number, then press #", nil, nil, nil, flows.NilMsgTopic)), + events.NewMsgWait(nil, nil, hints.NewTerminatedDigitsHint("#")), + }, `[{"action":"talk","text":"enter a number, then press #","bargeIn":true},{"action":"input","maxDigits":20,"submitOnHash":true,"timeOut":30,"eventUrl":["http://temba.io/resume?session=1\u0026wait_type=gather\u0026sig=OjsMUDhaBTUVLq1e6I4cM0SKYpk%3D"],"eventMethod":"POST"}]`, }, { - []flows.Event{events.NewIVRCreated(flows.NewMsgOut(urn, channelRef, "say something", nil, nil, nil, flows.NilMsgTopic))}, - waits.NewActivatedMsgWait(nil, hints.NewAudioHint()), + []flows.Event{ + events.NewIVRCreated(flows.NewMsgOut(urn, channelRef, "say something", nil, nil, nil, flows.NilMsgTopic)), + events.NewMsgWait(nil, nil, hints.NewAudioHint()), + }, `[{"action":"talk","text":"say something"},{"action":"record","endOnKey":"#","timeOut":600,"endOnSilence":5,"eventUrl":["http://temba.io/resume?session=1\u0026wait_type=recording_url\u0026recording_uuid=f3ede2d6-becc-4ea3-ae5e-88526a9f4a57\u0026sig=Am9z7fXyU3SPCZagkSpddZSi6xY%3D"],"eventMethod":"POST"},{"action":"input","submitOnHash":true,"timeOut":1,"eventUrl":["http://temba.io/resume?session=1\u0026wait_type=record\u0026recording_uuid=f3ede2d6-becc-4ea3-ae5e-88526a9f4a57\u0026sig=fX1RhjcJNN4xYaiojVYakaz5F%2Fk%3D"],"eventMethod":"POST"}]`, }, + { + []flows.Event{ + events.NewDialWait(urns.URN(`tel:+1234567890`), &expiresOn), + }, + `[{"action":"conversation","name":"8bcb9ef2-d4a6-4314-b68d-6d299761ea9e"}]`, + }, } for i, tc := range tcs { - response, err := provider.responseForSprint(ctx, rp, channel, nil, resumeURL, tc.Wait, tc.Events) + response, err := provider.responseForSprint(ctx, rp, channel, conn, resumeURL, tc.events) assert.NoError(t, err, "%d: unexpected error") - assert.Equal(t, tc.Expected, response, "%d: unexpected response", i) + assert.Equal(t, tc.expected, response, "%d: unexpected response", i) } } diff --git a/testsuite/assert.go b/testsuite/assert.go index c00d90d4e..0a15e2f28 100644 --- a/testsuite/assert.go +++ b/testsuite/assert.go @@ -6,7 +6,6 @@ import ( "testing" "github.com/gomodule/redigo/redis" - "github.com/jmoiron/sqlx" "github.com/nyaruka/gocommon/jsonx" "github.com/nyaruka/goflow/test" "github.com/nyaruka/mailroom/core/models" @@ -29,7 +28,7 @@ func AssertCourierQueues(t *testing.T, expected map[string][]int, errMsg ...inte actual[queueKey] = make([]int, size) if size > 0 { - results, err := redis.Values(rc.Do("ZPOPMAX", queueKey, size)) + results, err := redis.Values(rc.Do("ZRANGE", queueKey, 0, -1, "WITHSCORES")) require.NoError(t, err) require.Equal(t, int(size*2), len(results)) // result is (item, score, item, score, ...) @@ -61,43 +60,3 @@ func AssertContactTasks(t *testing.T, orgID models.OrgID, contactID models.Conta test.AssertEqualJSON(t, expectedJSON, actualJSON, "") } - -// AssertQuery creates a new query on which one can assert things -func AssertQuery(t *testing.T, db *sqlx.DB, sql string, args ...interface{}) *Query { - return &Query{t, db, sql, args} -} - -type Query struct { - t *testing.T - db *sqlx.DB - sql string - args []interface{} -} - -func (q *Query) Returns(expected interface{}, msgAndArgs ...interface{}) { - q.t.Helper() - - // get a variable of same type to hold actual result - actual := expected - - err := q.db.Get(&actual, q.sql, q.args...) - assert.NoError(q.t, err, msgAndArgs...) - - // not sure why but if you pass an int you get back an int64.. - switch expected.(type) { - case int: - actual = int(actual.(int64)) - } - - assert.Equal(q.t, expected, actual, msgAndArgs...) -} - -func (q *Query) Columns(expected map[string]interface{}, msgAndArgs ...interface{}) { - q.t.Helper() - - actual := make(map[string]interface{}, len(expected)) - - err := q.db.QueryRowx(q.sql, q.args...).MapScan(actual) - assert.NoError(q.t, err, msgAndArgs...) - assert.Equal(q.t, expected, actual, msgAndArgs...) -} diff --git a/testsuite/db.go b/testsuite/db.go index 36fcf5452..2087ce43f 100644 --- a/testsuite/db.go +++ b/testsuite/db.go @@ -62,6 +62,13 @@ func (d *MockDB) NamedExecContext(ctx context.Context, query string, arg interfa return d.real.NamedExecContext(ctx, query, arg) } +func (d *MockDB) SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error { + if err := d.check("SelectContext"); err != nil { + return err + } + return d.real.SelectContext(ctx, dest, query, args...) +} + func (d *MockDB) GetContext(ctx context.Context, value interface{}, query string, args ...interface{}) error { if err := d.check("GetContext"); err != nil { return err diff --git a/testsuite/testdata/campaigns.go b/testsuite/testdata/campaigns.go new file mode 100644 index 000000000..754e41463 --- /dev/null +++ b/testsuite/testdata/campaigns.go @@ -0,0 +1,52 @@ +package testdata + +import ( + "time" + + "github.com/jmoiron/sqlx" + "github.com/nyaruka/gocommon/uuids" + "github.com/nyaruka/mailroom/core/models" +) + +type Campaign struct { + ID models.CampaignID + UUID models.CampaignUUID +} + +type CampaignEvent struct { + ID models.CampaignEventID +} + +func InsertCampaign(db *sqlx.DB, org *Org, name string, group *Group) *Campaign { + uuid := models.CampaignUUID(uuids.New()) + var id models.CampaignID + must(db.Get(&id, + `INSERT INTO campaigns_campaign(uuid, org_id, name, group_id, is_archived, is_active, created_on, modified_on, created_by_id, modified_by_id) + VALUES($1, $2, $3, $4, FALSE, TRUE, NOW(), NOW(), 1, 1) RETURNING id`, uuid, org.ID, name, group.ID, + )) + return &Campaign{id, uuid} +} + +func InsertCampaignFlowEvent(db *sqlx.DB, campaign *Campaign, flow *Flow, relativeTo *Field, offset int, unit string) *CampaignEvent { + uuid := models.CampaignEventUUID(uuids.New()) + var id models.CampaignEventID + must(db.Get(&id, + `INSERT INTO campaigns_campaignevent( + uuid, campaign_id, event_type, flow_id, relative_to_id, "offset", unit, delivery_hour, start_mode, + is_active, created_on, modified_on, created_by_id, modified_by_id + ) VALUES( + $1, $2, 'F', $3, $4, $5, $6, -1, 'I', + TRUE, NOW(), NOW(), 1, 1 + ) RETURNING id`, + uuid, campaign.ID, flow.ID, relativeTo.ID, offset, unit, + )) + return &CampaignEvent{id} +} + +func InsertEventFire(db *sqlx.DB, contact *Contact, event *CampaignEvent, scheduled time.Time) models.FireID { + var id models.FireID + must(db.Get(&id, + `INSERT INTO campaigns_eventfire(contact_id, event_id, scheduled) VALUES ($1, $2, $3) RETURNING id;`, contact.ID, event.ID, scheduled, + )) + return id +} diff --git a/testsuite/testdata/channels.go b/testsuite/testdata/channels.go index 0abf0b191..d28778809 100644 --- a/testsuite/testdata/channels.go +++ b/testsuite/testdata/channels.go @@ -25,3 +25,13 @@ func InsertChannel(db *sqlx.DB, org *Org, channelType, name string, schemes []st )) return &Channel{id, uuid} } + +// InsertConnection inserts a channel connection +func InsertConnection(db *sqlx.DB, org *Org, channel *Channel, contact *Contact) models.ConnectionID { + var id models.ConnectionID + must(db.Get(&id, + `INSERT INTO channels_channelconnection(created_on, modified_on, external_id, status, direction, connection_type, error_count, org_id, channel_id, contact_id, contact_urn_id) + VALUES(NOW(), NOW(), 'ext1', 'I', 'I', 'V', 0, $1, $2, $3, $4) RETURNING id`, org.ID, channel.ID, contact.ID, contact.URNID, + )) + return id +} diff --git a/testsuite/testdata/constants.go b/testsuite/testdata/constants.go index 6b5bcfc24..fba4ba651 100644 --- a/testsuite/testdata/constants.go +++ b/testsuite/testdata/constants.go @@ -1,44 +1,17 @@ package testdata import ( - "github.com/nyaruka/gocommon/uuids" "github.com/nyaruka/goflow/assets" "github.com/nyaruka/mailroom/core/models" ) // Constants used in tests, these are tied to the DB created by the RapidPro `mailroom_db` management command. -type Org struct { - ID models.OrgID - UUID uuids.UUID -} - -type User struct { - ID models.UserID - Email string -} - -func (u *User) SafeID() models.UserID { - if u != nil { - return u.ID - } - return models.NilUserID -} - type Classifier struct { ID models.ClassifierID UUID assets.ClassifierUUID } -type Campaign struct { - ID models.CampaignID - UUID models.CampaignUUID -} - -type CampaignEvent struct { - ID models.CampaignEventID -} - var Org1 = &Org{1, "bf0514a5-9407-44c9-b0f9-3f36f9c18414"} var Admin = &User{3, "admin1@nyaruka.com"} var Editor = &User{4, "editor1@nyaruka.com"} diff --git a/testsuite/testdata/flows.go b/testsuite/testdata/flows.go index fdccee3db..f11bdc200 100644 --- a/testsuite/testdata/flows.go +++ b/testsuite/testdata/flows.go @@ -3,6 +3,7 @@ package testdata import ( "time" + "github.com/buger/jsonparser" "github.com/nyaruka/gocommon/uuids" "github.com/nyaruka/goflow/assets" "github.com/nyaruka/goflow/flows" @@ -17,6 +18,29 @@ type Flow struct { UUID assets.FlowUUID } +func (f *Flow) Reference() *assets.FlowReference { + return &assets.FlowReference{UUID: f.UUID, Name: ""} +} + +// InsertFlow inserts a flow +func InsertFlow(db *sqlx.DB, org *Org, definition []byte) *Flow { + uuid, err := jsonparser.GetString(definition, "uuid") + if err != nil { + panic(err) + } + + var id models.FlowID + must(db.Get(&id, + `INSERT INTO flows_flow(uuid, org_id, name, flow_type, version_number, expires_after_minutes, ignore_triggers, has_issues, is_active, is_archived, is_system, created_by_id, created_on, modified_by_id, modified_on, saved_on, saved_by_id) + VALUES($1, $2, 'Test', 'M', 1, 10, FALSE, FALSE, TRUE, FALSE, FALSE, $3, NOW(), $3, NOW(), NOW(), $3) RETURNING id`, uuid, org.ID, Admin.ID, + )) + + db.MustExec(`INSERT INTO flows_flowrevision(flow_id, definition, spec_version, revision, is_active, created_by_id, created_on, modified_by_id, modified_on) + VALUES($1, $2, '13.1.0', 1, TRUE, $3, NOW(), $3, NOW())`, id, definition, Admin.ID) + + return &Flow{ID: id, UUID: assets.FlowUUID(uuid)} +} + // InsertFlowStart inserts a flow start func InsertFlowStart(db *sqlx.DB, org *Org, flow *Flow, contacts []*Contact) models.StartID { var id models.StartID @@ -33,11 +57,20 @@ func InsertFlowStart(db *sqlx.DB, org *Org, flow *Flow, contacts []*Contact) mod } // InsertFlowSession inserts a flow session -func InsertFlowSession(db *sqlx.DB, org *Org, contact *Contact, status models.SessionStatus, timeoutOn *time.Time) models.SessionID { +func InsertFlowSession(db *sqlx.DB, org *Org, contact *Contact, sessionType models.FlowType, status models.SessionStatus, currentFlow *Flow, connectionID models.ConnectionID, timeoutOn *time.Time) models.SessionID { + now := time.Now() + tomorrow := now.Add(time.Hour * 24) + + var waitStartedOn, waitExpiresOn *time.Time + if status == models.SessionStatusWaiting { + waitStartedOn = &now + waitExpiresOn = &tomorrow + } + var id models.SessionID must(db.Get(&id, - `INSERT INTO flows_flowsession(uuid, org_id, contact_id, status, responded, created_on, timeout_on, session_type) - VALUES($1, $2, $3, $4, TRUE, NOW(), $5, 'M') RETURNING id`, uuids.New(), org.ID, contact.ID, status, timeoutOn, + `INSERT INTO flows_flowsession(uuid, org_id, contact_id, status, responded, created_on, session_type, current_flow_id, connection_id, timeout_on, wait_started_on, wait_expires_on, wait_resume_on_expire) + VALUES($1, $2, $3, $4, TRUE, NOW(), $5, $6, $7, $8, $9, $10, FALSE) RETURNING id`, uuids.New(), org.ID, contact.ID, status, sessionType, currentFlow.ID, connectionID, timeoutOn, waitStartedOn, waitExpiresOn, )) return id } diff --git a/testsuite/testdata/msgs.go b/testsuite/testdata/msgs.go index 2614f1a05..1a3f1c8a0 100644 --- a/testsuite/testdata/msgs.go +++ b/testsuite/testdata/msgs.go @@ -37,8 +37,24 @@ func InsertIncomingMsg(db *sqlx.DB, org *Org, channel *Channel, contact *Contact } // InsertOutgoingMsg inserts an outgoing message -func InsertOutgoingMsg(db *sqlx.DB, org *Org, channel *Channel, contact *Contact, text string, attachments []utils.Attachment, status models.MsgStatus) *flows.MsgOut { - msg := flows.NewMsgOut(contact.URN, assets.NewChannelReference(channel.UUID, ""), text, attachments, nil, nil, flows.NilMsgTopic) +func InsertOutgoingMsg(db *sqlx.DB, org *Org, channel *Channel, contact *Contact, text string, attachments []utils.Attachment, status models.MsgStatus, highPriority bool) *flows.MsgOut { + return insertOutgoingMsg(db, org, channel, contact, text, attachments, status, highPriority, 0, nil) +} + +// InsertErroredOutgoingMsg inserts an ERRORED(E) outgoing message +func InsertErroredOutgoingMsg(db *sqlx.DB, org *Org, channel *Channel, contact *Contact, text string, errorCount int, nextAttempt time.Time, highPriority bool) *flows.MsgOut { + return insertOutgoingMsg(db, org, channel, contact, text, nil, models.MsgStatusErrored, highPriority, errorCount, &nextAttempt) +} + +func insertOutgoingMsg(db *sqlx.DB, org *Org, channel *Channel, contact *Contact, text string, attachments []utils.Attachment, status models.MsgStatus, highPriority bool, errorCount int, nextAttempt *time.Time) *flows.MsgOut { + var channelRef *assets.ChannelReference + var channelID models.ChannelID + if channel != nil { + channelRef = assets.NewChannelReference(channel.UUID, "") + channelID = channel.ID + } + + msg := flows.NewMsgOut(contact.URN, channelRef, text, attachments, nil, nil, flows.NilMsgTopic) var sentOn *time.Time if status == models.MsgStatusWired || status == models.MsgStatusSent || status == models.MsgStatusDelivered { @@ -48,8 +64,9 @@ func InsertOutgoingMsg(db *sqlx.DB, org *Org, channel *Channel, contact *Contact var id flows.MsgID must(db.Get(&id, - `INSERT INTO msgs_msg(uuid, text, attachments, created_on, direction, status, visibility, msg_count, error_count, next_attempt, contact_id, contact_urn_id, org_id, channel_id, sent_on) - VALUES($1, $2, $3, NOW(), 'O', $4, 'V', 1, 0, NOW(), $5, $6, $7, $8, $9) RETURNING id`, msg.UUID(), text, pq.Array(attachments), status, contact.ID, contact.URNID, org.ID, channel.ID, sentOn, + `INSERT INTO msgs_msg(uuid, text, attachments, created_on, direction, status, visibility, contact_id, contact_urn_id, org_id, channel_id, sent_on, msg_count, error_count, next_attempt, high_priority) + VALUES($1, $2, $3, NOW(), 'O', $4, 'V', $5, $6, $7, $8, $9, 1, $10, $11, $12) RETURNING id`, + msg.UUID(), text, pq.Array(attachments), status, contact.ID, contact.URNID, org.ID, channelID, sentOn, errorCount, nextAttempt, highPriority, )) msg.SetID(id) return msg diff --git a/testsuite/testdata/orgs.go b/testsuite/testdata/orgs.go new file mode 100644 index 000000000..b45d488d3 --- /dev/null +++ b/testsuite/testdata/orgs.go @@ -0,0 +1,32 @@ +package testdata + +import ( + "context" + + "github.com/nyaruka/gocommon/uuids" + "github.com/nyaruka/mailroom/core/models" + "github.com/nyaruka/mailroom/runtime" +) + +type Org struct { + ID models.OrgID + UUID uuids.UUID +} + +func (o *Org) Load(rt *runtime.Runtime) *models.OrgAssets { + oa, err := models.GetOrgAssets(context.Background(), rt, o.ID) + must(err) + return oa +} + +type User struct { + ID models.UserID + Email string +} + +func (u *User) SafeID() models.UserID { + if u != nil { + return u.ID + } + return models.NilUserID +} diff --git a/testsuite/testsuite.go b/testsuite/testsuite.go index 6dd46c4fd..2f16e3da6 100644 --- a/testsuite/testsuite.go +++ b/testsuite/testsuite.go @@ -7,10 +7,14 @@ import ( "os/exec" "path" "strings" + "testing" + "github.com/nyaruka/gocommon/jsonx" "github.com/nyaruka/gocommon/storage" "github.com/nyaruka/mailroom/core/models" + "github.com/nyaruka/mailroom/core/queue" "github.com/nyaruka/mailroom/runtime" + "github.com/stretchr/testify/require" "github.com/gomodule/redigo/redis" "github.com/jmoiron/sqlx" @@ -104,9 +108,9 @@ func getRP() *redis.Pool { // returns a redis connection, Close() should be called on it when done func getRC() redis.Conn { conn, err := redis.Dial("tcp", "localhost:6379") - must(err) + noError(err) _, err = conn.Do("SELECT", 0) - must(err) + noError(err) return conn } @@ -154,12 +158,26 @@ func resetStorage() { var resetDataSQL = ` DELETE FROM notifications_notification; +DELETE FROM notifications_incident; DELETE FROM request_logs_httplog; DELETE FROM tickets_ticketevent; DELETE FROM tickets_ticket; +DELETE FROM triggers_trigger_contacts WHERE trigger_id >= 30000; +DELETE FROM triggers_trigger_groups WHERE trigger_id >= 30000; +DELETE FROM triggers_trigger WHERE id >= 30000; DELETE FROM channels_channelcount; DELETE FROM msgs_msg; +DELETE FROM flows_flowrun; +DELETE FROM flows_flowpathcount; +DELETE FROM flows_flownodecount; +DELETE FROM flows_flowruncount; +DELETE FROM flows_flowcategorycount; +DELETE FROM flows_flowsession; +DELETE FROM flows_flowrevision WHERE flow_id >= 30000; +DELETE FROM flows_flow WHERE id >= 30000; DELETE FROM campaigns_eventfire; +DELETE FROM campaigns_campaignevent WHERE id >= 30000; +DELETE FROM campaigns_campaign WHERE id >= 30000; DELETE FROM contacts_contactimportbatch; DELETE FROM contacts_contactimport; DELETE FROM contacts_contacturn WHERE id >= 30000; @@ -168,11 +186,16 @@ DELETE FROM contacts_contact WHERE id >= 30000; DELETE FROM contacts_contactgroupcount WHERE group_id >= 30000; DELETE FROM contacts_contactgroup WHERE id >= 30000; +ALTER SEQUENCE flows_flow_id_seq RESTART WITH 30000; ALTER SEQUENCE tickets_ticket_id_seq RESTART WITH 1; ALTER SEQUENCE msgs_msg_id_seq RESTART WITH 1; +ALTER SEQUENCE flows_flowrun_id_seq RESTART WITH 1; +ALTER SEQUENCE flows_flowsession_id_seq RESTART WITH 1; ALTER SEQUENCE contacts_contact_id_seq RESTART WITH 30000; ALTER SEQUENCE contacts_contacturn_id_seq RESTART WITH 30000; -ALTER SEQUENCE contacts_contactgroup_id_seq RESTART WITH 30000;` +ALTER SEQUENCE contacts_contactgroup_id_seq RESTART WITH 30000; +ALTER SEQUENCE campaigns_campaign_id_seq RESTART WITH 30000; +ALTER SEQUENCE campaigns_campaignevent_id_seq RESTART WITH 30000;` // removes contact data not in the test database dump. Note that this function can't // undo changes made to the contact data in the test database dump. @@ -195,8 +218,45 @@ func mustExec(command string, args ...string) { } } +// convenience way to call a func and panic if it errors, e.g. must(foo()) func must(err error) { if err != nil { panic(err) } } + +// if just checking an error is nil noError(err) reads better than must(err) +var noError = must + +func ReadFile(path string) []byte { + d, err := os.ReadFile(path) + noError(err) + return d +} + +func CurrentOrgTasks(t *testing.T, rp *redis.Pool) map[models.OrgID][]*queue.Task { + rc := rp.Get() + defer rc.Close() + + // get all active org queues + active, err := redis.Ints(rc.Do("ZRANGE", "batch:active", 0, -1)) + require.NoError(t, err) + + tasks := make(map[models.OrgID][]*queue.Task) + for _, orgID := range active { + orgTasksEncoded, err := redis.Strings(rc.Do("ZRANGE", fmt.Sprintf("batch:%d", orgID), 0, -1)) + require.NoError(t, err) + + orgTasks := make([]*queue.Task, len(orgTasksEncoded)) + + for i := range orgTasksEncoded { + task := &queue.Task{} + jsonx.MustUnmarshal([]byte(orgTasksEncoded[i]), task) + orgTasks[i] = task + } + + tasks[models.OrgID(orgID)] = orgTasks + } + + return tasks +} diff --git a/utils/cron/cron.go b/utils/cron/cron.go index 0b5b3d921..d673c1e62 100644 --- a/utils/cron/cron.go +++ b/utils/cron/cron.go @@ -5,20 +5,28 @@ import ( "time" "github.com/apex/log" - "github.com/gomodule/redigo/redis" - "github.com/nyaruka/mailroom/utils/locker" + "github.com/nyaruka/mailroom/runtime" + "github.com/nyaruka/redisx" "github.com/sirupsen/logrus" ) // Function is the function that will be called on our schedule -type Function func(lockName string, lockValue string) error +type Function func() error -// StartCron calls the passed in function every minute, making sure it acquires a +// Start calls the passed in function every interval, making sure it acquires a // lock so that only one process is running at once. Note that across processes // crons may be called more often than duration as there is no inter-process // coordination of cron fires. (this might be a worthy addition) -func StartCron(quit chan bool, rp *redis.Pool, name string, interval time.Duration, cronFunc Function) { - lockName := fmt.Sprintf("%s_lock", name) +func Start(quit chan bool, rt *runtime.Runtime, name string, interval time.Duration, allInstances bool, cronFunc Function) { + lockName := fmt.Sprintf("lock:%s_lock", name) // for historical reasons... + + // for jobs that run on all instances, the lock key is specific to this instance + if allInstances { + lockName = fmt.Sprintf("%s:%s", lockName, rt.Config.InstanceName) + } + + locker := redisx.NewLocker(lockName, time.Minute*5) + wait := time.Duration(0) lastFire := time.Now() @@ -27,7 +35,6 @@ func StartCron(quit chan bool, rp *redis.Pool, name string, interval time.Durati go func() { defer log.Info("cron exiting") - // we run expiration every minute on the minute for { select { case <-quit: @@ -35,10 +42,10 @@ func StartCron(quit chan bool, rp *redis.Pool, name string, interval time.Durati return case <-time.After(wait): - // try to insert our expiring lock to redis lastFire = time.Now() - lock, err := locker.GrabLock(rp, lockName, time.Minute*5, 0) + // try to get lock but don't retry - if lock is taken then task is still running or running on another instance + lock, err := locker.Grab(rt.RP, 0) if err != nil { break } @@ -56,14 +63,14 @@ func StartCron(quit chan bool, rp *redis.Pool, name string, interval time.Durati } // release our lock - err = locker.ReleaseLock(rp, lockName, lock) + err = locker.Release(rt.RP, lock) if err != nil { log.WithError(err).Error("error releasing lock") } } // calculate our next fire time - nextFire := nextFire(lastFire, interval) + nextFire := NextFire(lastFire, interval) wait = time.Until(nextFire) if wait < time.Duration(0) { wait = time.Duration(0) @@ -76,6 +83,7 @@ func StartCron(quit chan bool, rp *redis.Pool, name string, interval time.Durati // catching and logging panics func fireCron(cronFunc Function, lockName string, lockValue string) error { log := log.WithField("lockValue", lockValue).WithField("func", cronFunc) + defer func() { // catch any panics and recover panicLog := recover() @@ -84,11 +92,11 @@ func fireCron(cronFunc Function, lockName string, lockValue string) error { } }() - return cronFunc(lockName, lockValue) + return cronFunc() } -// nextFire returns the next time we should fire based on the passed in time and interval -func nextFire(last time.Time, interval time.Duration) time.Time { +// NextFire returns the next time we should fire based on the passed in time and interval +func NextFire(last time.Time, interval time.Duration) time.Time { if interval >= time.Second && interval < time.Minute { normalizedInterval := interval - ((time.Duration(last.Second()) * time.Second) % interval) return last.Add(normalizedInterval) diff --git a/utils/cron/cron_test.go b/utils/cron/cron_test.go index 2e9c95096..f0cb771f1 100644 --- a/utils/cron/cron_test.go +++ b/utils/cron/cron_test.go @@ -1,41 +1,119 @@ -package cron +package cron_test import ( - "sync" "testing" "time" "github.com/nyaruka/mailroom/testsuite" + "github.com/nyaruka/mailroom/utils/cron" "github.com/stretchr/testify/assert" ) func TestCron(t *testing.T) { - _, _, _, rp := testsuite.Get() + _, rt, _, _ := testsuite.Get() defer testsuite.Reset(testsuite.ResetRedis) - rc := rp.Get() - defer rc.Close() + align := func() { + untilNextSecond := time.Nanosecond * time.Duration(1_000_000_000-time.Now().Nanosecond()) // time until next second boundary + time.Sleep(untilNextSecond) // wait until after second boundary + } + + createCronFunc := func(running *bool, fired *int, delays map[int]time.Duration, defaultDelay time.Duration) cron.Function { + return func() error { + if *running { + assert.Fail(t, "more than 1 thread is trying to run our cron job") + } + + *running = true + delay := delays[*fired] + if delay == 0 { + delay = defaultDelay + } + time.Sleep(delay) + *fired++ + *running = false + return nil + } + } - mutex := sync.RWMutex{} fired := 0 quit := make(chan bool) + running := false - // our cron worker is just going to increment an int on every fire - increment := func(lockName string, lockValue string) error { - mutex.Lock() - fired++ - mutex.Unlock() - return nil - } + align() - StartCron(quit, rp, "test", time.Millisecond*100, increment) + // start a job that takes ~100 ms and runs every 250ms + cron.Start(quit, rt, "test1", time.Millisecond*250, false, createCronFunc(&running, &fired, map[int]time.Duration{}, time.Millisecond*100)) - // wait a bit, should only have fired three times (initial time + three timeouts) - time.Sleep(time.Millisecond * 320) + // wait a bit, should only have fired three times (initial time + three repeats) + time.Sleep(time.Millisecond * 875) // time for 3 delays between tasks plus half of another delay assert.Equal(t, 4, fired) + // tell the job to quit and check we don't see more fires + close(quit) + + time.Sleep(time.Millisecond * 500) + assert.Equal(t, 4, fired) + + fired = 0 + quit = make(chan bool) + running = false + + align() + + // simulate the job taking 400ms to run on the second fire, thus skipping the third fire + cron.Start(quit, rt, "test2", time.Millisecond*250, false, createCronFunc(&running, &fired, map[int]time.Duration{1: time.Millisecond * 400}, time.Millisecond*100)) + + time.Sleep(time.Millisecond * 875) + assert.Equal(t, 3, fired) + + close(quit) + + // simulate two different instances running the same cron + cfg1 := *rt.Config + cfg2 := *rt.Config + cfg1.InstanceName = "instance1" + cfg2.InstanceName = "instance2" + rt1 := *rt + rt1.Config = &cfg1 + rt2 := *rt + rt2.Config = &cfg2 + + fired1 := 0 + fired2 := 0 + quit = make(chan bool) + running = false + + align() + + cron.Start(quit, &rt1, "test3", time.Millisecond*250, false, createCronFunc(&running, &fired1, map[int]time.Duration{}, time.Millisecond*100)) + cron.Start(quit, &rt2, "test3", time.Millisecond*250, false, createCronFunc(&running, &fired2, map[int]time.Duration{}, time.Millisecond*100)) + + // same number of fires as if only a single instance was running it... + time.Sleep(time.Millisecond * 875) + assert.Equal(t, 4, fired1+fired2) // can't say which instances will run the 4 fires + + close(quit) + + fired1 = 0 + fired2 = 0 + quit = make(chan bool) + running1 := false + running2 := false + + align() + + // unless we start the cron with allInstances = true + cron.Start(quit, &rt1, "test4", time.Millisecond*250, true, createCronFunc(&running1, &fired1, map[int]time.Duration{}, time.Millisecond*100)) + cron.Start(quit, &rt2, "test4", time.Millisecond*250, true, createCronFunc(&running2, &fired2, map[int]time.Duration{}, time.Millisecond*100)) + + // now both instances fire 4 times + time.Sleep(time.Millisecond * 875) + assert.Equal(t, 4, fired1) + assert.Equal(t, 4, fired2) + close(quit) } @@ -43,7 +121,7 @@ func TestNextFire(t *testing.T) { tcs := []struct { last time.Time interval time.Duration - next time.Time + expected time.Time }{ {time.Date(2000, 1, 1, 1, 1, 4, 0, time.UTC), time.Minute, time.Date(2000, 1, 1, 1, 2, 1, 0, time.UTC)}, {time.Date(2000, 1, 1, 1, 1, 44, 0, time.UTC), time.Minute, time.Date(2000, 1, 1, 1, 2, 1, 0, time.UTC)}, @@ -53,7 +131,7 @@ func TestNextFire(t *testing.T) { } for _, tc := range tcs { - next := nextFire(tc.last, tc.interval) - assert.Equal(t, tc.next, next) + actual := cron.NextFire(tc.last, tc.interval) + assert.Equal(t, tc.expected, actual, "next fire mismatch for %s + %s", tc.last, tc.interval) } } diff --git a/utils/dbutil/errors.go b/utils/dbutil/errors.go deleted file mode 100644 index ebb316e48..000000000 --- a/utils/dbutil/errors.go +++ /dev/null @@ -1,55 +0,0 @@ -package dbutil - -import ( - "errors" - "fmt" - - "github.com/lib/pq" - "github.com/sirupsen/logrus" -) - -// IsUniqueViolation returns true if the given error is a violation of unique constraint -func IsUniqueViolation(err error) bool { - if pqErr, ok := err.(*pq.Error); ok { - return pqErr.Code.Name() == "unique_violation" - } - return false -} - -// QueryError is an error type for failed SQL queries -type QueryError struct { - cause error - message string - sql string - sqlArgs []interface{} -} - -func (e *QueryError) Error() string { - return e.message + ": " + e.cause.Error() -} - -func (e *QueryError) Unwrap() error { - return e.cause -} - -func (e *QueryError) Fields() logrus.Fields { - return logrus.Fields{ - "sql": fmt.Sprintf("%.1000s", e.sql), - "sql_args": e.sqlArgs, - } -} - -func NewQueryErrorf(cause error, sql string, sqlArgs []interface{}, message string, msgArgs ...interface{}) error { - return &QueryError{ - cause: cause, - message: fmt.Sprintf(message, msgArgs...), - sql: sql, - sqlArgs: sqlArgs, - } -} - -func AsQueryError(err error) *QueryError { - var qerr *QueryError - errors.As(err, &qerr) - return qerr -} diff --git a/utils/dbutil/errors_test.go b/utils/dbutil/errors_test.go deleted file mode 100644 index 712e13a4c..000000000 --- a/utils/dbutil/errors_test.go +++ /dev/null @@ -1,43 +0,0 @@ -package dbutil_test - -import ( - "testing" - - "github.com/lib/pq" - "github.com/nyaruka/mailroom/utils/dbutil" - "github.com/sirupsen/logrus" - - "github.com/pkg/errors" - "github.com/stretchr/testify/assert" -) - -func TestIsUniqueViolation(t *testing.T) { - var err error = &pq.Error{Code: pq.ErrorCode("23505")} - - assert.True(t, dbutil.IsUniqueViolation(err)) - assert.False(t, dbutil.IsUniqueViolation(errors.New("boom"))) -} - -func TestQueryError(t *testing.T) { - var err error = &pq.Error{Code: pq.ErrorCode("22025"), Message: "unsupported Unicode escape sequence"} - - qerr := dbutil.NewQueryErrorf(err, "SELECT * FROM foo WHERE id = $1", []interface{}{234}, "error selecting foo %d", 234) - assert.Error(t, qerr) - assert.Equal(t, `error selecting foo 234: pq: unsupported Unicode escape sequence`, qerr.Error()) - - // can unwrap to the original error - var pqerr *pq.Error - assert.True(t, errors.As(qerr, &pqerr)) - assert.Equal(t, err, pqerr) - - // can unwrap a wrapped error to find the first query error - wrapped := errors.Wrap(errors.Wrap(qerr, "error doing this"), "error doing that") - unwrapped := dbutil.AsQueryError(wrapped) - assert.Equal(t, qerr, unwrapped) - - // nil if error was never a query error - wrapped = errors.Wrap(errors.New("error doing this"), "error doing that") - assert.Nil(t, dbutil.AsQueryError(wrapped)) - - assert.Equal(t, logrus.Fields{"sql": "SELECT * FROM foo WHERE id = $1", "sql_args": []interface{}{234}}, unwrapped.Fields()) -} diff --git a/utils/dbutil/json.go b/utils/dbutil/json.go deleted file mode 100644 index f862373ce..000000000 --- a/utils/dbutil/json.go +++ /dev/null @@ -1,33 +0,0 @@ -package dbutil - -import ( - "encoding/json" - - "github.com/jmoiron/sqlx" - "github.com/pkg/errors" - "gopkg.in/go-playground/validator.v9" -) - -var validate = validator.New() - -// ReadJSONRow reads a row which is JSON into a destination struct -func ReadJSONRow(rows *sqlx.Rows, destination interface{}) error { - var jsonBlob string - err := rows.Scan(&jsonBlob) - if err != nil { - return errors.Wrap(err, "error scanning row JSON") - } - - err = json.Unmarshal([]byte(jsonBlob), destination) - if err != nil { - return errors.Wrap(err, "error unmarshalling row JSON") - } - - // validate our final struct - err = validate.Struct(destination) - if err != nil { - return errors.Wrapf(err, "failed validation for JSON: %s", jsonBlob) - } - - return nil -} diff --git a/utils/dbutil/json_test.go b/utils/dbutil/json_test.go deleted file mode 100644 index 70901a8f4..000000000 --- a/utils/dbutil/json_test.go +++ /dev/null @@ -1,49 +0,0 @@ -package dbutil_test - -import ( - "testing" - - "github.com/nyaruka/mailroom/testsuite" - "github.com/nyaruka/mailroom/testsuite/testdata" - "github.com/nyaruka/mailroom/utils/dbutil" - - "github.com/jmoiron/sqlx" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestReadJSONRow(t *testing.T) { - ctx, _, db, _ := testsuite.Get() - - type group struct { - UUID string `json:"uuid"` - Name string `json:"name"` - } - - queryRows := func(sql string, args ...interface{}) *sqlx.Rows { - rows, err := db.QueryxContext(ctx, sql, args...) - require.NoError(t, err) - require.True(t, rows.Next()) - return rows - } - - // if query returns valid JSON which can be unmarshaled into our struct, all good - rows := queryRows(`SELECT ROW_TO_JSON(r) FROM (SELECT g.uuid as uuid, g.name AS name FROM contacts_contactgroup g WHERE id = $1) r`, testdata.TestersGroup.ID) - - g := &group{} - err := dbutil.ReadJSONRow(rows, g) - assert.NoError(t, err) - assert.Equal(t, "5e9d8fab-5e7e-4f51-b533-261af5dea70d", g.UUID) - assert.Equal(t, "Testers", g.Name) - - // error if row value is not JSON - rows = queryRows(`SELECT id FROM contacts_contactgroup g WHERE id = $1`, testdata.TestersGroup.ID) - err = dbutil.ReadJSONRow(rows, g) - assert.EqualError(t, err, "error unmarshalling row JSON: json: cannot unmarshal number into Go value of type dbutil_test.group") - - // error if rows aren't ready to be scanned - e.g. next hasn't been called - rows, err = db.QueryxContext(ctx, `SELECT ROW_TO_JSON(r) FROM (SELECT g.uuid as uuid, g.name AS name FROM contacts_contactgroup g WHERE id = $1) r`, testdata.TestersGroup.ID) - require.NoError(t, err) - err = dbutil.ReadJSONRow(rows, g) - assert.EqualError(t, err, "error scanning row JSON: sql: Scan called without calling Next") -} diff --git a/utils/dbutil/query.go b/utils/dbutil/query.go deleted file mode 100644 index 4b2ade841..000000000 --- a/utils/dbutil/query.go +++ /dev/null @@ -1,132 +0,0 @@ -package dbutil - -import ( - "context" - "strings" - - "github.com/jmoiron/sqlx" - "github.com/pkg/errors" -) - -// Queryer is the DB/TX functionality needed for operations in this package -type Queryer interface { - Rebind(query string) string - QueryxContext(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error) -} - -// BulkQuery runs the query as a bulk operation with the given structs -func BulkQuery(ctx context.Context, tx Queryer, query string, structs []interface{}) error { - // no structs, nothing to do - if len(structs) == 0 { - return nil - } - - // rewrite query as a bulk operation - bulkQuery, args, err := BulkSQL(tx, query, structs) - if err != nil { - return err - } - - rows, err := tx.QueryxContext(ctx, bulkQuery, args...) - if err != nil { - return NewQueryErrorf(err, bulkQuery, args, "error making bulk query") - } - defer rows.Close() - - // if have a returning clause, read them back and try to map them - if strings.Contains(strings.ToUpper(query), "RETURNING") { - for _, s := range structs { - if !rows.Next() { - return errors.Errorf("did not receive expected number of rows on insert") - } - - err = rows.StructScan(s) - if err != nil { - return errors.Wrap(err, "error scanning for insert id") - } - } - } - - // iterate our remaining rows - for rows.Next() { - } - - // check for any error - if rows.Err() != nil { - return errors.Wrapf(rows.Err(), "error in row cursor") - } - - return nil -} - -// BulkSQL takes a query which uses VALUES with struct bindings and rewrites it as a bulk operation. -// It returns the new SQL query and the args to pass to it. -func BulkSQL(tx Queryer, sql string, structs []interface{}) (string, []interface{}, error) { - if len(structs) == 0 { - return "", nil, errors.New("can't generate bulk sql with zero structs") - } - - // this will be our SQL placeholders for values in our final query, built dynamically - values := strings.Builder{} - values.Grow(7 * len(structs)) - - // this will be each of the arguments to match the positional values above - args := make([]interface{}, 0, len(structs)*5) - - // for each value we build a bound SQL statement, then extract the values clause - for i, value := range structs { - valueSQL, valueArgs, err := sqlx.Named(sql, value) - if err != nil { - return "", nil, errors.Wrapf(err, "error converting bulk insert args") - } - - args = append(args, valueArgs...) - argValues := extractValues(valueSQL) - if argValues == "" { - return "", nil, errors.Errorf("error extracting VALUES from sql: %s", valueSQL) - } - - // append to our global values, adding comma if necessary - values.WriteString(argValues) - if i+1 < len(structs) { - values.WriteString(",") - } - } - - valuesSQL := extractValues(sql) - if valuesSQL == "" { - return "", nil, errors.Errorf("error extracting VALUES from sql: %s", sql) - } - - return tx.Rebind(strings.Replace(sql, valuesSQL, values.String(), -1)), args, nil -} - -// extractValues is just a simple utility method that extracts the portion between `VALUE(` -// and `)` in the passed in string. (leaving VALUE but not the parentheses) -func extractValues(sql string) string { - startValues := strings.Index(sql, "VALUES(") - if startValues <= 0 { - return "" - } - - // find the matching end parentheses, we need to count balances here - openCount := 1 - endValues := -1 - for i, r := range sql[startValues+7:] { - if r == '(' { - openCount++ - } else if r == ')' { - openCount-- - if openCount == 0 { - endValues = i + startValues + 7 - break - } - } - } - - if endValues <= 0 { - return "" - } - - return sql[startValues+6 : endValues+1] -} diff --git a/utils/dbutil/query_test.go b/utils/dbutil/query_test.go deleted file mode 100644 index 3f50c9bf9..000000000 --- a/utils/dbutil/query_test.go +++ /dev/null @@ -1,84 +0,0 @@ -package dbutil_test - -import ( - "testing" - - "github.com/nyaruka/mailroom/testsuite" - "github.com/nyaruka/mailroom/utils/dbutil" - - _ "github.com/lib/pq" - "github.com/stretchr/testify/assert" -) - -func TestBulkSQL(t *testing.T) { - _, _, db, _ := testsuite.Get() - - type contact struct { - ID int `db:"id"` - Name string `db:"name"` - } - - _, _, err := dbutil.BulkSQL(db, `UPDATE contact_contact SET name = :name WHERE id = :id`, []interface{}{contact{ID: 1, Name: "Bob"}}) - assert.EqualError(t, err, "error extracting VALUES from sql: UPDATE contact_contact SET name = ? WHERE id = ?") - - sql := `INSERT INTO contacts_contact (id, name) VALUES(:id, :name)` - - // try with zero structs - _, _, err = dbutil.BulkSQL(db, sql, []interface{}{}) - assert.EqualError(t, err, "can't generate bulk sql with zero structs") - - // try with one struct - query, args, err := dbutil.BulkSQL(db, sql, []interface{}{contact{ID: 1, Name: "Bob"}}) - assert.NoError(t, err) - assert.Equal(t, `INSERT INTO contacts_contact (id, name) VALUES($1, $2)`, query) - assert.Equal(t, []interface{}{1, "Bob"}, args) - - // try with multiple... - query, args, err = dbutil.BulkSQL(db, sql, []interface{}{contact{ID: 1, Name: "Bob"}, contact{ID: 2, Name: "Cathy"}, contact{ID: 3, Name: "George"}}) - assert.NoError(t, err) - assert.Equal(t, `INSERT INTO contacts_contact (id, name) VALUES($1, $2),($3, $4),($5, $6)`, query) - assert.Equal(t, []interface{}{1, "Bob", 2, "Cathy", 3, "George"}, args) -} - -func TestBulkQuery(t *testing.T) { - ctx, _, db, _ := testsuite.Get() - - defer testsuite.Reset(testsuite.ResetAll) - - db.MustExec(`CREATE TABLE foo (id serial NOT NULL PRIMARY KEY, name VARCHAR(3), age INT)`) - - type foo struct { - ID int `db:"id"` - Name string `db:"name"` - Age int `db:"age"` - } - - sql := `INSERT INTO foo (name, age) VALUES(:name, :age) RETURNING id` - - // noop with zero structs - err := dbutil.BulkQuery(ctx, db, sql, nil) - assert.NoError(t, err) - - // returned ids are scanned into structs - foo1 := &foo{Name: "Bob", Age: 64} - foo2 := &foo{Name: "Jon", Age: 34} - err = dbutil.BulkQuery(ctx, db, sql, []interface{}{foo1, foo2}) - assert.NoError(t, err) - assert.Equal(t, 1, foo1.ID) - assert.Equal(t, 2, foo2.ID) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM foo WHERE name = 'Bob' AND age = 64`).Returns(1) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM foo WHERE name = 'Jon' AND age = 34`).Returns(1) - - // returning ids is optional - foo3 := &foo{Name: "Jim", Age: 54} - err = dbutil.BulkQuery(ctx, db, `INSERT INTO foo (name, age) VALUES(:name, :age)`, []interface{}{foo3}) - assert.NoError(t, err) - assert.Equal(t, 0, foo3.ID) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM foo WHERE name = 'Jim' AND age = 54`).Returns(1) - - // try with a struct that is invalid - foo4 := &foo{Name: "Jonny", Age: 34} - err = dbutil.BulkQuery(ctx, db, `INSERT INTO foo (name, age) VALUES(:name, :age)`, []interface{}{foo4}) - assert.EqualError(t, err, "error making bulk query: pq: value too long for type character varying(3)") - assert.Equal(t, 0, foo4.ID) -} diff --git a/utils/marker/marker.go b/utils/marker/marker.go deleted file mode 100644 index 80fa47991..000000000 --- a/utils/marker/marker.go +++ /dev/null @@ -1,69 +0,0 @@ -package marker - -import ( - "fmt" - "time" - - "github.com/gomodule/redigo/redis" - "github.com/pkg/errors" -) - -const ( - keyPattern = "%s_%s" - oneDay = 60 * 60 * 24 -) - -var hasTask = redis.NewScript(3, - `-- KEYS: [TodayKey, YesterdayKey, TaskID] - local found = redis.call("sismember", KEYS[1], KEYS[3]) - if found == 1 then - return 1 - end - return redis.call("sismember", KEYS[2], KEYS[3]) -`) - -// HasTask returns whether the passed in taskID has already been marked for execution -func HasTask(rc redis.Conn, taskGroup string, taskID string) (bool, error) { - todayKey := fmt.Sprintf(keyPattern, taskGroup, time.Now().UTC().Format("2006_01_02")) - yesterdayKey := fmt.Sprintf(keyPattern, taskGroup, time.Now().Add(time.Hour*-24).UTC().Format("2006_01_02")) - found, err := redis.Bool(hasTask.Do(rc, todayKey, yesterdayKey, taskID)) - if err != nil { - return false, errors.Wrapf(err, "error checking for task: %s for group: %s", taskID, taskGroup) - } - return found, nil -} - -// AddTask marks the passed in task -func AddTask(rc redis.Conn, taskGroup string, taskID string) error { - dateKey := fmt.Sprintf(keyPattern, taskGroup, time.Now().UTC().Format("2006_01_02")) - rc.Send("sadd", dateKey, taskID) - rc.Send("expire", dateKey, oneDay) - _, err := rc.Do("") - if err != nil { - return errors.Wrapf(err, "error adding task: %s to redis set for group: %s", taskID, taskGroup) - } - return nil -} - -// RemoveTask removes the task with the passed in id from our lock -func RemoveTask(rc redis.Conn, taskGroup string, taskID string) error { - todayKey := fmt.Sprintf(keyPattern, taskGroup, time.Now().UTC().Format("2006_01_02")) - yesterdayKey := fmt.Sprintf(keyPattern, taskGroup, time.Now().Add(time.Hour*-24).UTC().Format("2006_01_02")) - rc.Send("srem", todayKey, taskID) - rc.Send("srem", yesterdayKey, taskID) - _, err := rc.Do("") - if err != nil { - return errors.Wrapf(err, "error removing task: %s from redis set for group: %s", taskID, taskGroup) - } - return nil -} - -// ClearTasks removes all tasks for the passed in group (mostly useful in unit testing) -func ClearTasks(rc redis.Conn, taskGroup string) error { - todayKey := fmt.Sprintf(keyPattern, taskGroup, time.Now().UTC().Format("2006_01_02")) - yesterdayKey := fmt.Sprintf(keyPattern, taskGroup, time.Now().Add(time.Hour*-24).UTC().Format("2006_01_02")) - rc.Send("del", todayKey) - rc.Send("del", yesterdayKey) - _, err := rc.Do("") - return err -} diff --git a/utils/marker/marker_test.go b/utils/marker/marker_test.go deleted file mode 100644 index 7e2e463ad..000000000 --- a/utils/marker/marker_test.go +++ /dev/null @@ -1,54 +0,0 @@ -package marker_test - -import ( - "testing" - - "github.com/nyaruka/mailroom/testsuite" - "github.com/nyaruka/mailroom/utils/marker" - - "github.com/stretchr/testify/assert" -) - -func TestMarker(t *testing.T) { - _, _, _, rp := testsuite.Get() - - defer testsuite.Reset(testsuite.ResetRedis) - - rc := rp.Get() - defer rc.Close() - - tcs := []struct { - Group string - TaskID string - Action string - }{ - {"1", "1", "remove"}, - {"2", "1", "remove"}, - {"1", "2", "remove"}, - {"1", "1", "absent"}, - {"1", "1", "add"}, - {"1", "1", "present"}, - {"2", "1", "absent"}, - {"1", "2", "absent"}, - {"1", "1", "remove"}, - {"1", "1", "absent"}, - } - - for i, tc := range tcs { - if tc.Action == "absent" { - present, err := marker.HasTask(rc, tc.Group, tc.TaskID) - assert.NoError(t, err) - assert.False(t, present, "%d: %s:%s should be absent", i, tc.Group, tc.TaskID) - } else if tc.Action == "present" { - present, err := marker.HasTask(rc, tc.Group, tc.TaskID) - assert.NoError(t, err) - assert.True(t, present, "%d: %s:%s should be present", i, tc.Group, tc.TaskID) - } else if tc.Action == "add" { - err := marker.AddTask(rc, tc.Group, tc.TaskID) - assert.NoError(t, err) - } else if tc.Action == "remove" { - err := marker.RemoveTask(rc, tc.Group, tc.TaskID) - assert.NoError(t, err) - } - } -} diff --git a/web/contact/contact_test.go b/web/contact/contact_test.go index 70711f5d1..c7cb5b762 100644 --- a/web/contact/contact_test.go +++ b/web/contact/contact_test.go @@ -4,7 +4,6 @@ import ( "testing" "time" - "github.com/nyaruka/gocommon/uuids" _ "github.com/nyaruka/mailroom/core/handlers" "github.com/nyaruka/mailroom/core/models" "github.com/nyaruka/mailroom/testsuite" @@ -37,11 +36,7 @@ func TestModifyContacts(t *testing.T) { db.MustExec(`UPDATE contacts_contactgroup SET query = 'age > 18' WHERE id = $1`, testdata.DoctorsGroup.ID) // insert an event on our campaign that is based on created on - db.MustExec( - `INSERT INTO campaigns_campaignevent(is_active, created_on, modified_on, uuid, "offset", unit, event_type, delivery_hour, - campaign_id, created_by_id, modified_by_id, flow_id, relative_to_id, start_mode) - VALUES(TRUE, NOW(), NOW(), $1, 1000, 'W', 'F', -1, $2, 1, 1, $3, $4, 'I')`, - uuids.New(), testdata.RemindersCampaign.ID, testdata.Favorites.ID, testdata.CreatedOnField.ID) + testdata.InsertCampaignFlowEvent(db, testdata.RemindersCampaign, testdata.Favorites, testdata.CreatedOnField, 1000, "W") // for simpler tests we clear out cathy's fields and groups to start db.MustExec(`UPDATE contacts_contact SET fields = NULL WHERE id = $1`, testdata.Cathy.ID) diff --git a/web/contact/search.go b/web/contact/search.go index 96b8c12f0..8a31b935d 100644 --- a/web/contact/search.go +++ b/web/contact/search.go @@ -59,10 +59,6 @@ type searchResponse struct { Offset int `json:"offset"` Sort string `json:"sort"` Metadata *contactql.Inspection `json:"metadata,omitempty"` - - // deprecated - Fields []string `json:"fields"` - AllowAsGroup bool `json:"allow_as_group"` } // handles a contact search request @@ -97,29 +93,20 @@ func handleSearch(ctx context.Context, rt *runtime.Runtime, r *http.Request) (in // normalize and inspect the query normalized := "" var metadata *contactql.Inspection - allowAsGroup := false - fields := make([]string, 0) if parsed != nil { normalized = parsed.String() metadata = contactql.Inspect(parsed) - fields = append(fields, metadata.Attributes...) - for _, f := range metadata.Fields { - fields = append(fields, f.Key) - } - allowAsGroup = metadata.AllowAsGroup } // build our response response := &searchResponse{ - Query: normalized, - ContactIDs: hits, - Total: total, - Offset: request.Offset, - Sort: request.Sort, - Metadata: metadata, - Fields: fields, - AllowAsGroup: allowAsGroup, + Query: normalized, + ContactIDs: hits, + Total: total, + Offset: request.Offset, + Sort: request.Sort, + Metadata: metadata, } return response, http.StatusOK, nil diff --git a/web/contact/search_test.go b/web/contact/search_test.go index c99674f54..f98a41fe5 100644 --- a/web/contact/search_test.go +++ b/web/contact/search_test.go @@ -10,6 +10,7 @@ import ( "testing" "time" + "github.com/nyaruka/goflow/assets" "github.com/nyaruka/goflow/test" _ "github.com/nyaruka/mailroom/core/handlers" "github.com/nyaruka/mailroom/core/models" @@ -75,57 +76,66 @@ func TestSearch(t *testing.T) { }`, testdata.Cathy.ID) tcs := []struct { - URL string - Method string - Body string - ESResponse string - ExpectedStatus int - ExpectedError string - ExpectedHits []models.ContactID - ExpectedQuery string - ExpectedFields []string - ExpectedESRequest string + method string + url string + body string + esResponse string + expectedStatus int + expectedError string + expectedHits []models.ContactID + expectedQuery string + expectedAttributes []string + expectedFields []*assets.FieldReference + expectedSchemes []string + expectedAllowAsGroup bool + expectedESRequest string }{ { - Method: "GET", - URL: "/mr/contact/search", - ExpectedStatus: 405, - ExpectedError: "illegal method: GET", + method: "GET", + url: "/mr/contact/search", + expectedStatus: 405, + expectedError: "illegal method: GET", }, { - Method: "POST", - URL: "/mr/contact/search", - Body: fmt.Sprintf(`{"org_id": 1, "query": "birthday = tomorrow", "group_uuid": "%s"}`, testdata.AllContactsGroup.UUID), - ExpectedStatus: 400, - ExpectedError: "can't resolve 'birthday' to attribute, scheme or field", + method: "POST", + url: "/mr/contact/search", + body: fmt.Sprintf(`{"org_id": 1, "query": "birthday = tomorrow", "group_uuid": "%s"}`, testdata.AllContactsGroup.UUID), + expectedStatus: 400, + expectedError: "can't resolve 'birthday' to attribute, scheme or field", }, { - Method: "POST", - URL: "/mr/contact/search", - Body: fmt.Sprintf(`{"org_id": 1, "query": "age > tomorrow", "group_uuid": "%s"}`, testdata.AllContactsGroup.UUID), - ExpectedStatus: 400, - ExpectedError: "can't convert 'tomorrow' to a number", + method: "POST", + url: "/mr/contact/search", + body: fmt.Sprintf(`{"org_id": 1, "query": "age > tomorrow", "group_uuid": "%s"}`, testdata.AllContactsGroup.UUID), + expectedStatus: 400, + expectedError: "can't convert 'tomorrow' to a number", }, { - Method: "POST", - URL: "/mr/contact/search", - Body: fmt.Sprintf(`{"org_id": 1, "query": "Cathy", "group_uuid": "%s"}`, testdata.AllContactsGroup.UUID), - ESResponse: singleESResponse, - ExpectedStatus: 200, - ExpectedHits: []models.ContactID{testdata.Cathy.ID}, - ExpectedQuery: `name ~ "Cathy"`, - ExpectedFields: []string{"name"}, + method: "POST", + url: "/mr/contact/search", + body: fmt.Sprintf(`{"org_id": 1, "query": "Cathy", "group_uuid": "%s"}`, testdata.AllContactsGroup.UUID), + esResponse: singleESResponse, + expectedStatus: 200, + expectedHits: []models.ContactID{testdata.Cathy.ID}, + expectedQuery: `name ~ "Cathy"`, + expectedAttributes: []string{"name"}, + expectedFields: []*assets.FieldReference{}, + expectedSchemes: []string{}, + expectedAllowAsGroup: true, }, { - Method: "POST", - URL: "/mr/contact/search", - Body: fmt.Sprintf(`{"org_id": 1, "query": "Cathy", "group_uuid": "%s", "exclude_ids": [%d, %d]}`, testdata.AllContactsGroup.UUID, testdata.Bob.ID, testdata.George.ID), - ESResponse: singleESResponse, - ExpectedStatus: 200, - ExpectedHits: []models.ContactID{testdata.Cathy.ID}, - ExpectedQuery: `name ~ "Cathy"`, - ExpectedFields: []string{"name"}, - ExpectedESRequest: `{ + method: "POST", + url: "/mr/contact/search", + body: fmt.Sprintf(`{"org_id": 1, "query": "Cathy", "group_uuid": "%s", "exclude_ids": [%d, %d]}`, testdata.AllContactsGroup.UUID, testdata.Bob.ID, testdata.George.ID), + esResponse: singleESResponse, + expectedStatus: 200, + expectedHits: []models.ContactID{testdata.Cathy.ID}, + expectedQuery: `name ~ "Cathy"`, + expectedAttributes: []string{"name"}, + expectedFields: []*assets.FieldReference{}, + expectedSchemes: []string{}, + expectedAllowAsGroup: true, + expectedESRequest: `{ "_source": false, "from": 0, "query": { @@ -176,42 +186,51 @@ func TestSearch(t *testing.T) { }`, }, { - Method: "POST", - URL: "/mr/contact/search", - Body: fmt.Sprintf(`{"org_id": 1, "query": "AGE = 10 and gender = M", "group_uuid": "%s"}`, testdata.AllContactsGroup.UUID), - ESResponse: singleESResponse, - ExpectedStatus: 200, - ExpectedHits: []models.ContactID{testdata.Cathy.ID}, - ExpectedQuery: `age = 10 AND gender = "M"`, - ExpectedFields: []string{"age", "gender"}, + method: "POST", + url: "/mr/contact/search", + body: fmt.Sprintf(`{"org_id": 1, "query": "AGE = 10 and gender = M", "group_uuid": "%s"}`, testdata.AllContactsGroup.UUID), + esResponse: singleESResponse, + expectedStatus: 200, + expectedHits: []models.ContactID{testdata.Cathy.ID}, + expectedQuery: `age = 10 AND gender = "M"`, + expectedAttributes: []string{}, + expectedFields: []*assets.FieldReference{ + assets.NewFieldReference("age", "Age"), + assets.NewFieldReference("gender", "Gender"), + }, + expectedSchemes: []string{}, + expectedAllowAsGroup: true, }, { - Method: "POST", - URL: "/mr/contact/search", - Body: fmt.Sprintf(`{"org_id": 1, "query": "", "group_uuid": "%s"}`, testdata.AllContactsGroup.UUID), - ESResponse: singleESResponse, - ExpectedStatus: 200, - ExpectedHits: []models.ContactID{testdata.Cathy.ID}, - ExpectedQuery: ``, - ExpectedFields: []string{}, + method: "POST", + url: "/mr/contact/search", + body: fmt.Sprintf(`{"org_id": 1, "query": "", "group_uuid": "%s"}`, testdata.AllContactsGroup.UUID), + esResponse: singleESResponse, + expectedStatus: 200, + expectedHits: []models.ContactID{testdata.Cathy.ID}, + expectedQuery: ``, + expectedAttributes: []string{}, + expectedFields: []*assets.FieldReference{}, + expectedSchemes: []string{}, + expectedAllowAsGroup: true, }, } for i, tc := range tcs { var body io.Reader - es.NextResponse = tc.ESResponse + es.NextResponse = tc.esResponse - if tc.Body != "" { - body = bytes.NewReader([]byte(tc.Body)) + if tc.body != "" { + body = bytes.NewReader([]byte(tc.body)) } - req, err := http.NewRequest(tc.Method, "http://localhost:8090"+tc.URL, body) + req, err := http.NewRequest(tc.method, "http://localhost:8090"+tc.url, body) assert.NoError(t, err, "%d: error creating request", i) resp, err := http.DefaultClient.Do(req) assert.NoError(t, err, "%d: error making request", i) - assert.Equal(t, tc.ExpectedStatus, resp.StatusCode, "%d: unexpected status", i) + assert.Equal(t, tc.expectedStatus, resp.StatusCode, "%d: unexpected status", i) content, err := io.ReadAll(resp.Body) assert.NoError(t, err, "%d: error reading body", i) @@ -221,18 +240,24 @@ func TestSearch(t *testing.T) { r := &searchResponse{} err = json.Unmarshal(content, r) assert.NoError(t, err) - assert.Equal(t, tc.ExpectedHits, r.ContactIDs) - assert.Equal(t, tc.ExpectedQuery, r.Query) - assert.Equal(t, tc.ExpectedFields, r.Fields) + assert.Equal(t, tc.expectedHits, r.ContactIDs) + assert.Equal(t, tc.expectedQuery, r.Query) - if tc.ExpectedESRequest != "" { - test.AssertEqualJSON(t, []byte(tc.ExpectedESRequest), []byte(es.LastBody), "elastic request mismatch") + if len(tc.expectedAttributes) > 0 || len(tc.expectedFields) > 0 || len(tc.expectedSchemes) > 0 { + assert.Equal(t, tc.expectedAttributes, r.Metadata.Attributes) + assert.Equal(t, tc.expectedFields, r.Metadata.Fields) + assert.Equal(t, tc.expectedSchemes, r.Metadata.Schemes) + assert.Equal(t, tc.expectedAllowAsGroup, r.Metadata.AllowAsGroup) + } + + if tc.expectedESRequest != "" { + test.AssertEqualJSON(t, []byte(tc.expectedESRequest), []byte(es.LastBody), "elastic request mismatch") } } else { r := &web.ErrorResponse{} err = json.Unmarshal(content, r) assert.NoError(t, err) - assert.Equal(t, tc.ExpectedError, r.Error) + assert.Equal(t, tc.expectedError, r.Error) } } } diff --git a/web/ivr/ivr_test.go b/web/ivr/ivr_test.go index 5c0f20621..cd61c262d 100644 --- a/web/ivr/ivr_test.go +++ b/web/ivr/ivr_test.go @@ -11,6 +11,7 @@ import ( "sync" "testing" + "github.com/nyaruka/gocommon/dbutil/assertdb" "github.com/nyaruka/gocommon/jsonx" "github.com/nyaruka/goflow/test" "github.com/nyaruka/mailroom/core/models" @@ -115,11 +116,11 @@ func TestTwilioIVR(t *testing.T) { require.NoError(t, err) // check our 3 contacts have 3 wired calls - testsuite.AssertQuery(t, db, `SELECT COUNT(*) FROM channels_channelconnection WHERE contact_id = $1 AND status = $2 AND external_id = $3`, + assertdb.Query(t, db, `SELECT COUNT(*) FROM channels_channelconnection WHERE contact_id = $1 AND status = $2 AND external_id = $3`, testdata.Cathy.ID, models.ConnectionStatusWired, "Call1").Returns(1) - testsuite.AssertQuery(t, db, `SELECT COUNT(*) FROM channels_channelconnection WHERE contact_id = $1 AND status = $2 AND external_id = $3`, + assertdb.Query(t, db, `SELECT COUNT(*) FROM channels_channelconnection WHERE contact_id = $1 AND status = $2 AND external_id = $3`, testdata.Bob.ID, models.ConnectionStatusWired, "Call2").Returns(1) - testsuite.AssertQuery(t, db, `SELECT COUNT(*) FROM channels_channelconnection WHERE contact_id = $1 AND status = $2 AND external_id = $3`, + assertdb.Query(t, db, `SELECT COUNT(*) FROM channels_channelconnection WHERE contact_id = $1 AND status = $2 AND external_id = $3`, testdata.George.ID, models.ConnectionStatusWired, "Call3").Returns(1) tcs := []struct { @@ -319,23 +320,23 @@ func TestTwilioIVR(t *testing.T) { } for connExtID, expStatus := range tc.expectedConnStatus { - testsuite.AssertQuery(t, db, `SELECT status FROM channels_channelconnection WHERE external_id = $1`, connExtID). + assertdb.Query(t, db, `SELECT status FROM channels_channelconnection WHERE external_id = $1`, connExtID). Columns(map[string]interface{}{"status": expStatus}, "status mismatch for connection '%s' in test '%s'", connExtID, tc.label) } } // check our final state of sessions, runs, msgs, connections - testsuite.AssertQuery(t, db, `SELECT count(*) FROM flows_flowsession WHERE contact_id = $1 AND status = 'C'`, testdata.Cathy.ID).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM flows_flowsession WHERE contact_id = $1 AND status = 'C'`, testdata.Cathy.ID).Returns(1) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM flows_flowrun WHERE contact_id = $1 AND is_active = FALSE`, testdata.Cathy.ID).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM flows_flowrun WHERE contact_id = $1 AND status = 'C'`, testdata.Cathy.ID).Returns(1) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM msgs_msg WHERE contact_id = $1 AND msg_type = 'V' AND connection_id = 1 AND status = 'W' AND direction = 'O'`, testdata.Cathy.ID).Returns(8) + assertdb.Query(t, db, `SELECT count(*) FROM msgs_msg WHERE contact_id = $1 AND msg_type = 'V' AND status = 'W' AND direction = 'O'`, testdata.Cathy.ID).Returns(8) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM msgs_msg WHERE contact_id = $1 AND msg_type = 'V' AND connection_id = 1 AND status = 'H' AND direction = 'I'`, testdata.Cathy.ID).Returns(5) + assertdb.Query(t, db, `SELECT count(*) FROM msgs_msg WHERE contact_id = $1 AND msg_type = 'V' AND status = 'H' AND direction = 'I'`, testdata.Cathy.ID).Returns(5) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM channels_channellog WHERE connection_id = 1 AND channel_id IS NOT NULL`).Returns(9) + assertdb.Query(t, db, `SELECT count(*) FROM channels_channellog WHERE connection_id = 1 AND channel_id IS NOT NULL`).Returns(9) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM msgs_msg WHERE contact_id = $1 AND msg_type = 'V' AND connection_id = 2 + assertdb.Query(t, db, `SELECT count(*) FROM msgs_msg WHERE contact_id = $1 AND msg_type = 'V' AND ((status = 'H' AND direction = 'I') OR (status = 'W' AND direction = 'O'))`, testdata.Bob.ID).Returns(2) } @@ -421,10 +422,10 @@ func TestVonageIVR(t *testing.T) { err = ivr_tasks.HandleFlowStartBatch(ctx, rt, batch) assert.NoError(t, err) - testsuite.AssertQuery(t, db, `SELECT COUNT(*) FROM channels_channelconnection WHERE contact_id = $1 AND status = $2 AND external_id = $3`, + assertdb.Query(t, db, `SELECT COUNT(*) FROM channels_channelconnection WHERE contact_id = $1 AND status = $2 AND external_id = $3`, testdata.Cathy.ID, models.ConnectionStatusWired, "Call1").Returns(1) - testsuite.AssertQuery(t, db, `SELECT COUNT(*) FROM channels_channelconnection WHERE contact_id = $1 AND status = $2 AND external_id = $3`, + assertdb.Query(t, db, `SELECT COUNT(*) FROM channels_channelconnection WHERE contact_id = $1 AND status = $2 AND external_id = $3`, testdata.George.ID, models.ConnectionStatusWired, "Call2").Returns(1) tcs := []struct { @@ -621,24 +622,21 @@ func TestVonageIVR(t *testing.T) { } // check our final state of sessions, runs, msgs, connections - testsuite.AssertQuery(t, db, `SELECT count(*) FROM flows_flowsession WHERE contact_id = $1 AND status = 'C'`, testdata.Cathy.ID).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM flows_flowsession WHERE contact_id = $1 AND status = 'C'`, testdata.Cathy.ID).Returns(1) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM flows_flowrun WHERE contact_id = $1 AND is_active = FALSE`, testdata.Cathy.ID).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM flows_flowrun WHERE contact_id = $1 AND is_active = FALSE`, testdata.Cathy.ID).Returns(1) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM channels_channelconnection WHERE contact_id = $1 AND status = 'D' AND duration = 50`, testdata.Cathy.ID).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM channels_channelconnection WHERE contact_id = $1 AND status = 'D' AND duration = 50`, testdata.Cathy.ID).Returns(1) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM msgs_msg WHERE contact_id = $1 AND msg_type = 'V' - AND connection_id = 1 AND status = 'W' AND direction = 'O'`, testdata.Cathy.ID).Returns(9) + assertdb.Query(t, db, `SELECT count(*) FROM msgs_msg WHERE contact_id = $1 AND msg_type = 'V' AND status = 'W' AND direction = 'O'`, testdata.Cathy.ID).Returns(9) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM channels_channelconnection WHERE status = 'F' AND direction = 'I'`).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM channels_channelconnection WHERE status = 'F' AND direction = 'I'`).Returns(1) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM msgs_msg WHERE contact_id = $1 AND msg_type = 'V' - AND connection_id = 1 AND status = 'H' AND direction = 'I'`, testdata.Cathy.ID).Returns(5) + assertdb.Query(t, db, `SELECT count(*) FROM msgs_msg WHERE contact_id = $1 AND msg_type = 'V' AND status = 'H' AND direction = 'I'`, testdata.Cathy.ID).Returns(5) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM channels_channellog WHERE connection_id = 1 AND channel_id IS NOT NULL`).Returns(10) + assertdb.Query(t, db, `SELECT count(*) FROM channels_channellog WHERE connection_id = 1 AND channel_id IS NOT NULL`).Returns(10) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM msgs_msg WHERE contact_id = $1 AND msg_type = 'V' - AND connection_id = 2 AND ((status = 'H' AND direction = 'I') OR (status = 'W' AND direction = 'O'))`, testdata.George.ID).Returns(3) + assertdb.Query(t, db, `SELECT count(*) FROM msgs_msg WHERE contact_id = $1 AND msg_type = 'V' AND ((status = 'H' AND direction = 'I') OR (status = 'W' AND direction = 'O'))`, testdata.George.ID).Returns(3) - testsuite.AssertQuery(t, db, `SELECT count(*) FROM channels_channelconnection WHERE status = 'D' AND contact_id = $1`, testdata.George.ID).Returns(1) + assertdb.Query(t, db, `SELECT count(*) FROM channels_channelconnection WHERE status = 'D' AND contact_id = $1`, testdata.George.ID).Returns(1) } diff --git a/web/msg/msg.go b/web/msg/msg.go index 73498d8a5..038233543 100644 --- a/web/msg/msg.go +++ b/web/msg/msg.go @@ -43,7 +43,7 @@ func handleResend(ctx context.Context, rt *runtime.Runtime, r *http.Request) (in return nil, http.StatusInternalServerError, errors.Wrapf(err, "unable to load org assets") } - msgs, err := models.LoadMessages(ctx, rt.DB, request.OrgID, models.DirectionOut, request.MsgIDs) + msgs, err := models.GetMessagesByID(ctx, rt.DB, request.OrgID, models.DirectionOut, request.MsgIDs) if err != nil { return nil, http.StatusInternalServerError, errors.Wrap(err, "error loading messages to resend") } diff --git a/web/msg/msg_test.go b/web/msg/msg_test.go index fc18a1a56..b136293d7 100644 --- a/web/msg/msg_test.go +++ b/web/msg/msg_test.go @@ -16,8 +16,8 @@ func TestServer(t *testing.T) { defer testsuite.Reset(testsuite.ResetData) cathyIn := testdata.InsertIncomingMsg(db, testdata.Org1, testdata.TwilioChannel, testdata.Cathy, "hello", models.MsgStatusHandled) - cathyOut := testdata.InsertOutgoingMsg(db, testdata.Org1, testdata.TwilioChannel, testdata.Cathy, "how can we help", nil, models.MsgStatusSent) - bobOut := testdata.InsertOutgoingMsg(db, testdata.Org1, testdata.VonageChannel, testdata.Bob, "this failed", nil, models.MsgStatusFailed) + cathyOut := testdata.InsertOutgoingMsg(db, testdata.Org1, testdata.TwilioChannel, testdata.Cathy, "how can we help", nil, models.MsgStatusSent, false) + bobOut := testdata.InsertOutgoingMsg(db, testdata.Org1, testdata.VonageChannel, testdata.Bob, "this failed", nil, models.MsgStatusFailed, false) web.RunWebTests(t, ctx, rt, "testdata/resend.json", map[string]string{ "cathy_msgin_id": fmt.Sprintf("%d", cathyIn.ID()), diff --git a/web/po/po.go b/web/po/po.go index 3f67d5c9c..9157089ca 100644 --- a/web/po/po.go +++ b/web/po/po.go @@ -27,15 +27,13 @@ func init() { // { // "org_id": 123, // "flow_ids": [123, 354, 456], -// "language": "spa", -// "exclude_arguments": true +// "language": "spa" // } // type exportRequest struct { - OrgID models.OrgID `json:"org_id" validate:"required"` - FlowIDs []models.FlowID `json:"flow_ids" validate:"required"` - Language envs.Language `json:"language" validate:"omitempty,language"` - ExcludeArguments bool `json:"exclude_arguments"` + OrgID models.OrgID `json:"org_id" validate:"required"` + FlowIDs []models.FlowID `json:"flow_ids" validate:"required"` + Language envs.Language `json:"language" validate:"omitempty,language"` } func handleExport(ctx context.Context, rt *runtime.Runtime, r *http.Request, rawW http.ResponseWriter) error { @@ -49,12 +47,8 @@ func handleExport(ctx context.Context, rt *runtime.Runtime, r *http.Request, raw return err } - var excludeProperties []string - if request.ExcludeArguments { - excludeProperties = []string{"arguments"} - } - - po, err := translation.ExtractFromFlows("Generated by mailroom", request.Language, excludeProperties, flows...) + // extract everything the engine considers localizable except router arguments + po, err := translation.ExtractFromFlows("Generated by mailroom", request.Language, []string{"arguments"}, flows...) if err != nil { return errors.Wrapf(err, "unable to extract PO from flows") } diff --git a/web/po/testdata/export.json b/web/po/testdata/export.json index c933b1884..3336155e5 100644 --- a/web/po/testdata/export.json +++ b/web/po/testdata/export.json @@ -35,21 +35,5 @@ }, "status": 200, "response_file": "testdata/multiple_flows.es.po" - }, - { - "label": "export Spanish PO from flow exluding test arguments", - "method": "POST", - "path": "/mr/po/export", - "body": { - "org_id": 1, - "flow_ids": [ - 10000, - 10001 - ], - "language": "spa", - "exclude_arguments": true - }, - "status": 200, - "response_file": "testdata/multiple_flows_noargs.es.po" } ] \ No newline at end of file diff --git a/web/po/testdata/favorites.po b/web/po/testdata/favorites.po index 25ad7f8b6..e518bf326 100644 --- a/web/po/testdata/favorites.po +++ b/web/po/testdata/favorites.po @@ -14,12 +14,10 @@ msgstr "" msgid "All Responses" msgstr "" -#: Favorites/8d2e259c-bc3c-464f-8c15-985bc736e212/arguments:0 #: Favorites/baf07ebb-8a2a-4e63-aa08-d19aa408cd45/name:0 msgid "Blue" msgstr "" -#: Favorites/3e2dcf45-ffc0-4197-b5ab-25ed974ea612/arguments:0 #: Favorites/6e367c0c-65ab-479a-82e3-c597d8e35eef/name:0 msgid "Cyan" msgstr "" @@ -28,7 +26,6 @@ msgstr "" msgid "Good choice, I like @results.color.category_localized too! What is your favorite beer?" msgstr "" -#: Favorites/34a421ac-34cb-49d8-a2a5-534f52c60851/arguments:0 #: Favorites/c102acfc-8cc5-41fa-89ed-41cbfa362ba6/name:0 msgid "Green" msgstr "" @@ -45,15 +42,10 @@ msgstr "" msgid "Mmmmm... delicious @results.beer.category_localized. If only they made @(lower(results.color)) @results.beer.category_localized! Lastly, what is your name?" msgstr "" -#: Favorites/a03dceb1-7ac1-491d-93ef-23d3e099633b/arguments:0 #: Favorites/b9d718d3-b5e0-4d26-998e-2da31b24f2f9/name:0 msgid "Mutzig" msgstr "" -#: Favorites/3b400f91-db69-42b9-9fe2-24ad556b067a/arguments:0 -msgid "Navy" -msgstr "" - #: Favorites/7624633a-01a9-48f0-abca-957e7290df0a/name:0 msgid "No Response" msgstr "" @@ -63,18 +55,15 @@ msgstr "" msgid "Other" msgstr "" -#: Favorites/58119801-ed31-4538-888d-23779a01707f/arguments:0 #: Favorites/f1ca9ac8-d0aa-4758-a969-195be7330267/name:0 msgid "Primus" msgstr "" #: Favorites/58284598-805a-4740-8966-dcb09e3b670a/name:0 -#: Favorites/b0c29972-6fd4-485e-83c2-057a3f7a04da/arguments:0 msgid "Red" msgstr "" #: Favorites/52d7a9ab-52b7-4e82-ba7f-672fb8d6ec91/name:0 -#: Favorites/ada3d96a-a1a2-41eb-aac7-febdb98a9b4c/arguments:0 msgid "Skol" msgstr "" @@ -86,7 +75,6 @@ msgstr "" msgid "Thanks @results.name, we are all done!" msgstr "" -#: Favorites/2ba89eb6-6981-4c0d-a19d-3cf1fde52a43/arguments:0 #: Favorites/dbc3b9d2-e6ce-4ebe-9552-8ddce482c1d1/name:0 msgid "Turbo King" msgstr "" diff --git a/web/po/testdata/multiple_flows.es.po b/web/po/testdata/multiple_flows.es.po index 5a91e2e83..6ed3ab183 100644 --- a/web/po/testdata/multiple_flows.es.po +++ b/web/po/testdata/multiple_flows.es.po @@ -10,29 +10,19 @@ msgstr "" "Language-3: spa\n" "Source-Flows: 9de3663f-c5c5-4c92-9f45-ecbc09abcc85; 5890fe3a-f204-4661-b74d-025be4ee019c\n" -#: Pick+a+Number/b634f07f-7b2d-47bd-8795-051e56cf2609/arguments:0 -msgid "1" -msgstr "" - #: Pick+a+Number/f90c9734-3e58-4c07-96cc-315266c8ecfd/name:0 msgid "1-10" msgstr "" -#: Pick+a+Number/b634f07f-7b2d-47bd-8795-051e56cf2609/arguments:1 -msgid "10" -msgstr "" - #: Favorites/a602e75e-0814-4034-bb95-770906ddfe34/name:0 #: Pick+a+Number/ee9c1a1d-3426-4f07-83c8-dc3c1949fe6c/name:0 msgid "All Responses" msgstr "" -#: Favorites/8d2e259c-bc3c-464f-8c15-985bc736e212/arguments:0 #: Favorites/baf07ebb-8a2a-4e63-aa08-d19aa408cd45/name:0 msgid "Blue" msgstr "" -#: Favorites/3e2dcf45-ffc0-4197-b5ab-25ed974ea612/arguments:0 #: Favorites/6e367c0c-65ab-479a-82e3-c597d8e35eef/name:0 msgid "Cyan" msgstr "" @@ -41,7 +31,6 @@ msgstr "" msgid "Good choice, I like @results.color.category_localized too! What is your favorite beer?" msgstr "" -#: Favorites/34a421ac-34cb-49d8-a2a5-534f52c60851/arguments:0 #: Favorites/c102acfc-8cc5-41fa-89ed-41cbfa362ba6/name:0 msgid "Green" msgstr "" @@ -58,15 +47,10 @@ msgstr "" msgid "Mmmmm... delicious @results.beer.category_localized. If only they made @(lower(results.color)) @results.beer.category_localized! Lastly, what is your name?" msgstr "" -#: Favorites/a03dceb1-7ac1-491d-93ef-23d3e099633b/arguments:0 #: Favorites/b9d718d3-b5e0-4d26-998e-2da31b24f2f9/name:0 msgid "Mutzig" msgstr "" -#: Favorites/3b400f91-db69-42b9-9fe2-24ad556b067a/arguments:0 -msgid "Navy" -msgstr "" - #: Favorites/7624633a-01a9-48f0-abca-957e7290df0a/name:0 msgid "No Response" msgstr "" @@ -81,18 +65,15 @@ msgstr "" msgid "Pick a number between 1-10." msgstr "" -#: Favorites/58119801-ed31-4538-888d-23779a01707f/arguments:0 #: Favorites/f1ca9ac8-d0aa-4758-a969-195be7330267/name:0 msgid "Primus" msgstr "" #: Favorites/58284598-805a-4740-8966-dcb09e3b670a/name:0 -#: Favorites/b0c29972-6fd4-485e-83c2-057a3f7a04da/arguments:0 msgid "Red" msgstr "" #: Favorites/52d7a9ab-52b7-4e82-ba7f-672fb8d6ec91/name:0 -#: Favorites/ada3d96a-a1a2-41eb-aac7-febdb98a9b4c/arguments:0 msgid "Skol" msgstr "" @@ -104,7 +85,6 @@ msgstr "" msgid "Thanks @results.name, we are all done!" msgstr "" -#: Favorites/2ba89eb6-6981-4c0d-a19d-3cf1fde52a43/arguments:0 #: Favorites/dbc3b9d2-e6ce-4ebe-9552-8ddce482c1d1/name:0 msgid "Turbo King" msgstr "" diff --git a/web/po/testdata/multiple_flows_noargs.es.po b/web/po/testdata/multiple_flows_noargs.es.po deleted file mode 100644 index 6ed3ab183..000000000 --- a/web/po/testdata/multiple_flows_noargs.es.po +++ /dev/null @@ -1,99 +0,0 @@ -# Generated by mailroom -# -#, fuzzy -msgid "" -msgstr "" -"POT-Creation-Date: 2018-07-06 12:30+0000\n" -"Language: es\n" -"MIME-Version: 1.0\n" -"Content-Type: text/plain; charset=UTF-8\n" -"Language-3: spa\n" -"Source-Flows: 9de3663f-c5c5-4c92-9f45-ecbc09abcc85; 5890fe3a-f204-4661-b74d-025be4ee019c\n" - -#: Pick+a+Number/f90c9734-3e58-4c07-96cc-315266c8ecfd/name:0 -msgid "1-10" -msgstr "" - -#: Favorites/a602e75e-0814-4034-bb95-770906ddfe34/name:0 -#: Pick+a+Number/ee9c1a1d-3426-4f07-83c8-dc3c1949fe6c/name:0 -msgid "All Responses" -msgstr "" - -#: Favorites/baf07ebb-8a2a-4e63-aa08-d19aa408cd45/name:0 -msgid "Blue" -msgstr "" - -#: Favorites/6e367c0c-65ab-479a-82e3-c597d8e35eef/name:0 -msgid "Cyan" -msgstr "" - -#: Favorites/4cadf512-1299-468f-85e4-26af9edec193/text:0 -msgid "Good choice, I like @results.color.category_localized too! What is your favorite beer?" -msgstr "" - -#: Favorites/c102acfc-8cc5-41fa-89ed-41cbfa362ba6/name:0 -msgid "Green" -msgstr "" - -#: Favorites/66c38ec3-0acd-4bf7-a5d5-278af1bee492/text:0 -msgid "I don't know that color. Try again." -msgstr "" - -#: Favorites/0f0e66a8-9062-444f-b636-3d5374466e31/text:0 -msgid "I don't know that one, try again please." -msgstr "" - -#: Favorites/fc551cb4-e797-4076-b40a-433c44ad492b/text:0 -msgid "Mmmmm... delicious @results.beer.category_localized. If only they made @(lower(results.color)) @results.beer.category_localized! Lastly, what is your name?" -msgstr "" - -#: Favorites/b9d718d3-b5e0-4d26-998e-2da31b24f2f9/name:0 -msgid "Mutzig" -msgstr "" - -#: Favorites/7624633a-01a9-48f0-abca-957e7290df0a/name:0 -msgid "No Response" -msgstr "" - -#: Favorites/3ffb6f24-2ed8-4fd5-bcc0-b2e2668672a8/name:0 -#: Favorites/a813de57-c92a-4128-804d-56e80b332142/name:0 -#: Pick+a+Number/f3087862-dca9-4eaf-8cea-13f85cb52353/name:0 -msgid "Other" -msgstr "" - -#: Pick+a+Number/1b0564e8-c806-4b08-9e3d-06370d9c064c/text:0 -msgid "Pick a number between 1-10." -msgstr "" - -#: Favorites/f1ca9ac8-d0aa-4758-a969-195be7330267/name:0 -msgid "Primus" -msgstr "" - -#: Favorites/58284598-805a-4740-8966-dcb09e3b670a/name:0 -msgid "Red" -msgstr "" - -#: Favorites/52d7a9ab-52b7-4e82-ba7f-672fb8d6ec91/name:0 -msgid "Skol" -msgstr "" - -#: Favorites/1470d5e6-08dd-479b-a207-9b2b27b924d3/text:0 -msgid "Sorry you can't participate right now, I'll try again later." -msgstr "" - -#: Favorites/e92b12c5-1817-468e-aa2f-8791fb6247e9/text:0 -msgid "Thanks @results.name, we are all done!" -msgstr "" - -#: Favorites/dbc3b9d2-e6ce-4ebe-9552-8ddce482c1d1/name:0 -msgid "Turbo King" -msgstr "" - -#: Favorites/943f85bb-50bc-40c3-8d6f-57dbe34c87f7/text:0 -msgid "What is your favorite color?" -msgstr "" - -#: Pick+a+Number/41f97c7a-3397-4076-95ab-3f1aa9e2acb2/text:0 -msgid "You picked @results.number!" -msgstr "" - diff --git a/web/simulation/simulation.go b/web/simulation/simulation.go index 1d15f2163..e355ed1c5 100644 --- a/web/simulation/simulation.go +++ b/web/simulation/simulation.go @@ -9,7 +9,7 @@ import ( "github.com/nyaruka/goflow/assets" "github.com/nyaruka/goflow/assets/static" "github.com/nyaruka/goflow/excellent/tools" - xtypes "github.com/nyaruka/goflow/excellent/types" + "github.com/nyaruka/goflow/excellent/types" "github.com/nyaruka/goflow/flows" "github.com/nyaruka/goflow/flows/events" "github.com/nyaruka/goflow/flows/resumes" @@ -60,24 +60,25 @@ func (r *sessionRequest) channels() []assets.Channel { } type simulationResponse struct { - Session flows.Session `json:"session"` - Events []flows.Event `json:"events"` - Context *xtypes.XObject `json:"context,omitempty"` + Session flows.Session `json:"session"` + Events []flows.Event `json:"events"` + Segments []flows.Segment `json:"segments"` + Context *types.XObject `json:"context,omitempty"` } func newSimulationResponse(session flows.Session, sprint flows.Sprint) *simulationResponse { - var context *xtypes.XObject + var context *types.XObject if session != nil { context = session.CurrentContext() // include object defaults which are not marshaled by default if context != nil { - tools.ContextWalkObjects(context, func(o *xtypes.XObject) { + tools.ContextWalkObjects(context, func(o *types.XObject) { o.SetMarshalDefault(true) }) } } - return &simulationResponse{Session: session, Events: sprint.Events(), Context: context} + return &simulationResponse{Session: session, Events: sprint.Events(), Segments: sprint.Segments(), Context: context} } // Starts a new engine session @@ -234,7 +235,7 @@ func handleResume(ctx context.Context, rt *runtime.Runtime, r *http.Request) (in } if triggeredFlow != nil { - tb := triggers.NewBuilder(oa.Env(), triggeredFlow.FlowReference(), resume.Contact()) + tb := triggers.NewBuilder(oa.Env(), triggeredFlow.Reference(), resume.Contact()) var sessionTrigger flows.Trigger if triggeredFlow.FlowType() == models.FlowTypeVoice { diff --git a/web/simulation/simulation_test.go b/web/simulation/simulation_test.go index 9bab48e28..0c3da1a51 100644 --- a/web/simulation/simulation_test.go +++ b/web/simulation/simulation_test.go @@ -199,7 +199,7 @@ const ( func TestServer(t *testing.T) { ctx, rt, db, _ := testsuite.Get() - defer testsuite.Reset(testsuite.ResetAll) + defer testsuite.Reset(testsuite.ResetData) wg := &sync.WaitGroup{} diff --git a/web/surveyor/surveyor.go b/web/surveyor/surveyor.go index 733ec0769..ead07db8f 100644 --- a/web/surveyor/surveyor.go +++ b/web/surveyor/surveyor.go @@ -125,7 +125,7 @@ func handleSubmit(ctx context.Context, rt *runtime.Runtime, r *http.Request) (in modifierEvents = append(modifierEvents, sessionEvents...) // create our sprint - sprint := engine.NewSprint(mods, modifierEvents) + sprint := engine.NewSprint(mods, modifierEvents, nil) // write our session out tx, err := rt.DB.BeginTxx(ctx, nil) diff --git a/web/surveyor/surveyor_test.go b/web/surveyor/surveyor_test.go index f28e007c2..a06022c56 100644 --- a/web/surveyor/surveyor_test.go +++ b/web/surveyor/surveyor_test.go @@ -5,7 +5,6 @@ import ( "fmt" "io" "net/http" - "os" "path/filepath" "sync" "testing" @@ -114,8 +113,7 @@ func TestSurveyor(t *testing.T) { for i, tc := range tcs { testID := fmt.Sprintf("%s[token=%s]", tc.File, tc.Token) path := filepath.Join("testdata", tc.File) - submission, err := os.ReadFile(path) - assert.NoError(t, err) + submission := testsuite.ReadFile(path) url := "http://localhost:8090/mr/surveyor/submit" req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(submission)) diff --git a/web/testing.go b/web/testing.go index 66e2a6191..bcb21d4ac 100644 --- a/web/testing.go +++ b/web/testing.go @@ -16,6 +16,7 @@ import ( "time" "github.com/nyaruka/gocommon/dates" + "github.com/nyaruka/gocommon/dbutil/assertdb" "github.com/nyaruka/gocommon/httpx" "github.com/nyaruka/gocommon/jsonx" "github.com/nyaruka/gocommon/uuids" @@ -62,14 +63,14 @@ func RunWebTests(t *testing.T, ctx context.Context, rt *runtime.Runtime, truthFi actualResponse []byte } tcs := make([]TestCase, 0, 20) - tcJSON, err := os.ReadFile(truthFile) - require.NoError(t, err) + tcJSON := testsuite.ReadFile(truthFile) for key, value := range substitutions { tcJSON = bytes.ReplaceAll(tcJSON, []byte("$"+key+"$"), []byte(value)) } jsonx.MustUnmarshal(tcJSON, &tcs) + var err error for i, tc := range tcs { dates.SetNowSource(dates.NewSequentialNowSource(time.Date(2018, 7, 6, 12, 30, 0, 123456789, time.UTC))) @@ -130,8 +131,7 @@ func RunWebTests(t *testing.T, ctx context.Context, rt *runtime.Runtime, truthFi expectedIsJSON := false if tc.ResponseFile != "" { - expectedResponse, err = os.ReadFile(tc.ResponseFile) - require.NoError(t, err) + expectedResponse = testsuite.ReadFile(tc.ResponseFile) expectedIsJSON = strings.HasSuffix(tc.ResponseFile, ".json") } else { @@ -146,7 +146,7 @@ func RunWebTests(t *testing.T, ctx context.Context, rt *runtime.Runtime, truthFi } for _, dba := range tc.DBAssertions { - testsuite.AssertQuery(t, rt.DB, dba.Query).Returns(dba.Count, "%s: '%s' returned wrong count", tc.Label, dba.Query) + assertdb.Query(t, rt.DB, dba.Query).Returns(dba.Count, "%s: '%s' returned wrong count", tc.Label, dba.Query) } } else { diff --git a/web/wrappers_test.go b/web/wrappers_test.go index c09dc235a..0fe7ff378 100644 --- a/web/wrappers_test.go +++ b/web/wrappers_test.go @@ -5,6 +5,7 @@ import ( "net/http" "testing" + "github.com/nyaruka/gocommon/dbutil/assertdb" "github.com/nyaruka/gocommon/httpx" "github.com/nyaruka/goflow/flows" "github.com/nyaruka/mailroom/core/models" @@ -60,5 +61,5 @@ func TestWithHTTPLogs(t *testing.T) { assert.NoError(t, err) // check HTTP logs were created - testsuite.AssertQuery(t, db, `select count(*) from request_logs_httplog where ticketer_id = $1;`, testdata.Mailgun.ID).Returns(2) + assertdb.Query(t, db, `select count(*) from request_logs_httplog where ticketer_id = $1;`, testdata.Mailgun.ID).Returns(2) }