Skip to content

Commit

Permalink
feat: add load completion file in the sdk module (#215)
Browse files Browse the repository at this point in the history
* add load completion file in the sdk module

* fix bind completion file process

* ruff check

* ruff check
  • Loading branch information
dtria91 authored Dec 18, 2024
1 parent f3b750e commit 0f56bc9
Show file tree
Hide file tree
Showing 19 changed files with 728 additions and 35 deletions.
9 changes: 9 additions & 0 deletions api/app/models/dataset_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,15 @@ class FileReference(BaseModel):
)


class FileCompletion(BaseModel):
file_url: str

model_config = ConfigDict(
populate_by_name=True,
alias_generator=to_camel,
)


class OrderType(str, Enum):
ASC = 'asc'
DESC = 'desc'
11 changes: 11 additions & 0 deletions api/app/routes/upload_dataset_route.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from app.models.dataset_dto import (
CompletionDatasetDTO,
CurrentDatasetDTO,
FileCompletion,
FileReference,
OrderType,
ReferenceDatasetDTO,
Expand Down Expand Up @@ -75,6 +76,16 @@ def upload_completion_file(
) -> CompletionDatasetDTO:
return file_service.upload_completion_file(model_uuid, json_file)

@router.post(
'/{model_uuid}/completion/bind',
status_code=status.HTTP_200_OK,
response_model=CompletionDatasetDTO,
)
def bind_completion_file(
model_uuid: UUID, file_completion: FileCompletion
) -> CompletionDatasetDTO:
return file_service.bind_completion_file(model_uuid, file_completion)

@router.get(
'/{model_uuid}/reference',
status_code=200,
Expand Down
51 changes: 51 additions & 0 deletions api/app/services/file_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from app.models.dataset_dto import (
CompletionDatasetDTO,
CurrentDatasetDTO,
FileCompletion,
FileReference,
OrderType,
ReferenceDatasetDTO,
Expand Down Expand Up @@ -398,6 +399,56 @@ def upload_completion_file(
except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) from e

def bind_completion_file(
self, model_uuid: UUID, file_completion: FileCompletion
) -> CompletionDatasetDTO:
model_out = self.model_svc.get_model_by_uuid(model_uuid)
if not model_out:
logger.error('Model %s not found', model_uuid)
raise ModelNotFoundError(f'Model {model_uuid} not found')
try:
url_parts = file_completion.file_url.replace('s3://', '').split('/')
self.s3_client.head_object(Bucket=url_parts[0], Key='/'.join(url_parts[1:]))

inserted_file = self.completion_dataset_dao.insert_completion_dataset(
CompletionDataset(
uuid=uuid4(),
model_uuid=model_uuid,
path=file_completion.file_url,
date=datetime.datetime.now(tz=datetime.UTC),
status=JobStatus.IMPORTING,
)
)
logger.debug('File %s has been correctly stored in the db', inserted_file)

spark_config = get_config().spark_config
self.__submit_app(
app_name=str(model_out.uuid),
app_path=spark_config.spark_completion_app_path,
app_arguments=[
file_completion.file_url.replace('s3://', 's3a://'),
str(inserted_file.uuid),
CompletionDatasetMetrics.__tablename__,
CompletionDataset.__tablename__,
],
)

return CompletionDatasetDTO.from_completion_dataset(inserted_file)

except NoCredentialsError as nce:
raise HTTPException(
status_code=500, detail='S3 credentials not available'
) from nce
except ClientError as e:
if e.response['Error']['Code'] == '404':
raise HTTPException(
status_code=404,
detail=f'File {file_completion.file_url} not exists',
) from None
raise HTTPException(status_code=500, detail=str(e)) from e
except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) from e

def get_all_reference_datasets_by_model_uuid_paginated(
self,
model_uuid: UUID,
Expand Down
3 changes: 2 additions & 1 deletion sdk/radicalbit_platform_sdk/apis/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .model_current_dataset import ModelCurrentDataset
from .model_reference_dataset import ModelReferenceDataset
from .model_completion_dataset import ModelCompletionDataset
from .model import Model

__all__ = ['Model', 'ModelCurrentDataset', 'ModelReferenceDataset']
__all__ = ['Model', 'ModelCurrentDataset', 'ModelReferenceDataset', 'ModelCompletionDataset']
123 changes: 121 additions & 2 deletions sdk/radicalbit_platform_sdk/apis/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from io import BytesIO
import json
import os
from typing import List, Optional
from typing import Dict, List, Optional
from uuid import UUID

import boto3
Expand All @@ -8,12 +10,18 @@
from pydantic import TypeAdapter, ValidationError
import requests

from radicalbit_platform_sdk.apis import ModelCurrentDataset, ModelReferenceDataset
from radicalbit_platform_sdk.apis import (
ModelCompletionDataset,
ModelCurrentDataset,
ModelReferenceDataset,
)
from radicalbit_platform_sdk.commons import invoke
from radicalbit_platform_sdk.errors import ClientError
from radicalbit_platform_sdk.models import (
AwsCredentials,
ColumnDefinition,
CompletionFileUpload,
CompletionResponses,
CurrentFileUpload,
DataType,
FileReference,
Expand All @@ -24,6 +32,7 @@
OutputType,
ReferenceFileUpload,
)
from radicalbit_platform_sdk.models.file_upload_result import FileCompletion


class Model:
Expand Down Expand Up @@ -504,7 +513,117 @@ def __callback(response: requests.Response) -> ModelCurrentDataset:
data=file_ref.model_dump_json(),
)

def load_completion_dataset(
self,
file_name: str,
bucket: str,
object_name: Optional[str] = None,
aws_credentials: Optional[AwsCredentials] = None,
) -> ModelCompletionDataset:
"""Upload completion dataset to an S3 bucket and then bind it inside the platform.
Raises `ClientError` in case S3 upload fails.
:param file_name: The name of the completion file.
:param bucket: The name of the S3 bucket.
:param object_name: The optional name of the object uploaded to S3. Default value is None.
:param aws_credentials: AWS credentials used to connect to S3 bucket. Default value is None.
:return: An instance of `ModelCompletionDataset` representing the completion dataset
"""

try:
with open(file_name, encoding='utf-8') as f:
raw_json = json.load(f)
validated_json_bytes = self.__validate_json(raw_json)
except Exception as e:
raise ClientError(
f'Failed to validate JSON file {file_name}: {str(e)}'
) from e

if object_name is None:
object_name = f'{self.__uuid}/completion/{os.path.basename(file_name)}'

try:
s3_client = boto3.client(
's3',
aws_access_key_id=(
None if aws_credentials is None else aws_credentials.access_key_id
),
aws_secret_access_key=(
None
if aws_credentials is None
else aws_credentials.secret_access_key
),
region_name=(
None if aws_credentials is None else aws_credentials.default_region
),
endpoint_url=(
None
if aws_credentials is None
else (
None
if aws_credentials.endpoint_url is None
else aws_credentials.endpoint_url
)
),
)

s3_client.upload_fileobj(
validated_json_bytes,
bucket,
object_name,
ExtraArgs={
'Metadata': {
'model_uuid': str(self.__uuid),
'model_name': self.__name,
'file_type': 'completion',
}
},
)
except BotoClientError as e:
raise ClientError(
f'Unable to upload file {file_name} to remote storage: {e}'
) from e
return self.__bind_completion_dataset(f's3://{bucket}/{object_name}')

def __bind_completion_dataset(
self,
dataset_url: str,
) -> ModelCompletionDataset:
def __callback(response: requests.Response) -> ModelCompletionDataset:
try:
response = CompletionFileUpload.model_validate(response.json())
return ModelCompletionDataset(
self.__base_url, self.__uuid, self.__model_type, response
)
except ValidationError as e:
raise ClientError(f'Unable to parse response: {response.text}') from e

file_completion = FileCompletion(
file_url=dataset_url,
)

return invoke(
method='POST',
url=f'{self.__base_url}/api/models/{str(self.__uuid)}/completion/bind',
valid_response_code=200,
func=__callback,
data=file_completion.model_dump_json(),
)

def __required_headers(self) -> List[str]:
model_columns = self.__features + self.__outputs.output
model_columns.append(self.__target)
return [model_column.name for model_column in model_columns]

@staticmethod
def __validate_json(json_data: List[Dict]) -> BytesIO:
try:
validated_data = CompletionResponses.model_validate(json_data)
return BytesIO(validated_data.model_dump_json().encode())
except ValidationError as e:
raise ClientError(f'JSON validation error: {str(e)}') from e
except Exception as e:
raise ClientError(
f'Unexpected error during JSON validation: {str(e)}'
) from e
103 changes: 103 additions & 0 deletions sdk/radicalbit_platform_sdk/apis/model_completion_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
from typing import Optional
from uuid import UUID

from pydantic import ValidationError
import requests

from radicalbit_platform_sdk.commons import invoke
from radicalbit_platform_sdk.errors import ClientError
from radicalbit_platform_sdk.models import (
CompletionFileUpload,
CompletionTextGenerationModelQuality,
JobStatus,
ModelQuality,
ModelType,
)


class ModelCompletionDataset:
def __init__(
self,
base_url: str,
model_uuid: UUID,
model_type: ModelType,
upload: CompletionFileUpload,
) -> None:
self.__base_url = base_url
self.__model_uuid = model_uuid
self.__model_type = model_type
self.__uuid = upload.uuid
self.__path = upload.path
self.__date = upload.date
self.__status = upload.status
self.__model_metrics = None

def uuid(self) -> UUID:
return self.__uuid

def path(self) -> str:
return self.__path

def date(self) -> str:
return self.__date

def status(self) -> str:
return self.__status

def model_quality(self) -> Optional[ModelQuality]:
"""Get model quality metrics about the completion dataset
:return: The `ModelQuality` if exists
"""

def __callback(
response: requests.Response,
) -> tuple[JobStatus, Optional[ModelQuality]]:
try:
response_json = response.json()
job_status = JobStatus(response_json['jobStatus'])
if 'modelQuality' in response_json:
match self.__model_type:
case ModelType.TEXT_GENERATION:
return (
job_status,
CompletionTextGenerationModelQuality.model_validate(
response_json['modelQuality']
),
)
case _:
raise ClientError(
'Unable to parse metrics because of not managed model type'
) from None
except KeyError as e:
raise ClientError(f'Unable to parse response: {response.text}') from e
except ValidationError as e:
raise ClientError(f'Unable to parse response: {response.text}') from e
else:
return job_status, None

match self.__status:
case JobStatus.ERROR:
self.__model_metrics = None
case JobStatus.MISSING_COMPLETION:
self.__model_metrics = None
case JobStatus.SUCCEEDED:
if self.__model_metrics is None:
_, metrics = invoke(
method='GET',
url=f'{self.__base_url}/api/models/{str(self.__model_uuid)}/completion/{str(self.__uuid)}/model-quality',
valid_response_code=200,
func=__callback,
)
self.__model_metrics = metrics
case JobStatus.IMPORTING:
status, metrics = invoke(
method='GET',
url=f'{self.__base_url}/api/models/{str(self.__model_uuid)}/completion/{str(self.__uuid)}/model-quality',
valid_response_code=200,
func=__callback,
)
self.__status = status
self.__model_metrics = metrics

return self.__model_metrics
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,14 @@ def distribution_chart(self, data: BinaryDistributionChartData) -> EChartsRawWid
if data.current_data:
assert len(data.current_data) <= 2

reference_json_data = [binary_data.model_dump() for binary_data in data.reference_data]
current_data_json = [binary_data.model_dump() for binary_data in data.current_data] if data.current_data else []
reference_json_data = [
binary_data.model_dump() for binary_data in data.reference_data
]
current_data_json = (
[binary_data.model_dump() for binary_data in data.current_data]
if data.current_data
else []
)

reference_series_data = {
'title': data.title,
Expand Down Expand Up @@ -87,7 +93,6 @@ def distribution_chart(self, data: BinaryDistributionChartData) -> EChartsRawWid
return EChartsRawWidget(option=option)

def linear_chart(self, data: BinaryLinearChartData) -> EChartsRawWidget:

reference_series_data = {
'name': 'Reference',
'type': 'line',
Expand Down
Loading

0 comments on commit 0f56bc9

Please sign in to comment.