From 357b5c47c6034c1fb556dbdbf5540fb49fcab8b4 Mon Sep 17 00:00:00 2001 From: Michael Derynck Date: Fri, 8 Nov 2024 19:42:11 -0700 Subject: [PATCH 01/12] Limit Slack block text length when rendering alert group timeline (#5246) # What this PR does Limit the length of text in blocks being posted to Slack when showing the alert group timeline. ## Which issue(s) this PR closes Related to [issue link here] ## Checklist - [x] Unit, integration, and e2e (if applicable) tests updated - [x] Documentation added (or `pr:no public docs` PR label added if not required) - [x] Added the relevant release notes label (see labels prefixed w/ `release:`). These labels dictate how your PR will show up in the autogenerated release notes. --- engine/apps/slack/scenarios/alertgroup_timeline.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/engine/apps/slack/scenarios/alertgroup_timeline.py b/engine/apps/slack/scenarios/alertgroup_timeline.py index 08f74b8802..7ca3a56f2d 100644 --- a/engine/apps/slack/scenarios/alertgroup_timeline.py +++ b/engine/apps/slack/scenarios/alertgroup_timeline.py @@ -2,6 +2,7 @@ from apps.api.permissions import RBACPermission from apps.slack.chatops_proxy_routing import make_private_metadata +from apps.slack.constants import BLOCK_SECTION_TEXT_MAX_SIZE from apps.slack.scenarios import scenario_step from apps.slack.scenarios.slack_renderer import AlertGroupLogSlackRenderer from apps.slack.types import ( @@ -47,9 +48,13 @@ def process_scenario( future_log_report = AlertGroupLogSlackRenderer.render_alert_group_future_log_report_text(alert_group) blocks: typing.List[Block.Section] = [] if past_log_report: - blocks.append({"type": "section", "text": {"type": "mrkdwn", "text": past_log_report}}) + blocks.append( + {"type": "section", "text": {"type": "mrkdwn", "text": past_log_report[:BLOCK_SECTION_TEXT_MAX_SIZE]}} + ) if future_log_report: - blocks.append({"type": "section", "text": {"type": "mrkdwn", "text": future_log_report}}) + blocks.append( + {"type": "section", "text": {"type": "mrkdwn", "text": future_log_report[:BLOCK_SECTION_TEXT_MAX_SIZE]}} + ) view: ModalView = { "blocks": blocks, From df6bb69d29c7c496d70aaa583a7046787dfe68bd Mon Sep 17 00:00:00 2001 From: Dominik Broj Date: Tue, 12 Nov 2024 16:48:47 +0100 Subject: [PATCH 02/12] fix: disable accessControlOnCall for Grafana 11.3 (#5245) # What this PR does Disable accessControlOnCall for Grafana 11.3 ## Checklist - [ ] Unit, integration, and e2e (if applicable) tests updated - [x] Documentation added (or `pr:no public docs` PR label added if not required) - [x] Added the relevant release notes label (see labels prefixed w/ `release:`). These labels dictate how your PR will show up in the autogenerated release notes.
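For reference, every config touched below converges on the same two Grafana settings: the `accessControlOnCall` feature toggle is turned off and managed service accounts are enabled for plugin authentication. A minimal `grafana.ini` sketch of the intended end state (assuming Grafana's standard `GF_<SECTION>_<KEY>` env-var-to-ini mapping, so the `[auth]` entry below is the equivalent of the `GF_AUTH_MANAGED_SERVICE_ACCOUNTS_ENABLED=true` env var used in the diffs):

```ini
[feature_toggles]
; disabled for Grafana 11.3, matching the Tiltfile, helm, and docker-compose changes below
accessControlOnCall = false

[auth]
; equivalent to GF_AUTH_MANAGED_SERVICE_ACCOUNTS_ENABLED=true
managed_service_accounts_enabled = true
```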
--- .github/workflows/linting-and-tests.yml | 1 + Tiltfile | 24 ++++++++++++++++++++++-- dev/helm-local.yml | 3 +++ helm/oncall/values.yaml | 3 +++ 4 files changed, 29 insertions(+), 2 deletions(-) diff --git a/.github/workflows/linting-and-tests.yml b/.github/workflows/linting-and-tests.yml index fc43b57276..23688595e6 100644 --- a/.github/workflows/linting-and-tests.yml +++ b/.github/workflows/linting-and-tests.yml @@ -244,6 +244,7 @@ jobs: grafana_version: - 10.3.0 - 11.2.0 + - latest fail-fast: false with: grafana_version: ${{ matrix.grafana_version }} diff --git a/Tiltfile b/Tiltfile index 264424161c..00d7ec4189 100644 --- a/Tiltfile +++ b/Tiltfile @@ -32,12 +32,23 @@ def plugin_json(): return plugin_file return 'NOT_A_PLUGIN' +def extra_grafana_ini(): + return { + 'feature_toggles': { + 'accessControlOnCall': 'false' + } + } + def extra_env(): return { "GF_APP_URL": grafana_url, "GF_SERVER_ROOT_URL": grafana_url, "GF_FEATURE_TOGGLES_ENABLE": "externalServiceAccounts", - "ONCALL_API_URL": "http://oncall-dev-engine:8080" + "ONCALL_API_URL": "http://oncall-dev-engine:8080", + + # Enables managed service accounts for plugin authentication in Grafana >= 11.3 + # https://grafana.com/docs/grafana/latest/setup-grafana/configure-grafana/#managed_service_accounts_enabled + "GF_AUTH_MANAGED_SERVICE_ACCOUNTS_ENABLED": "true", } def extra_deps(): @@ -132,7 +143,16 @@ def load_grafana(): "GF_APP_URL": grafana_url, # older versions of grafana need this "GF_SERVER_ROOT_URL": grafana_url, "GF_FEATURE_TOGGLES_ENABLE": "externalServiceAccounts", - "ONCALL_API_URL": "http://oncall-dev-engine:8080" + "ONCALL_API_URL": "http://oncall-dev-engine:8080", + + # Enables managed service accounts for plugin authentication in Grafana >= 11.3 + # https://grafana.com/docs/grafana/latest/setup-grafana/configure-grafana/#managed_service_accounts_enabled + "GF_AUTH_MANAGED_SERVICE_ACCOUNTS_ENABLED": "true", + }, + extra_grafana_ini={ + "feature_toggles": { + "accessControlOnCall": "false" + } }, ) # --- GRAFANA END ---- diff --git a/dev/helm-local.yml b/dev/helm-local.yml index 33a28790c6..8655df43fd 100644 --- a/dev/helm-local.yml +++ b/dev/helm-local.yml @@ -47,6 +47,8 @@ externalGrafana: grafana: enabled: false grafana.ini: + feature_toggles: + accessControlOnCall: false server: domain: localhost:3000 root_url: "%(protocol)s://%(domain)s" @@ -71,6 +73,7 @@ grafana: value: oncallpassword env: GF_FEATURE_TOGGLES_ENABLE: externalServiceAccounts + GF_AUTH_MANAGED_SERVICE_ACCOUNTS_ENABLED: true GF_SECURITY_ADMIN_PASSWORD: oncall GF_SECURITY_ADMIN_USER: oncall GF_PLUGINS_ALLOW_LOADING_UNSIGNED_PLUGINS: grafana-oncall-app diff --git a/helm/oncall/values.yaml b/helm/oncall/values.yaml index 8ca59a2664..826e0a5be3 100644 --- a/helm/oncall/values.yaml +++ b/helm/oncall/values.yaml @@ -639,6 +639,9 @@ grafana: serve_from_sub_path: true feature_toggles: enable: externalServiceAccounts + accessControlOnCall: false + env: + GF_AUTH_MANAGED_SERVICE_ACCOUNTS_ENABLED: true persistence: enabled: true # Disable psp as PodSecurityPolicy is deprecated in v1.21+, unavailable in v1.25+ From 9338cff0ef36661cdd1440724d8d163ad27fdc65 Mon Sep 17 00:00:00 2001 From: Michael Derynck Date: Thu, 14 Nov 2024 09:19:30 -0700 Subject: [PATCH 03/12] fix: disable accessControlOnCall for Grafana 11.3 in docker compose (#5255) # What this PR does Disable accessControlOnCall for Grafana 11.3 in docker compose Similar to https://github.com/grafana/oncall/pull/5245 ## Checklist - [ ] Unit, integration, and e2e (if applicable) tests updated - [x]
Documentation added (or `pr:no public docs` PR label added if not required) - [x] Added the relevant release notes label (see labels prefixed w/ `release:`). These labels dictate how your PR will show up in the autogenerated release notes. --- docker-compose-developer.yml | 1 + docker-compose-mysql-rabbitmq.yml | 10 ++++++++++ docker-compose.yml | 10 ++++++++++ 3 files changed, 21 insertions(+) diff --git a/docker-compose-developer.yml b/docker-compose-developer.yml index b751ab1e98..ee668df794 100644 --- a/docker-compose-developer.yml +++ b/docker-compose-developer.yml @@ -324,6 +324,7 @@ services: GF_PLUGINS_ALLOW_LOADING_UNSIGNED_PLUGINS: grafana-oncall-app GF_FEATURE_TOGGLES_ENABLE: externalServiceAccounts ONCALL_API_URL: http://host.docker.internal:8080 + GF_AUTH_MANAGED_SERVICE_ACCOUNTS_ENABLED: true env_file: - ./dev/.env.${DB}.dev ports: diff --git a/docker-compose-mysql-rabbitmq.yml b/docker-compose-mysql-rabbitmq.yml index f587902e76..60b320e80f 100644 --- a/docker-compose-mysql-rabbitmq.yml +++ b/docker-compose-mysql-rabbitmq.yml @@ -144,6 +144,7 @@ services: GF_SECURITY_ADMIN_PASSWORD: ${GRAFANA_PASSWORD:-admin} GF_PLUGINS_ALLOW_LOADING_UNSIGNED_PLUGINS: grafana-oncall-app GF_INSTALL_PLUGINS: grafana-oncall-app + GF_AUTH_MANAGED_SERVICE_ACCOUNTS_ENABLED: true deploy: resources: limits: @@ -156,7 +157,16 @@ condition: service_healthy profiles: - with_grafana + configs: + - source: grafana.ini + target: /etc/grafana/grafana.ini volumes: dbdata: rabbitmqdata: + +configs: + grafana.ini: + content: | + [feature_toggles] + accessControlOnCall = false diff --git a/docker-compose.yml b/docker-compose.yml index b115199f8c..c54c2fb33f 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -94,6 +94,7 @@ services: GF_SECURITY_ADMIN_PASSWORD: ${GRAFANA_PASSWORD:-admin} GF_PLUGINS_ALLOW_LOADING_UNSIGNED_PLUGINS: grafana-oncall-app GF_INSTALL_PLUGINS: grafana-oncall-app + GF_AUTH_MANAGED_SERVICE_ACCOUNTS_ENABLED: true volumes: - grafana_data:/var/lib/grafana deploy: @@ -103,9 +104,18 @@ cpus: "0.5" profiles: - with_grafana + configs: + - source: grafana.ini + target: /etc/grafana/grafana.ini volumes: grafana_data: prometheus_data: oncall_data: redis_data: + +configs: + grafana.ini: + content: | + [feature_toggles] + accessControlOnCall = false From 208db9cdb7a45a35867949aa3e97a8fbd59bb02a Mon Sep 17 00:00:00 2001 From: Salvatore Giordano Date: Fri, 15 Nov 2024 11:29:00 +0100 Subject: [PATCH 04/12] remove add_stack_slug_to_message_title utility from push notification titles (#5258) # What this PR does We noticed that the backend was adding the stack name to the notification title only on Android. We thought it makes sense to add the stack name only if the user has more than 1 stack connected, but that's not doable right now since the backend doesn't know how many stacks are connected in the app. Also, we took a look at the analytics for the app and basically 95% of the users have only 1 stack connected. This PR removes the stack name from the notification title. If in the future we think it makes sense to add it conditionally based on the number of stacks, we can open another PR, but given the very small number of users with more than 1 stack I think this is not needed. ## Checklist - [x] Unit, integration, and e2e (if applicable) tests updated - [x] Documentation added (or `pr:no public docs` PR label added if not required) - [x] Added the relevant release notes label (see labels prefixed w/ `release:`).
These labels dictate how your PR will show up in the autogenerated release notes. --- engine/apps/mobile_app/demo_push.py | 4 ++-- .../apps/mobile_app/tasks/going_oncall_notification.py | 9 ++------- engine/apps/mobile_app/tasks/new_alert_group.py | 9 ++------- engine/apps/mobile_app/tasks/new_shift_swap_request.py | 9 ++------- .../tests/tasks/test_going_oncall_notification.py | 3 +-- .../tests/tasks/test_new_shift_swap_request.py | 7 ++----- engine/apps/mobile_app/tests/test_demo_push.py | 7 +++---- 7 files changed, 14 insertions(+), 34 deletions(-) diff --git a/engine/apps/mobile_app/demo_push.py b/engine/apps/mobile_app/demo_push.py index 19daca5b2f..01194c1487 100644 --- a/engine/apps/mobile_app/demo_push.py +++ b/engine/apps/mobile_app/demo_push.py @@ -8,7 +8,7 @@ from apps.mobile_app.exceptions import DeviceNotSet from apps.mobile_app.types import FCMMessageData, MessageType, Platform -from apps.mobile_app.utils import add_stack_slug_to_message_title, construct_fcm_message, send_push_notification +from apps.mobile_app.utils import construct_fcm_message, send_push_notification from apps.user_management.models import User if typing.TYPE_CHECKING: @@ -47,7 +47,7 @@ def _get_test_escalation_fcm_message(user: User, device_to_notify: "FCMDevice", apns_sound_name = mobile_app_user_settings.get_notification_sound_name(message_type, Platform.IOS) fcm_message_data: FCMMessageData = { - "title": add_stack_slug_to_message_title(get_test_push_title(critical), user.organization), + "title": get_test_push_title(critical), "orgName": user.organization.stack_slug, # Pass user settings, so the Android app can use them to play the correct sound and volume "default_notification_sound_name": mobile_app_user_settings.get_notification_sound_name( diff --git a/engine/apps/mobile_app/tasks/going_oncall_notification.py b/engine/apps/mobile_app/tasks/going_oncall_notification.py index 214fa19df8..34fd41607c 100644 --- a/engine/apps/mobile_app/tasks/going_oncall_notification.py +++ b/engine/apps/mobile_app/tasks/going_oncall_notification.py @@ -12,12 +12,7 @@ from firebase_admin.messaging import APNSPayload, Aps, ApsAlert, CriticalSound, Message from apps.mobile_app.types import FCMMessageData, MessageType, Platform -from apps.mobile_app.utils import ( - MAX_RETRIES, - add_stack_slug_to_message_title, - construct_fcm_message, - send_push_notification, -) +from apps.mobile_app.utils import MAX_RETRIES, construct_fcm_message, send_push_notification from apps.schedules.models.on_call_schedule import OnCallSchedule, ScheduleEvent from apps.user_management.models import User from common.cache import ensure_cache_key_allocates_to_the_same_hash_slot @@ -82,7 +77,7 @@ def _get_fcm_message( notification_subtitle = _get_notification_subtitle(schedule, schedule_event, mobile_app_user_settings) data: FCMMessageData = { - "title": add_stack_slug_to_message_title(notification_title, user.organization), + "title": notification_title, "subtitle": notification_subtitle, "orgName": user.organization.stack_slug, "info_notification_sound_name": mobile_app_user_settings.get_notification_sound_name( diff --git a/engine/apps/mobile_app/tasks/new_alert_group.py b/engine/apps/mobile_app/tasks/new_alert_group.py index e33e91112e..2b759f5f6e 100644 --- a/engine/apps/mobile_app/tasks/new_alert_group.py +++ b/engine/apps/mobile_app/tasks/new_alert_group.py @@ -8,12 +8,7 @@ from apps.alerts.models import AlertGroup from apps.mobile_app.alert_rendering import get_push_notification_subtitle, get_push_notification_title from 
apps.mobile_app.types import FCMMessageData, MessageType, Platform -from apps.mobile_app.utils import ( - MAX_RETRIES, - add_stack_slug_to_message_title, - construct_fcm_message, - send_push_notification, -) +from apps.mobile_app.utils import MAX_RETRIES, construct_fcm_message, send_push_notification from apps.user_management.models import User from common.custom_celery_tasks import shared_dedicated_queue_retry_task @@ -46,7 +41,7 @@ def _get_fcm_message(alert_group: AlertGroup, user: User, device_to_notify: "FCM apns_sound_name = mobile_app_user_settings.get_notification_sound_name(message_type, Platform.IOS) fcm_message_data: FCMMessageData = { - "title": add_stack_slug_to_message_title(alert_title, alert_group.channel.organization), + "title": alert_title, "subtitle": alert_subtitle, "orgId": alert_group.channel.organization.public_primary_key, "orgName": alert_group.channel.organization.stack_slug, diff --git a/engine/apps/mobile_app/tasks/new_shift_swap_request.py b/engine/apps/mobile_app/tasks/new_shift_swap_request.py index a6d49c8b20..3ab7167410 100644 --- a/engine/apps/mobile_app/tasks/new_shift_swap_request.py +++ b/engine/apps/mobile_app/tasks/new_shift_swap_request.py @@ -10,12 +10,7 @@ from firebase_admin.messaging import APNSPayload, Aps, ApsAlert, CriticalSound, Message from apps.mobile_app.types import FCMMessageData, MessageType, Platform -from apps.mobile_app.utils import ( - MAX_RETRIES, - add_stack_slug_to_message_title, - construct_fcm_message, - send_push_notification, -) +from apps.mobile_app.utils import MAX_RETRIES, construct_fcm_message, send_push_notification from apps.schedules.models import ShiftSwapRequest from apps.user_management.models import User from common.custom_celery_tasks import shared_dedicated_queue_retry_task @@ -121,7 +116,7 @@ def _get_fcm_message( route = f"/schedules/{shift_swap_request.schedule.public_primary_key}/ssrs/{shift_swap_request.public_primary_key}" data: FCMMessageData = { - "title": add_stack_slug_to_message_title(notification_title, user.organization), + "title": notification_title, "subtitle": notification_subtitle, "orgName": user.organization.stack_slug, "route": route, diff --git a/engine/apps/mobile_app/tests/tasks/test_going_oncall_notification.py b/engine/apps/mobile_app/tests/tasks/test_going_oncall_notification.py index 2541d507f9..051e4ffbfe 100644 --- a/engine/apps/mobile_app/tests/tasks/test_going_oncall_notification.py +++ b/engine/apps/mobile_app/tests/tasks/test_going_oncall_notification.py @@ -18,7 +18,6 @@ conditionally_send_going_oncall_push_notifications_for_schedule, ) from apps.mobile_app.types import MessageType, Platform -from apps.mobile_app.utils import add_stack_slug_to_message_title from apps.schedules.models import OnCallScheduleCalendar, OnCallScheduleICal, OnCallScheduleWeb from apps.schedules.models.on_call_schedule import ScheduleEvent @@ -228,7 +227,7 @@ def test_get_fcm_message( maus = MobileAppUserSettings.objects.create(user=user, time_zone=user_tz) data = { - "title": add_stack_slug_to_message_title(mock_notification_title, organization), + "title": mock_notification_title, "subtitle": mock_notification_subtitle, "orgName": organization.stack_slug, "info_notification_sound_name": maus.get_notification_sound_name(MessageType.INFO, Platform.ANDROID), diff --git a/engine/apps/mobile_app/tests/tasks/test_new_shift_swap_request.py b/engine/apps/mobile_app/tests/tasks/test_new_shift_swap_request.py index 452b98952f..f77674f8fc 100644 --- 
a/engine/apps/mobile_app/tests/tasks/test_new_shift_swap_request.py +++ b/engine/apps/mobile_app/tests/tasks/test_new_shift_swap_request.py @@ -19,7 +19,6 @@ notify_shift_swap_requests, notify_user_about_shift_swap_request, ) -from apps.mobile_app.utils import add_stack_slug_to_message_title from apps.schedules.models import CustomOnCallShift, OnCallScheduleWeb, ShiftSwapRequest from apps.user_management.models import User from apps.user_management.models.user import default_working_hours @@ -288,7 +287,7 @@ def test_notify_user_about_shift_swap_request( message: Message = mock_send_push_notification.call_args.args[1] assert message.data["type"] == "oncall.info" - assert message.data["title"] == add_stack_slug_to_message_title("New shift swap request", organization) + assert message.data["title"] == "New shift swap request" assert message.data["subtitle"] == "John Doe, Test Schedule" assert ( message.data["route"] @@ -487,9 +486,7 @@ def test_notify_beneficiary_about_taken_shift_swap_request( message: Message = mock_send_push_notification.call_args.args[1] assert message.data["type"] == "oncall.info" - assert message.data["title"] == add_stack_slug_to_message_title( - "Your shift swap request has been taken", organization - ) + assert message.data["title"] == "Your shift swap request has been taken" assert message.data["subtitle"] == schedule_name assert ( message.data["route"] diff --git a/engine/apps/mobile_app/tests/test_demo_push.py b/engine/apps/mobile_app/tests/test_demo_push.py index 769691f75e..abf5f6eb9f 100644 --- a/engine/apps/mobile_app/tests/test_demo_push.py +++ b/engine/apps/mobile_app/tests/test_demo_push.py @@ -2,7 +2,6 @@ from apps.mobile_app.demo_push import _get_test_escalation_fcm_message, get_test_push_title from apps.mobile_app.models import FCMDevice, MobileAppUserSettings -from apps.mobile_app.utils import add_stack_slug_to_message_title @pytest.mark.django_db @@ -34,7 +33,7 @@ def test_test_escalation_fcm_message_user_settings( # Check expected test push content assert message.apns.payload.aps.badge is None assert message.apns.payload.aps.alert.title == get_test_push_title(critical=False) - assert message.data["title"] == add_stack_slug_to_message_title(get_test_push_title(critical=False), organization) + assert message.data["title"] == get_test_push_title(critical=False) assert message.data["type"] == "oncall.message" @@ -68,7 +67,7 @@ def test_escalation_fcm_message_user_settings_critical( # Check expected test push content assert message.apns.payload.aps.badge is None assert message.apns.payload.aps.alert.title == get_test_push_title(critical=True) - assert message.data["title"] == add_stack_slug_to_message_title(get_test_push_title(critical=True), organization) + assert message.data["title"] == get_test_push_title(critical=True) assert message.data["type"] == "oncall.critical_message" @@ -94,4 +93,4 @@ def test_escalation_fcm_message_user_settings_critical_override_dnd_disabled( # Check expected test push content assert message.apns.payload.aps.badge is None assert message.apns.payload.aps.alert.title == get_test_push_title(critical=True) - assert message.data["title"] == add_stack_slug_to_message_title(get_test_push_title(critical=True), organization) + assert message.data["title"] == get_test_push_title(critical=True) From 10dc454c7b61a1bc98d6313a56e654c090c0abcc Mon Sep 17 00:00:00 2001 From: Vadim Stepanov Date: Mon, 18 Nov 2024 09:44:32 +0000 Subject: [PATCH 05/12] Inbound email improvements (#5259) # What this PR does * Allows using multiple inbound
email ESPs at the same time by setting the `INBOUND_EMAIL_ESP` env variable to `amazon_ses,mailgun` for example * Adds a new ESP `amazon_ses_validated` that performs SNS message validation (`django-anymail` doesn't implement it: [comment](https://github.com/anymail/django-anymail/blob/35383c7140289e82b39ada5980077898aa07d18d/anymail/webhooks/amazon_ses.py#L107-L108)) ## Which issue(s) this PR closes Related to https://github.com/grafana/oncall-private/issues/2905 ## Checklist - [x] Unit, integration, and e2e (if applicable) tests updated - [x] Documentation added (or `pr:no public docs` PR label added if not required) - [x] Added the relevant release notes label (see labels prefixed w/ `release:`). These labels dictate how your PR will show up in the autogenerated release notes. --- engine/apps/email/inbound.py | 83 ++-- engine/apps/email/tests/test_inbound_email.py | 450 ++++++++++++++++++ .../apps/email/validate_amazon_sns_message.py | 99 ++++ engine/settings/base.py | 1 + 4 files changed, 600 insertions(+), 33 deletions(-) create mode 100644 engine/apps/email/validate_amazon_sns_message.py diff --git a/engine/apps/email/inbound.py b/engine/apps/email/inbound.py index 1780f00c83..185234c521 100644 --- a/engine/apps/email/inbound.py +++ b/engine/apps/email/inbound.py @@ -1,27 +1,42 @@ import logging +from functools import cached_property from typing import Optional, TypedDict -from anymail.exceptions import AnymailInvalidAddress, AnymailWebhookValidationFailure +from anymail.exceptions import AnymailAPIError, AnymailInvalidAddress, AnymailWebhookValidationFailure from anymail.inbound import AnymailInboundMessage from anymail.signals import AnymailInboundEvent from anymail.webhooks import amazon_ses, mailgun, mailjet, mandrill, postal, postmark, sendgrid, sparkpost from django.http import HttpResponse, HttpResponseNotAllowed from django.utils import timezone from rest_framework import status -from rest_framework.request import Request from rest_framework.response import Response from rest_framework.views import APIView from apps.base.utils import live_settings +from apps.email.validate_amazon_sns_message import validate_amazon_sns_message from apps.integrations.mixins import AlertChannelDefiningMixin from apps.integrations.tasks import create_alert logger = logging.getLogger(__name__) +class AmazonSESValidatedInboundWebhookView(amazon_ses.AmazonSESInboundWebhookView): + # disable "Your Anymail webhooks are insecure and open to anyone on the web."
warning + warn_if_no_basic_auth = False + + def validate_request(self, request): + """Add SNS message validation to Amazon SES inbound webhook view, which is not implemented in Anymail.""" + + super().validate_request(request) + sns_message = self._parse_sns_message(request) + if not validate_amazon_sns_message(sns_message): + raise AnymailWebhookValidationFailure("SNS message validation failed") + + # {<esp name>: (<inbound webhook view class>, <webhook secret name>), ...} INBOUND_EMAIL_ESP_OPTIONS = { "amazon_ses": (amazon_ses.AmazonSESInboundWebhookView, None), + "amazon_ses_validated": (AmazonSESValidatedInboundWebhookView, None), "mailgun": (mailgun.MailgunInboundWebhookView, "webhook_signing_key"), "mailjet": (mailjet.MailjetInboundWebhookView, "webhook_secret"), "mandrill": (mandrill.MandrillCombinedWebhookView, "webhook_key"), @@ -62,38 +77,33 @@ def dispatch(self, request): return super().dispatch(request, alert_channel_key=integration_token) def post(self, request): - timestamp = timezone.now().isoformat() - for message in self.get_messages_from_esp_request(request): - payload = self.get_alert_payload_from_email_message(message) - create_alert.delay( - title=payload["subject"], - message=payload["message"], - alert_receive_channel_pk=request.alert_receive_channel.pk, - image_url=None, - link_to_upstream_details=None, - integration_unique_data=None, - raw_request_data=payload, - received_at=timestamp, - ) - + payload = self.get_alert_payload_from_email_message(self.message) + create_alert.delay( + title=payload["subject"], + message=payload["message"], + alert_receive_channel_pk=request.alert_receive_channel.pk, + image_url=None, + link_to_upstream_details=None, + integration_unique_data=None, + raw_request_data=payload, + received_at=timezone.now().isoformat(), + ) return Response("OK", status=status.HTTP_200_OK) def get_integration_token_from_request(self, request) -> Optional[str]: - messages = self.get_messages_from_esp_request(request) - if not messages: + if not self.message: return None - message = messages[0] # First try envelope_recipient field. # According to AnymailInboundMessage it's provided not by all ESPs.
- if message.envelope_recipient: - recipients = message.envelope_recipient.split(",") + if self.message.envelope_recipient: + recipients = self.message.envelope_recipient.split(",") for recipient in recipients: # if there is more than one recipient, the first matching the expected domain will be used try: token, domain = recipient.strip().split("@") except ValueError: logger.error( - f"get_integration_token_from_request: envelope_recipient field has unexpected format: {message.envelope_recipient}" + f"get_integration_token_from_request: envelope_recipient field has unexpected format: {self.message.envelope_recipient}" ) continue if domain == live_settings.INBOUND_EMAIL_DOMAIN: @@ -113,20 +123,27 @@ def get_integration_token_from_request(self, request) -> Optional[str]: # return cc.address.split("@")[0] return None - def get_messages_from_esp_request(self, request: Request) -> list[AnymailInboundMessage]: - view_class, secret_name = INBOUND_EMAIL_ESP_OPTIONS[live_settings.INBOUND_EMAIL_ESP] + @cached_property + def message(self) -> AnymailInboundMessage | None: + esps = live_settings.INBOUND_EMAIL_ESP.split(",") + for esp in esps: + view_class, secret_name = INBOUND_EMAIL_ESP_OPTIONS[esp] - kwargs = {secret_name: live_settings.INBOUND_EMAIL_WEBHOOK_SECRET} if secret_name else {} - view = view_class(**kwargs) + kwargs = {secret_name: live_settings.INBOUND_EMAIL_WEBHOOK_SECRET} if secret_name else {} + view = view_class(**kwargs) - try: - view.run_validators(request) - events = view.parse_events(request) - except AnymailWebhookValidationFailure as e: - logger.info(f"get_messages_from_esp_request: inbound email webhook validation failed: {e}") - return [] + try: + view.run_validators(self.request) + events = view.parse_events(self.request) + except (AnymailWebhookValidationFailure, AnymailAPIError) as e: + logger.info(f"inbound email webhook validation failed for ESP {esp}: {e}") + continue - return [event.message for event in events if isinstance(event, AnymailInboundEvent)] + messages = [event.message for event in events if isinstance(event, AnymailInboundEvent)] + if messages: + return messages[0] + + return None def check_inbound_email_settings_set(self): """ diff --git a/engine/apps/email/tests/test_inbound_email.py b/engine/apps/email/tests/test_inbound_email.py index 81a76e923a..35bccd10f2 100644 --- a/engine/apps/email/tests/test_inbound_email.py +++ b/engine/apps/email/tests/test_inbound_email.py @@ -1,13 +1,295 @@ +import datetime +import hashlib +import hmac import json +from base64 import b64encode from textwrap import dedent +from unittest.mock import ANY, Mock, patch import pytest from anymail.inbound import AnymailInboundMessage +from cryptography import x509 +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import padding, rsa +from cryptography.x509 import CertificateBuilder, NameOID +from django.conf import settings from django.urls import reverse from rest_framework import status from rest_framework.test import APIClient +from apps.alerts.models import AlertReceiveChannel from apps.email.inbound import InboundEmailWebhookView +from apps.integrations.tasks import create_alert + +PRIVATE_KEY = rsa.generate_private_key( + public_exponent=65537, + key_size=2048, +) +ISSUER_NAME = x509.Name( + [ + x509.NameAttribute(NameOID.COUNTRY_NAME, "US"), + x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "Test"), + x509.NameAttribute(NameOID.LOCALITY_NAME, "Test"), + x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Amazon"), + 
x509.NameAttribute(NameOID.COMMON_NAME, "Test"), + ] +) +CERTIFICATE = ( + CertificateBuilder() + .subject_name(ISSUER_NAME) + .issuer_name(ISSUER_NAME) + .public_key(PRIVATE_KEY.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.datetime.now() - datetime.timedelta(days=1)) + .not_valid_after(datetime.datetime.now() + datetime.timedelta(days=10)) + .sign(PRIVATE_KEY, hashes.SHA256()) + .public_bytes(serialization.Encoding.PEM) +) +AMAZON_SNS_TOPIC_ARN = "arn:aws:sns:us-east-2:123456789012:test" +SIGNING_CERT_URL = "https://sns.us-east-2.amazonaws.com/SimpleNotificationService-example.pem" + + +def _sns_inbound_email_payload_and_headers(sender_email, to_email, subject, message): + content = ( + f"From: Sender Name <{sender_email}>\n" + f"To: {to_email}\n" + f"Subject: {subject}\n" + "Date: Tue, 5 Nov 2024 16:05:39 +0000\n" + "Message-ID: \n\n" + f"{message}\r\n" + ) + + message = { + "notificationType": "Received", + "mail": { + "timestamp": "2024-11-05T16:05:52.387Z", + "source": sender_email, + "messageId": "example-message-id-5678", + "destination": [to_email], + "headersTruncated": False, + "headers": [ + {"name": "Return-Path", "value": f"<{sender_email}>"}, + { + "name": "Received", + "value": ( + f"from mail.example.com (mail.example.com [203.0.113.1]) " + f"by inbound-smtp.us-east-2.amazonaws.com with SMTP id example-id " + f"for {to_email}; Tue, 05 Nov 2024 16:05:52 +0000 (UTC)" + ), + }, + {"name": "X-SES-Spam-Verdict", "value": "PASS"}, + {"name": "X-SES-Virus-Verdict", "value": "PASS"}, + { + "name": "Received-SPF", + "value": ( + "pass (spfCheck: domain of example.com designates 203.0.113.1 as permitted sender) " + f"client-ip=203.0.113.1; envelope-from={sender_email}; helo=mail.example.com;" + ), + }, + { + "name": "Authentication-Results", + "value": ( + "amazonses.com; spf=pass (spfCheck: domain of example.com designates 203.0.113.1 as permitted sender) " + f"client-ip=203.0.113.1; envelope-from={sender_email}; helo=mail.example.com; " + "dkim=pass header.i=@example.com; dmarc=pass header.from=example.com;" + ), + }, + {"name": "X-SES-RECEIPT", "value": "example-receipt-data"}, + {"name": "X-SES-DKIM-SIGNATURE", "value": "example-dkim-signature"}, + { + "name": "Received", + "value": ( + f"by mail.example.com with SMTP id example-id for <{to_email}>; " + "Tue, 05 Nov 2024 08:05:52 -0800 (PST)" + ), + }, + { + "name": "DKIM-Signature", + "value": ( + "v=1; a=rsa-sha256; c=relaxed/relaxed; d=example.com; s=default; t=1234567890; " + "bh=examplehash; h=From:To:Subject:Date:Message-ID; b=example-signature" + ), + }, + {"name": "X-Google-DKIM-Signature", "value": "example-google-dkim-signature"}, + {"name": "X-Gm-Message-State", "value": "example-message-state"}, + {"name": "X-Google-Smtp-Source", "value": "example-smtp-source"}, + { + "name": "X-Received", + "value": "by 2002:a17:example with SMTP id example-id; Tue, 05 Nov 2024 08:05:50 -0800 (PST)", + }, + {"name": "MIME-Version", "value": "1.0"}, + {"name": "From", "value": f"Sender Name <{sender_email}>"}, + {"name": "Date", "value": "Tue, 5 Nov 2024 16:05:39 +0000"}, + {"name": "Message-ID", "value": ""}, + {"name": "Subject", "value": subject}, + {"name": "To", "value": to_email}, + { + "name": "Content-Type", + "value": 'multipart/alternative; boundary="00000000000036b9f706262c9312"', + }, + ], + "commonHeaders": { + "returnPath": sender_email, + "from": [f"Sender Name <{sender_email}>"], + "date": "Tue, 5 Nov 2024 16:05:39 +0000", + "to": [to_email], + "messageId": "", + "subject": 
subject, + }, + }, + "receipt": { + "timestamp": "2024-11-05T16:05:52.387Z", + "processingTimeMillis": 638, + "recipients": [to_email], + "spamVerdict": {"status": "PASS"}, + "virusVerdict": {"status": "PASS"}, + "spfVerdict": {"status": "PASS"}, + "dkimVerdict": {"status": "PASS"}, + "dmarcVerdict": {"status": "PASS"}, + "action": { + "type": "SNS", + "topicArn": "arn:aws:sns:us-east-2:123456789012:test", + "encoding": "BASE64", + }, + }, + "content": b64encode(content.encode()).decode(), + } + + payload = { + "Type": "Notification", + "MessageId": "example-message-id-1234", + "TopicArn": AMAZON_SNS_TOPIC_ARN, + "Subject": "Amazon SES Email Receipt Notification", + "Message": json.dumps(message), + "Timestamp": "2024-11-05T16:05:53.041Z", + "SignatureVersion": "1", + "SigningCertURL": SIGNING_CERT_URL, + "UnsubscribeURL": ( + "https://sns.us-east-2.amazonaws.com/?Action=Unsubscribe&SubscriptionArn=" + "arn:aws:sns:us-east-2:123456789012:test:example-subscription-id" + ), + } + # Sign the payload + canonical_message = "".join( + f"{key}\n{payload[key]}\n" for key in ("Message", "MessageId", "Subject", "Timestamp", "TopicArn", "Type") + ) + signature = PRIVATE_KEY.sign( + canonical_message.encode(), + padding.PKCS1v15(), + hashes.SHA1(), + ) + payload["Signature"] = b64encode(signature).decode() + + headers = { + "X-Amz-Sns-Message-Type": "Notification", + "X-Amz-Sns-Message-Id": "example-message-id-1234", + } + return payload, headers + + +def _mailgun_inbound_email_payload(sender_email, to_email, subject, message): + timestamp, token = "1731341416", "example-token" + signature = hmac.new( + key=settings.INBOUND_EMAIL_WEBHOOK_SECRET.encode("ascii"), + msg="{}{}".format(timestamp, token).encode("ascii"), + digestmod=hashlib.sha256, + ).hexdigest() + + return { + "Content-Type": 'multipart/alternative; boundary="000000000000267130626a556e5"', + "Date": "Mon, 11 Nov 2024 16:10:03 +0000", + "Dkim-Signature": ( + "v=1; a=rsa-sha256; c=relaxed/relaxed; d=example.com; s=default; " + "t=1731341415; x=1731946215; darn=example.com; " + "h=to:subject:message-id:date:from:mime-version:from:to:cc:subject " + ":date:message-id:reply-to; bh=examplebh; b=exampleb" + ), + "From": f"Sender Name <{sender_email}>", + "Message-Id": "", + "Mime-Version": "1.0", + "Received": ( + f"by mail.example.com with SMTP id example-id for <{to_email}>; " "Mon, 11 Nov 2024 08:10:15 -0800 (PST)" + ), + "Subject": subject, + "To": to_email, + "X-Envelope-From": sender_email, + "X-Gm-Message-State": "example-message-state", + "X-Google-Dkim-Signature": ( + "v=1; a=rsa-sha256; c=relaxed/relaxed; d=1e100.net; s=20230601; " + "t=1731341415; x=1731946215; " + "h=to:subject:message-id:date:from:mime-version:x-gm-message-state " + ":from:to:cc:subject:date:message-id:reply-to; bh=examplebh; b=exampleb" + ), + "X-Google-Smtp-Source": "example-smtp-source", + "X-Mailgun-Incoming": "Yes", + "X-Received": "by 2002:a17:example with SMTP id example-id; Mon, 11 Nov 2024 08:10:14 -0800 (PST)", + "body-html": f'
{message}
\r\n', + "body-plain": f"{message}\r\n", + "from": f"Sender Name <{sender_email}>", + "message-headers": json.dumps( + [ + ["X-Mailgun-Incoming", "Yes"], + ["X-Envelope-From", sender_email], + [ + "Received", + ( + "from mail.example.com (mail.example.com [203.0.113.1]) " + "by example.com with SMTP id example-id; " + "Mon, 11 Nov 2024 16:10:15 GMT" + ), + ], + [ + "Received", + ( + f"by mail.example.com with SMTP id example-id for <{to_email}>; " + "Mon, 11 Nov 2024 08:10:15 -0800 (PST)" + ), + ], + [ + "Dkim-Signature", + ( + "v=1; a=rsa-sha256; c=relaxed/relaxed; d=example.com; s=default; " + "t=1731341415; x=1731946215; darn=example.com; " + "h=to:subject:message-id:date:from:mime-version:from:to:cc:subject " + ":date:message-id:reply-to; bh=examplebh; b=exampleb" + ), + ], + [ + "X-Google-Dkim-Signature", + ( + "v=1; a=rsa-sha256; c=relaxed/relaxed; d=1e100.net; s=20230601; " + "t=1731341415; x=1731946215; " + "h=to:subject:message-id:date:from:mime-version:x-gm-message-state " + ":from:to:cc:subject:date:message-id:reply-to; bh=examplebh; b=exampleb" + ), + ], + ["X-Gm-Message-State", "example-message-state"], + ["X-Google-Smtp-Source", "example-smtp-source"], + [ + "X-Received", + "by 2002:a17:example with SMTP id example-id; Mon, 11 Nov 2024 08:10:14 -0800 (PST)", + ], + ["Mime-Version", "1.0"], + ["From", f"Sender Name <{sender_email}>"], + ["Date", "Mon, 11 Nov 2024 16:10:03 +0000"], + ["Message-Id", ""], + ["Subject", subject], + ["To", to_email], + [ + "Content-Type", + 'multipart/alternative; boundary="000000000000267130626a556e5"', + ], + ] + ), + "recipient": to_email, + "sender": sender_email, + "signature": signature, + "stripped-html": f'
{message}
\n', + "stripped-text": f"{message}\n", + "subject": subject, + "timestamp": timestamp, + "token": token, + } @pytest.mark.parametrize( @@ -141,3 +423,171 @@ def test_get_sender_from_email_message(sender_value, expected_result): view = InboundEmailWebhookView() result = view.get_sender_from_email_message(email) assert result == expected_result + + +@patch.object(create_alert, "delay") +@pytest.mark.django_db +def test_amazon_ses_pass(create_alert_mock, settings, make_organization, make_alert_receive_channel): + settings.INBOUND_EMAIL_ESP = "amazon_ses,mailgun" + settings.INBOUND_EMAIL_DOMAIN = "inbound.example.com" + settings.INBOUND_EMAIL_WEBHOOK_SECRET = "secret" + + organization = make_organization() + alert_receive_channel = make_alert_receive_channel( + organization, + integration=AlertReceiveChannel.INTEGRATION_INBOUND_EMAIL, + token="test-token", + ) + + sender_email = "sender@example.com" + to_email = "test-token@inbound.example.com" + subject = "Test email" + message = "This is a test email message body." + sns_payload, sns_headers = _sns_inbound_email_payload_and_headers( + sender_email=sender_email, + to_email=to_email, + subject=subject, + message=message, + ) + + client = APIClient() + response = client.post( + reverse("integrations:inbound_email_webhook"), + data=sns_payload, + headers=sns_headers, + format="json", + ) + + assert response.status_code == status.HTTP_200_OK + create_alert_mock.assert_called_once_with( + title=subject, + message=message, + alert_receive_channel_pk=alert_receive_channel.pk, + image_url=None, + link_to_upstream_details=None, + integration_unique_data=None, + raw_request_data={ + "subject": subject, + "message": message, + "sender": sender_email, + }, + received_at=ANY, + ) + + +@patch("requests.get", return_value=Mock(content=CERTIFICATE)) +@patch.object(create_alert, "delay") +@pytest.mark.django_db +def test_amazon_ses_validated_pass( + mock_create_alert, mock_requests_get, settings, make_organization, make_alert_receive_channel +): + settings.INBOUND_EMAIL_ESP = "amazon_ses_validated,mailgun" + settings.INBOUND_EMAIL_DOMAIN = "inbound.example.com" + settings.INBOUND_EMAIL_WEBHOOK_SECRET = "secret" + settings.INBOUND_EMAIL_AMAZON_SNS_TOPIC_ARN = AMAZON_SNS_TOPIC_ARN + + organization = make_organization() + alert_receive_channel = make_alert_receive_channel( + organization, + integration=AlertReceiveChannel.INTEGRATION_INBOUND_EMAIL, + token="test-token", + ) + + sender_email = "sender@example.com" + to_email = "test-token@inbound.example.com" + subject = "Test email" + message = "This is a test email message body." 
+ sns_payload, sns_headers = _sns_inbound_email_payload_and_headers( + sender_email=sender_email, + to_email=to_email, + subject=subject, + message=message, + ) + + client = APIClient() + response = client.post( + reverse("integrations:inbound_email_webhook"), + data=sns_payload, + headers=sns_headers, + format="json", + ) + + assert response.status_code == status.HTTP_200_OK + mock_create_alert.assert_called_once_with( + title=subject, + message=message, + alert_receive_channel_pk=alert_receive_channel.pk, + image_url=None, + link_to_upstream_details=None, + integration_unique_data=None, + raw_request_data={ + "subject": subject, + "message": message, + "sender": sender_email, + }, + received_at=ANY, + ) + + mock_requests_get.assert_called_once_with(SIGNING_CERT_URL, timeout=5) + + +@patch.object(create_alert, "delay") +@pytest.mark.django_db +def test_mailgun_pass(create_alert_mock, settings, make_organization, make_alert_receive_channel): + settings.INBOUND_EMAIL_ESP = "amazon_ses,mailgun" + settings.INBOUND_EMAIL_DOMAIN = "inbound.example.com" + settings.INBOUND_EMAIL_WEBHOOK_SECRET = "secret" + + organization = make_organization() + alert_receive_channel = make_alert_receive_channel( + organization, + integration=AlertReceiveChannel.INTEGRATION_INBOUND_EMAIL, + token="test-token", + ) + + sender_email = "sender@example.com" + to_email = "test-token@inbound.example.com" + subject = "Test email" + message = "This is a test email message body." + + mailgun_payload = _mailgun_inbound_email_payload( + sender_email=sender_email, + to_email=to_email, + subject=subject, + message=message, + ) + + client = APIClient() + response = client.post( + reverse("integrations:inbound_email_webhook"), + data=mailgun_payload, + format="multipart", + ) + + assert response.status_code == status.HTTP_200_OK + create_alert_mock.assert_called_once_with( + title=subject, + message=message, + alert_receive_channel_pk=alert_receive_channel.pk, + image_url=None, + link_to_upstream_details=None, + integration_unique_data=None, + raw_request_data={ + "subject": subject, + "message": message, + "sender": sender_email, + }, + received_at=ANY, + ) + + +@pytest.mark.django_db +def test_multiple_esps_fail(settings): + settings.INBOUND_EMAIL_ESP = "amazon_ses,mailgun" + settings.INBOUND_EMAIL_DOMAIN = "example.com" + settings.INBOUND_EMAIL_WEBHOOK_SECRET = "secret" + + client = APIClient() + response = client.post(reverse("integrations:inbound_email_webhook"), data={}) + + assert response.status_code == status.HTTP_400_BAD_REQUEST diff --git a/engine/apps/email/validate_amazon_sns_message.py b/engine/apps/email/validate_amazon_sns_message.py new file mode 100644 index 0000000000..f3d2aec482 --- /dev/null +++ b/engine/apps/email/validate_amazon_sns_message.py @@ -0,0 +1,99 @@ +import logging +import re +from base64 import b64decode +from urllib.parse import urlparse + +import requests +from cryptography.exceptions import InvalidSignature +from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15 +from cryptography.hazmat.primitives.hashes import SHA1, SHA256 +from cryptography.x509 import NameOID, load_pem_x509_certificate +from django.conf import settings + +logger = logging.getLogger(__name__) + +HOST_PATTERN = re.compile(r"^sns\.[a-zA-Z0-9\-]{3,}\.amazonaws\.com(\.cn)?$") +REQUIRED_KEYS = ( + "Message", + "MessageId", + "Timestamp", + "TopicArn", + "Type", + "Signature", + "SigningCertURL", + "SignatureVersion", +) +SIGNING_KEYS_NOTIFICATION = ("Message", "MessageId", "Subject", "Timestamp", "TopicArn", 
"Type") +SIGNING_KEYS_SUBSCRIPTION = ("Message", "MessageId", "SubscribeURL", "Timestamp", "Token", "TopicArn", "Type") + + +def validate_amazon_sns_message(message: dict) -> bool: + """ + Validate an AWS SNS message. Based on: + - https://docs.aws.amazon.com/sns/latest/dg/sns-verify-signature-of-message.html + - https://github.com/aws/aws-js-sns-message-validator/blob/a6ba4d646dc60912653357660301f3b25f94d686/index.js + - https://github.com/aws/aws-php-sns-message-validator/blob/3cee0fc1aee5538e1bd677654b09fad811061d0b/src/MessageValidator.php + """ + + # Check if the message has all the required keys + if not all(key in message for key in REQUIRED_KEYS): + logger.warning("Missing required keys in the message, got: %s", message.keys()) + return False + + # Check TopicArn + if message["TopicArn"] != settings.INBOUND_EMAIL_AMAZON_SNS_TOPIC_ARN: + logger.warning("Invalid TopicArn: %s", message["TopicArn"]) + return False + + # Construct the canonical message + if message["Type"] == "Notification": + signing_keys = SIGNING_KEYS_NOTIFICATION + elif message["Type"] in ("SubscriptionConfirmation", "UnsubscribeConfirmation"): + signing_keys = SIGNING_KEYS_SUBSCRIPTION + else: + logger.warning("Invalid message type: %s", message["Type"]) + return False + canonical_message = "".join(f"{key}\n{message[key]}\n" for key in signing_keys if key in message).encode() + + # Check if SigningCertURL is a valid SNS URL + signing_cert_url = message["SigningCertURL"] + parsed_url = urlparse(signing_cert_url) + if ( + parsed_url.scheme != "https" + or not HOST_PATTERN.match(parsed_url.netloc) + or not parsed_url.path.endswith(".pem") + ): + logger.warning("Invalid SigningCertURL: %s", signing_cert_url) + return False + + # Fetch the certificate + try: + response = requests.get(signing_cert_url, timeout=5) + response.raise_for_status() + certificate_bytes = response.content + except requests.RequestException as e: + logger.warning("Failed to fetch the certificate from %s: %s", signing_cert_url, e) + return False + + # Verify the certificate issuer + certificate = load_pem_x509_certificate(certificate_bytes) + if certificate.issuer.get_attributes_for_oid(NameOID.ORGANIZATION_NAME)[0].value != "Amazon": + logger.warning("Invalid certificate issuer: %s", certificate.issuer) + return False + + # Verify the signature + signature = b64decode(message["Signature"]) + if message["SignatureVersion"] == "1": + hash_algorithm = SHA1() + elif message["SignatureVersion"] == "2": + hash_algorithm = SHA256() + else: + logger.warning("Invalid SignatureVersion: %s", message["SignatureVersion"]) + return False + try: + certificate.public_key().verify(signature, canonical_message, PKCS1v15(), hash_algorithm) + except InvalidSignature: + logger.warning("Invalid signature") + return False + + return True diff --git a/engine/settings/base.py b/engine/settings/base.py index 5b6eba8f14..25ef7dc142 100644 --- a/engine/settings/base.py +++ b/engine/settings/base.py @@ -867,6 +867,7 @@ class BrokerTypes: INBOUND_EMAIL_ESP = os.getenv("INBOUND_EMAIL_ESP") INBOUND_EMAIL_DOMAIN = os.getenv("INBOUND_EMAIL_DOMAIN") INBOUND_EMAIL_WEBHOOK_SECRET = os.getenv("INBOUND_EMAIL_WEBHOOK_SECRET") +INBOUND_EMAIL_AMAZON_SNS_TOPIC_ARN = os.getenv("INBOUND_EMAIL_AMAZON_SNS_TOPIC_ARN") INSTALLED_ONCALL_INTEGRATIONS = [ # Featured From 5fbc3d058ca8ef0febe15e688307a1cab419e7fe Mon Sep 17 00:00:00 2001 From: Vadim Stepanov Date: Mon, 18 Nov 2024 12:09:05 +0000 Subject: [PATCH 06/12] Inbound email improvements (continued) (#5263) follow up to 
https://github.com/grafana/oncall/pull/5259: * Auto-confirm SNS subscriptions for ESP `amazon_ses_validated` * Add a couple of tests for SNS message validation (try with wrong SNS topic ARN, try with wrong signature) --- engine/apps/email/inbound.py | 11 +- engine/apps/email/tests/test_inbound_email.py | 148 +++++++++++++----- 2 files changed, 115 insertions(+), 44 deletions(-) diff --git a/engine/apps/email/inbound.py b/engine/apps/email/inbound.py index 185234c521..6c86e19485 100644 --- a/engine/apps/email/inbound.py +++ b/engine/apps/email/inbound.py @@ -2,6 +2,7 @@ from functools import cached_property from typing import Optional, TypedDict +import requests from anymail.exceptions import AnymailAPIError, AnymailInvalidAddress, AnymailWebhookValidationFailure from anymail.inbound import AnymailInboundMessage from anymail.signals import AnymailInboundEvent @@ -26,12 +27,14 @@ class AmazonSESValidatedInboundWebhookVie def validate_request(self, request): """Add SNS message validation to Amazon SES inbound webhook view, which is not implemented in Anymail.""" - - super().validate_request(request) - sns_message = self._parse_sns_message(request) - if not validate_amazon_sns_message(sns_message): + if not validate_amazon_sns_message(self._parse_sns_message(request)): raise AnymailWebhookValidationFailure("SNS message validation failed") + def auto_confirm_sns_subscription(self, sns_message): + """This method is called after validate_request, so we can be sure that the message is valid.""" + response = requests.get(sns_message["SubscribeURL"]) + response.raise_for_status() + # {<esp name>: (<inbound webhook view class>, <webhook secret name>), ...} INBOUND_EMAIL_ESP_OPTIONS = { diff --git a/engine/apps/email/tests/test_inbound_email.py b/engine/apps/email/tests/test_inbound_email.py index 35bccd10f2..252b529208 100644 --- a/engine/apps/email/tests/test_inbound_email.py +++ b/engine/apps/email/tests/test_inbound_email.py @@ -47,6 +47,10 @@ ) AMAZON_SNS_TOPIC_ARN = "arn:aws:sns:us-east-2:123456789012:test" SIGNING_CERT_URL = "https://sns.us-east-2.amazonaws.com/SimpleNotificationService-example.pem" +SENDER_EMAIL = "sender@example.com" +TO_EMAIL = "test-token@inbound.example.com" +SUBJECT = "Test email" +MESSAGE = "This is a test email message body." def _sns_inbound_email_payload_and_headers(sender_email, to_email, subject, message): @@ -439,15 +443,11 @@ def test_amazon_ses_pass(create_alert_mock, settings, make_organization, make_al token="test-token", ) - sender_email = "sender@example.com" - to_email = "test-token@inbound.example.com" - subject = "Test email" - message = "This is a test email message body."
sns_payload, sns_headers = _sns_inbound_email_payload_and_headers( - sender_email=sender_email, - to_email=to_email, - subject=subject, - message=message, + sender_email=SENDER_EMAIL, + to_email=TO_EMAIL, + subject=SUBJECT, + message=MESSAGE, ) client = APIClient() @@ -460,16 +460,16 @@ def test_amazon_ses_pass(create_alert_mock, settings, make_organization, make_al assert response.status_code == status.HTTP_200_OK create_alert_mock.assert_called_once_with( - title=subject, - message=message, + title=SUBJECT, + message=MESSAGE, alert_receive_channel_pk=alert_receive_channel.pk, image_url=None, link_to_upstream_details=None, integration_unique_data=None, raw_request_data={ - "subject": subject, - "message": message, - "sender": sender_email, + "subject": SUBJECT, + "message": MESSAGE, + "sender": SENDER_EMAIL, }, received_at=ANY, ) @@ -493,15 +493,11 @@ def test_amazon_ses_validated_pass( token="test-token", ) - sender_email = "sender@example.com" - to_email = "test-token@inbound.example.com" - subject = "Test email" - message = "This is a test email message body." sns_payload, sns_headers = _sns_inbound_email_payload_and_headers( - sender_email=sender_email, - to_email=to_email, - subject=subject, - message=message, + sender_email=SENDER_EMAIL, + to_email=TO_EMAIL, + subject=SUBJECT, + message=MESSAGE, ) client = APIClient() @@ -514,16 +510,16 @@ def test_amazon_ses_validated_pass( assert response.status_code == status.HTTP_200_OK mock_create_alert.assert_called_once_with( - title=subject, - message=message, + title=SUBJECT, + message=MESSAGE, alert_receive_channel_pk=alert_receive_channel.pk, image_url=None, link_to_upstream_details=None, integration_unique_data=None, raw_request_data={ - "subject": subject, - "message": message, - "sender": sender_email, + "subject": SUBJECT, + "message": MESSAGE, + "sender": SENDER_EMAIL, }, received_at=ANY, ) @@ -531,6 +527,83 @@ def test_amazon_ses_validated_pass( mock_requests_get.assert_called_once_with(SIGNING_CERT_URL, timeout=5) +@patch("requests.get", return_value=Mock(content=CERTIFICATE)) +@patch.object(create_alert, "delay") +@pytest.mark.django_db +def test_amazon_ses_validated_fail_wrong_sns_topic_arn( + mock_create_alert, mock_requests_get, settings, make_organization, make_alert_receive_channel +): + settings.INBOUND_EMAIL_ESP = "amazon_ses_validated,mailgun" + settings.INBOUND_EMAIL_DOMAIN = "inbound.example.com" + settings.INBOUND_EMAIL_WEBHOOK_SECRET = "secret" + settings.INBOUND_EMAIL_AMAZON_SNS_TOPIC_ARN = "arn:aws:sns:us-east-2:123456789013:test" + + organization = make_organization() + make_alert_receive_channel( + organization, + integration=AlertReceiveChannel.INTEGRATION_INBOUND_EMAIL, + token="test-token", + ) + + sns_payload, sns_headers = _sns_inbound_email_payload_and_headers( + sender_email=SENDER_EMAIL, + to_email=TO_EMAIL, + subject=SUBJECT, + message=MESSAGE, + ) + + client = APIClient() + response = client.post( + reverse("integrations:inbound_email_webhook"), + data=sns_payload, + headers=sns_headers, + format="json", + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + mock_create_alert.assert_not_called() + mock_requests_get.assert_not_called() + + +@patch("requests.get", return_value=Mock(content=CERTIFICATE)) +@patch.object(create_alert, "delay") +@pytest.mark.django_db +def test_amazon_ses_validated_fail_wrong_signature( + mock_create_alert, mock_requests_get, settings, make_organization, make_alert_receive_channel +): + settings.INBOUND_EMAIL_ESP = "amazon_ses_validated,mailgun" + 
settings.INBOUND_EMAIL_DOMAIN = "inbound.example.com" + settings.INBOUND_EMAIL_WEBHOOK_SECRET = "secret" + settings.INBOUND_EMAIL_AMAZON_SNS_TOPIC_ARN = AMAZON_SNS_TOPIC_ARN + + organization = make_organization() + make_alert_receive_channel( + organization, + integration=AlertReceiveChannel.INTEGRATION_INBOUND_EMAIL, + token="test-token", + ) + + sns_payload, sns_headers = _sns_inbound_email_payload_and_headers( + sender_email=SENDER_EMAIL, + to_email=TO_EMAIL, + subject=SUBJECT, + message=MESSAGE, + ) + sns_payload["Signature"] = "invalid-signature" + + client = APIClient() + response = client.post( + reverse("integrations:inbound_email_webhook"), + data=sns_payload, + headers=sns_headers, + format="json", + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + mock_create_alert.assert_not_called() + mock_requests_get.assert_called_once_with(SIGNING_CERT_URL, timeout=5) + + @patch.object(create_alert, "delay") @pytest.mark.django_db def test_mailgun_pass(create_alert_mock, settings, make_organization, make_alert_receive_channel): @@ -545,16 +618,11 @@ def test_mailgun_pass(create_alert_mock, settings, make_organization, make_alert token="test-token", ) - sender_email = "sender@example.com" - to_email = "test-token@inbound.example.com" - subject = "Test email" - message = "This is a test email message body." - mailgun_payload = _mailgun_inbound_email_payload( - sender_email=sender_email, - to_email=to_email, - subject=subject, - message=message, + sender_email=SENDER_EMAIL, + to_email=TO_EMAIL, + subject=SUBJECT, + message=MESSAGE, ) client = APIClient() @@ -566,16 +634,16 @@ def test_mailgun_pass(create_alert_mock, settings, make_organization, make_alert assert response.status_code == status.HTTP_200_OK create_alert_mock.assert_called_once_with( - title=subject, - message=message, + title=SUBJECT, + message=MESSAGE, alert_receive_channel_pk=alert_receive_channel.pk, image_url=None, link_to_upstream_details=None, integration_unique_data=None, raw_request_data={ - "subject": subject, - "message": message, - "sender": sender_email, + "subject": SUBJECT, + "message": MESSAGE, + "sender": SENDER_EMAIL, }, received_at=ANY, ) From 0c811e0249cb1c6d8b91e9443646149b92d4c478 Mon Sep 17 00:00:00 2001 From: Matias Bordese Date: Mon, 18 Nov 2024 17:29:23 -0300 Subject: [PATCH 07/12] fix: update `next_shifts_per_user` to only list users with upcoming shifts (#5264) Related to https://github.com/grafana/irm/issues/343 --- engine/apps/api/tests/test_schedules.py | 40 +++++++++++++++++++++---- engine/apps/api/views/schedule.py | 18 ++++++----- 2 files changed, 45 insertions(+), 13 deletions(-) diff --git a/engine/apps/api/tests/test_schedules.py b/engine/apps/api/tests/test_schedules.py index 4a29dc9dd6..8efcb6b236 100644 --- a/engine/apps/api/tests/test_schedules.py +++ b/engine/apps/api/tests/test_schedules.py @@ -1442,8 +1442,9 @@ def test_next_shifts_per_user( ("B", "UTC"), ("C", None), ("D", "America/Montevideo"), + ("E", None), ) - user_a, user_b, user_c, user_d = ( + user_a, user_b, user_c, user_d, user_e = ( make_user_for_organization(organization, username=i, _timezone=tz) for i, tz in users ) @@ -1469,8 +1470,7 @@ def test_next_shifts_per_user( ) on_call_shift.add_rolling_users([[user]]) - # override in the past: 17-18 / D - # won't be listed, but user D will still be included in the response + # override in the past, won't be listed: 17-18 / D override_data = { "start": tomorrow - timezone.timedelta(days=3), "rotation_start": tomorrow - timezone.timedelta(days=3), @@ -1483,6 
+1483,7 @@ override.add_rolling_users([[user_d]]) # override: 17-18 / C + # this is before C's shift, so it will be listed as upcoming override_data = { "start": tomorrow + timezone.timedelta(hours=17), "rotation_start": tomorrow + timezone.timedelta(hours=17), @@ -1494,11 +1495,26 @@ ) override.add_rolling_users([[user_c]]) + # override 15 days later: 17-18 / E + fifteen_days_later = tomorrow + timezone.timedelta(days=15) + override_data = { + "start": fifteen_days_later + timezone.timedelta(hours=17), + "rotation_start": fifteen_days_later + timezone.timedelta(hours=17), + "duration": timezone.timedelta(hours=1), + "schedule": schedule, + } + override = make_on_call_shift( + organization=organization, shift_type=CustomOnCallShift.TYPE_OVERRIDE, **override_data + ) + override.add_rolling_users([[user_e]]) + # final schedule: 7-12: B, 15-16: A, 16-17: B, 17-18: C (override), 18-20: C schedule.refresh_ical_final_schedule() url = reverse("api-internal:schedule-next-shifts-per-user", kwargs={"pk": schedule.public_primary_key}) - response = client.get(url, format="json", **make_user_auth_headers(admin, token)) + + # check for users with shifts in the next week + response = client.get(url + "?days=7", format="json", **make_user_auth_headers(admin, token)) assert response.status_code == status.HTTP_200_OK expected = { @@ -1517,13 +1533,27 @@ tomorrow + timezone.timedelta(hours=18), user_c.timezone, ), - user_d.public_primary_key: (None, None, user_d.timezone), } returned_data = { u: (ev.get("start"), ev.get("end"), ev.get("user_timezone")) for u, ev in response.data["users"].items() } assert returned_data == expected + # by default it will check for shifts in the next 30 days + response = client.get(url, format="json", **make_user_auth_headers(admin, token)) + assert response.status_code == status.HTTP_200_OK + + # include user E with the override + expected[user_e.public_primary_key] = ( + fifteen_days_later + timezone.timedelta(hours=17), + fifteen_days_later + timezone.timedelta(hours=18), + user_e.timezone, + ) + returned_data = { + u: (ev.get("start"), ev.get("end"), ev.get("user_timezone")) for u, ev in response.data["users"].items() } + assert returned_data == expected + @pytest.mark.django_db def test_next_shifts_per_user_ical_schedule_using_emails( diff --git a/engine/apps/api/views/schedule.py b/engine/apps/api/views/schedule.py index 78635290de..e30aa8cbde 100644 --- a/engine/apps/api/views/schedule.py +++ b/engine/apps/api/views/schedule.py @@ -388,20 +388,22 @@ def filter_shift_swaps(self, request: Request, pk: str) -> Response: @action(detail=True, methods=["get"]) def next_shifts_per_user(self, request, pk): """Return next shift for users in schedule.""" + days = self.request.query_params.get("days") + days = int(days) if days else 30 now = timezone.now() - datetime_end = now + datetime.timedelta(days=30) + datetime_end = now + datetime.timedelta(days=days) schedule = self.get_object(annotate=False) + users = {} events = schedule.final_events(now, datetime_end) - - # include user TZ information for every user - users = {u.public_primary_key: {"user_timezone": u.timezone} for u in schedule.related_users()} + users_tz = {u.public_primary_key: u.timezone for u in schedule.related_users()} added_users = set() for e in events: - user = e["users"][0]["pk"] if e["users"] else None - if user is not None and user not in added_users and user in users and e["end"] > now: - users[user].update(e) -
From 2bcbac8454904ae8e0e8783d41b32cb19ad2d7eb Mon Sep 17 00:00:00 2001
From: Matias Bordese
Date: Tue, 19 Nov 2024 09:52:23 -0300
Subject: [PATCH 08/12] Enable service account token auth for public API (#5254)

Related to https://github.com/grafana/oncall-private/issues/2826

Continuing work started in https://github.com/grafana/oncall/pull/5211, this
adds support for Grafana service account tokens for API authentication
(except alert group actions, which still require a user). Next steps are
updating the Go client and the Terraform provider to allow service account
token auth for OnCall resources.

Following proposal 1.1 from
[doc](https://docs.google.com/document/d/1I3nFbsUEkiNPphBXT-kWefIeramTY71qqZ1OA06Kmls/edit?usp=sharing).
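As an illustration of the intended usage (not shown in this PR itself), a client could authenticate against the public API with a service account token roughly like this; the host names and token are placeholders:

```python
# Illustrative sketch: calling the OnCall public API with a Grafana service
# account token. The raw "glsa_..." token goes in the Authorization header
# (no "Bearer " prefix -- the auth class matches on the glsa_ prefix), and on
# OSS installs X-Grafana-URL identifies the Grafana instance that issued it.
import requests

resp = requests.get(
    "https://oncall.example.com/api/v1/integrations/",  # placeholder host
    headers={
        "Authorization": "glsa_xxxxxxxxxxxxxxxxxxxxxxxx",  # placeholder token
        "X-Grafana-URL": "https://grafana.example.com",  # placeholder Grafana URL
    },
    timeout=10,
)
resp.raise_for_status()
print(resp.json())
```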
---
 ...065_alertreceivechannel_service_account.py |  20 ++
 .../alerts/models/alert_receive_channel.py    |  14 +-
 engine/apps/api/permissions.py                |   1 +
 engine/apps/auth_token/auth.py                |  43 +---
 .../auth_token/grafana/grafana_auth_token.py  |   6 +
 .../migrations/0007_serviceaccounttoken.py    |  29 +++
 engine/apps/auth_token/models/__init__.py     |   1 +
 .../models/service_account_token.py           | 110 +++++++++
 engine/apps/auth_token/tests/helpers.py       |  18 ++
 .../auth_token/tests/test_grafana_auth.py     | 229 +++++++++++++++++-
 engine/apps/grafana_plugin/helpers/client.py  |   3 +
 .../public_api/serializers/integrations.py    |   5 +-
 .../public_api/tests/test_alert_groups.py     |  34 +++
 .../public_api/tests/test_integrations.py     |  44 ++++
 .../public_api/tests/test_rbac_permissions.py | 104 ++++++++
 .../public_api/tests/test_resolution_notes.py |   6 +-
 engine/apps/public_api/views/alert_groups.py  |  25 +-
 engine/apps/public_api/views/alerts.py        |   4 +-
 .../public_api/views/escalation_chains.py     |   4 +-
 .../public_api/views/escalation_policies.py   |   4 +-
 engine/apps/public_api/views/integrations.py  |   4 +-
 .../apps/public_api/views/on_call_shifts.py   |   4 +-
 engine/apps/public_api/views/organizations.py |   4 +-
 engine/apps/public_api/views/routes.py        |   4 +-
 engine/apps/public_api/views/schedules.py     |   8 +-
 engine/apps/public_api/views/shift_swap.py    |   4 +-
 .../apps/public_api/views/slack_channels.py   |   4 +-
 engine/apps/public_api/views/teams.py         |   4 +-
 engine/apps/public_api/views/user_groups.py   |   4 +-
 engine/apps/public_api/views/users.py         |   8 +-
 engine/apps/public_api/views/webhooks.py      |   4 +-
 .../migrations/0027_serviceaccount.py         |  26 ++
 .../apps/user_management/models/__init__.py   |   1 +
 .../user_management/models/service_account.py |  55 +++++
 .../apps/user_management/tests/factories.py   |  10 +-
 engine/conftest.py                            |  36 ++-
 engine/engine/middlewares.py                  |   6 +-
 37 files changed, 816 insertions(+), 74 deletions(-)
 create mode 100644 engine/apps/alerts/migrations/0065_alertreceivechannel_service_account.py
 create mode 100644 engine/apps/auth_token/migrations/0007_serviceaccounttoken.py
 create mode 100644 engine/apps/auth_token/models/service_account_token.py
 create mode 100644 engine/apps/auth_token/tests/helpers.py
 create mode 100644 engine/apps/user_management/migrations/0027_serviceaccount.py
 create mode 100644 engine/apps/user_management/models/service_account.py

diff --git a/engine/apps/alerts/migrations/0065_alertreceivechannel_service_account.py b/engine/apps/alerts/migrations/0065_alertreceivechannel_service_account.py
new file mode 100644
index 0000000000..306d8a0408
--- /dev/null
+++ b/engine/apps/alerts/migrations/0065_alertreceivechannel_service_account.py
@@ -0,0 +1,20 @@
+# Generated by Django 4.2.15 on 2024-11-12 13:13
+
+from django.db import migrations, models
+import django.db.models.deletion
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        ('user_management', '0027_serviceaccount'),
+        ('alerts', '0064_migrate_resolutionnoteslackmessage_slack_channel_id'),
+    ]
+
+    operations = [
+        migrations.AddField(
+            model_name='alertreceivechannel',
+            name='service_account',
+            field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='alert_receive_channels', to='user_management.serviceaccount'),
+        ),
+    ]
diff --git a/engine/apps/alerts/models/alert_receive_channel.py b/engine/apps/alerts/models/alert_receive_channel.py
index 4fd926ac47..a8cb1494d9 100644
--- a/engine/apps/alerts/models/alert_receive_channel.py
+++ b/engine/apps/alerts/models/alert_receive_channel.py
@@ -234,6 +234,13 @@ class AlertReceiveChannel(IntegrationOptionsMixin, MaintainableObject):
     author = models.ForeignKey(
         "user_management.User", on_delete=models.SET_NULL, related_name="alert_receive_channels", blank=True, null=True
     )
+    service_account = models.ForeignKey(
+        "user_management.ServiceAccount",
+        on_delete=models.SET_NULL,
+        related_name="alert_receive_channels",
+        blank=True,
+        null=True,
+    )
     team = models.ForeignKey(
         "user_management.Team",
         on_delete=models.SET_NULL,
@@ -764,15 +771,16 @@ def listen_for_alertreceivechannel_model_save(
     from apps.heartbeat.models import IntegrationHeartBeat
 
     if created:
-        write_resource_insight_log(instance=instance, author=instance.author, event=EntityEvent.CREATED)
+        author = instance.author or instance.service_account
+        write_resource_insight_log(instance=instance, author=author, event=EntityEvent.CREATED)
         default_filter = ChannelFilter(alert_receive_channel=instance, filtering_term=None, is_default=True)
         default_filter.save()
-        write_resource_insight_log(instance=default_filter, author=instance.author, event=EntityEvent.CREATED)
+        write_resource_insight_log(instance=default_filter, author=author, event=EntityEvent.CREATED)
 
         TEN_MINUTES = 600  # this is the timeout for cloud heartbeats
         if instance.is_available_for_integration_heartbeat:
             heartbeat = IntegrationHeartBeat.objects.create(alert_receive_channel=instance, timeout_seconds=TEN_MINUTES)
-            write_resource_insight_log(instance=heartbeat, author=instance.author, event=EntityEvent.CREATED)
+            write_resource_insight_log(instance=heartbeat, author=author, event=EntityEvent.CREATED)
 
     metrics_add_integrations_to_cache([instance], instance.organization)
 
diff --git a/engine/apps/api/permissions.py b/engine/apps/api/permissions.py
index 852506a109..d9dad6b37d 100644
--- a/engine/apps/api/permissions.py
+++ b/engine/apps/api/permissions.py
@@ -18,6 +18,7 @@
 RBAC_PERMISSIONS_ATTR = "rbac_permissions"
 RBAC_OBJECT_PERMISSIONS_ATTR = "rbac_object_permissions"
 
+
 ViewSetOrAPIView = typing.Union[ViewSet, APIView]
 
 
diff --git a/engine/apps/auth_token/auth.py b/engine/apps/auth_token/auth.py
index dc6ccf7ae0..3a7e25d6bd 100644
--- a/engine/apps/auth_token/auth.py
+++ b/engine/apps/auth_token/auth.py
@@ -9,7 +9,6 @@
 from rest_framework.authentication import BaseAuthentication, get_authorization_header
 from rest_framework.request import Request
 
-from apps.api.permissions import GrafanaAPIPermissions, LegacyAccessControlRole
 from 
apps.grafana_plugin.helpers.gcom import check_token from apps.grafana_plugin.sync_data import SyncPermission, SyncUser from apps.user_management.exceptions import OrganizationDeletedException, OrganizationMovedException @@ -20,13 +19,13 @@ from .constants import GOOGLE_OAUTH2_AUTH_TOKEN_NAME, SCHEDULE_EXPORT_TOKEN_NAME, SLACK_AUTH_TOKEN_NAME from .exceptions import InvalidToken -from .grafana.grafana_auth_token import get_service_account_token_permissions from .models import ( ApiAuthToken, GoogleOAuth2Token, IntegrationBacksyncAuthToken, PluginAuthToken, ScheduleExportAuthToken, + ServiceAccountToken, SlackAuthToken, UserScheduleExportAuthToken, ) @@ -336,8 +335,8 @@ def authenticate_credentials( return auth_token.user, auth_token +X_GRAFANA_URL = "X-Grafana-URL" X_GRAFANA_INSTANCE_ID = "X-Grafana-Instance-ID" -GRAFANA_SA_PREFIX = "glsa_" class GrafanaServiceAccountAuthentication(BaseAuthentication): @@ -345,7 +344,7 @@ def authenticate(self, request): auth = get_authorization_header(request).decode("utf-8") if not auth: raise exceptions.AuthenticationFailed("Invalid token.") - if not auth.startswith(GRAFANA_SA_PREFIX): + if not auth.startswith(ServiceAccountToken.GRAFANA_SA_PREFIX): return None organization = self.get_organization(request) @@ -359,6 +358,13 @@ def authenticate(self, request): return self.authenticate_credentials(organization, auth) def get_organization(self, request): + grafana_url = request.headers.get(X_GRAFANA_URL) + if grafana_url: + organization = Organization.objects.filter(grafana_url=grafana_url).first() + if not organization: + raise exceptions.AuthenticationFailed("Invalid Grafana URL.") + return organization + if settings.LICENSE == settings.CLOUD_LICENSE_NAME: instance_id = request.headers.get(X_GRAFANA_INSTANCE_ID) if not instance_id: @@ -370,36 +376,13 @@ def get_organization(self, request): return Organization.objects.filter(org_slug=org_slug, stack_slug=instance_slug).first() def authenticate_credentials(self, organization, token): - permissions = get_service_account_token_permissions(organization, token) - if not permissions: + try: + user, auth_token = ServiceAccountToken.validate_token(organization, token) + except InvalidToken: raise exceptions.AuthenticationFailed("Invalid token.") - role = LegacyAccessControlRole.NONE - if not organization.is_rbac_permissions_enabled: - role = self.determine_role_from_permissions(permissions) - - user = User( - organization_id=organization.pk, - name="Grafana Service Account", - username="grafana_service_account", - role=role, - permissions=GrafanaAPIPermissions.construct_permissions(permissions.keys()), - ) - - auth_token = ApiAuthToken(organization=organization, user=user, name="Grafana Service Account") - return user, auth_token - # Using default permissions as proxies for roles since we cannot explicitly get role from the service account token - def determine_role_from_permissions(self, permissions): - if "plugins:write" in permissions: - return LegacyAccessControlRole.ADMIN - if "dashboards:write" in permissions: - return LegacyAccessControlRole.EDITOR - if "dashboards:read" in permissions: - return LegacyAccessControlRole.VIEWER - return LegacyAccessControlRole.NONE - class IntegrationBacksyncAuthentication(BaseAuthentication): model = IntegrationBacksyncAuthToken diff --git a/engine/apps/auth_token/grafana/grafana_auth_token.py b/engine/apps/auth_token/grafana/grafana_auth_token.py index 07bae6446f..6576e41793 100644 --- a/engine/apps/auth_token/grafana/grafana_auth_token.py +++ 
b/engine/apps/auth_token/grafana/grafana_auth_token.py @@ -46,3 +46,9 @@ def get_service_account_token_permissions(organization: Organization, token: str grafana_api_client = GrafanaAPIClient(api_url=organization.grafana_url, api_token=token) permissions, _ = grafana_api_client.get_service_account_token_permissions() return permissions + + +def get_service_account_details(organization: Organization, token: str) -> typing.Dict[str, typing.List[str]]: + grafana_api_client = GrafanaAPIClient(api_url=organization.grafana_url, api_token=token) + user_data, _ = grafana_api_client.get_current_user() + return user_data diff --git a/engine/apps/auth_token/migrations/0007_serviceaccounttoken.py b/engine/apps/auth_token/migrations/0007_serviceaccounttoken.py new file mode 100644 index 0000000000..920b9ada3e --- /dev/null +++ b/engine/apps/auth_token/migrations/0007_serviceaccounttoken.py @@ -0,0 +1,29 @@ +# Generated by Django 4.2.15 on 2024-11-12 13:13 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('user_management', '0027_serviceaccount'), + ('auth_token', '0006_googleoauth2token'), + ] + + operations = [ + migrations.CreateModel( + name='ServiceAccountToken', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('token_key', models.CharField(db_index=True, max_length=8)), + ('digest', models.CharField(max_length=128)), + ('created_at', models.DateTimeField(auto_now_add=True)), + ('revoked_at', models.DateTimeField(null=True)), + ('service_account', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='tokens', to='user_management.serviceaccount')), + ], + options={ + 'unique_together': {('token_key', 'service_account', 'digest')}, + }, + ), + ] diff --git a/engine/apps/auth_token/models/__init__.py b/engine/apps/auth_token/models/__init__.py index 272adbda60..42cc60c516 100644 --- a/engine/apps/auth_token/models/__init__.py +++ b/engine/apps/auth_token/models/__init__.py @@ -4,5 +4,6 @@ from .integration_backsync_auth_token import IntegrationBacksyncAuthToken # noqa: F401 from .plugin_auth_token import PluginAuthToken # noqa: F401 from .schedule_export_auth_token import ScheduleExportAuthToken # noqa: F401 +from .service_account_token import ServiceAccountToken # noqa: F401 from .slack_auth_token import SlackAuthToken # noqa: F401 from .user_schedule_export_auth_token import UserScheduleExportAuthToken # noqa: F401 diff --git a/engine/apps/auth_token/models/service_account_token.py b/engine/apps/auth_token/models/service_account_token.py new file mode 100644 index 0000000000..716dc55db3 --- /dev/null +++ b/engine/apps/auth_token/models/service_account_token.py @@ -0,0 +1,110 @@ +import binascii +from hmac import compare_digest + +from django.db import models + +from apps.api.permissions import GrafanaAPIPermissions, LegacyAccessControlRole +from apps.auth_token import constants +from apps.auth_token.crypto import hash_token_string +from apps.auth_token.exceptions import InvalidToken +from apps.auth_token.grafana.grafana_auth_token import ( + get_service_account_details, + get_service_account_token_permissions, +) +from apps.auth_token.models import BaseAuthToken +from apps.user_management.models import ServiceAccount, ServiceAccountUser + + +class ServiceAccountTokenManager(models.Manager): + def get_queryset(self): + return super().get_queryset().select_related("service_account__organization") + + +class 
ServiceAccountToken(BaseAuthToken): + GRAFANA_SA_PREFIX = "glsa_" + + objects = ServiceAccountTokenManager() + + service_account: "ServiceAccount" + service_account = models.ForeignKey(ServiceAccount, on_delete=models.CASCADE, related_name="tokens") + + class Meta: + unique_together = ("token_key", "service_account", "digest") + + @property + def organization(self): + return self.service_account.organization + + @classmethod + def validate_token(cls, organization, token): + # require RBAC enabled to allow service account auth + if not organization.is_rbac_permissions_enabled: + raise InvalidToken + + # Grafana API request: get permissions and confirm token is valid + permissions = get_service_account_token_permissions(organization, token) + if not permissions: + # NOTE: a token can be disabled/re-enabled (not setting as revoked in oncall DB for now) + raise InvalidToken + + # check if we have already seen this token + validated_token = None + service_account = None + prefix_length = len(cls.GRAFANA_SA_PREFIX) + token_key = token[prefix_length : prefix_length + constants.TOKEN_KEY_LENGTH] + try: + hashable_token = binascii.hexlify(token.encode()).decode() + digest = hash_token_string(hashable_token) + except (TypeError, binascii.Error): + raise InvalidToken + for existing_token in cls.objects.filter(service_account__organization=organization, token_key=token_key): + if compare_digest(digest, existing_token.digest): + validated_token = existing_token + service_account = existing_token.service_account + break + + if not validated_token: + # if it didn't match an existing token, create a new one + # make request to Grafana API api/user using token + service_account_data = get_service_account_details(organization, token) + if not service_account_data: + # Grafana versions < 11.3 return 403 trying to get user details with service account token + # use some default values + service_account_data = { + "login": "grafana_service_account", + "uid": None, # "service-account:7" + } + + grafana_id = 0 # default to zero for old Grafana versions (to keep service account unique) + if service_account_data["uid"] is not None: + # extract service account Grafana ID + try: + grafana_id = int(service_account_data["uid"].split(":")[-1]) + except ValueError: + pass + + # get or create service account + service_account, _ = ServiceAccount.objects.get_or_create( + organization=organization, + grafana_id=grafana_id, + defaults={ + "login": service_account_data["login"], + }, + ) + # create token + validated_token, _ = cls.objects.get_or_create( + service_account=service_account, + token_key=token_key, + digest=digest, + ) + + user = ServiceAccountUser( + organization=organization, + service_account=service_account, + username=service_account.username, + public_primary_key=service_account.public_primary_key, + role=LegacyAccessControlRole.NONE, + permissions=GrafanaAPIPermissions.construct_permissions(permissions.keys()), + ) + + return user, validated_token diff --git a/engine/apps/auth_token/tests/helpers.py b/engine/apps/auth_token/tests/helpers.py new file mode 100644 index 0000000000..bcecce6f2c --- /dev/null +++ b/engine/apps/auth_token/tests/helpers.py @@ -0,0 +1,18 @@ +import json + +import httpretty + + +def setup_service_account_api_mocks(organization, perms=None, user_data=None, perms_status=200, user_status=200): + # requires enabling httpretty + if perms is None: + perms = {} + mock_response = httpretty.Response(status=perms_status, body=json.dumps(perms)) + perms_url = 
f"{organization.grafana_url}/api/access-control/user/permissions" + httpretty.register_uri(httpretty.GET, perms_url, responses=[mock_response]) + + if user_data is None: + user_data = {"login": "some-login", "uid": "service-account:42"} + mock_response = httpretty.Response(status=user_status, body=json.dumps(user_data)) + user_url = f"{organization.grafana_url}/api/user" + httpretty.register_uri(httpretty.GET, user_url, responses=[mock_response]) diff --git a/engine/apps/auth_token/tests/test_grafana_auth.py b/engine/apps/auth_token/tests/test_grafana_auth.py index 5b78636c4a..3a8ec56c0d 100644 --- a/engine/apps/auth_token/tests/test_grafana_auth.py +++ b/engine/apps/auth_token/tests/test_grafana_auth.py @@ -1,11 +1,16 @@ import typing from unittest.mock import patch +import httpretty import pytest from rest_framework import exceptions from rest_framework.test import APIRequestFactory -from apps.auth_token.auth import GRAFANA_SA_PREFIX, X_GRAFANA_INSTANCE_ID, GrafanaServiceAccountAuthentication +from apps.api.permissions import LegacyAccessControlRole +from apps.auth_token.auth import X_GRAFANA_INSTANCE_ID, GrafanaServiceAccountAuthentication +from apps.auth_token.models import ServiceAccountToken +from apps.auth_token.tests.helpers import setup_service_account_api_mocks +from apps.user_management.models import ServiceAccountUser from settings.base import CLOUD_LICENSE_NAME, OPEN_SOURCE_LICENSE_NAME, SELF_HOSTED_SETTINGS @@ -53,7 +58,7 @@ def test_grafana_authentication_cloud_inputs(make_organization, settings): mock.assert_called_once_with(organization, token) -def check_common_inputs() -> (dict[str, typing.Any], str): +def check_common_inputs() -> tuple[dict[str, typing.Any], str]: request = APIRequestFactory().get("/") with pytest.raises(exceptions.AuthenticationFailed): GrafanaServiceAccountAuthentication().authenticate(request) @@ -65,7 +70,7 @@ def check_common_inputs() -> (dict[str, typing.Any], str): result = GrafanaServiceAccountAuthentication().authenticate(request) assert result is None - token = f"{GRAFANA_SA_PREFIX}xyz" + token = f"{ServiceAccountToken.GRAFANA_SA_PREFIX}xyz" headers = { "HTTP_AUTHORIZATION": token, } @@ -74,3 +79,221 @@ def check_common_inputs() -> (dict[str, typing.Any], str): GrafanaServiceAccountAuthentication().authenticate(request) return headers, token + + +@pytest.mark.django_db +@httpretty.activate(verbose=True, allow_net_connect=False) +def test_grafana_authentication_missing_org(): + token = f"{ServiceAccountToken.GRAFANA_SA_PREFIX}xyz" + headers = { + "HTTP_AUTHORIZATION": token, + } + request = APIRequestFactory().get("/", **headers) + + with pytest.raises(exceptions.AuthenticationFailed) as exc: + GrafanaServiceAccountAuthentication().authenticate(request) + assert exc.value.detail == "Invalid organization." + + +@pytest.mark.django_db +@httpretty.activate(verbose=True, allow_net_connect=False) +def test_grafana_authentication_invalid_grafana_url(): + token = f"{ServiceAccountToken.GRAFANA_SA_PREFIX}xyz" + headers = { + "HTTP_AUTHORIZATION": token, + "HTTP_X_GRAFANA_URL": "http://grafana.test", # no org for this URL + } + request = APIRequestFactory().get("/", **headers) + + with pytest.raises(exceptions.AuthenticationFailed) as exc: + GrafanaServiceAccountAuthentication().authenticate(request) + assert exc.value.detail == "Invalid Grafana URL." 
+ + +@pytest.mark.django_db +@httpretty.activate(verbose=True, allow_net_connect=False) +def test_grafana_authentication_rbac_disabled_fails(make_organization): + organization = make_organization(grafana_url="http://grafana.test") + if organization.is_rbac_permissions_enabled: + return + + token = f"{ServiceAccountToken.GRAFANA_SA_PREFIX}xyz" + headers = { + "HTTP_AUTHORIZATION": token, + "HTTP_X_GRAFANA_URL": organization.grafana_url, + } + request = APIRequestFactory().get("/", **headers) + + with pytest.raises(exceptions.AuthenticationFailed) as exc: + GrafanaServiceAccountAuthentication().authenticate(request) + assert exc.value.detail == "Invalid token." + + +@pytest.mark.django_db +@httpretty.activate(verbose=True, allow_net_connect=False) +def test_grafana_authentication_permissions_call_fails(make_organization): + organization = make_organization(grafana_url="http://grafana.test") + if not organization.is_rbac_permissions_enabled: + return + + token = f"{ServiceAccountToken.GRAFANA_SA_PREFIX}xyz" + headers = { + "HTTP_AUTHORIZATION": token, + "HTTP_X_GRAFANA_URL": organization.grafana_url, + } + request = APIRequestFactory().get("/", **headers) + + # setup Grafana API responses + # permissions endpoint returns a 401 + setup_service_account_api_mocks(organization, perms_status=401) + + with pytest.raises(exceptions.AuthenticationFailed) as exc: + GrafanaServiceAccountAuthentication().authenticate(request) + assert exc.value.detail == "Invalid token." + + last_request = httpretty.last_request() + assert last_request.method == "GET" + expected_url = f"{organization.grafana_url}/api/access-control/user/permissions" + assert last_request.url == expected_url + # the request uses the given token + assert last_request.headers["Authorization"] == f"Bearer {token}" + + +@pytest.mark.django_db +@httpretty.activate(verbose=True, allow_net_connect=False) +def test_grafana_authentication_existing_token( + make_organization, make_service_account_for_organization, make_token_for_service_account +): + organization = make_organization(grafana_url="http://grafana.test") + if not organization.is_rbac_permissions_enabled: + return + service_account = make_service_account_for_organization(organization) + token_string = "glsa_the-token" + token = make_token_for_service_account(service_account, token_string) + + headers = { + "HTTP_AUTHORIZATION": token_string, + "HTTP_X_GRAFANA_URL": organization.grafana_url, + } + request = APIRequestFactory().get("/", **headers) + + # setup Grafana API responses + setup_service_account_api_mocks(organization, {"some-perm": "value"}) + + user, auth_token = GrafanaServiceAccountAuthentication().authenticate(request) + + assert isinstance(user, ServiceAccountUser) + assert user.service_account == service_account + assert user.public_primary_key == service_account.public_primary_key + assert user.username == service_account.username + assert user.role == LegacyAccessControlRole.NONE + assert auth_token == token + + last_request = httpretty.last_request() + assert last_request.method == "GET" + expected_url = f"{organization.grafana_url}/api/access-control/user/permissions" + assert last_request.url == expected_url + # the request uses the given token + assert last_request.headers["Authorization"] == f"Bearer {token_string}" + + +@pytest.mark.django_db +@httpretty.activate(verbose=True, allow_net_connect=False) +def test_grafana_authentication_token_created(make_organization): + organization = make_organization(grafana_url="http://grafana.test") + if not 
organization.is_rbac_permissions_enabled: + return + token_string = "glsa_the-token" + + headers = { + "HTTP_AUTHORIZATION": token_string, + "HTTP_X_GRAFANA_URL": organization.grafana_url, + } + request = APIRequestFactory().get("/", **headers) + + # setup Grafana API responses + permissions = {"some-perm": "value"} + user_data = {"login": "some-login", "uid": "service-account:42"} + setup_service_account_api_mocks(organization, permissions, user_data) + + user, auth_token = GrafanaServiceAccountAuthentication().authenticate(request) + + assert isinstance(user, ServiceAccountUser) + service_account = user.service_account + assert service_account.organization == organization + assert user.public_primary_key == service_account.public_primary_key + assert user.username == service_account.username + assert service_account.grafana_id == 42 + assert service_account.login == "some-login" + assert user.role == LegacyAccessControlRole.NONE + assert user.permissions == [{"action": p} for p in permissions] + assert auth_token.service_account == user.service_account + + perms_request, user_request = httpretty.latest_requests() + for req in (perms_request, user_request): + assert req.method == "GET" + assert req.headers["Authorization"] == f"Bearer {token_string}" + perms_url = f"{organization.grafana_url}/api/access-control/user/permissions" + assert perms_request.url == perms_url + user_url = f"{organization.grafana_url}/api/user" + assert user_request.url == user_url + + +@pytest.mark.django_db +@httpretty.activate(verbose=True, allow_net_connect=False) +def test_grafana_authentication_token_created_older_grafana(make_organization): + organization = make_organization(grafana_url="http://grafana.test") + if not organization.is_rbac_permissions_enabled: + return + token_string = "glsa_the-token" + + headers = { + "HTTP_AUTHORIZATION": token_string, + "HTTP_X_GRAFANA_URL": organization.grafana_url, + } + request = APIRequestFactory().get("/", **headers) + + # setup Grafana API responses + permissions = {"some-perm": "value"} + # User API fails for older Grafana versions + setup_service_account_api_mocks(organization, permissions, user_status=400) + + user, auth_token = GrafanaServiceAccountAuthentication().authenticate(request) + + assert isinstance(user, ServiceAccountUser) + service_account = user.service_account + assert service_account.organization == organization + # use fallback data + assert service_account.grafana_id == 0 + assert service_account.login == "grafana_service_account" + assert auth_token.service_account == user.service_account + + +@pytest.mark.django_db +@httpretty.activate(verbose=True, allow_net_connect=False) +def test_grafana_authentication_token_reuse_service_account(make_organization, make_service_account_for_organization): + organization = make_organization(grafana_url="http://grafana.test") + if not organization.is_rbac_permissions_enabled: + return + service_account = make_service_account_for_organization(organization) + token_string = "glsa_the-token" + + headers = { + "HTTP_AUTHORIZATION": token_string, + "HTTP_X_GRAFANA_URL": organization.grafana_url, + } + request = APIRequestFactory().get("/", **headers) + + # setup Grafana API responses + permissions = {"some-perm": "value"} + user_data = { + "login": service_account.login, + "uid": f"service-account:{service_account.grafana_id}", + } + setup_service_account_api_mocks(organization, permissions, user_data) + + user, auth_token = GrafanaServiceAccountAuthentication().authenticate(request) + + assert isinstance(user, 
ServiceAccountUser) + assert user.service_account == service_account + assert auth_token.service_account == service_account diff --git a/engine/apps/grafana_plugin/helpers/client.py b/engine/apps/grafana_plugin/helpers/client.py index 2beafa8bdf..17d1cabd20 100644 --- a/engine/apps/grafana_plugin/helpers/client.py +++ b/engine/apps/grafana_plugin/helpers/client.py @@ -315,6 +315,9 @@ def get_grafana_labels_plugin_settings(self) -> APIClientResponse["GrafanaAPICli def get_grafana_irm_plugin_settings(self) -> APIClientResponse["GrafanaAPIClient.Types.PluginSettings"]: return self.get_grafana_plugin_settings(PluginID.IRM) + def get_current_user(self) -> APIClientResponse[typing.Dict[str, typing.List[str]]]: + return self.api_get("api/user") + def get_service_account(self, login: str) -> APIClientResponse["GrafanaAPIClient.Types.ServiceAccountResponse"]: return self.api_get(f"api/serviceaccounts/search?query={login}") diff --git a/engine/apps/public_api/serializers/integrations.py b/engine/apps/public_api/serializers/integrations.py index b16aeb5472..0cbf460583 100644 --- a/engine/apps/public_api/serializers/integrations.py +++ b/engine/apps/public_api/serializers/integrations.py @@ -7,6 +7,7 @@ from apps.alerts.models import AlertReceiveChannel from apps.base.messaging import get_messaging_backends from apps.integrations.legacy_prefix import has_legacy_prefix, remove_legacy_prefix +from apps.user_management.models import ServiceAccountUser from common.api_helpers.custom_fields import TeamPrimaryKeyRelatedField from common.api_helpers.exceptions import BadRequest from common.api_helpers.mixins import PHONE_CALL, SLACK, SMS, TELEGRAM, WEB, EagerLoadingMixin @@ -123,11 +124,13 @@ def create(self, validated_data): connection_error = GrafanaAlertingSyncManager.check_for_connection_errors(organization) if connection_error: raise serializers.ValidationError(connection_error) + user = self.context["request"].user with transaction.atomic(): try: instance = AlertReceiveChannel.create( **validated_data, - author=self.context["request"].user, + author=user if not isinstance(user, ServiceAccountUser) else None, + service_account=user.service_account if isinstance(user, ServiceAccountUser) else None, organization=organization, ) except AlertReceiveChannel.DuplicateDirectPagingError: diff --git a/engine/apps/public_api/tests/test_alert_groups.py b/engine/apps/public_api/tests/test_alert_groups.py index 71421cd318..e3cc872e3a 100644 --- a/engine/apps/public_api/tests/test_alert_groups.py +++ b/engine/apps/public_api/tests/test_alert_groups.py @@ -1,5 +1,6 @@ from unittest.mock import patch +import httpretty import pytest from django.urls import reverse from django.utils import timezone @@ -9,6 +10,8 @@ from apps.alerts.constants import ActionSource from apps.alerts.models import AlertGroup, AlertReceiveChannel from apps.alerts.tasks import delete_alert_group, wipe +from apps.api import permissions +from apps.auth_token.tests.helpers import setup_service_account_api_mocks def construct_expected_response_from_alert_groups(alert_groups): @@ -736,3 +739,34 @@ def test_alert_group_unsilence( assert alert_group.silenced == silenced assert response.status_code == status_code assert response_msg == response.json()["detail"] + + +@pytest.mark.django_db +@httpretty.activate(verbose=True, allow_net_connect=False) +def test_actions_disabled_for_service_accounts( + make_organization, + make_service_account_for_organization, + make_token_for_service_account, + make_escalation_chain, +): + organization = 
make_organization(grafana_url="http://grafana.test") + service_account = make_service_account_for_organization(organization) + token_string = "glsa_token" + make_token_for_service_account(service_account, token_string) + make_escalation_chain(organization) + + perms = { + permissions.RBACPermission.Permissions.ALERT_GROUPS_WRITE.value: ["*"], + } + setup_service_account_api_mocks(organization, perms=perms) + + client = APIClient() + disabled_actions = ["acknowledge", "unacknowledge", "resolve", "unresolve", "silence", "unsilence"] + for action in disabled_actions: + url = reverse(f"api-public:alert_groups-{action}", kwargs={"pk": "ABCDEFG"}) + response = client.post( + url, + HTTP_AUTHORIZATION=f"{token_string}", + HTTP_X_GRAFANA_URL=organization.grafana_url, + ) + assert response.status_code == status.HTTP_403_FORBIDDEN diff --git a/engine/apps/public_api/tests/test_integrations.py b/engine/apps/public_api/tests/test_integrations.py index b021df33e1..796942eb59 100644 --- a/engine/apps/public_api/tests/test_integrations.py +++ b/engine/apps/public_api/tests/test_integrations.py @@ -1,9 +1,12 @@ +import httpretty import pytest from django.urls import reverse from rest_framework import status from rest_framework.test import APIClient from apps.alerts.models import AlertReceiveChannel +from apps.api import permissions +from apps.auth_token.tests.helpers import setup_service_account_api_mocks from apps.base.tests.messaging_backend import TestOnlyBackend TEST_MESSAGING_BACKEND_FIELD = TestOnlyBackend.backend_id.lower() @@ -104,6 +107,47 @@ def test_create_integration( assert response.status_code == status.HTTP_201_CREATED +@pytest.mark.django_db +@httpretty.activate(verbose=True, allow_net_connect=False) +def test_create_integration_via_service_account( + make_organization, + make_service_account_for_organization, + make_token_for_service_account, + make_escalation_chain, +): + organization = make_organization(grafana_url="http://grafana.test") + service_account = make_service_account_for_organization(organization) + token_string = "glsa_token" + make_token_for_service_account(service_account, token_string) + make_escalation_chain(organization) + + perms = { + permissions.RBACPermission.Permissions.INTEGRATIONS_WRITE.value: ["*"], + } + setup_service_account_api_mocks(organization, perms) + + client = APIClient() + data_for_create = { + "type": "grafana", + "name": "grafana_created", + "team_id": None, + } + url = reverse("api-public:integrations-list") + response = client.post( + url, + data=data_for_create, + format="json", + HTTP_AUTHORIZATION=f"{token_string}", + HTTP_X_GRAFANA_URL=organization.grafana_url, + ) + if not organization.is_rbac_permissions_enabled: + assert response.status_code == status.HTTP_403_FORBIDDEN + else: + assert response.status_code == status.HTTP_201_CREATED + integration = AlertReceiveChannel.objects.get(public_primary_key=response.data["id"]) + assert integration.service_account == service_account + + @pytest.mark.django_db def test_integration_name_uniqueness( make_organization_and_user_with_token, diff --git a/engine/apps/public_api/tests/test_rbac_permissions.py b/engine/apps/public_api/tests/test_rbac_permissions.py index 9829550d8c..95154ab4de 100644 --- a/engine/apps/public_api/tests/test_rbac_permissions.py +++ b/engine/apps/public_api/tests/test_rbac_permissions.py @@ -1,5 +1,7 @@ +import json from unittest.mock import patch +import httpretty import pytest from django.urls import reverse from rest_framework import status @@ -9,6 +11,13 @@ from 
apps.api.permissions import GrafanaAPIPermission, LegacyAccessControlRole, get_most_authorized_role
 from apps.public_api.urls import router
 
+VIEWS_REQUIRING_USER_AUTH = (
+    "EscalationView",
+    "PersonalNotificationView",
+    "MakeCallView",
+    "SendSMSView",
+)
+
 
 @pytest.mark.parametrize(
     "rbac_enabled,role,give_perm",
@@ -96,3 +105,98 @@ def test_rbac_permissions(
     with patch(method_path, return_value=success):
         response = client.generic(path=url, method=http_method, HTTP_AUTHORIZATION=token)
     assert response.status_code == expected
+
+
+@pytest.mark.parametrize(
+    "rbac_enabled,role,give_perm",
+    [
+        # rbac disabled: service account auth is disabled
+        (False, LegacyAccessControlRole.ADMIN, None),
+        # rbac enabled: role is None, check that the required perm is enforced
+        (True, LegacyAccessControlRole.NONE, False),
+        (True, LegacyAccessControlRole.NONE, True),
+    ],
+)
+@pytest.mark.django_db
+@httpretty.activate(verbose=True, allow_net_connect=False)
+def test_service_account_auth(
+    make_organization,
+    make_service_account_for_organization,
+    make_token_for_service_account,
+    rbac_enabled,
+    role,
+    give_perm,
+):
+    # APIView default actions
+    # (name, http method, detail-based)
+    default_actions = {
+        "create": ("post", False),
+        "list": ("get", False),
+        "retrieve": ("get", True),
+        "update": ("put", True),
+        "partial_update": ("patch", True),
+        "destroy": ("delete", True),
+    }
+
+    organization = make_organization(grafana_url="http://grafana.test")
+    service_account = make_service_account_for_organization(organization)
+    token_string = "glsa_token"
+    make_token_for_service_account(service_account, token_string)
+
+    if organization.is_rbac_permissions_enabled != rbac_enabled:
+        # skip if the organization's rbac_enabled doesn't match what the test expects
+        return
+
+    client = APIClient()
+    # check all actions for all public API viewsets
+    for _, viewset, _basename in router.registry:
+        if viewset.__name__ == "ActionView":
+            # old actions (webhooks) are deprecated, no RBAC or service account support
+            continue
+        for viewset_method_name, required_perms in viewset.rbac_permissions.items():
+            # setup Grafana API permissions response
+            if rbac_enabled:
+                permissions = {"perm": "value"}
+                expected = status.HTTP_403_FORBIDDEN
+                if give_perm:
+                    permissions = {perm.value: "value" for perm in required_perms}
+                    expected = status.HTTP_200_OK
+                mock_response = httpretty.Response(status=200, body=json.dumps(permissions))
+                perms_url = f"{organization.grafana_url}/api/access-control/user/permissions"
+                httpretty.register_uri(httpretty.GET, perms_url, responses=[mock_response])
+            else:
+                # service account auth is disabled
+                expected = status.HTTP_403_FORBIDDEN
+
+            # iterate over all viewset actions, making an API request for each,
+            # using the user's token and confirming the response status code
+            if viewset_method_name in default_actions:
+                http_method, detail = default_actions[viewset_method_name]
+            else:
+                action_method = getattr(viewset, viewset_method_name)
+                http_method = list(action_method.mapping.keys())[0]
+                detail = action_method.detail
+
+            method_path = f"{viewset.__module__}.{viewset.__name__}.{viewset_method_name}"
+            success = Response(status=status.HTTP_200_OK)
+            kwargs = {"pk": "NONEXISTENT"} if detail else None
+            if viewset_method_name in default_actions and detail:
+                url = reverse(f"api-public:{_basename}-detail", kwargs=kwargs)
+            elif viewset_method_name in default_actions and not detail:
+                url = reverse(f"api-public:{_basename}-list", kwargs=kwargs)
+            else:
+                name = viewset_method_name.replace("_", "-")
+                url = 
reverse(f"api-public:{_basename}-{name}", kwargs=kwargs) + + with patch(method_path, return_value=success): + headers = { + "HTTP_AUTHORIZATION": token_string, + "HTTP_X_GRAFANA_URL": organization.grafana_url, + } + response = client.generic(path=url, method=http_method, **headers) + assert ( + response.status_code == expected + if viewset.__name__ not in VIEWS_REQUIRING_USER_AUTH + # user-specific APIs do not support service account auth + else status.HTTP_403_FORBIDDEN + ) diff --git a/engine/apps/public_api/tests/test_resolution_notes.py b/engine/apps/public_api/tests/test_resolution_notes.py index c3a89a1da4..7a730e18ca 100644 --- a/engine/apps/public_api/tests/test_resolution_notes.py +++ b/engine/apps/public_api/tests/test_resolution_notes.py @@ -6,8 +6,8 @@ from rest_framework.test import APIClient from apps.alerts.models import ResolutionNote -from apps.auth_token.auth import GRAFANA_SA_PREFIX, ApiTokenAuthentication, GrafanaServiceAccountAuthentication -from apps.auth_token.models import ApiAuthToken +from apps.auth_token.auth import ApiTokenAuthentication, GrafanaServiceAccountAuthentication +from apps.auth_token.models import ApiAuthToken, ServiceAccountToken @pytest.mark.django_db @@ -366,7 +366,7 @@ def test_create_resolution_note_grafana_auth(make_organization_and_user, make_al mock_api_key_auth.assert_called_once() assert response.status_code == status.HTTP_403_FORBIDDEN - token = f"{GRAFANA_SA_PREFIX}123" + token = f"{ServiceAccountToken.GRAFANA_SA_PREFIX}123" # GrafanaServiceAccountAuthentication handle invalid token with patch( "apps.auth_token.auth.ApiTokenAuthentication.authenticate", wraps=api_token_auth.authenticate diff --git a/engine/apps/public_api/views/alert_groups.py b/engine/apps/public_api/views/alert_groups.py index d4f4a302ff..fc5d01d029 100644 --- a/engine/apps/public_api/views/alert_groups.py +++ b/engine/apps/public_api/views/alert_groups.py @@ -12,12 +12,13 @@ from apps.alerts.tasks import delete_alert_group, wipe from apps.api.label_filtering import parse_label_query from apps.api.permissions import RBACPermission -from apps.auth_token.auth import ApiTokenAuthentication +from apps.auth_token.auth import ApiTokenAuthentication, GrafanaServiceAccountAuthentication from apps.public_api.constants import VALID_DATE_FOR_DELETE_INCIDENT from apps.public_api.helpers import is_valid_group_creation_date, team_has_slack_token_for_deleting from apps.public_api.serializers import AlertGroupSerializer from apps.public_api.throttlers.user_throttle import UserThrottle -from common.api_helpers.exceptions import BadRequest +from apps.user_management.models import ServiceAccountUser +from common.api_helpers.exceptions import BadRequest, Forbidden from common.api_helpers.filters import ( NO_TEAM_VALUE, ByTeamModelFieldFilterMixin, @@ -57,7 +58,7 @@ class AlertGroupView( mixins.DestroyModelMixin, GenericViewSet, ): - authentication_classes = (ApiTokenAuthentication,) + authentication_classes = (GrafanaServiceAccountAuthentication, ApiTokenAuthentication) permission_classes = (IsAuthenticated, RBACPermission) rbac_permissions = { @@ -170,6 +171,9 @@ def destroy(self, request, *args, **kwargs): @action(methods=["post"], detail=True) def acknowledge(self, request, pk): + if isinstance(request.user, ServiceAccountUser): + raise Forbidden(detail="Service accounts are not allowed to acknowledge alert groups") + alert_group = self.get_object() if alert_group.acknowledged: @@ -189,6 +193,9 @@ def acknowledge(self, request, pk): @action(methods=["post"], detail=True) def 
unacknowledge(self, request, pk): + if isinstance(request.user, ServiceAccountUser): + raise Forbidden(detail="Service accounts are not allowed to unacknowledge alert groups") + alert_group = self.get_object() if not alert_group.acknowledged: @@ -208,6 +215,9 @@ def unacknowledge(self, request, pk): @action(methods=["post"], detail=True) def resolve(self, request, pk): + if isinstance(request.user, ServiceAccountUser): + raise Forbidden(detail="Service accounts are not allowed to resolve alert groups") + alert_group = self.get_object() if alert_group.resolved: @@ -225,6 +235,9 @@ def resolve(self, request, pk): @action(methods=["post"], detail=True) def unresolve(self, request, pk): + if isinstance(request.user, ServiceAccountUser): + raise Forbidden(detail="Service accounts are not allowed to unresolve alert groups") + alert_group = self.get_object() if not alert_group.resolved: @@ -241,6 +254,9 @@ def unresolve(self, request, pk): @action(methods=["post"], detail=True) def silence(self, request, pk=None): + if isinstance(request.user, ServiceAccountUser): + raise Forbidden(detail="Service accounts are not allowed to silence alert groups") + alert_group = self.get_object() delay = request.data.get("delay") @@ -267,6 +283,9 @@ def silence(self, request, pk=None): @action(methods=["post"], detail=True) def unsilence(self, request, pk=None): + if isinstance(request.user, ServiceAccountUser): + raise Forbidden(detail="Service accounts are not allowed to unsilence alert groups") + alert_group = self.get_object() if not alert_group.silenced: diff --git a/engine/apps/public_api/views/alerts.py b/engine/apps/public_api/views/alerts.py index b96d51c50c..0f3d1d4669 100644 --- a/engine/apps/public_api/views/alerts.py +++ b/engine/apps/public_api/views/alerts.py @@ -7,7 +7,7 @@ from apps.alerts.models import Alert from apps.api.permissions import RBACPermission -from apps.auth_token.auth import ApiTokenAuthentication +from apps.auth_token.auth import ApiTokenAuthentication, GrafanaServiceAccountAuthentication from apps.public_api.serializers.alerts import AlertSerializer from apps.public_api.throttlers.user_throttle import UserThrottle from common.api_helpers.mixins import RateLimitHeadersMixin @@ -19,7 +19,7 @@ class AlertFilter(filters.FilterSet): class AlertView(RateLimitHeadersMixin, mixins.ListModelMixin, GenericViewSet): - authentication_classes = (ApiTokenAuthentication,) + authentication_classes = (GrafanaServiceAccountAuthentication, ApiTokenAuthentication) permission_classes = (IsAuthenticated, RBACPermission) rbac_permissions = { diff --git a/engine/apps/public_api/views/escalation_chains.py b/engine/apps/public_api/views/escalation_chains.py index 84bb71628d..52a1cc444c 100644 --- a/engine/apps/public_api/views/escalation_chains.py +++ b/engine/apps/public_api/views/escalation_chains.py @@ -5,7 +5,7 @@ from apps.alerts.models import EscalationChain from apps.api.permissions import RBACPermission -from apps.auth_token.auth import ApiTokenAuthentication +from apps.auth_token.auth import ApiTokenAuthentication, GrafanaServiceAccountAuthentication from apps.public_api.serializers import EscalationChainSerializer from apps.public_api.throttlers.user_throttle import UserThrottle from common.api_helpers.filters import ByTeamFilter @@ -15,7 +15,7 @@ class EscalationChainView(RateLimitHeadersMixin, ModelViewSet): - authentication_classes = (ApiTokenAuthentication,) + authentication_classes = (GrafanaServiceAccountAuthentication, ApiTokenAuthentication) permission_classes = (IsAuthenticated, 
RBACPermission) rbac_permissions = { diff --git a/engine/apps/public_api/views/escalation_policies.py b/engine/apps/public_api/views/escalation_policies.py index ddbaeae803..e91e52f48b 100644 --- a/engine/apps/public_api/views/escalation_policies.py +++ b/engine/apps/public_api/views/escalation_policies.py @@ -5,7 +5,7 @@ from apps.alerts.models import EscalationPolicy from apps.api.permissions import RBACPermission -from apps.auth_token.auth import ApiTokenAuthentication +from apps.auth_token.auth import ApiTokenAuthentication, GrafanaServiceAccountAuthentication from apps.public_api.serializers import EscalationPolicySerializer, EscalationPolicyUpdateSerializer from apps.public_api.throttlers.user_throttle import UserThrottle from common.api_helpers.mixins import RateLimitHeadersMixin, UpdateSerializerMixin @@ -14,7 +14,7 @@ class EscalationPolicyView(RateLimitHeadersMixin, UpdateSerializerMixin, ModelViewSet): - authentication_classes = (ApiTokenAuthentication,) + authentication_classes = (GrafanaServiceAccountAuthentication, ApiTokenAuthentication) permission_classes = (IsAuthenticated, RBACPermission) rbac_permissions = { diff --git a/engine/apps/public_api/views/integrations.py b/engine/apps/public_api/views/integrations.py index 26c55224fd..e8ec9a852b 100644 --- a/engine/apps/public_api/views/integrations.py +++ b/engine/apps/public_api/views/integrations.py @@ -5,7 +5,7 @@ from apps.alerts.models import AlertReceiveChannel from apps.api.permissions import RBACPermission -from apps.auth_token.auth import ApiTokenAuthentication +from apps.auth_token.auth import ApiTokenAuthentication, GrafanaServiceAccountAuthentication from apps.public_api.serializers import IntegrationSerializer, IntegrationUpdateSerializer from apps.public_api.throttlers.user_throttle import UserThrottle from common.api_helpers.exceptions import BadRequest @@ -24,7 +24,7 @@ class IntegrationView( MaintainableObjectMixin, ModelViewSet, ): - authentication_classes = (ApiTokenAuthentication,) + authentication_classes = (GrafanaServiceAccountAuthentication, ApiTokenAuthentication) permission_classes = (IsAuthenticated, RBACPermission) rbac_permissions = { diff --git a/engine/apps/public_api/views/on_call_shifts.py b/engine/apps/public_api/views/on_call_shifts.py index e825ea3537..2e091e947c 100644 --- a/engine/apps/public_api/views/on_call_shifts.py +++ b/engine/apps/public_api/views/on_call_shifts.py @@ -5,7 +5,7 @@ from rest_framework.viewsets import ModelViewSet from apps.api.permissions import RBACPermission -from apps.auth_token.auth import ApiTokenAuthentication +from apps.auth_token.auth import ApiTokenAuthentication, GrafanaServiceAccountAuthentication from apps.public_api.serializers import CustomOnCallShiftSerializer, CustomOnCallShiftUpdateSerializer from apps.public_api.throttlers.user_throttle import UserThrottle from apps.schedules.models import CustomOnCallShift @@ -16,7 +16,7 @@ class CustomOnCallShiftView(RateLimitHeadersMixin, UpdateSerializerMixin, ModelViewSet): - authentication_classes = (ApiTokenAuthentication,) + authentication_classes = (GrafanaServiceAccountAuthentication, ApiTokenAuthentication) permission_classes = (IsAuthenticated, RBACPermission) rbac_permissions = { diff --git a/engine/apps/public_api/views/organizations.py b/engine/apps/public_api/views/organizations.py index 1df2f63a5d..473d79de6c 100644 --- a/engine/apps/public_api/views/organizations.py +++ b/engine/apps/public_api/views/organizations.py @@ -3,7 +3,7 @@ from rest_framework.viewsets import ReadOnlyModelViewSet from 
apps.api.permissions import RBACPermission -from apps.auth_token.auth import ApiTokenAuthentication +from apps.auth_token.auth import ApiTokenAuthentication, GrafanaServiceAccountAuthentication from apps.public_api.serializers import OrganizationSerializer from apps.public_api.throttlers.user_throttle import UserThrottle from apps.user_management.models import Organization @@ -15,7 +15,7 @@ class OrganizationView( RateLimitHeadersMixin, ReadOnlyModelViewSet, ): - authentication_classes = (ApiTokenAuthentication,) + authentication_classes = (GrafanaServiceAccountAuthentication, ApiTokenAuthentication) permission_classes = (IsAuthenticated, RBACPermission) rbac_permissions = { diff --git a/engine/apps/public_api/views/routes.py b/engine/apps/public_api/views/routes.py index 7946152718..19ddc1056a 100644 --- a/engine/apps/public_api/views/routes.py +++ b/engine/apps/public_api/views/routes.py @@ -7,7 +7,7 @@ from apps.alerts.models import ChannelFilter from apps.api.permissions import RBACPermission -from apps.auth_token.auth import ApiTokenAuthentication +from apps.auth_token.auth import ApiTokenAuthentication, GrafanaServiceAccountAuthentication from apps.public_api.serializers import ChannelFilterSerializer, ChannelFilterUpdateSerializer from apps.public_api.throttlers.user_throttle import UserThrottle from common.api_helpers.exceptions import BadRequest @@ -17,7 +17,7 @@ class ChannelFilterView(RateLimitHeadersMixin, UpdateSerializerMixin, ModelViewSet): - authentication_classes = (ApiTokenAuthentication,) + authentication_classes = (GrafanaServiceAccountAuthentication, ApiTokenAuthentication) permission_classes = (IsAuthenticated, RBACPermission) rbac_permissions = { diff --git a/engine/apps/public_api/views/schedules.py b/engine/apps/public_api/views/schedules.py index 6dcca6fd08..5960ad4894 100644 --- a/engine/apps/public_api/views/schedules.py +++ b/engine/apps/public_api/views/schedules.py @@ -9,7 +9,11 @@ from rest_framework.viewsets import ModelViewSet from apps.api.permissions import RBACPermission -from apps.auth_token.auth import ApiTokenAuthentication, ScheduleExportAuthentication +from apps.auth_token.auth import ( + ApiTokenAuthentication, + GrafanaServiceAccountAuthentication, + ScheduleExportAuthentication, +) from apps.public_api.custom_renderers import CalendarRenderer from apps.public_api.serializers import PolymorphicScheduleSerializer, PolymorphicScheduleUpdateSerializer from apps.public_api.serializers.schedules_base import FinalShiftQueryParamsSerializer @@ -28,7 +32,7 @@ class OnCallScheduleChannelView(RateLimitHeadersMixin, UpdateSerializerMixin, ModelViewSet): - authentication_classes = (ApiTokenAuthentication,) + authentication_classes = (GrafanaServiceAccountAuthentication, ApiTokenAuthentication) permission_classes = (IsAuthenticated, RBACPermission) rbac_permissions = { diff --git a/engine/apps/public_api/views/shift_swap.py b/engine/apps/public_api/views/shift_swap.py index 07f978e5c9..c46c141965 100644 --- a/engine/apps/public_api/views/shift_swap.py +++ b/engine/apps/public_api/views/shift_swap.py @@ -10,7 +10,7 @@ from apps.api.permissions import AuthenticatedRequest, RBACPermission from apps.api.views.shift_swap import BaseShiftSwapViewSet -from apps.auth_token.auth import ApiTokenAuthentication +from apps.auth_token.auth import ApiTokenAuthentication, GrafanaServiceAccountAuthentication from apps.public_api.throttlers.user_throttle import UserThrottle from apps.schedules.models import ShiftSwapRequest from apps.user_management.models import User @@ -23,7 
+23,7 @@ class ShiftSwapViewSet(RateLimitHeadersMixin, BaseShiftSwapViewSet): # set authentication and permission classes - authentication_classes = (ApiTokenAuthentication,) + authentication_classes = (GrafanaServiceAccountAuthentication, ApiTokenAuthentication) permission_classes = (IsAuthenticated, RBACPermission) rbac_permissions = { diff --git a/engine/apps/public_api/views/slack_channels.py b/engine/apps/public_api/views/slack_channels.py index 77581f3dde..35f384021a 100644 --- a/engine/apps/public_api/views/slack_channels.py +++ b/engine/apps/public_api/views/slack_channels.py @@ -3,7 +3,7 @@ from rest_framework.viewsets import GenericViewSet from apps.api.permissions import RBACPermission -from apps.auth_token.auth import ApiTokenAuthentication +from apps.auth_token.auth import ApiTokenAuthentication, GrafanaServiceAccountAuthentication from apps.public_api.serializers.slack_channel import SlackChannelSerializer from apps.public_api.throttlers.user_throttle import UserThrottle from apps.slack.models import SlackChannel @@ -12,7 +12,7 @@ class SlackChannelView(RateLimitHeadersMixin, mixins.ListModelMixin, GenericViewSet): - authentication_classes = (ApiTokenAuthentication,) + authentication_classes = (GrafanaServiceAccountAuthentication, ApiTokenAuthentication) permission_classes = (IsAuthenticated, RBACPermission) rbac_permissions = { diff --git a/engine/apps/public_api/views/teams.py b/engine/apps/public_api/views/teams.py index 490e74efb1..6d399bade5 100644 --- a/engine/apps/public_api/views/teams.py +++ b/engine/apps/public_api/views/teams.py @@ -3,7 +3,7 @@ from rest_framework.permissions import IsAuthenticated from apps.api.permissions import RBACPermission -from apps.auth_token.auth import ApiTokenAuthentication +from apps.auth_token.auth import ApiTokenAuthentication, GrafanaServiceAccountAuthentication from apps.public_api.serializers.teams import TeamSerializer from apps.public_api.tf_sync import is_request_from_terraform, sync_teams_on_tf_request from apps.public_api.throttlers.user_throttle import UserThrottle @@ -14,7 +14,7 @@ class TeamView(PublicPrimaryKeyMixin, RetrieveModelMixin, ListModelMixin, viewsets.GenericViewSet): serializer_class = TeamSerializer - authentication_classes = (ApiTokenAuthentication,) + authentication_classes = (GrafanaServiceAccountAuthentication, ApiTokenAuthentication) permission_classes = (IsAuthenticated, RBACPermission) rbac_permissions = { diff --git a/engine/apps/public_api/views/user_groups.py b/engine/apps/public_api/views/user_groups.py index ced7f626bf..bb1dac7f37 100644 --- a/engine/apps/public_api/views/user_groups.py +++ b/engine/apps/public_api/views/user_groups.py @@ -3,7 +3,7 @@ from rest_framework.viewsets import GenericViewSet from apps.api.permissions import RBACPermission -from apps.auth_token.auth import ApiTokenAuthentication +from apps.auth_token.auth import ApiTokenAuthentication, GrafanaServiceAccountAuthentication from apps.public_api.serializers.user_groups import UserGroupSerializer from apps.public_api.throttlers.user_throttle import UserThrottle from apps.slack.models import SlackUserGroup @@ -12,7 +12,7 @@ class UserGroupView(RateLimitHeadersMixin, mixins.ListModelMixin, GenericViewSet): - authentication_classes = (ApiTokenAuthentication,) + authentication_classes = (GrafanaServiceAccountAuthentication, ApiTokenAuthentication) permission_classes = (IsAuthenticated, RBACPermission) rbac_permissions = { diff --git a/engine/apps/public_api/views/users.py b/engine/apps/public_api/views/users.py index 
97315fe202..129096e560 100644 --- a/engine/apps/public_api/views/users.py +++ b/engine/apps/public_api/views/users.py @@ -6,7 +6,11 @@ from rest_framework.viewsets import ReadOnlyModelViewSet from apps.api.permissions import LegacyAccessControlRole, RBACPermission -from apps.auth_token.auth import ApiTokenAuthentication, UserScheduleExportAuthentication +from apps.auth_token.auth import ( + ApiTokenAuthentication, + GrafanaServiceAccountAuthentication, + UserScheduleExportAuthentication, +) from apps.public_api.custom_renderers import CalendarRenderer from apps.public_api.serializers import FastUserSerializer, UserSerializer from apps.public_api.tf_sync import is_request_from_terraform, sync_users_on_tf_request @@ -35,7 +39,7 @@ class Meta: class UserView(RateLimitHeadersMixin, ShortSerializerMixin, ReadOnlyModelViewSet): - authentication_classes = (ApiTokenAuthentication,) + authentication_classes = (GrafanaServiceAccountAuthentication, ApiTokenAuthentication) permission_classes = (IsAuthenticated, RBACPermission) rbac_permissions = { diff --git a/engine/apps/public_api/views/webhooks.py b/engine/apps/public_api/views/webhooks.py index 8f75148b71..b1a6a47bb1 100644 --- a/engine/apps/public_api/views/webhooks.py +++ b/engine/apps/public_api/views/webhooks.py @@ -6,7 +6,7 @@ from rest_framework.viewsets import ModelViewSet from apps.api.permissions import RBACPermission -from apps.auth_token.auth import ApiTokenAuthentication +from apps.auth_token.auth import ApiTokenAuthentication, GrafanaServiceAccountAuthentication from apps.public_api.serializers.webhooks import ( WebhookCreateSerializer, WebhookResponseSerializer, @@ -21,7 +21,7 @@ class WebhooksView(RateLimitHeadersMixin, UpdateSerializerMixin, ModelViewSet): - authentication_classes = (ApiTokenAuthentication,) + authentication_classes = (GrafanaServiceAccountAuthentication, ApiTokenAuthentication) permission_classes = (IsAuthenticated, RBACPermission) rbac_permissions = { diff --git a/engine/apps/user_management/migrations/0027_serviceaccount.py b/engine/apps/user_management/migrations/0027_serviceaccount.py new file mode 100644 index 0000000000..dc9e520b3b --- /dev/null +++ b/engine/apps/user_management/migrations/0027_serviceaccount.py @@ -0,0 +1,26 @@ +# Generated by Django 4.2.15 on 2024-11-12 13:13 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('user_management', '0026_auto_20241017_1919'), + ] + + operations = [ + migrations.CreateModel( + name='ServiceAccount', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('grafana_id', models.PositiveIntegerField()), + ('login', models.CharField(max_length=300)), + ('organization', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='service_accounts', to='user_management.organization')), + ], + options={ + 'unique_together': {('grafana_id', 'organization')}, + }, + ), + ] diff --git a/engine/apps/user_management/models/__init__.py b/engine/apps/user_management/models/__init__.py index e2bcd4c7f0..2fd5a9aa1e 100644 --- a/engine/apps/user_management/models/__init__.py +++ b/engine/apps/user_management/models/__init__.py @@ -1,4 +1,5 @@ from .user import User # noqa: F401, isort: skip from .organization import Organization # noqa: F401 from .region import Region # noqa: F401 +from .service_account import ServiceAccount, ServiceAccountUser # noqa: F401 from .team import Team # noqa: F401 diff --git 
a/engine/apps/user_management/models/service_account.py b/engine/apps/user_management/models/service_account.py new file mode 100644 index 0000000000..5082f7b965 --- /dev/null +++ b/engine/apps/user_management/models/service_account.py @@ -0,0 +1,55 @@ +from dataclasses import dataclass +from typing import List + +from django.db import models + +from apps.user_management.models import Organization + + +@dataclass +class ServiceAccountUser: + """Authenticated service account in public API requests.""" + + service_account: "ServiceAccount" + organization: "Organization" # required for insight logs interface + username: str # required for insight logs interface + public_primary_key: str # required for insight logs interface + role: str # required for permissions check + permissions: List[str] # required for permissions check + + @property + def id(self): + return self.service_account.id + + @property + def pk(self): + return self.service_account.id + + @property + def organization_id(self): + return self.organization.id + + @property + def is_authenticated(self): + return True + + +class ServiceAccount(models.Model): + organization: "Organization" + + grafana_id = models.PositiveIntegerField() + organization = models.ForeignKey(Organization, on_delete=models.CASCADE, related_name="service_accounts") + login = models.CharField(max_length=300) + + class Meta: + unique_together = ("grafana_id", "organization") + + @property + def username(self): + # required for insight logs interface + return self.login + + @property + def public_primary_key(self): + # required for insight logs interface + return f"service-account:{self.grafana_id}" diff --git a/engine/apps/user_management/tests/factories.py b/engine/apps/user_management/tests/factories.py index ccfbb8586e..a33aefaca1 100644 --- a/engine/apps/user_management/tests/factories.py +++ b/engine/apps/user_management/tests/factories.py @@ -1,6 +1,6 @@ import factory -from apps.user_management.models import Organization, Region, Team, User +from apps.user_management.models import Organization, Region, ServiceAccount, Team, User from common.utils import UniqueFaker @@ -41,3 +41,11 @@ class RegionFactory(factory.DjangoModelFactory): class Meta: model = Region + + +class ServiceAccountFactory(factory.DjangoModelFactory): + grafana_id = UniqueFaker("pyint") + login = UniqueFaker("user_name") + + class Meta: + model = ServiceAccount diff --git a/engine/conftest.py b/engine/conftest.py index a95383dd94..0b66e3adea 100644 --- a/engine/conftest.py +++ b/engine/conftest.py @@ -1,3 +1,4 @@ +import binascii import datetime import json import os @@ -46,11 +47,14 @@ LegacyAccessControlRole, RBACPermission, ) +from apps.auth_token import constants as auth_token_constants +from apps.auth_token.crypto import hash_token_string from apps.auth_token.models import ( ApiAuthToken, GoogleOAuth2Token, IntegrationBacksyncAuthToken, PluginAuthToken, + ServiceAccountToken, SlackAuthToken, ) from apps.base.models.user_notification_policy_log_record import ( @@ -102,7 +106,13 @@ TelegramVerificationCodeFactory, ) from apps.user_management.models.user import User, listen_for_user_model_save -from apps.user_management.tests.factories import OrganizationFactory, RegionFactory, TeamFactory, UserFactory +from apps.user_management.tests.factories import ( + OrganizationFactory, + RegionFactory, + ServiceAccountFactory, + TeamFactory, + UserFactory, +) from apps.webhooks.presets.preset_options import WebhookPresetOptions from apps.webhooks.tests.factories import CustomWebhookFactory, 
WebhookResponseFactory from apps.webhooks.tests.test_webhook_presets import ( @@ -252,6 +262,30 @@ def _make_user_for_organization(organization, role: typing.Optional[LegacyAccess return _make_user_for_organization +@pytest.fixture +def make_service_account_for_organization(make_user): + def _make_service_account_for_organization(organization, **kwargs): + return ServiceAccountFactory(organization=organization, **kwargs) + + return _make_service_account_for_organization + + +@pytest.fixture +def make_token_for_service_account(): + def _make_token_for_service_account(service_account, token_string): + prefix_length = len(ServiceAccountToken.GRAFANA_SA_PREFIX) + token_key = token_string[prefix_length : prefix_length + auth_token_constants.TOKEN_KEY_LENGTH] + hashable_token = binascii.hexlify(token_string.encode()).decode() + digest = hash_token_string(hashable_token) + return ServiceAccountToken.objects.create( + service_account=service_account, + token_key=token_key, + digest=digest, + ) + + return _make_token_for_service_account + + @pytest.fixture def make_token_for_organization(): def _make_token_for_organization(organization):
diff --git a/engine/engine/middlewares.py b/engine/engine/middlewares.py index c3da3c4c2b..0173323bc0 100644 --- a/engine/engine/middlewares.py +++ b/engine/engine/middlewares.py @@ -28,9 +28,13 @@ def log_message(request, response, tag, message=""): ) if hasattr(request, "user") and request.user and request.user.id and hasattr(request.user, "organization"): user_id = request.user.id + if hasattr(request.user, "service_account"): + message += f"service_account_id={user_id} " + else: + message += f"user_id={user_id} " org_id = request.user.organization.id org_slug = request.user.organization.org_slug - message += f"user_id={user_id} org_id={org_id} org_slug={org_slug} " + message += f"org_id={org_id} org_slug={org_slug} " if request.path.startswith("/integrations/v1"): split_path = request.path.split("/") integration_type = split_path[3]
From 1bd30b3cf8e3bd6804da7ae323c73e85781aa8b4 Mon Sep 17 00:00:00 2001 From: Joey Orlando Date: Tue, 19 Nov 2024 14:23:48 -0500 Subject: [PATCH 09/12] chore: remove deprecated `AlertGroupPostMortem` model + recently refactored/deprecated slack channel related columns (#5240) # What this PR does - `AlertGroupPostMortem` has no references in the codebase. I stumbled across it while working on https://github.com/grafana/oncall/pull/5224 and decided to just remove it - Removing old Slack channel related `VARCHAR` columns; these were refactored to foreign key references to the `slack_slackchannel` table in the following PRs: - https://github.com/grafana/oncall/pull/5224 - https://github.com/grafana/oncall/pull/5199 - https://github.com/grafana/oncall/pull/5191 ## Checklist - [x] Unit, integration, and e2e (if applicable) tests updated - [x] Documentation added (or `pr:no public docs` PR label added if not required) - [x] Added the relevant release notes label (see labels prefixed w/ `release:`). These labels dictate how your PR will show up in the autogenerated release notes.
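To make the column-to-foreign-key refactor concrete, here is a rough sketch of the access-pattern change these drops finalize (the `post_to_slack` helper is hypothetical, and `slack_id` is assumed to be how `slack.SlackChannel` exposes the raw Slack channel id):

```python
# Hypothetical sketch, not part of this patch.

# Before: routes kept a raw Slack channel id in a VARCHAR column.
# slack_channel_id = channel_filter._slack_channel_id  # e.g. "C01ABCDEF" or None

# After: routes hold a real foreign key to slack.SlackChannel, so stale ids
# cannot dangle; on_delete=SET_NULL clears the reference when the row goes away.
def notify_route(channel_filter, post_to_slack):
    slack_channel = channel_filter.slack_channel  # SlackChannel instance or None
    if slack_channel is not None:
        post_to_slack(channel=slack_channel.slack_id, text="escalation started")
```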
--- .../migrations/0001_squashed_initial.py | 2 +- ...hannelfilter__slack_channel_id_and_more.py | 26 ++++++++++ engine/apps/alerts/models/channel_filter.py | 4 -- engine/apps/alerts/models/resolution_note.py | 50 +++---------------- .../0020_remove_oncallschedule_channel.py | 19 +++++++ .../apps/schedules/models/on_call_schedule.py | 2 - ...ove_organization_general_log_channel_id.py | 19 +++++++ .../user_management/models/organization.py | 3 -- 8 files changed, 73 insertions(+), 52 deletions(-) create mode 100644 engine/apps/alerts/migrations/0066_remove_channelfilter__slack_channel_id_and_more.py create mode 100644 engine/apps/schedules/migrations/0020_remove_oncallschedule_channel.py create mode 100644 engine/apps/user_management/migrations/0028_remove_organization_general_log_channel_id.py diff --git a/engine/apps/alerts/migrations/0001_squashed_initial.py b/engine/apps/alerts/migrations/0001_squashed_initial.py index 0c96d7d4ad..8426d2635a 100644 --- a/engine/apps/alerts/migrations/0001_squashed_initial.py +++ b/engine/apps/alerts/migrations/0001_squashed_initial.py @@ -119,7 +119,7 @@ class Migration(migrations.Migration): name='AlertGroupPostmortem', fields=[ ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), - ('public_primary_key', models.CharField(default=apps.alerts.models.resolution_note.generate_public_primary_key_for_alert_group_postmortem, max_length=20, unique=True, validators=[django.core.validators.MinLengthValidator(13)])), + ('public_primary_key', models.CharField(max_length=20, unique=True, validators=[django.core.validators.MinLengthValidator(13)])), ('created_at', models.DateTimeField(auto_now_add=True)), ('last_modified', models.DateTimeField(auto_now=True)), ('text', models.TextField(default=None, max_length=3000, null=True)), diff --git a/engine/apps/alerts/migrations/0066_remove_channelfilter__slack_channel_id_and_more.py b/engine/apps/alerts/migrations/0066_remove_channelfilter__slack_channel_id_and_more.py new file mode 100644 index 0000000000..03c5f53430 --- /dev/null +++ b/engine/apps/alerts/migrations/0066_remove_channelfilter__slack_channel_id_and_more.py @@ -0,0 +1,26 @@ +# Generated by Django 4.2.16 on 2024-11-06 21:11 + +from django.db import migrations +import django_migration_linter as linter + + +class Migration(migrations.Migration): + + dependencies = [ + ('alerts', '0065_alertreceivechannel_service_account'), + ] + + operations = [ + linter.IgnoreMigration(), + migrations.RemoveField( + model_name='channelfilter', + name='_slack_channel_id', + ), + migrations.RemoveField( + model_name='resolutionnoteslackmessage', + name='_slack_channel_id', + ), + migrations.DeleteModel( + name='AlertGroupPostmortem', + ), + ] diff --git a/engine/apps/alerts/models/channel_filter.py b/engine/apps/alerts/models/channel_filter.py index f7cb302f7a..3ea2ea8bcb 100644 --- a/engine/apps/alerts/models/channel_filter.py +++ b/engine/apps/alerts/models/channel_filter.py @@ -69,9 +69,6 @@ class ChannelFilter(OrderedModel): notify_in_slack = models.BooleanField(null=True, default=True) notify_in_telegram = models.BooleanField(null=True, default=False) - - # TODO: remove _slack_channel_id in future release - _slack_channel_id = models.CharField(max_length=100, null=True, default=None) slack_channel = models.ForeignKey( "slack.SlackChannel", null=True, @@ -79,7 +76,6 @@ class ChannelFilter(OrderedModel): on_delete=models.SET_NULL, related_name="+", ) - telegram_channel = models.ForeignKey( 
"telegram.TelegramToOrganizationConnector", on_delete=models.SET_NULL, diff --git a/engine/apps/alerts/models/resolution_note.py b/engine/apps/alerts/models/resolution_note.py index e2f3586a55..90e651662a 100644 --- a/engine/apps/alerts/models/resolution_note.py +++ b/engine/apps/alerts/models/resolution_note.py @@ -14,20 +14,7 @@ if typing.TYPE_CHECKING: from apps.alerts.models import AlertGroup from apps.slack.models import SlackChannel - - -def generate_public_primary_key_for_alert_group_postmortem(): - prefix = "P" - new_public_primary_key = generate_public_primary_key(prefix) - - failure_counter = 0 - while AlertGroupPostmortem.objects.filter(public_primary_key=new_public_primary_key).exists(): - new_public_primary_key = increase_public_primary_key_length( - failure_counter=failure_counter, prefix=prefix, model_name="AlertGroupPostmortem" - ) - failure_counter += 1 - - return new_public_primary_key + from apps.user_management.models import User def generate_public_primary_key_for_resolution_note(): @@ -75,9 +62,6 @@ class ResolutionNoteSlackMessage(models.Model): related_name="added_resolution_note_slack_messages", ) text = models.TextField(max_length=3000, default=None, null=True) - - # TODO: remove _slack_channel_id in future release - _slack_channel_id = models.CharField(max_length=100, null=True, default=None) slack_channel = models.ForeignKey( "slack.SlackChannel", null=True, @@ -85,7 +69,6 @@ class ResolutionNoteSlackMessage(models.Model): on_delete=models.SET_NULL, related_name="+", ) - ts = models.CharField(max_length=100, null=True, default=None) thread_ts = models.CharField(max_length=100, null=True, default=None) permalink = models.CharField(max_length=250, null=True, default=None) @@ -130,6 +113,7 @@ def filter(self, *args, **kwargs): class ResolutionNote(models.Model): alert_group: "AlertGroup" + author: typing.Optional["User"] resolution_note_slack_message: typing.Optional[ResolutionNoteSlackMessage] objects = ResolutionNoteQueryset.as_manager() @@ -213,29 +197,11 @@ def render_log_line_json(self): return result - def author_verbal(self, mention): - """ - Postmortems to resolution notes included migrating AlertGroupPostmortem to ResolutionNotes. - But AlertGroupPostmortem has no author field. So this method was introduces as workaround. + def author_verbal(self, mention: bool) -> str: """ - if self.author is not None: - return self.author.get_username_with_slack_verbal(mention) - else: - return "" + Postmortems to resolution notes included migrating `AlertGroupPostmortem` to `ResolutionNote`s. + But `AlertGroupPostmortem` has no author field. So this method was introduced as a workaround. 
- -class AlertGroupPostmortem(models.Model): - public_primary_key = models.CharField( - max_length=20, - validators=[MinLengthValidator(settings.PUBLIC_PRIMARY_KEY_MIN_LENGTH + 1)], - unique=True, - default=generate_public_primary_key_for_alert_group_postmortem, - ) - alert_group = models.ForeignKey( - "alerts.AlertGroup", - on_delete=models.CASCADE, - related_name="postmortem_text", - ) - created_at = models.DateTimeField(auto_now_add=True) - last_modified = models.DateTimeField(auto_now=True) - text = models.TextField(max_length=3000, default=None, null=True) + (see git history for more details on what `AlertGroupPostmortem` was) + """ + return "" if self.author is None else self.author.get_username_with_slack_verbal(mention) diff --git a/engine/apps/schedules/migrations/0020_remove_oncallschedule_channel.py b/engine/apps/schedules/migrations/0020_remove_oncallschedule_channel.py new file mode 100644 index 0000000000..e4d1913827 --- /dev/null +++ b/engine/apps/schedules/migrations/0020_remove_oncallschedule_channel.py @@ -0,0 +1,19 @@ +# Generated by Django 4.2.16 on 2024-11-06 21:13 + +from django.db import migrations +import django_migration_linter as linter + + +class Migration(migrations.Migration): + + dependencies = [ + ('schedules', '0019_auto_20241021_1735'), + ] + + operations = [ + linter.IgnoreMigration(), + migrations.RemoveField( + model_name='oncallschedule', + name='channel', + ), + ] diff --git a/engine/apps/schedules/models/on_call_schedule.py b/engine/apps/schedules/models/on_call_schedule.py index 544ec847b2..e57cf4bc48 100644 --- a/engine/apps/schedules/models/on_call_schedule.py +++ b/engine/apps/schedules/models/on_call_schedule.py @@ -209,8 +209,6 @@ class OnCallSchedule(PolymorphicModel): name = models.CharField(max_length=200) - # TODO: drop this field in a subsequent release, this has been migrated to slack_channel field - channel = models.CharField(max_length=100, null=True, default=None) slack_channel = models.ForeignKey( "slack.SlackChannel", null=True, diff --git a/engine/apps/user_management/migrations/0028_remove_organization_general_log_channel_id.py b/engine/apps/user_management/migrations/0028_remove_organization_general_log_channel_id.py new file mode 100644 index 0000000000..6d415bdb44 --- /dev/null +++ b/engine/apps/user_management/migrations/0028_remove_organization_general_log_channel_id.py @@ -0,0 +1,19 @@ +# Generated by Django 4.2.16 on 2024-11-06 21:11 + +from django.db import migrations +import django_migration_linter as linter + + +class Migration(migrations.Migration): + + dependencies = [ + ('user_management', '0027_serviceaccount'), + ] + + operations = [ + linter.IgnoreMigration(), + migrations.RemoveField( + model_name='organization', + name='general_log_channel_id', + ), + ] diff --git a/engine/apps/user_management/models/organization.py b/engine/apps/user_management/models/organization.py index aac0aeae9a..2fbeefca1d 100644 --- a/engine/apps/user_management/models/organization.py +++ b/engine/apps/user_management/models/organization.py @@ -162,9 +162,6 @@ class Organization(MaintainableObject): slack_team_identity = models.ForeignKey( "slack.SlackTeamIdentity", on_delete=models.PROTECT, null=True, default=None, related_name="organizations" ) - - # TODO: drop this field in a subsequent release, this has been migrated to default_slack_channel field - general_log_channel_id = models.CharField(max_length=100, null=True, default=None) default_slack_channel = models.ForeignKey( "slack.SlackChannel", null=True, From 
2024ee7f78aee69e7b7a4d7371c31cb593b8f3ee Mon Sep 17 00:00:00 2001 From: Michael Derynck Date: Tue, 19 Nov 2024 15:23:15 -0700 Subject: [PATCH 10/12] feat: Auto retry escalation on failed audit (#5265) # What this PR does Automatically retries escalation when alert groups fail auditing. It has the same effect as the `continue_escalation` command without any of the extra arguments. ## Checklist - [x] Unit, integration, and e2e (if applicable) tests updated - [x] Documentation added (or `pr:no public docs` PR label added if not required) - [x] Added the relevant release notes label (see labels prefixed w/ `release:`). These labels dictate how your PR will show up in the autogenerated release notes. --- .../alerts/tasks/check_escalation_finished.py | 41 +++++- .../test_check_escalation_finished_task.py | 123 ++++++++++++++++++ engine/settings/base.py | 2 + 3 files changed, 165 insertions(+), 1 deletion(-)
diff --git a/engine/apps/alerts/tasks/check_escalation_finished.py b/engine/apps/alerts/tasks/check_escalation_finished.py index 9f3fb62d8c..8ae6d8146c 100644 --- a/engine/apps/alerts/tasks/check_escalation_finished.py +++ b/engine/apps/alerts/tasks/check_escalation_finished.py @@ -2,7 +2,9 @@ import typing import requests +from celery import uuid as celery_uuid from django.conf import settings +from django.core.cache import cache from django.db.models import Avg, F, Max, Q from django.utils import timezone @@ -174,6 +176,42 @@ def check_personal_notifications_task() -> None: task_logger.info(f"personal_notifications_triggered={triggered} personal_notifications_completed={completed}") +# Retries an alert group that has failed auditing if it is within the retry limit +# Returns whether an alert group escalation is being retried +def retry_audited_alert_group(alert_group) -> bool: + cache_key = f"audited-alert-group-retry-count-{alert_group.id}" + retry_count = cache.get(cache_key, 0) + if retry_count >= settings.AUDITED_ALERT_GROUP_MAX_RETRIES: + task_logger.info(f"Not retrying audited alert_group={alert_group.id} max retries exceeded.") + return False + + if alert_group.is_silenced_for_period: + task_logger.info(f"Not retrying audited alert_group={alert_group.id} as it is silenced.") + return False + + if not alert_group.escalation_snapshot: + task_logger.info(f"Not retrying audited alert_group={alert_group.id} as its escalation snapshot is empty.") + return False + + retry_count += 1 + cache.set(cache_key, retry_count, timeout=3600) + + task_id = celery_uuid() + alert_group.active_escalation_id = task_id + alert_group.save(update_fields=["active_escalation_id"]) + + from apps.alerts.tasks import escalate_alert_group + + escalate_alert_group.apply_async( + args=(alert_group.pk,), + immutable=True, + task_id=task_id, + eta=alert_group.next_step_eta, + ) + task_logger.info(f"Retrying audited alert_group={alert_group.id} attempt={retry_count}") + return True + + @shared_log_exception_on_failure_task def check_escalation_finished_task() -> None: """ @@ -221,7 +259,8 @@ def check_escalation_finished_task() -> None: try: audit_alert_group_escalation(alert_group) except AlertGroupEscalationPolicyExecutionAuditException: - alert_group_ids_that_failed_audit.append(str(alert_group.id)) + if not retry_audited_alert_group(alert_group): + alert_group_ids_that_failed_audit.append(str(alert_group.id)) failed_alert_groups_count = len(alert_group_ids_that_failed_audit) success_ratio = (
diff --git a/engine/apps/alerts/tests/test_check_escalation_finished_task.py
b/engine/apps/alerts/tests/test_check_escalation_finished_task.py index 8aa5cbbdd9..229fabff49 100644 --- a/engine/apps/alerts/tests/test_check_escalation_finished_task.py +++ b/engine/apps/alerts/tests/test_check_escalation_finished_task.py @@ -6,12 +6,14 @@ from django.utils import timezone from apps.alerts.models import EscalationPolicy +from apps.alerts.tasks import escalate_alert_group from apps.alerts.tasks.check_escalation_finished import ( AlertGroupEscalationPolicyExecutionAuditException, audit_alert_group_escalation, check_alert_group_personal_notifications_task, check_escalation_finished_task, check_personal_notifications_task, + retry_audited_alert_group, send_alert_group_escalation_auditor_task_heartbeat, ) from apps.base.models import UserNotificationPolicy, UserNotificationPolicyLogRecord @@ -580,3 +582,124 @@ def test_check_escalation_finished_task_calls_audit_alert_group_personal_notific check_personal_notifications_task() assert "personal_notifications_triggered=6 personal_notifications_completed=2" in caplog.text + + +@patch("apps.alerts.tasks.check_escalation_finished.audit_alert_group_escalation") +@patch("apps.alerts.tasks.check_escalation_finished.retry_audited_alert_group") +@patch("apps.alerts.tasks.check_escalation_finished.send_alert_group_escalation_auditor_task_heartbeat") +@pytest.mark.django_db +def test_invoke_retry_from_check_escalation_finished_task( + mocked_send_alert_group_escalation_auditor_task_heartbeat, + mocked_retry_audited_alert_group, + mocked_audit_alert_group_escalation, + make_organization_and_user, + make_alert_receive_channel, + make_alert_group_that_started_at_specific_date, +): + organization, _ = make_organization_and_user() + alert_receive_channel = make_alert_receive_channel(organization) + + # Pass audit (should not be counted in final message or go to retry function) + alert_group1 = make_alert_group_that_started_at_specific_date(alert_receive_channel, received_delta=1) + # Fail audit but not retrying (should be counted in final message) + alert_group2 = make_alert_group_that_started_at_specific_date(alert_receive_channel, received_delta=5) + # Fail audit but retry (should not be counted in final message) + alert_group3 = make_alert_group_that_started_at_specific_date(alert_receive_channel, received_delta=10) + + def _mocked_audit_alert_group_escalation(alert_group): + if alert_group.id == alert_group2.id or alert_group.id == alert_group3.id: + raise AlertGroupEscalationPolicyExecutionAuditException(f"{alert_group.id} failed audit") + + mocked_audit_alert_group_escalation.side_effect = _mocked_audit_alert_group_escalation + + def _mocked_retry_audited_alert_group(alert_group): + if alert_group.id == alert_group2.id: + return False + return True + + mocked_retry_audited_alert_group.side_effect = _mocked_retry_audited_alert_group + + with pytest.raises(AlertGroupEscalationPolicyExecutionAuditException) as exc: + check_escalation_finished_task() + + error_msg = str(exc.value) + + assert "The following alert group id(s) failed auditing:" in error_msg + assert str(alert_group1.id) not in error_msg + assert str(alert_group2.id) in error_msg + assert str(alert_group3.id) not in error_msg + + assert mocked_retry_audited_alert_group.call_count == 2 + mocked_send_alert_group_escalation_auditor_task_heartbeat.assert_not_called() + + +@patch.object(escalate_alert_group, "apply_async") +@override_settings(AUDITED_ALERT_GROUP_MAX_RETRIES=1) +@pytest.mark.django_db +def test_retry_audited_alert_group( + mocked_escalate_alert_group,
make_organization_and_user, + make_user_for_organization, + make_user_notification_policy, + make_escalation_chain, + make_escalation_policy, + make_channel_filter, + make_alert_receive_channel, + make_alert_group_that_started_at_specific_date, +): + organization, user = make_organization_and_user() + make_user_notification_policy( + user=user, + step=UserNotificationPolicy.Step.NOTIFY, + notify_by=UserNotificationPolicy.NotificationChannel.SLACK, + ) + + alert_receive_channel = make_alert_receive_channel(organization) + escalation_chain = make_escalation_chain(organization) + channel_filter = make_channel_filter(alert_receive_channel, escalation_chain=escalation_chain) + notify_to_multiple_users_step = make_escalation_policy( + escalation_chain=channel_filter.escalation_chain, + escalation_policy_step=EscalationPolicy.STEP_NOTIFY_MULTIPLE_USERS, + ) + notify_to_multiple_users_step.notify_to_users_queue.set([user]) + + alert_group1 = make_alert_group_that_started_at_specific_date(alert_receive_channel, channel_filter=channel_filter) + alert_group1.raw_escalation_snapshot = alert_group1.build_raw_escalation_snapshot() + alert_group1.raw_escalation_snapshot["last_active_escalation_policy_order"] = 1 + alert_group1.save() + + # Retry should occur + is_retrying = retry_audited_alert_group(alert_group1) + assert is_retrying + mocked_escalate_alert_group.assert_called() + mocked_escalate_alert_group.reset_mock() + + # No retry as attempts == max + is_retrying = retry_audited_alert_group(alert_group1) + assert not is_retrying + mocked_escalate_alert_group.assert_not_called() + mocked_escalate_alert_group.reset_mock() + + alert_group2 = make_alert_group_that_started_at_specific_date(alert_receive_channel, channel_filter=channel_filter) + # No retry because no escalation snapshot + is_retrying = retry_audited_alert_group(alert_group2) + assert not is_retrying + mocked_escalate_alert_group.assert_not_called() + mocked_escalate_alert_group.reset_mock() + + alert_group3 = make_alert_group_that_started_at_specific_date( + alert_receive_channel, + channel_filter=channel_filter, + silenced=True, + silenced_at=timezone.now(), + silenced_by_user=user, + silenced_until=(timezone.now() + timezone.timedelta(hours=1)), + ) + alert_group3.raw_escalation_snapshot = alert_group1.build_raw_escalation_snapshot() + alert_group3.raw_escalation_snapshot["last_active_escalation_policy_order"] = 1 + alert_group3.save() + + # No retry because alert group silenced + is_retrying = retry_audited_alert_group(alert_group3) + assert not is_retrying + mocked_escalate_alert_group.assert_not_called()
diff --git a/engine/settings/base.py b/engine/settings/base.py index 25ef7dc142..2b3cc9714e 100644 --- a/engine/settings/base.py +++ b/engine/settings/base.py @@ -988,3 +988,5 @@ class BrokerTypes: SYNC_V2_MAX_TASKS = getenv_integer("SYNC_V2_MAX_TASKS", 6) SYNC_V2_PERIOD_SECONDS = getenv_integer("SYNC_V2_PERIOD_SECONDS", 240) SYNC_V2_BATCH_SIZE = getenv_integer("SYNC_V2_BATCH_SIZE", 500) + +AUDITED_ALERT_GROUP_MAX_RETRIES = getenv_integer("AUDITED_ALERT_GROUP_MAX_RETRIES", 1)
From 336b924a0811e1209e6543ca5caaa769519958c0 Mon Sep 17 00:00:00 2001 From: Jack Baldry Date: Wed, 20 Nov 2024 10:05:03 +0000 Subject: [PATCH 11/12] Fix first heading level (#5269) --- docs/sources/configure/jinja2-templating/_index.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/docs/sources/configure/jinja2-templating/_index.md b/docs/sources/configure/jinja2-templating/_index.md index 6883785877..6cc158f7f8 100644 ---
a/docs/sources/configure/jinja2-templating/_index.md +++ b/docs/sources/configure/jinja2-templating/_index.md @@ -23,8 +23,7 @@ refs: destination: /docs/grafana-cloud/alerting-and-irm/oncall/configure/integrations/references/webhook/ --- - -## Configure templates +# Configure templates Grafana OnCall integrates with your monitoring systems using webhooks with JSON payloads. By default, these webhooks deliver raw JSON payloads.
From fda05a6cc43b509b8aed651355a659e558a1aca9 Mon Sep 17 00:00:00 2001 From: Joey Orlando Date: Wed, 20 Nov 2024 11:17:04 -0500 Subject: [PATCH 12/12] chore: remove deprecated `slack_channel` and `heartbeat` integration types (#5270) # What this PR does See [Slack discussion](https://raintank-corp.slack.com/archives/C06K1MQ07GS/p1732110700877869) for more context ## Checklist - [x] Unit, integration, and e2e (if applicable) tests updated - [x] Documentation added (or `pr:no public docs` PR label added if not required) - [x] Added the relevant release notes label (see labels prefixed w/ `release:`). These labels dictate how your PR will show up in the autogenerated release notes. --- dev/helm-local.yml | 2 +- .../alerts/models/alert_receive_channel.py | 18 +++----- ...rtbeat_actual_check_up_task_id_and_more.py | 23 ++++++++++ engine/apps/heartbeat/models.py | 10 ----- engine/apps/heartbeat/tasks.py | 6 --- engine/apps/heartbeat/tests/factories.py | 2 - engine/apps/integrations/tasks.py | 5 +-- .../public_api/tests/test_integrations.py | 1 - .../apps/slack/alert_group_slack_service.py | 5 +-- .../apps/slack/scenarios/distribute_alerts.py | 18 +------- engine/config_integrations/heartbeat.py | 29 ------------ engine/config_integrations/slack_channel.py | 44 ------------------- engine/settings/base.py | 2 - engine/settings/celery_task_routes.py | 1 - 14 files changed, 33 insertions(+), 133 deletions(-) create mode 100644 engine/apps/heartbeat/migrations/0003_remove_integrationheartbeat_actual_check_up_task_id_and_more.py delete mode 100644 engine/config_integrations/heartbeat.py delete mode 100644 engine/config_integrations/slack_channel.py
diff --git a/dev/helm-local.yml b/dev/helm-local.yml index 8655df43fd..770a5dfb0c 100644 --- a/dev/helm-local.yml +++ b/dev/helm-local.yml @@ -39,7 +39,7 @@ engine: replicaCount: 1 celery: replicaCount: 1 - worker_beat_enabled: false + worker_beat_enabled: true externalGrafana: url: http://grafana:3000
diff --git a/engine/apps/alerts/models/alert_receive_channel.py b/engine/apps/alerts/models/alert_receive_channel.py index a8cb1494d9..7a351d2aad 100644 --- a/engine/apps/alerts/models/alert_receive_channel.py +++ b/engine/apps/alerts/models/alert_receive_channel.py @@ -525,29 +525,21 @@ def short_name(self): ) @property - def short_name_with_maintenance_status(self): - if self.maintenance_mode is not None: - return ( - self.short_name + f" *[ on " - f"{AlertReceiveChannel.MAINTENANCE_MODE_CHOICES[self.maintenance_mode][1]}" - f" :construction: ]*" - ) - else: - return self.short_name - - @property - def created_name(self): + def created_name(self) -> str: return f"{self.get_integration_display()} {self.smile_code}" @property def web_link(self) -> str: return UIURLBuilder(self.organization).integration_detail(self.public_primary_key) + @property + def is_maintenance_integration(self) -> bool: + return self.integration == AlertReceiveChannel.INTEGRATION_MAINTENANCE + @property def integration_url(self) -> str | None: if self.integration in [ AlertReceiveChannel.INTEGRATION_MANUAL, - AlertReceiveChannel.INTEGRATION_SLACK_CHANNEL,
AlertReceiveChannel.INTEGRATION_INBOUND_EMAIL, AlertReceiveChannel.INTEGRATION_MAINTENANCE, ]: diff --git a/engine/apps/heartbeat/migrations/0003_remove_integrationheartbeat_actual_check_up_task_id_and_more.py b/engine/apps/heartbeat/migrations/0003_remove_integrationheartbeat_actual_check_up_task_id_and_more.py new file mode 100644 index 0000000000..e50d915ee5 --- /dev/null +++ b/engine/apps/heartbeat/migrations/0003_remove_integrationheartbeat_actual_check_up_task_id_and_more.py @@ -0,0 +1,23 @@ +# Generated by Django 4.2.16 on 2024-11-20 15:39 + +from django.db import migrations +import django_migration_linter as linter + + +class Migration(migrations.Migration): + + dependencies = [ + ('heartbeat', '0002_delete_heartbeat'), + ] + + operations = [ + linter.IgnoreMigration(), + migrations.RemoveField( + model_name='integrationheartbeat', + name='actual_check_up_task_id', + ), + migrations.RemoveField( + model_name='integrationheartbeat', + name='last_checkup_task_time', + ), + ] diff --git a/engine/apps/heartbeat/models.py b/engine/apps/heartbeat/models.py index 0c0084bd15..4688cc716d 100644 --- a/engine/apps/heartbeat/models.py +++ b/engine/apps/heartbeat/models.py @@ -48,16 +48,6 @@ class IntegrationHeartBeat(models.Model): Stores the latest received heartbeat signal time """ - last_checkup_task_time = models.DateTimeField(default=None, null=True) - """ - Deprecated. This field is not used. TODO: remove it - """ - - actual_check_up_task_id = models.CharField(max_length=100) - """ - Deprecated. Stored the latest scheduled `integration_heartbeat_checkup` task id. TODO: remove it - """ - previous_alerted_state_was_life = models.BooleanField(default=True) """ Last status of the heartbeat. Determines if integration was alive on latest checkup diff --git a/engine/apps/heartbeat/tasks.py b/engine/apps/heartbeat/tasks.py index 7939290ec5..e9d26c578d 100644 --- a/engine/apps/heartbeat/tasks.py +++ b/engine/apps/heartbeat/tasks.py @@ -105,12 +105,6 @@ def _get_timeout_expression() -> ExpressionWrapper: return f"Found {expired_count} expired and {restored_count} restored heartbeats" -@shared_dedicated_queue_retry_task() -def integration_heartbeat_checkup(heartbeat_id: int) -> None: - """Deprecated. 
TODO: Remove this task after this task cleared from queue""" - pass - - @shared_dedicated_queue_retry_task() def process_heartbeat_task(alert_receive_channel_pk): IntegrationHeartBeat.objects.filter(
diff --git a/engine/apps/heartbeat/tests/factories.py b/engine/apps/heartbeat/tests/factories.py index 5e69db9de9..40011255e3 100644 --- a/engine/apps/heartbeat/tests/factories.py +++ b/engine/apps/heartbeat/tests/factories.py @@ -4,7 +4,5 @@ class IntegrationHeartBeatFactory(factory.DjangoModelFactory): - actual_check_up_task_id = "none" - class Meta: model = IntegrationHeartBeat
diff --git a/engine/apps/integrations/tasks.py b/engine/apps/integrations/tasks.py index 45f3e04f2a..91f6a7d416 100644 --- a/engine/apps/integrations/tasks.py +++ b/engine/apps/integrations/tasks.py @@ -31,10 +31,7 @@ def create_alertmanager_alerts(alert_receive_channel_pk, alert, is_demo=False, r from apps.alerts.models import Alert, AlertReceiveChannel alert_receive_channel = AlertReceiveChannel.objects_with_deleted.get(pk=alert_receive_channel_pk) - if ( - alert_receive_channel.deleted_at is not None - or alert_receive_channel.integration == AlertReceiveChannel.INTEGRATION_MAINTENANCE - ): + if alert_receive_channel.deleted_at is not None or alert_receive_channel.is_maintenance_integration: logger.info("AlertReceiveChannel alert ignored if deleted/maintenance") return
diff --git a/engine/apps/public_api/tests/test_integrations.py b/engine/apps/public_api/tests/test_integrations.py index 796942eb59..9a4e29c64f 100644 --- a/engine/apps/public_api/tests/test_integrations.py +++ b/engine/apps/public_api/tests/test_integrations.py @@ -903,7 +903,6 @@ def test_get_list_integrations_link_and_inbound_email( if integration_type in [ AlertReceiveChannel.INTEGRATION_MANUAL, - AlertReceiveChannel.INTEGRATION_SLACK_CHANNEL, AlertReceiveChannel.INTEGRATION_MAINTENANCE, ]: assert integration_link is None
diff --git a/engine/apps/slack/alert_group_slack_service.py b/engine/apps/slack/alert_group_slack_service.py index 9bb9510bde..ed614305f8 100644 --- a/engine/apps/slack/alert_group_slack_service.py +++ b/engine/apps/slack/alert_group_slack_service.py @@ -35,9 +35,8 @@ def __init__( self._slack_client = SlackClient(slack_team_identity) def update_alert_group_slack_message(self, alert_group: "AlertGroup") -> None: - from apps.alerts.models import AlertReceiveChannel - logger.info(f"Update message for alert_group {alert_group.pk}") + try: self._slack_client.chat_update( channel=alert_group.slack_message.channel_id, @@ -47,7 +46,7 @@ def update_alert_group_slack_message(self, alert_group: "AlertGroup") -> None: ) logger.info(f"Message has been updated for alert_group {alert_group.pk}") except SlackAPIRatelimitError as e: - if alert_group.channel.integration != AlertReceiveChannel.INTEGRATION_MAINTENANCE: + if not alert_group.channel.is_maintenance_integration: if not alert_group.channel.is_rate_limited_in_slack: alert_group.channel.start_send_rate_limit_message_task(e.retry_after) logger.info(
diff --git a/engine/apps/slack/scenarios/distribute_alerts.py b/engine/apps/slack/scenarios/distribute_alerts.py index 3a7090e320..3d3c1a60a8 100644 --- a/engine/apps/slack/scenarios/distribute_alerts.py +++ b/engine/apps/slack/scenarios/distribute_alerts.py @@ -141,22 +141,6 @@ def _post_alert_group_to_slack( channel_id=channel_id, ) - # If alert was made out of a message: - if alert_group.channel.integration == AlertReceiveChannel.INTEGRATION_SLACK_CHANNEL: - channel = json.loads(alert.integration_unique_data)["channel"] - result =
self._slack_client.chat_postMessage( - channel=channel, - thread_ts=json.loads(alert.integration_unique_data)["ts"], - text=":rocket: <{}|Incident registered!>".format(alert_group.slack_message.permalink), - team=slack_team_identity, - ) - alert_group.slack_messages.create( - slack_id=result["ts"], - organization=alert_group.channel.organization, - _slack_team_identity=self.slack_team_identity, - channel_id=channel, - ) - alert.delivered = True except SlackAPITokenError: alert_group.reason_to_skip_escalation = AlertGroup.ACCOUNT_INACTIVE @@ -172,7 +156,7 @@ def _post_alert_group_to_slack( logger.info("Not delivering alert due to channel is archived.") except SlackAPIRatelimitError as e: # don't rate limit maintenance alert - if alert_group.channel.integration != AlertReceiveChannel.INTEGRATION_MAINTENANCE: + if not alert_group.channel.is_maintenance_integration: alert_group.reason_to_skip_escalation = AlertGroup.RATE_LIMITED alert_group.save(update_fields=["reason_to_skip_escalation"]) alert_group.channel.start_send_rate_limit_message_task(e.retry_after)
diff --git a/engine/config_integrations/heartbeat.py b/engine/config_integrations/heartbeat.py deleted file mode 100644 index 60699c4507..0000000000 --- a/engine/config_integrations/heartbeat.py +++ /dev/null @@ -1,29 +0,0 @@ -# Main -enabled = True -title = "Heartbeat" -slug = "heartbeat" -short_description = None -description = None -is_displayed_on_web = False -is_featured = False -is_able_to_autoresolve = True -is_demo_alert_enabled = False - -description = None - -# Default templates -slack_title = """\ -*<{{ grafana_oncall_link }}|#{{ grafana_oncall_incident_id }} {{ payload.get("title", "Title undefined (check Slack Title Template)") }}>* via {{ integration_name }} -{% if source_link %} - (*<{{ source_link }}|source>*) -{%- endif %}""" - -grouping_id = """\ -{{ payload.get("id", "") }}{{ payload.get("user_defined_id", "") }} -""" - -resolve_condition = '{{ payload.get("is_resolve", False) == True }}' - -acknowledge_condition = None - -example_payload = None
diff --git a/engine/config_integrations/slack_channel.py b/engine/config_integrations/slack_channel.py deleted file mode 100644 index 05021935f1..0000000000 --- a/engine/config_integrations/slack_channel.py +++ /dev/null @@ -1,44 +0,0 @@ -# Main -enabled = True -title = "Slack Channel" -slug = "slack_channel" -short_description = None -description = None -is_displayed_on_web = False -is_featured = False -is_able_to_autoresolve = False -is_demo_alert_enabled = False - -description = None - -# Default templates -slack_title = """\ -{% if source_link -%} -*<{{ source_link }}|<#{{ payload.get("channel", "") }}>>* -{%- else -%} -<#{{ payload.get("channel", "") }}> -{%- endif %}""" - -web_title = """\ -{% if source_link -%} -[#{{ grafana_oncall_incident_id }}]{{ source_link }}) <#{{ payload.get("channel", "") }}>>* -{%- else -%} -*#{{ grafana_oncall_incident_id }}* <#{{ payload.get("channel", "") }}> -{%- endif %}""" - -telegram_title = """\ -{% if source_link -%} -#{{ grafana_oncall_incident_id }} {{ payload.get("channel", "") }} -{%- else -%} -*#{{ grafana_oncall_incident_id }}* <#{{ payload.get("channel", "") }}> -{%- endif %}""" - -grouping_id = '{{ payload.get("ts", "") }}' - -resolve_condition = None - -acknowledge_condition = None - -source_link = '{{ payload.get("amixr_mixin", {}).get("permalink", "")}}' - -example_payload = None
diff --git a/engine/settings/base.py b/engine/settings/base.py index 2b3cc9714e..0f73c8d5af 100644 --- a/engine/settings/base.py +++
b/engine/settings/base.py @@ -878,11 +878,9 @@ class BrokerTypes: "config_integrations.formatted_webhook", "config_integrations.kapacitor", "config_integrations.elastalert", - "config_integrations.heartbeat", "config_integrations.inbound_email", "config_integrations.maintenance", "config_integrations.manual", - "config_integrations.slack_channel", "config_integrations.zabbix", "config_integrations.direct_paging", # Actually it's Grafana 8 integration. diff --git a/engine/settings/celery_task_routes.py b/engine/settings/celery_task_routes.py index 04a8ffa49a..7ef62121dd 100644 --- a/engine/settings/celery_task_routes.py +++ b/engine/settings/celery_task_routes.py @@ -12,7 +12,6 @@ "common.oncall_gateway.tasks.delete_oncall_connector_async": {"queue": "default"}, "common.oncall_gateway.tasks.create_slack_connector_async_v2": {"queue": "default"}, "common.oncall_gateway.tasks.delete_slack_connector_async_v2": {"queue": "default"}, - "apps.heartbeat.tasks.integration_heartbeat_checkup": {"queue": "default"}, "apps.heartbeat.tasks.process_heartbeat_task": {"queue": "default"}, "apps.labels.tasks.update_labels_cache": {"queue": "default"}, "apps.labels.tasks.update_instances_labels_cache": {"queue": "default"},
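A closing note on the refactor threaded through this last patch: call sites that compared `channel.integration` against `AlertReceiveChannel.INTEGRATION_MAINTENANCE` now go through a single model property. A minimal, self-contained sketch of that pattern (the plain class and the `"maintenance"` constant value are illustrative stand-ins for the real Django model):

```python
# Illustrative sketch of the property-over-comparison refactor; the real
# AlertReceiveChannel is a Django model with many more fields.
class AlertReceiveChannel:
    INTEGRATION_MAINTENANCE = "maintenance"  # assumed value, for illustration

    def __init__(self, integration: str) -> None:
        self.integration = integration

    @property
    def is_maintenance_integration(self) -> bool:
        # One place to answer "is this the maintenance integration?", keeping
        # Slack delivery, rate limiting, and alert creation checks in sync.
        return self.integration == AlertReceiveChannel.INTEGRATION_MAINTENANCE


channel = AlertReceiveChannel("maintenance")
assert channel.is_maintenance_integration
```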