Skip to content

Commit

Permalink
Added permissions syncing for slack (#2602)
Browse files Browse the repository at this point in the history
* Added permissions syncing for slack

* add no email case handling

* mypy fixes

* frontend

* minor cleanup

* param tweak
  • Loading branch information
hagen-danswer authored Sep 30, 2024
1 parent 728a41a commit b005690
Show file tree
Hide file tree
Showing 11 changed files with 457 additions and 84 deletions.
1 change: 1 addition & 0 deletions backend/danswer/connectors/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def identify_connector_class(
DocumentSource.SLACK: {
InputType.LOAD_STATE: SlackLoadConnector,
InputType.POLL: SlackPollConnector,
InputType.PRUNE: SlackPollConnector,
},
DocumentSource.GITHUB: GithubConnector,
DocumentSource.GMAIL: GmailConnector,
Expand Down
133 changes: 84 additions & 49 deletions backend/danswer/connectors/slack/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,12 @@

from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from slack_sdk.web import SlackResponse

from danswer.configs.app_configs import ENABLE_EXPENSIVE_EXPERT_CALLS
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import IdConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import BasicExpertInfo
Expand All @@ -23,9 +22,8 @@
from danswer.connectors.models import Section
from danswer.connectors.slack.utils import expert_info_from_slack_id
from danswer.connectors.slack.utils import get_message_link
from danswer.connectors.slack.utils import make_slack_api_call_logged
from danswer.connectors.slack.utils import make_slack_api_call_paginated
from danswer.connectors.slack.utils import make_slack_api_rate_limited
from danswer.connectors.slack.utils import make_paginated_slack_api_call_w_retries
from danswer.connectors.slack.utils import make_slack_api_call_w_retries
from danswer.connectors.slack.utils import SlackTextCleaner
from danswer.utils.logger import setup_logger

Expand All @@ -38,47 +36,18 @@
# list of messages in a thread
ThreadType = list[MessageType]

basic_retry_wrapper = retry_builder()


def _make_paginated_slack_api_call(
call: Callable[..., SlackResponse], **kwargs: Any
) -> Generator[dict[str, Any], None, None]:
return make_slack_api_call_paginated(
basic_retry_wrapper(
make_slack_api_rate_limited(make_slack_api_call_logged(call))
)
)(**kwargs)


def _make_slack_api_call(
call: Callable[..., SlackResponse], **kwargs: Any
) -> SlackResponse:
return basic_retry_wrapper(
make_slack_api_rate_limited(make_slack_api_call_logged(call))
)(**kwargs)


def get_channel_info(client: WebClient, channel_id: str) -> ChannelType:
"""Get information about a channel. Needed to convert channel ID to channel name"""
return _make_slack_api_call(client.conversations_info, channel=channel_id)[0][
"channel"
]


def _get_channels(
def _collect_paginated_channels(
client: WebClient,
exclude_archived: bool,
get_private: bool,
channel_types: list[str],
) -> list[ChannelType]:
channels: list[dict[str, Any]] = []
for result in _make_paginated_slack_api_call(
for result in make_paginated_slack_api_call_w_retries(
client.conversations_list,
exclude_archived=exclude_archived,
# also get private channels the bot is added to
types=["public_channel", "private_channel"]
if get_private
else ["public_channel"],
types=channel_types,
):
channels.extend(result["channels"])

Expand All @@ -88,19 +57,38 @@ def _get_channels(
def get_channels(
client: WebClient,
exclude_archived: bool = True,
get_public: bool = True,
get_private: bool = True,
) -> list[ChannelType]:
"""Get all channels in the workspace"""
channels: list[dict[str, Any]] = []
channel_types = []
if get_public:
channel_types.append("public_channel")
if get_private:
channel_types.append("private_channel")
# try getting private channels as well at first
try:
return _get_channels(
client=client, exclude_archived=exclude_archived, get_private=True
channels = _collect_paginated_channels(
client=client,
exclude_archived=exclude_archived,
channel_types=channel_types,
)
except SlackApiError as e:
logger.info(f"Unable to fetch private channels due to - {e}")
logger.info("trying again without private channels")
if get_public:
channel_types = ["public_channel"]
else:
logger.warning("No channels to fetch")
return []
channels = _collect_paginated_channels(
client=client,
exclude_archived=exclude_archived,
channel_types=channel_types,
)

return _get_channels(
client=client, exclude_archived=exclude_archived, get_private=False
)
return channels


def get_channel_messages(
Expand All @@ -112,14 +100,14 @@ def get_channel_messages(
"""Get all messages in a channel"""
# join so that the bot can access messages
if not channel["is_member"]:
_make_slack_api_call(
make_slack_api_call_w_retries(
client.conversations_join,
channel=channel["id"],
is_private=channel["is_private"],
)
logger.info(f"Successfully joined '{channel['name']}'")

for result in _make_paginated_slack_api_call(
for result in make_paginated_slack_api_call_w_retries(
client.conversations_history,
channel=channel["id"],
oldest=oldest,
Expand All @@ -131,7 +119,7 @@ def get_channel_messages(
def get_thread(client: WebClient, channel_id: str, thread_id: str) -> ThreadType:
"""Get all messages in a thread"""
threads: list[MessageType] = []
for result in _make_paginated_slack_api_call(
for result in make_paginated_slack_api_call_w_retries(
client.conversations_replies, channel=channel_id, ts=thread_id
):
threads.extend(result["messages"])
Expand Down Expand Up @@ -266,7 +254,7 @@ def filter_channels(
]


def get_all_docs(
def _get_all_docs(
client: WebClient,
workspace: str,
channels: list[str] | None = None,
Expand Down Expand Up @@ -328,7 +316,44 @@ def get_all_docs(
)


class SlackPollConnector(PollConnector):
def _get_all_doc_ids(
    client: WebClient,
    channels: list[str] | None = None,
    channel_name_regex_enabled: bool = False,
    msg_filter_func: Callable[[MessageType], bool] = _default_msg_filter,
) -> set[str]:
    """
    Collect the ids of every document in the workspace, one channel at a time.

    Walks the same channels as _get_all_docs (all channels, narrowed by
    `channels` / `channel_name_regex_enabled` through filter_channels), but
    only builds id strings instead of full documents. No thread is fetched
    along the way, which saves time and API calls.

    Each id has the form "<channel id>__<message ts>". Messages for which
    `msg_filter_func` returns a truthy value are left out.
    """
    selected_channels = filter_channels(
        get_channels(client), channels, channel_name_regex_enabled
    )

    # The document id is the channel id plus the ts of the first message in
    # the thread. That first message is already part of the channel history,
    # so the id can be built without requesting the thread itself.
    return {
        f"{channel['id']}__{message['ts']}"
        for channel in selected_channels
        for message_batch in get_channel_messages(client=client, channel=channel)
        for message in message_batch
        if not msg_filter_func(message)
    }


class SlackPollConnector(PollConnector, IdConnector):
def __init__(
self,
workspace: str,
Expand All @@ -349,14 +374,24 @@ def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None
self.client = WebClient(token=bot_token)
return None

def retrieve_all_source_ids(self) -> set[str]:
    """
    Return the ids of all documents currently reachable for this connector.

    Delegates to _get_all_doc_ids using the connector's configured channels
    and channel-regex setting.

    Raises ConnectorMissingCredentialError when no Slack client is set
    (self.client is None).
    """
    slack_client = self.client
    if slack_client is None:
        raise ConnectorMissingCredentialError("Slack")

    return _get_all_doc_ids(
        client=slack_client,
        channels=self.channels,
        channel_name_regex_enabled=self.channel_regex_enabled,
    )

def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
if self.client is None:
raise ConnectorMissingCredentialError("Slack")

documents: list[Document] = []
for document in get_all_docs(
for document in _get_all_docs(
client=self.client,
workspace=self.workspace,
channels=self.channels,
Expand Down
24 changes: 22 additions & 2 deletions backend/danswer/connectors/slack/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@
from slack_sdk.errors import SlackApiError
from slack_sdk.web import SlackResponse

from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
from danswer.connectors.models import BasicExpertInfo
from danswer.utils.logger import setup_logger

logger = setup_logger()

basic_retry_wrapper = retry_builder()
# number of messages we request per page when fetching paginated slack messages
_SLACK_LIMIT = 900

Expand All @@ -34,7 +36,7 @@ def get_message_link(
)


def make_slack_api_call_logged(
def _make_slack_api_call_logged(
call: Callable[..., SlackResponse],
) -> Callable[..., SlackResponse]:
@wraps(call)
Expand All @@ -47,7 +49,7 @@ def logged_call(**kwargs: Any) -> SlackResponse:
return logged_call


def make_slack_api_call_paginated(
def _make_slack_api_call_paginated(
call: Callable[..., SlackResponse],
) -> Callable[..., Generator[dict[str, Any], None, None]]:
"""Wraps calls to slack API so that they automatically handle pagination"""
Expand Down Expand Up @@ -116,6 +118,24 @@ def rate_limited_call(**kwargs: Any) -> SlackResponse:
return rate_limited_call


def make_slack_api_call_w_retries(
    call: Callable[..., SlackResponse], **kwargs: Any
) -> SlackResponse:
    """
    Invoke a Slack API method once, with `kwargs` as its keyword arguments.

    The call is wrapped, innermost first, by the logging wrapper, the
    rate-limit wrapper and finally the module-level retry wrapper.
    """
    logged_call = _make_slack_api_call_logged(call)
    rate_limited_call = make_slack_api_rate_limited(logged_call)
    retrying_call = basic_retry_wrapper(rate_limited_call)
    return retrying_call(**kwargs)


def make_paginated_slack_api_call_w_retries(
    call: Callable[..., SlackResponse], **kwargs: Any
) -> Generator[dict[str, Any], None, None]:
    """
    Invoke a paginated Slack API method with `kwargs` and return a generator
    of its result pages.

    The underlying call is wrapped, innermost first, by the logging wrapper,
    the rate-limit wrapper and the module-level retry wrapper; the pagination
    wrapper sits on the outside of all of them.
    """
    logged_call = _make_slack_api_call_logged(call)
    rate_limited_call = make_slack_api_rate_limited(logged_call)
    retrying_call = basic_retry_wrapper(rate_limited_call)
    paginated_call = _make_slack_api_call_paginated(retrying_call)
    return paginated_call(**kwargs)


def expert_info_from_slack_id(
user_id: str | None,
client: WebClient,
Expand Down
4 changes: 1 addition & 3 deletions backend/danswer/db/connector_credential_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@
from danswer.server.models import StatusResponse
from danswer.utils.logger import setup_logger
from ee.danswer.db.external_perm import delete_user__ext_group_for_cc_pair__no_commit
from ee.danswer.external_permissions.permission_sync_function_map import (
check_if_valid_sync_source,
)
from ee.danswer.external_permissions.sync_params import check_if_valid_sync_source

logger = setup_logger()

Expand Down
Loading

0 comments on commit b005690

Please sign in to comment.