diff --git a/bot/adapters/discord/client.py b/bot/adapters/discord/client.py index 359ab00..fa5612f 100644 --- a/bot/adapters/discord/client.py +++ b/bot/adapters/discord/client.py @@ -77,6 +77,11 @@ def __init__(self, *, intents: discord.Intents, **options: typing.Any) -> None: description='Clears post cache for server', callback=self.clear_cache_cmd, ), + app_commands.Command( + name='provision', + description='Provisions a server', + callback=self.provision_cmd, + ), ] self.tree = app_commands.CommandTree(client=self) @@ -316,6 +321,35 @@ async def clear_cache_cmd( ephemeral=True, ) + @checks.has_permissions(administrator=True) + async def provision_cmd( + self, + interaction: discord.Interaction, + tier: constants.ServerTier, + integration: typing.Optional[constants.Integration] = None, + ) -> None: + await interaction.response.defer(ephemeral=True) + + service.provision_server( + server_vendor=constants.ServerVendor.DISCORD, + server_vendor_uid=str(interaction.guild_id), + tier=tier, + integrations=[integration] if integration else [], + ) + + logger.info( + 'Admin provisioned server', + tier=tier.value, + integration=integration.value if integration else 'all', + admin=interaction.user.id, + server_id=interaction.guild_id, + ) + + await interaction.followup.send( + content='Provisioned server', + ephemeral=True, + ) + async def _send_post( self, post: domain.Post, diff --git a/bot/management/commands/provision.py b/bot/management/commands/provision.py index e4333c9..c842e11 100644 --- a/bot/management/commands/provision.py +++ b/bot/management/commands/provision.py @@ -3,8 +3,7 @@ from django.core.management import base from bot import constants -from bot import logger -from bot import models +from bot import service class Command(base.BaseCommand): @@ -37,43 +36,19 @@ def add_arguments(self, parser: base.CommandParser) -> None: help='List of supported integrations', ) - def handle(self, *args: typing.Any, **options: typing.Any) -> typing.NoReturn: - server_vendor_uid = options.get('server_vendor_uid') + def handle(self, *args: typing.Any, **options: typing.Any) -> None: + server_vendor_uid = options.get('server_vendor_uid', '') server_vendor = options.get('server_vendor', constants.ServerVendor.DISCORD) integrations = self._parse_integrations(options.get('integrations', [])) tier = options.get('server_tier', constants.ServerTier.FREE) - logger.info( - 'Provisioning integrations for server', - integrations=[integration.value for integration in integrations] or 'all', - server_vendor=server_vendor.value, + service.provision_server( + server_vendor=server_vendor, server_vendor_uid=server_vendor_uid, + tier=tier, + integrations=integrations, ) - server = models.Server.objects.filter( - vendor=server_vendor, - vendor_uid=server_vendor_uid, - ).first() - if not server: - server = models.Server.objects.create( - vendor=server_vendor, - vendor_uid=server_vendor_uid, - tier=tier, - ) - - models.ServerIntegration.objects.bulk_create( - [ - models.ServerIntegration( - integration=integration, - server=server, - enabled=True, - ) - for integration in integrations - ] - ) - - logger.info('Successfully created server with integrations') - @staticmethod def _parse_integrations(integrations: typing.List[str]) -> typing.List[constants.Integration]: if not integrations: diff --git a/bot/service.py b/bot/service.py index 73c38c9..fd67642 100644 --- a/bot/service.py +++ b/bot/service.py @@ -1,10 +1,12 @@ import datetime import typing +from bot import cache from bot import constants from bot import domain from bot import exceptions from bot import logger +from bot import models from bot import repository from bot.integrations import registry @@ -214,3 +216,43 @@ async def get_comments( # noqa: C901 except Exception as e: logger.error('Failed downloading', url=url, num_comments=n, error=str(e)) raise e + + +def provision_server( + server_vendor: constants.ServerVendor, + server_vendor_uid: str, + tier: constants.ServerTier, + integrations: typing.List[constants.Integration], +) -> None: + logger.info( + 'Provisioning integrations for server', + integrations=[integration.value for integration in integrations] or 'all', + server_vendor=server_vendor.value, + server_vendor_uid=server_vendor_uid, + ) + + server = models.Server.objects.filter( + vendor=server_vendor, + vendor_uid=server_vendor_uid, + ).first() + if not server: + server = models.Server.objects.create( + vendor=server_vendor, + vendor_uid=server_vendor_uid, + tier=tier, + ) + + models.ServerIntegration.objects.bulk_create( + [ + models.ServerIntegration( + integration=integration, + server=server, + enabled=True, + ) + for integration in integrations or list(constants.Integration) + ] + ) + + cache.delete(store=cache.Store.SERVER, key=f'{server_vendor.value}_{server_vendor_uid}') + + logger.info('Successfully created server with integrations') diff --git a/docker/db/init.sql b/docker/db/init.sql deleted file mode 100644 index 28b5a8a..0000000 --- a/docker/db/init.sql +++ /dev/null @@ -1,11 +0,0 @@ --- Database -CREATE DATABASE embed_bot; - --- User -CREATE USER bot WITH ENCRYPTED PASSWORD 'bot'; -GRANT ALL PRIVILEGES ON DATABASE embed_bot TO bot; - --- Django conf -ALTER ROLE bot SET client_encoding TO 'utf8'; -ALTER ROLE bot SET default_transaction_isolation TO 'read committed'; -ALTER ROLE bot SET timezone TO 'UTC';