From f1df14e2298169d3371c9fddb9621b0a593c7b6b Mon Sep 17 00:00:00 2001 From: Anthony LC Date: Fri, 20 Sep 2024 22:42:46 +0200 Subject: [PATCH 1/6] =?UTF-8?q?=E2=9C=A8(backend)=20create=20ai=20endpoint?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit We created a new endpoint to perform AI operations. POST /api/v1.0/ai/ with expected data: - text: str - action: str [prompt, correct, rephrase, summarize, translate_en, translate_de, translate_fr] Return JSON response with the processed text. --- env.d/development/common.dist | 4 + src/backend/core/api/viewsets.py | 124 +++++++++++++++++ src/backend/core/tests/test_api_ai.py | 191 ++++++++++++++++++++++++++ src/backend/core/urls.py | 4 + src/backend/impress/settings.py | 3 + src/backend/pyproject.toml | 1 + 6 files changed, 327 insertions(+) create mode 100644 src/backend/core/tests/test_api_ai.py diff --git a/env.d/development/common.dist b/env.d/development/common.dist index 712c4dbd1..e9a4ba685 100644 --- a/env.d/development/common.dist +++ b/env.d/development/common.dist @@ -39,3 +39,7 @@ LOGOUT_REDIRECT_URL=http://localhost:3000 OIDC_REDIRECT_ALLOWED_HOSTS=["http://localhost:8083", "http://localhost:3000"] OIDC_AUTH_REQUEST_EXTRA_PARAMS={"acr_values": "eidas1"} + +AI_BASE_URL=https://openaiendpoint.com +AI_API_KEY=password +AI_MODEL=llama diff --git a/src/backend/core/api/viewsets.py b/src/backend/core/api/viewsets.py index e3a0003db..7450107fa 100644 --- a/src/backend/core/api/viewsets.py +++ b/src/backend/core/api/viewsets.py @@ -1,5 +1,6 @@ """API endpoints""" +import json import os import re import uuid @@ -18,6 +19,7 @@ from django.http import Http404 from botocore.exceptions import ClientError +from openai import OpenAI from rest_framework import ( decorators, exceptions, @@ -781,3 +783,125 @@ def perform_create(self, serializer): invitation.document.email_invitation( language, invitation.email, invitation.role, self.request.user.email ) + + +class AIViewSet(viewsets.ViewSet): + """API ViewSet for handling AI tasks""" + + permission_classes = [permissions.IsAuthenticated] + + def create(self, request): + """ + POST /api/v1.0/ai/ with expected data: + - text: str + - action: str [prompt, correct, rephrase, summarize, + translate_en, translate_de, translate_fr] + Return JSON response with the processed text. + """ + if not request.user.is_authenticated: + raise exceptions.NotAuthenticated() + + if ( + settings.AI_BASE_URL is None + or settings.AI_API_KEY is None + or settings.AI_MODEL is None + ): + raise exceptions.ValidationError({"error": "AI configuration not set"}) + + action = request.data.get("action") + text = request.data.get("text") + + action_configs = { + "prompt": { + "system_content": ( + "Answer the prompt in markdown format. Return JSON: " + '{"answer": "Your markdown answer"}.' + "Do not provide any other information." + ), + "response_key": "answer", + }, + "correct": { + "system_content": ( + "Correct grammar and spelling of the markdown text, " + "preserving language and markdown formatting. " + 'Return JSON: {"answer": "your corrected markdown text"}.' + "Do not provide any other information." + ), + "response_key": "answer", + }, + "rephrase": { + "system_content": ( + "Rephrase the given markdown text, " + "preserving language and markdown formatting. " + 'Return JSON: {"answer": "your rephrased markdown text"}.' + "Do not provide any other information." + ), + "response_key": "answer", + }, + "summarize": { + "system_content": ( + "Summarize the markdown text, preserving language and markdown formatting. " + 'Return JSON: {"answer": "your markdown summary"}.' + "Do not provide any other information." + ), + "response_key": "answer", + }, + "translate_en": { + "system_content": ( + "Translate the markdown text to English, preserving markdown formatting. " + 'Return JSON: {"answer": "Your translated markdown text in English"}.' + "Do not provide any other information." + ), + "response_key": "answer", + }, + "translate_de": { + "system_content": ( + "Translate the markdown text to German, preserving markdown formatting. " + 'Return JSON: {"answer": "Your translated markdown text in German"}.' + "Do not provide any other information." + ), + "response_key": "answer", + }, + "translate_fr": { + "system_content": ( + "Translate the markdown text to French, preserving markdown formatting. " + 'Return JSON: {"answer": "Your translated markdown text in French"}.' + "Do not provide any other information." + ), + "response_key": "answer", + }, + } + + if action not in action_configs: + raise exceptions.ValidationError({"error": "Invalid action"}) + + config = action_configs[action] + + try: + client = OpenAI(base_url=settings.AI_BASE_URL, api_key=settings.AI_API_KEY) + response = client.chat.completions.create( + model=settings.AI_MODEL, + response_format={"type": "json_object"}, + messages=[ + {"role": "system", "content": config["system_content"]}, + {"role": "user", "content": json.dumps({"mardown_input": text})}, + ], + ) + + corrected_response = json.loads(response.choices[0].message.content) + + if "answer" not in corrected_response: + raise exceptions.ValidationError("Invalid response format") + + return drf_response.Response(corrected_response, status=status.HTTP_200_OK) + + except exceptions.ValidationError as e: + return drf_response.Response( + {"error": e.detail}, status=status.HTTP_400_BAD_REQUEST + ) + + except exceptions.APIException as e: + return drf_response.Response( + {"error": f"Error processing AI response: {str(e)}"}, + status=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) diff --git a/src/backend/core/tests/test_api_ai.py b/src/backend/core/tests/test_api_ai.py new file mode 100644 index 000000000..57aa608a3 --- /dev/null +++ b/src/backend/core/tests/test_api_ai.py @@ -0,0 +1,191 @@ +""" +Test ai API endpoints in the impress core app. +""" + +from typing import List +from unittest.mock import patch + +from django.test.utils import override_settings + +import pytest +from pydantic import BaseModel +from rest_framework.exceptions import APIException +from rest_framework.test import APIClient + +from core import factories + +pytestmark = pytest.mark.django_db + + +class MessageMock(BaseModel): + """Message mock""" + + content: str + + +class ChoiceMock(BaseModel): + """Choice mock""" + + message: MessageMock + + +class ChatCompletionMock(BaseModel): + """ChatCompletion mock""" + + id: str + choices: List[ChoiceMock] + + +def test_api_ai__unauthentified(): + """Unauthentified users should not be allowed""" + + client = APIClient() + response = client.post("/api/v1.0/ai/") + + assert response.status_code == 401 + assert response.json() == { + "detail": "Authentication credentials were not provided." + } + + +@pytest.mark.parametrize( + "setting_name, setting_value", + [ + ("AI_BASE_URL", None), + ("AI_API_KEY", None), + ("AI_MODEL", None), + ], +) +def test_api_ai_setting_missing(setting_name, setting_value): + """Setting should be set""" + + user = factories.UserFactory() + + client = APIClient() + client.force_login(user) + + with override_settings(**{setting_name: setting_value}): + response = client.post("/api/v1.0/ai/") + + assert response.status_code == 400 + assert response.json() == {"error": "AI configuration not set"} + + +@override_settings( + AI_BASE_URL="http://example.com", AI_API_KEY="test-key", AI_MODEL="test-model" +) +def test_api_ai__bad_action_config(): + """ + Action config should raised when the action is not correct + """ + user = factories.UserFactory() + + client = APIClient() + client.force_login(user) + response = client.post( + "/api/v1.0/ai/", + { + "action": "bad_action", + "text": "Hello world", + }, + ) + + assert response.status_code == 400 + assert response.json() == {"error": "Invalid action"} + + +@override_settings( + AI_BASE_URL="http://example.com", AI_API_KEY="test-key", AI_MODEL="test-model" +) +def test_api_ai__client_error(): + """ + Fail when the client raises an error + """ + user = factories.UserFactory() + + client = APIClient() + client.force_login(user) + + with patch("openai.resources.chat.completions.Completions.create") as mock_create: + mock_create.side_effect = APIException("Mocked client error") + + response = client.post( + "/api/v1.0/ai/", + { + "action": "translate_fr", + "text": "Hello world", + }, + ) + + assert response.status_code == 500 + assert response.json() == { + "error": "Error processing AI response: Mocked client error" + } + + +@override_settings( + AI_BASE_URL="http://example.com", AI_API_KEY="test-key", AI_MODEL="test-model" +) +def test_api_ai__client_invalid_response(): + """ + Fail when the client response is invalid + """ + user = factories.UserFactory() + + client = APIClient() + client.force_login(user) + + with patch("openai.resources.chat.completions.Completions.create") as mock_create: + mock_create.return_value = ChatCompletionMock( + id="test-id", + choices=[ + ChoiceMock( + message=MessageMock( + content='{"no_answer": "This is an invalid response"}' + ) + ) + ], + ) + + response = client.post( + "/api/v1.0/ai/", + { + "action": "translate_fr", + "text": "Hello world", + }, + ) + + assert response.status_code == 400 + assert response.json() == {"error": ["Invalid response format"]} + + +@override_settings( + AI_BASE_URL="http://example.com", AI_API_KEY="test-key", AI_MODEL="test-model" +) +def test_api_ai__success(): + """ + Test the ai request with a success response + """ + user = factories.UserFactory() + + client = APIClient() + client.force_login(user) + + with patch("openai.resources.chat.completions.Completions.create") as mock_create: + mock_create.return_value = ChatCompletionMock( + id="test-id", + choices=[ + ChoiceMock(message=MessageMock(content='{"answer": "Salut le monde"}')) + ], + ) + + response = client.post( + "/api/v1.0/ai/", + { + "action": "translate_fr", + "text": "Hello world", + }, + ) + + assert response.status_code == 200 + assert response.json() == {"answer": "Salut le monde"} diff --git a/src/backend/core/urls.py b/src/backend/core/urls.py index d16e3ee0a..c78789e0f 100644 --- a/src/backend/core/urls.py +++ b/src/backend/core/urls.py @@ -52,6 +52,10 @@ r"^templates/(?P[0-9a-z-]*)/", include(template_related_router.urls), ), + re_path( + r"ai", + viewsets.AIViewSet.as_view({"post": "create"}), + ), ] ), ), diff --git a/src/backend/impress/settings.py b/src/backend/impress/settings.py index a4812c445..d0f63ded7 100755 --- a/src/backend/impress/settings.py +++ b/src/backend/impress/settings.py @@ -393,6 +393,9 @@ class Base(Configuration): ALLOW_LOGOUT_GET_METHOD = values.BooleanValue( default=True, environ_name="ALLOW_LOGOUT_GET_METHOD", environ_prefix=None ) + AI_API_KEY = values.Value(None, environ_name="AI_API_KEY", environ_prefix=None) + AI_BASE_URL = values.Value(None, environ_name="AI_BASE_URL", environ_prefix=None) + AI_MODEL = values.Value(None, environ_name="AI_MODEL", environ_prefix=None) USER_OIDC_FIELDS_TO_FULLNAME = values.ListValue( default=["first_name", "last_name"], diff --git a/src/backend/pyproject.toml b/src/backend/pyproject.toml index 0849bb1b9..d3f0f10b9 100644 --- a/src/backend/pyproject.toml +++ b/src/backend/pyproject.toml @@ -47,6 +47,7 @@ dependencies = [ "jsonschema==4.23.0", "markdown==3.7", "nested-multipart-parser==1.5.0", + "openai==1.44.1", "psycopg[binary]==3.2.3", "PyJWT==2.9.0", "pypandoc==1.13", From d303f123d26a1939e76656e5397f367db9dd59c0 Mon Sep 17 00:00:00 2001 From: Anthony LC Date: Mon, 30 Sep 2024 10:05:17 +0200 Subject: [PATCH 2/6] =?UTF-8?q?fixup!=20=E2=9C=A8(backend)=20create=20ai?= =?UTF-8?q?=20endpoint?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/core/api/serializers.py | 33 +- src/backend/core/api/utils.py | 137 +++++++ src/backend/core/api/viewsets.py | 193 +++------- src/backend/core/enums.py | 14 +- src/backend/core/models.py | 2 + src/backend/core/services/__init__.py | 0 src/backend/core/services/ai_services.py | 103 +++++ .../test_api_documents_ai_transform.py | 333 ++++++++++++++++ .../test_api_documents_ai_translate.py | 358 ++++++++++++++++++ .../documents/test_api_documents_retrieve.py | 4 + src/backend/core/tests/test_api_ai.py | 191 ---------- ...st_api_utils_ai_document_rate_throttles.py | 127 +++++++ .../test_api_utils_ai_user_rate_throttles.py | 146 +++++++ .../core/tests/test_models_documents.py | 16 + .../core/tests/test_services_ai_services.py | 104 +++++ src/backend/core/urls.py | 4 - src/backend/impress/settings.py | 11 + 17 files changed, 1439 insertions(+), 337 deletions(-) create mode 100644 src/backend/core/services/__init__.py create mode 100644 src/backend/core/services/ai_services.py create mode 100644 src/backend/core/tests/documents/test_api_documents_ai_transform.py create mode 100644 src/backend/core/tests/documents/test_api_documents_ai_translate.py delete mode 100644 src/backend/core/tests/test_api_ai.py create mode 100644 src/backend/core/tests/test_api_utils_ai_document_rate_throttles.py create mode 100644 src/backend/core/tests/test_api_utils_ai_user_rate_throttles.py create mode 100644 src/backend/core/tests/test_services_ai_services.py diff --git a/src/backend/core/api/serializers.py b/src/backend/core/api/serializers.py index 6154ced33..1e59111ef 100644 --- a/src/backend/core/api/serializers.py +++ b/src/backend/core/api/serializers.py @@ -8,7 +8,8 @@ from rest_framework import exceptions, serializers -from core import models +from core import enums, models +from core.services.ai_services import AI_ACTIONS class UserSerializer(serializers.ModelSerializer): @@ -350,3 +351,33 @@ class VersionFilterSerializer(serializers.Serializer): page_size = serializers.IntegerField( required=False, min_value=1, max_value=50, default=20 ) + + +class AITransformSerializer(serializers.Serializer): + """Serializer for AI transform requests.""" + + action = serializers.ChoiceField(choices=AI_ACTIONS, required=True) + text = serializers.CharField(required=True) + + def validate_text(self, value): + """Ensure the text field is not empty.""" + + if len(value.strip()) == 0: + raise serializers.ValidationError("Text field cannot be empty.") + return value + + +class AITranslateSerializer(serializers.Serializer): + """Serializer for AI translate requests.""" + + language = serializers.ChoiceField( + choices=tuple(enums.ALL_LANGUAGES.items()), required=True + ) + text = serializers.CharField(required=True) + + def validate_text(self, value): + """Ensure the text field is not empty.""" + + if len(value.strip()) == 0: + raise serializers.ValidationError("Text field cannot be empty.") + return value diff --git a/src/backend/core/api/utils.py b/src/backend/core/api/utils.py index 53d84f681..cfe2395d2 100644 --- a/src/backend/core/api/utils.py +++ b/src/backend/core/api/utils.py @@ -1,8 +1,13 @@ """Util to generate S3 authorization headers for object storage access control""" +import time + +from django.conf import settings +from django.core.cache import cache from django.core.files.storage import default_storage import botocore +from rest_framework.throttling import BaseThrottle def generate_s3_authorization_headers(key): @@ -31,3 +36,135 @@ def generate_s3_authorization_headers(key): auth.add_auth(request) return request + + +class AIDocumentRateThrottle(BaseThrottle): + """Throttle for limiting AI requests per document with backoff.""" + + def __init__(self, *args, **kwargs): + """Initialize instance attributes""" + super().__init__(*args, **kwargs) + self.rates = settings.AI_DOCUMENT_RATE_THROTTLE_RATES + self.cache_key = None + self.recent_requests_minute = 0 + self.recent_requests_hour = 0 + self.recent_requests_day = 0 + + def get_cache_key(self, request, view): + """Include document ID in the cache key""" + document_id = view.kwargs["pk"] + return f"document_{document_id}_throttle_ai" + + def allow_request(self, request, view): + """Check that limits are not exceeded""" + self.cache_key = self.get_cache_key(request, view) + + now = time.time() + history = cache.get(self.cache_key, []) + history = [ + req for req in history if req > now - 86400 + ] # Keep requests from the last day + + # Calculate recent requests + self.recent_requests_minute = len([req for req in history if req > now - 60]) + self.recent_requests_hour = len([req for req in history if req > now - 3600]) + self.recent_requests_day = len(history) + + # Check rate limits + if self.recent_requests_minute >= self.rates["minute"]: + return False # Exceeded minute limit + if self.recent_requests_hour >= self.rates["hour"]: + return False # Exceeded hour limit + if self.recent_requests_day >= self.rates["day"]: + return False # Exceeded daily limit + + # Log the request + history.append(now) + cache.set(self.cache_key, history, timeout=86400) + return True + + def wait(self): + """Implement a backoff strategy by increasing wait time with each throttle hit.""" + if self.recent_requests_day >= self.rates["day"]: + return 86400 # Throttled by day limit, wait 24 hours + if self.recent_requests_hour >= self.rates["hour"]: + return 3600 # Throttled by hour limit, wait 1 hour + if self.recent_requests_minute >= self.rates["minute"]: + return 60 # Throttled by minute limit, wait 1 minute + + return None # No backoff required + + +class AIUserRateThrottle(BaseThrottle): + """Throttle that limits requests per user or IP with backoff and rate limits.""" + + def __init__(self, *args, **kwargs): + """Initialize instance attributes""" + super().__init__(*args, **kwargs) + self.rates = settings.AI_USER_RATE_THROTTLE_RATES + self.cache_key = None + self.recent_requests_minute = 0 + self.recent_requests_hour = 0 + self.recent_requests_day = 0 + + # pylint: disable=unused-argument + def get_cache_key(self, request, view): + """Generate a cache key based on the user ID or IP for anonymous users.""" + if request.user.is_authenticated: + return f"user_{request.user.id!s}_throttle_ai" + + # Use IP address for anonymous users + ip_addr = self.get_ident(request) + return f"anonymous_{ip_addr:s}_throttle_ai" + + def allow_request(self, request, view): + """ + Check if the request should be allowed based on the user-specific or IP-specific usage. + """ + self.cache_key = self.get_cache_key(request, view) + if not self.cache_key: + return True # Allow requests if no cache key (fallback) + + now = time.time() + history = cache.get(self.cache_key, []) + # Remove entries older than a day (86400 seconds) + history = [req for req in history if req > now - 86400] + + # Calculate recent request counts + self.recent_requests_minute = len([req for req in history if req > now - 60]) + self.recent_requests_hour = len([req for req in history if req > now - 3600]) + self.recent_requests_day = len(history) + + # Check if the user has exceeded the limits + if self.recent_requests_minute >= self.rates["minute"]: + return False # Exceeded minute limit + if self.recent_requests_hour >= self.rates["hour"]: + return False # Exceeded hour limit + if self.recent_requests_day >= self.rates["day"]: + return False # Exceeded daily limit + + # If not throttled, store the request timestamp + history.append(now) + cache.set(self.cache_key, history, timeout=86400) + return True + + def wait(self): + """Calculate and return the backoff time based on the number of requests made.""" + if self.recent_requests_day >= self.rates["day"]: + return 86400 # If the day limit is exceeded, wait 24 hours + if self.recent_requests_hour >= self.rates["hour"]: + return 3600 # If the hour limit is exceeded, wait 1 hour + if self.recent_requests_minute >= self.rates["minute"]: + return 60 # If the minute limit is exceeded, wait 1 minute + + return None # No backoff required + + def get_ident(self, request): + """Return the request IP address.""" + # Use REMOTE_ADDR for IP identification, or X-Forwarded-For if behind a proxy + x_forwarded_for = request.META.get("HTTP_X_FORWARDED_FOR") + if x_forwarded_for: + ip = x_forwarded_for.split(",")[0] + else: + ip = request.META.get("REMOTE_ADDR") + return ip diff --git a/src/backend/core/api/viewsets.py b/src/backend/core/api/viewsets.py index 7450107fa..240461c57 100644 --- a/src/backend/core/api/viewsets.py +++ b/src/backend/core/api/viewsets.py @@ -1,6 +1,5 @@ """API endpoints""" -import json import os import re import uuid @@ -19,7 +18,6 @@ from django.http import Http404 from botocore.exceptions import ClientError -from openai import OpenAI from rest_framework import ( decorators, exceptions, @@ -34,6 +32,7 @@ ) from core import models +from core.services.ai_services import AIService from . import permissions, serializers, utils @@ -458,10 +457,7 @@ def link_configuration(self, request, *args, **kwargs): serializer = serializers.LinkDocumentSerializer( document, data=request.data, partial=True ) - if not serializer.is_valid(): - return drf_response.Response( - serializer.errors, status=status.HTTP_400_BAD_REQUEST - ) + serializer.is_valid(raise_exception=True) serializer.save() return drf_response.Response(serializer.data, status=status.HTTP_200_OK) @@ -474,10 +470,8 @@ def attachment_upload(self, request, *args, **kwargs): # Validate metadata in payload serializer = serializers.FileUploadSerializer(data=request.data) - if not serializer.is_valid(): - return drf_response.Response( - serializer.errors, status=status.HTTP_400_BAD_REQUEST - ) + serializer.is_valid(raise_exception=True) + # Extract the file extension from the original filename file = serializer.validated_data["file"] extension = os.path.splitext(file.name)[1] @@ -533,6 +527,63 @@ def retrieve_auth(self, request, *args, **kwargs): request = utils.generate_s3_authorization_headers(f"{pk:s}/{attachment_key:s}") return drf_response.Response("authorized", headers=request.headers, status=200) + @decorators.action( + detail=True, + methods=["post"], + name="Apply a transformation action on a piece of text with AI", + url_path="ai-transform", + throttle_classes=[utils.AIDocumentRateThrottle, utils.AIUserRateThrottle], + ) + def ai_transform(self, request, *args, **kwargs): + """ + POST /api/v1.0/documents//ai-transform + with expected data: + - text: str + - action: str [prompt, correct, rephrase, summarize] + Return JSON response with the processed text. + """ + # Check permissions first + self.get_object() + + serializer = serializers.AITransformSerializer(data=request.data) + serializer.is_valid(raise_exception=True) + + text = serializer.validated_data["text"] + action = serializer.validated_data["action"] + + response = AIService().transform(text, action) + + return drf_response.Response(response, status=status.HTTP_200_OK) + + @decorators.action( + detail=True, + methods=["post"], + name="Translate a piece of text with AI", + serializer_class=serializers.AITranslateSerializer, + url_path="ai-translate", + throttle_classes=[utils.AIDocumentRateThrottle, utils.AIUserRateThrottle], + ) + def ai_translate(self, request, *args, **kwargs): + """ + POST /api/v1.0/documents//ai-translate + with expected data: + - text: str + - language: str [settings.LANGUAGES] + Return JSON response with the translated text. + """ + # Check permissions first + self.get_object() + + serializer = self.get_serializer(data=request.data) + serializer.is_valid(raise_exception=True) + + text = serializer.validated_data["text"] + language = serializer.validated_data["language"] + + response = AIService().translate(text, language) + + return drf_response.Response(response, status=status.HTTP_200_OK) + class DocumentAccessViewSet( ResourceAccessViewsetMixin, @@ -783,125 +834,3 @@ def perform_create(self, serializer): invitation.document.email_invitation( language, invitation.email, invitation.role, self.request.user.email ) - - -class AIViewSet(viewsets.ViewSet): - """API ViewSet for handling AI tasks""" - - permission_classes = [permissions.IsAuthenticated] - - def create(self, request): - """ - POST /api/v1.0/ai/ with expected data: - - text: str - - action: str [prompt, correct, rephrase, summarize, - translate_en, translate_de, translate_fr] - Return JSON response with the processed text. - """ - if not request.user.is_authenticated: - raise exceptions.NotAuthenticated() - - if ( - settings.AI_BASE_URL is None - or settings.AI_API_KEY is None - or settings.AI_MODEL is None - ): - raise exceptions.ValidationError({"error": "AI configuration not set"}) - - action = request.data.get("action") - text = request.data.get("text") - - action_configs = { - "prompt": { - "system_content": ( - "Answer the prompt in markdown format. Return JSON: " - '{"answer": "Your markdown answer"}.' - "Do not provide any other information." - ), - "response_key": "answer", - }, - "correct": { - "system_content": ( - "Correct grammar and spelling of the markdown text, " - "preserving language and markdown formatting. " - 'Return JSON: {"answer": "your corrected markdown text"}.' - "Do not provide any other information." - ), - "response_key": "answer", - }, - "rephrase": { - "system_content": ( - "Rephrase the given markdown text, " - "preserving language and markdown formatting. " - 'Return JSON: {"answer": "your rephrased markdown text"}.' - "Do not provide any other information." - ), - "response_key": "answer", - }, - "summarize": { - "system_content": ( - "Summarize the markdown text, preserving language and markdown formatting. " - 'Return JSON: {"answer": "your markdown summary"}.' - "Do not provide any other information." - ), - "response_key": "answer", - }, - "translate_en": { - "system_content": ( - "Translate the markdown text to English, preserving markdown formatting. " - 'Return JSON: {"answer": "Your translated markdown text in English"}.' - "Do not provide any other information." - ), - "response_key": "answer", - }, - "translate_de": { - "system_content": ( - "Translate the markdown text to German, preserving markdown formatting. " - 'Return JSON: {"answer": "Your translated markdown text in German"}.' - "Do not provide any other information." - ), - "response_key": "answer", - }, - "translate_fr": { - "system_content": ( - "Translate the markdown text to French, preserving markdown formatting. " - 'Return JSON: {"answer": "Your translated markdown text in French"}.' - "Do not provide any other information." - ), - "response_key": "answer", - }, - } - - if action not in action_configs: - raise exceptions.ValidationError({"error": "Invalid action"}) - - config = action_configs[action] - - try: - client = OpenAI(base_url=settings.AI_BASE_URL, api_key=settings.AI_API_KEY) - response = client.chat.completions.create( - model=settings.AI_MODEL, - response_format={"type": "json_object"}, - messages=[ - {"role": "system", "content": config["system_content"]}, - {"role": "user", "content": json.dumps({"mardown_input": text})}, - ], - ) - - corrected_response = json.loads(response.choices[0].message.content) - - if "answer" not in corrected_response: - raise exceptions.ValidationError("Invalid response format") - - return drf_response.Response(corrected_response, status=status.HTTP_200_OK) - - except exceptions.ValidationError as e: - return drf_response.Response( - {"error": e.detail}, status=status.HTTP_400_BAD_REQUEST - ) - - except exceptions.APIException as e: - return drf_response.Response( - {"error": f"Error processing AI response: {str(e)}"}, - status=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) diff --git a/src/backend/core/enums.py b/src/backend/core/enums.py index e67d7b5b5..8f7e70cfc 100644 --- a/src/backend/core/enums.py +++ b/src/backend/core/enums.py @@ -2,15 +2,11 @@ Core application enums declaration """ -from django.conf import global_settings, settings +from django.conf import global_settings from django.utils.translation import gettext_lazy as _ -# Django sets `LANGUAGES` by default with all supported languages. We can use it for -# the choice of languages which should not be limited to the few languages active in -# the app. +# In Django's code base, `LANGUAGES` is set by default with all supported languages. +# We can use it for the choice of languages which should not be limited to the few languages +# active in the app. # pylint: disable=no-member -ALL_LANGUAGES = getattr( - settings, - "ALL_LANGUAGES", - [(language, _(name)) for language, name in global_settings.LANGUAGES], -) +ALL_LANGUAGES = {language: _(name) for language, name in global_settings.LANGUAGES} diff --git a/src/backend/core/models.py b/src/backend/core/models.py index bea277993..49f9a0ec6 100644 --- a/src/backend/core/models.py +++ b/src/backend/core/models.py @@ -508,6 +508,8 @@ def get_abilities(self, user): can_get = bool(roles) return { + "ai_transform": is_owner_or_admin or is_editor, + "ai_translate": is_owner_or_admin or is_editor, "attachment_upload": is_owner_or_admin or is_editor, "destroy": RoleChoices.OWNER in roles, "link_configuration": is_owner_or_admin, diff --git a/src/backend/core/services/__init__.py b/src/backend/core/services/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/backend/core/services/ai_services.py b/src/backend/core/services/ai_services.py new file mode 100644 index 000000000..2745b1085 --- /dev/null +++ b/src/backend/core/services/ai_services.py @@ -0,0 +1,103 @@ +"""AI services.""" + +import json +import re + +from django.conf import settings + +from openai import OpenAI +from rest_framework import serializers + +from core import enums + +AI_ACTIONS = { + "prompt": ( + "Answer the prompt in markdown format. Return JSON: " + '{"answer": "Your markdown answer"}. ' + "Do not provide any other information." + ), + "correct": ( + "Correct grammar and spelling of the markdown text, " + "preserving language and markdown formatting. " + 'Return JSON: {"answer": "your corrected markdown text"}. ' + "Do not provide any other information." + ), + "rephrase": ( + "Rephrase the given markdown text, " + "preserving language and markdown formatting. " + 'Return JSON: {"answer": "your rephrased markdown text"}. ' + "Do not provide any other information." + ), + "summarize": ( + "Summarize the markdown text, preserving language and markdown formatting. " + 'Return JSON: {"answer": "your markdown summary"}. ' + "Do not provide any other information." + ), +} + +AI_TRANSLATE = ( + "Translate the markdown text to {language:s}, preserving markdown formatting. " + 'Return JSON: {{"answer": "your translated markdown text in {language:s}"}}. ' + "Do not provide any other information." +) + + +class AIService: + """Service class for AI-related operations.""" + + def __init__(self): + """Ensure that the AI configuration is set properly.""" + if ( + settings.AI_BASE_URL is None + or settings.AI_API_KEY is None + or settings.AI_MODEL is None + ): + raise serializers.ValidationError("AI configuration not set") + self.client = OpenAI(base_url=settings.AI_BASE_URL, api_key=settings.AI_API_KEY) + + def transform(self, text, action): + """Call the OpenAI API with the transform prompt and return the response.""" + system_content = AI_ACTIONS[action] + response = self.client.chat.completions.create( + model=settings.AI_MODEL, + response_format={"type": "json_object"}, + messages=[ + {"role": "system", "content": system_content}, + {"role": "user", "content": json.dumps({"markdown_input": text})}, + ], + ) + + content = response.choices[0].message.content + sanitized_content = re.sub(r"(?[0-9a-z-]*)/", include(template_related_router.urls), ), - re_path( - r"ai", - viewsets.AIViewSet.as_view({"post": "create"}), - ), ] ), ), diff --git a/src/backend/impress/settings.py b/src/backend/impress/settings.py index d0f63ded7..623f1c9e5 100755 --- a/src/backend/impress/settings.py +++ b/src/backend/impress/settings.py @@ -397,6 +397,17 @@ class Base(Configuration): AI_BASE_URL = values.Value(None, environ_name="AI_BASE_URL", environ_prefix=None) AI_MODEL = values.Value(None, environ_name="AI_MODEL", environ_prefix=None) + AI_DOCUMENT_RATE_THROTTLE_RATES = { + "minute": 5, + "hour": 100, + "day": 500, + } + AI_USER_RATE_THROTTLE_RATES = { + "minute": 3, + "hour": 50, + "day": 200, + } + USER_OIDC_FIELDS_TO_FULLNAME = values.ListValue( default=["first_name", "last_name"], environ_name="USER_OIDC_FIELDS_TO_FULLNAME", From fea99a4db54e09fd5f9c5abbcc83bf20b4b1957e Mon Sep 17 00:00:00 2001 From: Samuel Paccoud - DINUM Date: Wed, 9 Oct 2024 11:38:38 +0200 Subject: [PATCH 3/6] =?UTF-8?q?fixup!=20fixup!=20=E2=9C=A8(backend)=20crea?= =?UTF-8?q?te=20ai=20endpoint?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/core/api/viewsets.py | 21 +++++++++++++- .../test_api_documents_ai_transform.py | 21 +++++++------- .../test_api_documents_ai_translate.py | 28 +++++++++---------- .../core/tests/test_services_ai_services.py | 8 +++--- 4 files changed, 47 insertions(+), 31 deletions(-) diff --git a/src/backend/core/api/viewsets.py b/src/backend/core/api/viewsets.py index 240461c57..8024e7436 100644 --- a/src/backend/core/api/viewsets.py +++ b/src/backend/core/api/viewsets.py @@ -22,6 +22,7 @@ decorators, exceptions, filters, + metadata, mixins, pagination, status, @@ -31,7 +32,7 @@ response as drf_response, ) -from core import models +from core import enums, models from core.services.ai_services import AIService from . import permissions, serializers, utils @@ -304,6 +305,23 @@ def perform_update(self, serializer): serializer.save() +class DocumentMetadata(metadata.SimpleMetadata): + """Custom metadata class to add information""" + + def determine_metadata(self, request, view): + """Add language choices only for the list endpoint.""" + simple_metadata = super().determine_metadata(request, view) + + if request.path.endswith("/documents/"): + simple_metadata["actions"]["POST"]["language"] = { + "choices": [ + {"value": code, "display_name": name} + for code, name in enums.ALL_LANGUAGES.items() + ] + } + return simple_metadata + + class DocumentViewSet( ResourceViewsetMixin, mixins.CreateModelMixin, @@ -321,6 +339,7 @@ class DocumentViewSet( resource_field_name = "document" queryset = models.Document.objects.all() ordering = ["-updated_at"] + metadata_class = DocumentMetadata def list(self, request, *args, **kwargs): """Restrict resources returned by the list endpoint""" diff --git a/src/backend/core/tests/documents/test_api_documents_ai_transform.py b/src/backend/core/tests/documents/test_api_documents_ai_transform.py index 4562f6df3..6dfe63130 100644 --- a/src/backend/core/tests/documents/test_api_documents_ai_transform.py +++ b/src/backend/core/tests/documents/test_api_documents_ai_transform.py @@ -2,7 +2,6 @@ Test AI transform API endpoint for users in impress's core app. """ -import json from unittest.mock import MagicMock, patch from django.core.cache import cache @@ -57,9 +56,9 @@ def test_api_documents_ai_transform_anonymous_success(mock_create): """ document = factories.DocumentFactory(link_reach="public", link_role="editor") - answer = {"answer": "Salut"} + answer = '{"answer": "Salut"}' mock_create.return_value = MagicMock( - choices=[MagicMock(message=MagicMock(content=json.dumps(answer)))] + choices=[MagicMock(message=MagicMock(content=answer))] ) url = f"/api/v1.0/documents/{document.id!s}/ai-transform/" @@ -134,9 +133,9 @@ def test_api_documents_ai_transform_authenticated_success(mock_create, reach, ro document = factories.DocumentFactory(link_reach=reach, link_role=role) - answer = {"answer": "Salut"} + answer = '{"answer": "Salut"}' mock_create.return_value = MagicMock( - choices=[MagicMock(message=MagicMock(content=json.dumps(answer)))] + choices=[MagicMock(message=MagicMock(content=answer))] ) url = f"/api/v1.0/documents/{document.id!s}/ai-transform/" @@ -209,9 +208,9 @@ def test_api_documents_ai_transform_success(mock_create, via, role, mock_user_te document=document, team="lasuite", role=role ) - answer = {"answer": "Salut"} + answer = '{"answer": "Salut"}' mock_create.return_value = MagicMock( - choices=[MagicMock(message=MagicMock(content=json.dumps(answer)))] + choices=[MagicMock(message=MagicMock(content=answer))] ) url = f"/api/v1.0/documents/{document.id!s}/ai-transform/" @@ -277,9 +276,9 @@ def test_api_documents_ai_transform_throttling_document(mock_create): client = APIClient() document = factories.DocumentFactory(link_reach="public", link_role="editor") - answer = {"answer": "Salut"} + answer = '{"answer": "Salut"}' mock_create.return_value = MagicMock( - choices=[MagicMock(message=MagicMock(content=json.dumps(answer)))] + choices=[MagicMock(message=MagicMock(content=answer))] ) url = f"/api/v1.0/documents/{document.id!s}/ai-transform/" @@ -311,9 +310,9 @@ def test_api_documents_ai_transform_throttling_user(mock_create): client = APIClient() client.force_login(user) - answer = {"answer": "Salut"} + answer = '{"answer": "Salut"}' mock_create.return_value = MagicMock( - choices=[MagicMock(message=MagicMock(content=json.dumps(answer)))] + choices=[MagicMock(message=MagicMock(content=answer))] ) for _ in range(3): diff --git a/src/backend/core/tests/documents/test_api_documents_ai_translate.py b/src/backend/core/tests/documents/test_api_documents_ai_translate.py index 7e43622a9..2b3befcbf 100644 --- a/src/backend/core/tests/documents/test_api_documents_ai_translate.py +++ b/src/backend/core/tests/documents/test_api_documents_ai_translate.py @@ -2,7 +2,6 @@ Test AI translate API endpoint for users in impress's core app. """ -import json from unittest.mock import MagicMock, patch from django.core.cache import cache @@ -30,14 +29,13 @@ def test_api_documents_ai_translate_viewset_options_metadata(): client = APIClient() client.force_login(user) - document = factories.DocumentFactory(link_reach="public", link_role="editor") + factories.DocumentFactory(link_reach="public", link_role="editor") - url = f"/api/v1.0/documents/{document.id!s}/ai-translate/" - response = APIClient().options(url) + response = APIClient().options("/api/v1.0/documents/") assert response.status_code == 200 metadata = response.json() - assert metadata["name"] == "Translate a piece of text with AI" + assert metadata["name"] == "Document List" assert metadata["actions"]["POST"]["language"]["choices"][0] == { "value": "af", "display_name": "Afrikaans", @@ -78,9 +76,9 @@ def test_api_documents_ai_translate_anonymous_success(mock_create): """ document = factories.DocumentFactory(link_reach="public", link_role="editor") - answer = {"answer": "Salut"} + answer = '{"answer": "Salut"}' mock_create.return_value = MagicMock( - choices=[MagicMock(message=MagicMock(content=json.dumps(answer)))] + choices=[MagicMock(message=MagicMock(content=answer))] ) url = f"/api/v1.0/documents/{document.id!s}/ai-translate/" @@ -155,9 +153,9 @@ def test_api_documents_ai_translate_authenticated_success(mock_create, reach, ro document = factories.DocumentFactory(link_reach=reach, link_role=role) - answer = {"answer": "Salut"} + answer = '{"answer": "Salut"}' mock_create.return_value = MagicMock( - choices=[MagicMock(message=MagicMock(content=json.dumps(answer)))] + choices=[MagicMock(message=MagicMock(content=answer))] ) url = f"/api/v1.0/documents/{document.id!s}/ai-translate/" @@ -232,9 +230,9 @@ def test_api_documents_ai_translate_success(mock_create, via, role, mock_user_te document=document, team="lasuite", role=role ) - answer = {"answer": "Salut"} + answer = '{"answer": "Salut"}' mock_create.return_value = MagicMock( - choices=[MagicMock(message=MagicMock(content=json.dumps(answer)))] + choices=[MagicMock(message=MagicMock(content=answer))] ) url = f"/api/v1.0/documents/{document.id!s}/ai-translate/" @@ -302,9 +300,9 @@ def test_api_documents_ai_translate_throttling_document(mock_create): client = APIClient() document = factories.DocumentFactory(link_reach="public", link_role="editor") - answer = {"answer": "Salut"} + answer = '{"answer": "Salut"}' mock_create.return_value = MagicMock( - choices=[MagicMock(message=MagicMock(content=json.dumps(answer)))] + choices=[MagicMock(message=MagicMock(content=answer))] ) url = f"/api/v1.0/documents/{document.id!s}/ai-translate/" @@ -336,9 +334,9 @@ def test_api_documents_ai_translate_throttling_user(mock_create): client = APIClient() client.force_login(user) - answer = {"answer": "Salut"} + answer = '{"answer": "Salut"}' mock_create.return_value = MagicMock( - choices=[MagicMock(message=MagicMock(content=json.dumps(answer)))] + choices=[MagicMock(message=MagicMock(content=answer))] ) for _ in range(3): diff --git a/src/backend/core/tests/test_services_ai_services.py b/src/backend/core/tests/test_services_ai_services.py index f11f4504b..0dc866243 100644 --- a/src/backend/core/tests/test_services_ai_services.py +++ b/src/backend/core/tests/test_services_ai_services.py @@ -77,9 +77,9 @@ def test_api_ai__client_invalid_response(mock_create): def test_api_ai__success(mock_create): """The AI request should work as expect when called with valid arguments.""" - answer = {"answer": "Salut"} + answer = '{"answer": "Salut"}' mock_create.return_value = MagicMock( - choices=[MagicMock(message=MagicMock(content=json.dumps(answer)))] + choices=[MagicMock(message=MagicMock(content=answer))] ) response = AIService().transform("hello", "prompt") @@ -94,9 +94,9 @@ def test_api_ai__success(mock_create): def test_api_ai__success_sanitize(mock_create): """The AI response should be sanitized""" - answer = {"answer": "Salut\\n \tle \nmonde"} + answer = '{"answer": "Salut\\n \tle \nmonde"}' mock_create.return_value = MagicMock( - choices=[MagicMock(message=MagicMock(content=json.dumps(answer)))] + choices=[MagicMock(message=MagicMock(content=answer))] ) response = AIService().transform("hello", "prompt") From 539efe5b7eefbbf311e8e6f49179adb10c5ae8d9 Mon Sep 17 00:00:00 2001 From: Anthony LC Date: Fri, 20 Sep 2024 22:45:06 +0200 Subject: [PATCH 4/6] =?UTF-8?q?=E2=9C=A8(frontend)=20add=20ai=20blocknote?= =?UTF-8?q?=20feature?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add AI button to the editor toolbar. We can use AI to generate content with our editor. A list of predefined actions are available to use. --- CHANGELOG.md | 1 + .../__tests__/app-impress/doc-editor.spec.ts | 47 +++++ .../features/docs/doc-editor/api/useAI.tsx | 43 +++++ .../docs/doc-editor/components/AIButton.tsx | 179 ++++++++++++++++++ .../components/BlockNoteToolbar.tsx | 5 + .../apps/impress/src/i18n/translations.json | 11 +- 6 files changed, 285 insertions(+), 1 deletion(-) create mode 100644 src/frontend/apps/impress/src/features/docs/doc-editor/api/useAI.tsx create mode 100644 src/frontend/apps/impress/src/features/docs/doc-editor/components/AIButton.tsx diff --git a/CHANGELOG.md b/CHANGELOG.md index af20e9b7c..13efa51dd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ and this project adheres to ## Added +- ✨AI to doc editor #250 - ✨(backend) add name fields to the user synchronized with OIDC #301 - ✨(ci) add security scan #291 - ✨(frontend) Activate versions feature #240 diff --git a/src/frontend/apps/e2e/__tests__/app-impress/doc-editor.spec.ts b/src/frontend/apps/e2e/__tests__/app-impress/doc-editor.spec.ts index a39362e23..3656d5675 100644 --- a/src/frontend/apps/e2e/__tests__/app-impress/doc-editor.spec.ts +++ b/src/frontend/apps/e2e/__tests__/app-impress/doc-editor.spec.ts @@ -181,4 +181,51 @@ test.describe('Doc Editor', () => { /http:\/\/localhost:8083\/media\/.*\/attachments\/.*.png/, ); }); + + test('it checks the AI buttons', async ({ page }) => { + await page.route(/.*\/ai\//, async (route) => { + const request = route.request(); + if (request.method().includes('POST')) { + await route.fulfill({ + json: { + answer: 'Bonjour le monde', + }, + }); + } else { + await route.continue(); + } + }); + + await goToGridDoc(page); + + await page.locator('.bn-block-outer').last().fill('Hello World'); + + const editor = page.locator('.ProseMirror'); + await editor.getByText('Hello').dblclick(); + + await page.getByRole('button', { name: 'AI' }).click(); + + await expect( + page.getByRole('menuitem', { name: 'Use as prompt' }), + ).toBeVisible(); + await expect( + page.getByRole('menuitem', { name: 'Rephrase' }), + ).toBeVisible(); + await expect( + page.getByRole('menuitem', { name: 'Summarize' }), + ).toBeVisible(); + await expect(page.getByRole('menuitem', { name: 'Correct' })).toBeVisible(); + await expect( + page.getByRole('menuitem', { name: 'Language' }), + ).toBeVisible(); + + await page.getByRole('menuitem', { name: 'Language' }).hover(); + await expect(page.getByRole('menuitem', { name: 'English' })).toBeVisible(); + await expect(page.getByRole('menuitem', { name: 'French' })).toBeVisible(); + await expect(page.getByRole('menuitem', { name: 'German' })).toBeVisible(); + + await page.getByRole('menuitem', { name: 'English' }).click(); + + await expect(editor.getByText('Bonjour le monde')).toBeVisible(); + }); }); diff --git a/src/frontend/apps/impress/src/features/docs/doc-editor/api/useAI.tsx b/src/frontend/apps/impress/src/features/docs/doc-editor/api/useAI.tsx new file mode 100644 index 000000000..ae67fcf27 --- /dev/null +++ b/src/frontend/apps/impress/src/features/docs/doc-editor/api/useAI.tsx @@ -0,0 +1,43 @@ +import { useMutation } from '@tanstack/react-query'; + +import { APIError, errorCauses, fetchAPI } from '@/api'; + +export type AIActions = + | 'prompt' + | 'rephrase' + | 'summarize' + | 'translate' + | 'correct' + | 'translate_fr' + | 'translate_en' + | 'translate_de'; + +export type AIParams = { + text: string; + action: AIActions; +}; + +export type AIResponse = { + answer: string; +}; + +export const AI = async ({ ...params }: AIParams): Promise => { + const response = await fetchAPI(`ai/`, { + method: 'POST', + body: JSON.stringify({ + ...params, + }), + }); + + if (!response.ok) { + throw new APIError('Failed to request ai', await errorCauses(response)); + } + + return response.json() as Promise; +}; + +export function useAI() { + return useMutation({ + mutationFn: AI, + }); +} diff --git a/src/frontend/apps/impress/src/features/docs/doc-editor/components/AIButton.tsx b/src/frontend/apps/impress/src/features/docs/doc-editor/components/AIButton.tsx new file mode 100644 index 000000000..6e2feddc0 --- /dev/null +++ b/src/frontend/apps/impress/src/features/docs/doc-editor/components/AIButton.tsx @@ -0,0 +1,179 @@ +import { + ComponentProps, + useBlockNoteEditor, + useComponentsContext, + useSelectedBlocks, +} from '@blocknote/react'; +import { Loader } from '@openfun/cunningham-react'; +import { ReactNode, useCallback, useMemo } from 'react'; +import { useTranslation } from 'react-i18next'; + +import { Box, Text } from '@/components'; + +import { AIActions, useAI } from '../api/useAI'; + +export function AIGroupButton() { + const editor = useBlockNoteEditor(); + const Components = useComponentsContext(); + const selectedBlocks = useSelectedBlocks(editor); + const { t } = useTranslation(); + + const show = useMemo(() => { + return !!selectedBlocks.find((block) => block.content !== undefined); + }, [selectedBlocks]); + + if (!show || !editor.isEditable || !Components) { + return null; + } + + return ( + + + + auto_awesome + + } + /> + + + + text_fields + + } + > + {t('Use as prompt')} + + + refresh + + } + > + {t('Rephrase')} + + + summarize + + } + > + {t('Summarize')} + + + check + + } + > + {t('Correct')} + + + + + + + translate + + {t('Language')} + + + + + {t('English')} + {t('French')} + {t('German')} + + + + + ); +} + +/** + * Item is derived from Mantime, some props seem lacking or incorrect. + */ +type ItemDefault = ComponentProps['Generic']['Menu']['Item']; +type ItemProps = Omit & { + rightSection?: ReactNode; + closeMenuOnClick?: boolean; + onClick: (e: React.MouseEvent) => void; +}; + +interface AIMenuItemProps { + action: AIActions; + children: ReactNode; + icon?: ReactNode; +} + +const AIMenuItem = ({ action, children, icon }: AIMenuItemProps) => { + const editor = useBlockNoteEditor(); + const Components = useComponentsContext(); + const { mutateAsync: requestAI, isPending } = useAI(); + + const handleAIAction = useCallback(async () => { + const selectedBlocks = editor.getSelection()?.blocks; + + if (!selectedBlocks || selectedBlocks.length === 0) { + return; + } + + const markdown = await editor.blocksToMarkdownLossy(selectedBlocks); + const responseAI = await requestAI({ + text: markdown, + action, + }); + + if (!responseAI.answer) { + return; + } + + const blockMarkdown = await editor.tryParseMarkdownToBlocks( + responseAI.answer, + ); + editor.replaceBlocks(selectedBlocks, blockMarkdown); + }, [editor, requestAI, action]); + + if (!Components) { + return null; + } + + const Item = Components.Generic.Menu.Item as React.FC; + + return ( + { + e.stopPropagation(); + void handleAIAction(); + }} + rightSection={isPending ? : undefined} + > + {children} + + ); +}; diff --git a/src/frontend/apps/impress/src/features/docs/doc-editor/components/BlockNoteToolbar.tsx b/src/frontend/apps/impress/src/features/docs/doc-editor/components/BlockNoteToolbar.tsx index d7deeee2c..d35637622 100644 --- a/src/frontend/apps/impress/src/features/docs/doc-editor/components/BlockNoteToolbar.tsx +++ b/src/frontend/apps/impress/src/features/docs/doc-editor/components/BlockNoteToolbar.tsx @@ -14,6 +14,8 @@ import React from 'react'; import { MarkdownButton } from './MarkdownButton'; +import { AIGroupButton } from './AIButton'; + export const BlockNoteToolbar = () => { return ( { + {/* Extra button to do some AI powered actions */} + + {/* Extra button to convert from markdown to json */} diff --git a/src/frontend/apps/impress/src/i18n/translations.json b/src/frontend/apps/impress/src/i18n/translations.json index 5a94e33db..f8d6784b5 100644 --- a/src/frontend/apps/impress/src/i18n/translations.json +++ b/src/frontend/apps/impress/src/i18n/translations.json @@ -143,7 +143,16 @@ "accessibility-dinum-services": "DINUM s'engage à rendre accessibles ses services numériques, conformément à l'article 47 de la loi n° 2005-102 du 11 février 2005.", "accessibility-form-defenseurdesdroits": "Écrire un message au<1>Défenseur des droits", "accessibility-not-audit": "docs.numerique.gouv.fr n'est pas en conformité avec le RGAA 4.1. Le site n'a pas encore été audité.", - "you have reported to the website manager a lack of accessibility that prevents you from accessing content or one of the services of the portal and you have not received a satisfactory response.": "vous avez signalé au responsable du site internet un défaut d'accessibilité qui vous empêche d'accéder à un contenu ou à un des services du portail et vous n'avez pas obtenu de réponse satisfaisante." + "you have reported to the website manager a lack of accessibility that prevents you from accessing content or one of the services of the portal and you have not received a satisfactory response.": "vous avez signalé au responsable du site internet un défaut d'accessibilité qui vous empêche d'accéder à un contenu ou à un des services du portail et vous n'avez pas obtenu de réponse satisfaisante.", + "AI Actions": "Actions IA", + "Use as prompt": "Utiliser comme prompt", + "Rephrase": "Reformuler", + "Summarize": "Résumer", + "Correct": "Corriger", + "Translate": "Traduire", + "English": "Anglais", + "French": "Français", + "German": "Allemand" } } } From 5ec13728bc5ecdd235b270ebf371e412dd5755d5 Mon Sep 17 00:00:00 2001 From: Anthony LC Date: Mon, 30 Sep 2024 10:04:40 +0200 Subject: [PATCH 5/6] =?UTF-8?q?fixup!=20=E2=9C=A8(frontend)=20add=20ai=20b?= =?UTF-8?q?locknote=20feature?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../docs/doc-editor/components/AIButton.tsx | 27 +++++++++++-------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/src/frontend/apps/impress/src/features/docs/doc-editor/components/AIButton.tsx b/src/frontend/apps/impress/src/features/docs/doc-editor/components/AIButton.tsx index 6e2feddc0..16cd2518c 100644 --- a/src/frontend/apps/impress/src/features/docs/doc-editor/components/AIButton.tsx +++ b/src/frontend/apps/impress/src/features/docs/doc-editor/components/AIButton.tsx @@ -142,19 +142,24 @@ const AIMenuItem = ({ action, children, icon }: AIMenuItemProps) => { } const markdown = await editor.blocksToMarkdownLossy(selectedBlocks); - const responseAI = await requestAI({ - text: markdown, - action, - }); - if (!responseAI.answer) { - return; + try { + const responseAI = await requestAI({ + text: markdown, + action, + }); + + if (!responseAI.answer) { + return; + } + + const blockMarkdown = await editor.tryParseMarkdownToBlocks( + responseAI.answer, + ); + editor.replaceBlocks(selectedBlocks, blockMarkdown); + } catch (error) { + console.error(error); } - - const blockMarkdown = await editor.tryParseMarkdownToBlocks( - responseAI.answer, - ); - editor.replaceBlocks(selectedBlocks, blockMarkdown); }, [editor, requestAI, action]); if (!Components) { From 978d71124a04168325a9ab569bcce829b7b4330e Mon Sep 17 00:00:00 2001 From: Anthony LC Date: Thu, 26 Sep 2024 09:35:11 +0200 Subject: [PATCH 6/6] =?UTF-8?q?=F0=9F=94=A7(helm)=20add=20ai=20setting=20t?= =?UTF-8?q?o=20environments?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add the ai setting to the environments. --- src/helm/env.d/preprod/values.impress.yaml.gotmpl | 11 ++++++++++- src/helm/env.d/production/values.impress.yaml.gotmpl | 11 ++++++++++- src/helm/env.d/staging/values.impress.yaml.gotmpl | 9 +++++++++ src/helm/impress/templates/secrets.yaml | 2 ++ 4 files changed, 31 insertions(+), 2 deletions(-) diff --git a/src/helm/env.d/preprod/values.impress.yaml.gotmpl b/src/helm/env.d/preprod/values.impress.yaml.gotmpl index cad2efdbb..0b2d9f9e9 100644 --- a/src/helm/env.d/preprod/values.impress.yaml.gotmpl +++ b/src/helm/env.d/preprod/values.impress.yaml.gotmpl @@ -8,6 +8,15 @@ backend: argocd.argoproj.io/hook: PreSync argocd.argoproj.io/hook-delete-policy: HookSucceeded envVars: + AI_API_KEY: + secretKeyRef: + name: backend + key: AI_API_KEY + AI_BASE_URL: + secretKeyRef: + name: backend + key: AI_BASE_URL + AI_MODEL: meta-llama/Meta-Llama-3.1-70B-Instruct DJANGO_CSRF_TRUSTED_ORIGINS: http://impress-preprod.beta.numerique.gouv.fr,https://impress-preprod.beta.numerique.gouv.fr DJANGO_CONFIGURATION: Production DJANGO_ALLOWED_HOSTS: "*" @@ -171,4 +180,4 @@ ingressMedia: serviceMedia: host: s3.margaret-hamilton.indiehosters.net - port: 443 \ No newline at end of file + port: 443 diff --git a/src/helm/env.d/production/values.impress.yaml.gotmpl b/src/helm/env.d/production/values.impress.yaml.gotmpl index 5ca48a096..edceebd84 100644 --- a/src/helm/env.d/production/values.impress.yaml.gotmpl +++ b/src/helm/env.d/production/values.impress.yaml.gotmpl @@ -8,6 +8,15 @@ backend: argocd.argoproj.io/hook: PostSync argocd.argoproj.io/hook-delete-policy: HookSucceeded envVars: + AI_API_KEY: + secretKeyRef: + name: backend + key: AI_API_KEY + AI_BASE_URL: + secretKeyRef: + name: backend + key: AI_BASE_URL + AI_MODEL: meta-llama/Meta-Llama-3.1-70B-Instruct DJANGO_CSRF_TRUSTED_ORIGINS: https://docs.numerique.gouv.fr DJANGO_CONFIGURATION: Production DJANGO_ALLOWED_HOSTS: "*" @@ -171,4 +180,4 @@ ingressMedia: serviceMedia: host: s3.hedy-lamarr.indiehosters.net - port: 443 \ No newline at end of file + port: 443 diff --git a/src/helm/env.d/staging/values.impress.yaml.gotmpl b/src/helm/env.d/staging/values.impress.yaml.gotmpl index 9cc68bf81..5eb0464d1 100644 --- a/src/helm/env.d/staging/values.impress.yaml.gotmpl +++ b/src/helm/env.d/staging/values.impress.yaml.gotmpl @@ -8,6 +8,15 @@ backend: argocd.argoproj.io/hook: PreSync argocd.argoproj.io/hook-delete-policy: HookSucceeded envVars: + AI_API_KEY: + secretKeyRef: + name: backend + key: AI_API_KEY + AI_BASE_URL: + secretKeyRef: + name: backend + key: AI_BASE_URL + AI_MODEL: meta-llama/Meta-Llama-3.1-70B-Instruct DJANGO_CSRF_TRUSTED_ORIGINS: http://impress-staging.beta.numerique.gouv.fr,https://impress-staging.beta.numerique.gouv.fr DJANGO_CONFIGURATION: Production DJANGO_ALLOWED_HOSTS: "*" diff --git a/src/helm/impress/templates/secrets.yaml b/src/helm/impress/templates/secrets.yaml index efe2dbb0e..c308fca80 100644 --- a/src/helm/impress/templates/secrets.yaml +++ b/src/helm/impress/templates/secrets.yaml @@ -19,3 +19,5 @@ stringData: {{- end }} OIDC_RP_CLIENT_ID: {{ .Values.oidc.clientId }} OIDC_RP_CLIENT_SECRET: {{ .Values.oidc.clientSecret }} + AI_API_KEY: {{ .Values.aiApiKey }} + AI_BASE_URL: {{ .Values.aiBaseUrl }}