Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨AI to doc editor #250

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ and this project adheres to

## Added

- ✨AI to doc editor #250
- ✨(backend) add name fields to the user synchronized with OIDC #301
- ✨(ci) add security scan #291
- ✨(frontend) Activate versions feature #240
Expand Down
4 changes: 4 additions & 0 deletions env.d/development/common.dist
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,7 @@ LOGOUT_REDIRECT_URL=http://localhost:3000

OIDC_REDIRECT_ALLOWED_HOSTS=["http://localhost:8083", "http://localhost:3000"]
OIDC_AUTH_REQUEST_EXTRA_PARAMS={"acr_values": "eidas1"}

AI_BASE_URL=https://openaiendpoint.com
AI_API_KEY=password
AI_MODEL=llama
33 changes: 32 additions & 1 deletion src/backend/core/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

from rest_framework import exceptions, serializers

from core import models
from core import enums, models
from core.services.ai_services import AI_ACTIONS


class UserSerializer(serializers.ModelSerializer):
Expand Down Expand Up @@ -350,3 +351,33 @@ class VersionFilterSerializer(serializers.Serializer):
page_size = serializers.IntegerField(
required=False, min_value=1, max_value=50, default=20
)


class AITransformSerializer(serializers.Serializer):
    """Validate the payload of an AI transform request: an action and the text to process."""

    action = serializers.ChoiceField(choices=AI_ACTIONS, required=True)
    text = serializers.CharField(required=True)

    def validate_text(self, value):
        """Return the text unchanged, rejecting values made only of whitespace."""
        if value.strip():
            return value
        raise serializers.ValidationError("Text field cannot be empty.")


class AITranslateSerializer(serializers.Serializer):
    """Validate the payload of an AI translate request: a target language and the text."""

    # Choices are the (code, name) pairs of every language declared in `enums.ALL_LANGUAGES`.
    language = serializers.ChoiceField(
        choices=tuple(enums.ALL_LANGUAGES.items()), required=True
    )
    text = serializers.CharField(required=True)

    def validate_text(self, value):
        """Return the text unchanged, rejecting values made only of whitespace."""
        if value.strip():
            return value
        raise serializers.ValidationError("Text field cannot be empty.")
137 changes: 137 additions & 0 deletions src/backend/core/api/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
"""Util to generate S3 authorization headers for object storage access control"""

import time

from django.conf import settings
from django.core.cache import cache
from django.core.files.storage import default_storage

import botocore
from rest_framework.throttling import BaseThrottle


def generate_s3_authorization_headers(key):
Expand Down Expand Up @@ -31,3 +36,135 @@ def generate_s3_authorization_headers(key):
auth.add_auth(request)

return request


class AIDocumentRateThrottle(BaseThrottle):
    """Limit AI requests per document over sliding minute, hour and day windows."""

    def __init__(self, *args, **kwargs):
        """Read the per-document limits from settings and zero the window counters."""
        super().__init__(*args, **kwargs)
        self.rates = settings.AI_DOCUMENT_RATE_THROTTLE_RATES
        self.cache_key = None
        self.recent_requests_minute = self.recent_requests_hour = self.recent_requests_day = 0

    def get_cache_key(self, request, view):
        """Build the cache key from the "pk" URL keyword argument of the view."""
        return f"document_{view.kwargs['pk']}_throttle_ai"

    def allow_request(self, request, view):
        """
        Return True and record the request when no window limit is reached.

        The timestamps stored in cache for this document are pruned to the last 24 hours,
        counted per window and compared to the configured rates. A refused request is
        not added to the stored history.
        """
        self.cache_key = self.get_cache_key(request, view)

        current = time.time()
        day_start = current - 86400
        # Only the last day is relevant: it is the largest window being enforced.
        timestamps = [
            stamp for stamp in cache.get(self.cache_key, []) if stamp > day_start
        ]

        self.recent_requests_minute = sum(1 for stamp in timestamps if stamp > current - 60)
        self.recent_requests_hour = sum(1 for stamp in timestamps if stamp > current - 3600)
        self.recent_requests_day = len(timestamps)

        # Windows are evaluated from the shortest to the longest one.
        usage = (
            ("minute", self.recent_requests_minute),
            ("hour", self.recent_requests_hour),
            ("day", self.recent_requests_day),
        )
        for period, count in usage:
            if count >= self.rates[period]:
                return False

        timestamps.append(current)
        cache.set(self.cache_key, timestamps, timeout=86400)
        return True

    def wait(self):
        """
        Return the number of seconds to wait, or None when no limit is reached.

        The longest exceeded window wins: 86400 for the day, 3600 for the hour,
        60 for the minute.
        """
        backoffs = (
            ("day", self.recent_requests_day, 86400),
            ("hour", self.recent_requests_hour, 3600),
            ("minute", self.recent_requests_minute, 60),
        )
        for period, count, delay in backoffs:
            if count >= self.rates[period]:
                return delay

        return None


class AIUserRateThrottle(BaseThrottle):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't you want to define an AIBaseThrottle to prevent duplicated code here?

"""Throttle that limits requests per user or IP with backoff and rate limits."""

def __init__(self, *args, **kwargs):
    """Read the per-user limits from settings and zero the window counters."""
    super().__init__(*args, **kwargs)
    self.rates = settings.AI_USER_RATE_THROTTLE_RATES
    self.cache_key = None
    self.recent_requests_minute = self.recent_requests_hour = self.recent_requests_day = 0

# pylint: disable=unused-argument
def get_cache_key(self, request, view):
    """
    Generate a cache key based on the user ID, or on the IP for anonymous users.

    Return None when the request is anonymous and no IP address can be determined,
    which `allow_request` treats as "no throttling key available".
    """
    if request.user.is_authenticated:
        return f"user_{request.user.id!s}_throttle_ai"

    # Use IP address for anonymous users
    ip_addr = self.get_ident(request)
    if not ip_addr:
        # `get_ident` yields None when neither X-Forwarded-For nor REMOTE_ADDR is set;
        # formatting None with ":s" would raise a TypeError, so report "no key" instead.
        return None
    return f"anonymous_{ip_addr:s}_throttle_ai"

def allow_request(self, request, view):
    """
    Return True and record the request when no per-user (or per-IP) limit is reached.

    The timestamps stored in cache under the requester's key are pruned to the last
    24 hours, counted per window and compared to the configured rates. A refused
    request is not added to the stored history.
    """
    self.cache_key = self.get_cache_key(request, view)
    if not self.cache_key:
        # Without a key there is nothing to count against: let the request through.
        return True

    current = time.time()
    day_start = current - 86400
    # Only the last day is relevant: it is the largest window being enforced.
    timestamps = [stamp for stamp in cache.get(self.cache_key, []) if stamp > day_start]

    self.recent_requests_minute = sum(1 for stamp in timestamps if stamp > current - 60)
    self.recent_requests_hour = sum(1 for stamp in timestamps if stamp > current - 3600)
    self.recent_requests_day = len(timestamps)

    # Windows are evaluated from the shortest to the longest one.
    usage = (
        ("minute", self.recent_requests_minute),
        ("hour", self.recent_requests_hour),
        ("day", self.recent_requests_day),
    )
    for period, count in usage:
        if count >= self.rates[period]:
            return False

    timestamps.append(current)
    cache.set(self.cache_key, timestamps, timeout=86400)
    return True

def wait(self):
    """
    Return the number of seconds to wait, or None when no limit is reached.

    The longest exceeded window wins: 86400 for the day, 3600 for the hour,
    60 for the minute.
    """
    backoffs = (
        ("day", self.recent_requests_day, 86400),
        ("hour", self.recent_requests_hour, 3600),
        ("minute", self.recent_requests_minute, 60),
    )
    for period, count, delay in backoffs:
        if count >= self.rates[period]:
            return delay

    return None

def get_ident(self, request):
    """
    Return the request IP address, or None when it cannot be determined.

    The first entry of the X-Forwarded-For header is used when present and not blank
    (surrounding whitespace is removed); otherwise REMOTE_ADDR is returned.
    """
    # NOTE(review): X-Forwarded-For is sent by the client unless a trusted proxy
    # overwrites it, so an anonymous caller may be able to vary it to obtain a fresh
    # throttling bucket - confirm the deployment's proxy setup.
    x_forwarded_for = request.META.get("HTTP_X_FORWARDED_FOR")
    if x_forwarded_for:
        client_ip = x_forwarded_for.split(",")[0].strip()
        if client_ip:
            return client_ip
    # Header missing or its first entry blank: fall back to the socket peer address.
    return request.META.get("REMOTE_ADDR")
90 changes: 81 additions & 9 deletions src/backend/core/api/viewsets.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
decorators,
exceptions,
filters,
metadata,
mixins,
pagination,
status,
Expand All @@ -31,7 +32,8 @@
response as drf_response,
)

from core import models
from core import enums, models
from core.services.ai_services import AIService

from . import permissions, serializers, utils

Expand Down Expand Up @@ -303,6 +305,23 @@ def perform_update(self, serializer):
serializer.save()


class DocumentMetadata(metadata.SimpleMetadata):
    """Custom metadata class adding the language choices to the documents list endpoint."""

    def determine_metadata(self, request, view):
        """
        Return the default metadata, adding language choices only for the list endpoint.

        The "language" entry is attached to the POST action description. When the default
        metadata holds no POST action, it is returned untouched.
        """
        simple_metadata = super().determine_metadata(request, view)

        if request.path.endswith("/documents/"):
            # "actions" / "POST" can be missing from the default metadata (e.g. when the
            # requester is not allowed to POST), so look them up without indexing blindly.
            post_action = simple_metadata.get("actions", {}).get("POST")
            if post_action is not None:
                # NOTE(review): a reviewer asked whether the AI can manage the same set of
                # languages as the ones available in the backend - to be confirmed.
                post_action["language"] = {
                    "choices": [
                        {"value": code, "display_name": name}
                        for code, name in enums.ALL_LANGUAGES.items()
                    ]
                }
        return simple_metadata


class DocumentViewSet(
ResourceViewsetMixin,
mixins.CreateModelMixin,
Expand All @@ -320,6 +339,7 @@ class DocumentViewSet(
resource_field_name = "document"
queryset = models.Document.objects.all()
ordering = ["-updated_at"]
metadata_class = DocumentMetadata

def list(self, request, *args, **kwargs):
"""Restrict resources returned by the list endpoint"""
Expand Down Expand Up @@ -456,10 +476,7 @@ def link_configuration(self, request, *args, **kwargs):
serializer = serializers.LinkDocumentSerializer(
document, data=request.data, partial=True
)
if not serializer.is_valid():
return drf_response.Response(
serializer.errors, status=status.HTTP_400_BAD_REQUEST
)
serializer.is_valid(raise_exception=True)

serializer.save()
return drf_response.Response(serializer.data, status=status.HTTP_200_OK)
Expand All @@ -472,10 +489,8 @@ def attachment_upload(self, request, *args, **kwargs):

# Validate metadata in payload
serializer = serializers.FileUploadSerializer(data=request.data)
if not serializer.is_valid():
return drf_response.Response(
serializer.errors, status=status.HTTP_400_BAD_REQUEST
)
serializer.is_valid(raise_exception=True)

# Extract the file extension from the original filename
file = serializer.validated_data["file"]
extension = os.path.splitext(file.name)[1]
Expand Down Expand Up @@ -531,6 +546,63 @@ def retrieve_auth(self, request, *args, **kwargs):
request = utils.generate_s3_authorization_headers(f"{pk:s}/{attachment_key:s}")
return drf_response.Response("authorized", headers=request.headers, status=200)

@decorators.action(
    detail=True,
    methods=["post"],
    name="Apply a transformation action on a piece of text with AI",
    url_path="ai-transform",
    throttle_classes=[utils.AIDocumentRateThrottle, utils.AIUserRateThrottle],
)
def ai_transform(self, request, *args, **kwargs):
    """
    POST /api/v1.0/documents/<resource_id>/ai-transform
    with expected data:
    - text: str
    - action: str [prompt, correct, rephrase, summarize]
    Return JSON response with the processed text.
    """
    # Called for its side effect only (check permissions first); the document
    # itself is not needed to process the text.
    self.get_object()

    payload = serializers.AITransformSerializer(data=request.data)
    payload.is_valid(raise_exception=True)
    validated = payload.validated_data

    result = AIService().transform(validated["text"], validated["action"])
    return drf_response.Response(result, status=status.HTTP_200_OK)

@decorators.action(
    detail=True,
    methods=["post"],
    name="Translate a piece of text with AI",
    serializer_class=serializers.AITranslateSerializer,
    url_path="ai-translate",
    throttle_classes=[utils.AIDocumentRateThrottle, utils.AIUserRateThrottle],
)
def ai_translate(self, request, *args, **kwargs):
    """
    POST /api/v1.0/documents/<resource_id>/ai-translate
    with expected data:
    - text: str
    - language: str [a language code accepted by AITranslateSerializer]
    Return JSON response with the translated text.
    """
    # Called for its side effect only (check permissions first); the document
    # itself is not needed to translate the text.
    self.get_object()

    payload = self.get_serializer(data=request.data)
    payload.is_valid(raise_exception=True)
    validated = payload.validated_data

    result = AIService().translate(validated["text"], validated["language"])
    return drf_response.Response(result, status=status.HTTP_200_OK)


class DocumentAccessViewSet(
ResourceAccessViewsetMixin,
Expand Down
14 changes: 5 additions & 9 deletions src/backend/core/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,11 @@
Core application enums declaration
"""

from django.conf import global_settings, settings
from django.conf import global_settings
from django.utils.translation import gettext_lazy as _

# Django sets `LANGUAGES` by default with all supported languages. We can use it for
# the choice of languages which should not be limited to the few languages active in
# the app.
# In Django's code base, `LANGUAGES` is set by default with all supported languages.
# We can use it for the choice of languages which should not be limited to the few languages
# active in the app.
# pylint: disable=no-member
# NOTE(review): two assignments to ALL_LANGUAGES appear back to back in this view;
# presumably the first one is the removed side of a diff and only the second one
# remains in the file - confirm against the repository.
# List of (code, lazily translated name) pairs, overridable through an
# `ALL_LANGUAGES` attribute on the settings object.
ALL_LANGUAGES = getattr(
    settings,
    "ALL_LANGUAGES",
    [(language, _(name)) for language, name in global_settings.LANGUAGES],
)
# Mapping of language code to lazily translated name, built from
# `global_settings.LANGUAGES`.
ALL_LANGUAGES = {language: _(name) for language, name in global_settings.LANGUAGES}
2 changes: 2 additions & 0 deletions src/backend/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,8 @@ def get_abilities(self, user):
can_get = bool(roles)

return {
"ai_transform": is_owner_or_admin or is_editor,
"ai_translate": is_owner_or_admin or is_editor,
"attachment_upload": is_owner_or_admin or is_editor,
"destroy": RoleChoices.OWNER in roles,
"link_configuration": is_owner_or_admin,
Expand Down
Empty file.
Loading
Loading