Skip to content

Commit

Permalink
Merge pull request #18 from ai-forever/feature/check_credentials
Browse files Browse the repository at this point in the history
Check credentials and log errors from server
  • Loading branch information
Rai220 authored May 2, 2024
2 parents a6bcae4 + 5134c5a commit 5a142d0
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 5 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "gigachat"
version = "0.1.23"
version = "0.1.24"
description = "GigaChat. Python-library for GigaChain and LangChain"
authors = ["Konstantin Krestnikov <[email protected]>", "Sergey Malyshev <[email protected]>"]
license = "MIT"
Expand Down
16 changes: 16 additions & 0 deletions src/gigachat/api/post_auth.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import base64
import binascii
import logging
import uuid
from http import HTTPStatus
from typing import Any, Dict
Expand All @@ -8,6 +11,8 @@
from gigachat.exceptions import AuthenticationError, ResponseError
from gigachat.models import AccessToken

_logger = logging.getLogger(__name__)


def _get_kwargs(*, url: str, credentials: str, scope: str) -> Dict[str, Any]:
headers = {
Expand All @@ -32,13 +37,24 @@ def _build_response(response: httpx.Response) -> AccessToken:
raise ResponseError(response.url, response.status_code, response.content, response.headers)


def _validate_credentials(credentials: str) -> None:
try:
base64.b64decode(credentials, validate=True)
except (ValueError, binascii.Error):
_logger.warning(
"Invalid credentials format. Please use only base64 credentials (Authorization data, not client secret!)"
)


def sync(client: httpx.Client, *, url: str, credentials: str, scope: str) -> AccessToken:
_validate_credentials(credentials)
kwargs = _get_kwargs(url=url, credentials=credentials, scope=scope)
response = client.request(**kwargs)
return _build_response(response)


async def asyncio(client: httpx.AsyncClient, *, url: str, credentials: str, scope: str) -> AccessToken:
_validate_credentials(credentials)
kwargs = _get_kwargs(url=url, credentials=credentials, scope=scope)
response = await client.request(**kwargs)
return _build_response(response)
15 changes: 12 additions & 3 deletions src/gigachat/api/stream_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,18 @@ def _check_response(response: httpx.Response) -> None:
if response.status_code == HTTPStatus.OK:
_check_content_type(response)
elif response.status_code == HTTPStatus.UNAUTHORIZED:
raise AuthenticationError(response.url, response.status_code, b"", response.headers)
raise AuthenticationError(response.url, response.status_code, response.read(), response.headers)
else:
raise ResponseError(response.url, response.status_code, b"", response.headers)
raise ResponseError(response.url, response.status_code, response.read(), response.headers)


async def _acheck_response(response: httpx.Response) -> None:
if response.status_code == HTTPStatus.OK:
_check_content_type(response)
elif response.status_code == HTTPStatus.UNAUTHORIZED:
raise AuthenticationError(response.url, response.status_code, await response.aread(), response.headers)
else:
raise ResponseError(response.url, response.status_code, await response.aread(), response.headers)


def sync(
Expand All @@ -75,7 +84,7 @@ async def asyncio(
) -> AsyncIterator[ChatCompletionChunk]:
kwargs = _get_kwargs(chat=chat, access_token=access_token)
async with client.stream(**kwargs) as response:
_check_response(response)
await _acheck_response(response)
async for line in response.aiter_lines():
if chunk := _parse_chunk(line):
yield chunk
3 changes: 2 additions & 1 deletion src/gigachat/models/messages.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Any, Optional

from gigachat.models.function_call import FunctionCall
from gigachat.models.messages_role import MessagesRole
Expand All @@ -14,6 +14,7 @@ class Messages(BaseModel):
"""Текст сообщения"""
function_call: Optional[FunctionCall] = None
"""Вызов функции"""
id: Optional[Any] = None # noqa: A003

class Config:
use_enum_values = True

0 comments on commit 5a142d0

Please sign in to comment.