feature: migrate DatashareTaskClient to datashare-python
ClemDoum committed Dec 20, 2024
1 parent 4ccc79e commit 5f866f0
Showing 5 changed files with 325 additions and 26 deletions.
6 changes: 3 additions & 3 deletions datashare_python/config.py
@@ -54,7 +54,7 @@ def to_es_client(self, address: str | None = None) -> "ESClient":
        )
        return client

-    def to_task_client(self) -> "DSTaskClient":
-        from datashare_python.utils import DSTaskClient
+    def to_task_client(self) -> "DatashareTaskClient":
+        from datashare_python.task_client import DatashareTaskClient

-        return DSTaskClient(self.ds_url)
+        return DatashareTaskClient(self.ds_url)
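For orientation, here is a minimal sketch of the new call path. The settings class that owns to_task_client() is not shown in this hunk, so the name Config and the URL below are placeholders, not names taken from the commit:

# Hypothetical sketch: Config and the ds_url value are placeholders.
from datashare_python.config import Config  # assumed class name

config = Config(ds_url="http://localhost:8080")
task_client = config.to_task_client()  # now returns a DatashareTaskClient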
124 changes: 124 additions & 0 deletions datashare_python/task_client.py
@@ -0,0 +1,124 @@
import uuid
from typing import Any, Dict, Optional

from icij_common.pydantic_utils import jsonable_encoder
from icij_worker import Task, TaskError, TaskState
from icij_worker.exceptions import UnknownTask
from icij_worker.utils.http import AiohttpClient

# TODO: maxRetries is not supported by java, it's automatically set to 3
_TASK_UNSUPPORTED = {"max_retries"}


class DatashareTaskClient(AiohttpClient):
    def __init__(self, datashare_url: str, api_key: str | None = None) -> None:
        headers = None
        if api_key is not None:
            headers = {"Authorization": f"Bearer {api_key}"}
        super().__init__(datashare_url, headers=headers)

    async def __aenter__(self):
        await super().__aenter__()
        if "Authorization" not in self._headers:
            async with self._get("/settings") as res:
                # SimpleCookie doesn't seem to parse DS cookie so we perform some dirty
                # hack here
                session_id = [
                    item
                    for item in res.headers["Set-Cookie"].split("; ")
                    if "session_id" in item
                ]
                if len(session_id) != 1:
                    raise ValueError("Invalid cookie")
                k, v = session_id[0].split("=")
                self._session.cookie_jar.update_cookies({k: v})

    async def create_task(
        self,
        name: str,
        args: Dict[str, Any],
        *,
        id_: Optional[str] = None,
        group: Optional[str] = None,
    ) -> str:
        if id_ is None:
            id_ = _generate_task_id(name)
        task = Task.create(task_id=id_, task_name=name, args=args)
        task = jsonable_encoder(task, exclude=_TASK_UNSUPPORTED, exclude_unset=True)
        task.pop("createdAt")
        url = f"/api/task/{id_}"
        if group is not None:
            if not isinstance(group, str):
                raise TypeError(f"expected group to be a string found {group}")
            url += f"?group={group}"
        async with self._put(url, json=task) as res:
            task_res = await res.json()
        return task_res["taskId"]

    async def get_task(self, id_: str) -> Task:
        url = f"/api/task/{id_}"
        async with self._get(url) as res:
            task = await res.json()
        if task is None:
            raise UnknownTask(id_)
        # TODO: align Java on Python here... it's not a good idea to store results
        # inside tasks since result can be quite large and we may want to get the task
        # metadata without having to deal with the large task results...
        task = _ds_to_icij_worker_task(task)
        task = Task(**task)
        return task

    async def get_tasks(self) -> list[Task]:
        url = "/api/task/all"
        async with self._get(url) as res:
            tasks = await res.json()
        # TODO: align Java on Python here... it's not a good idea to store results
        # inside tasks since result can be quite large and we may want to get the task
        # metadata without having to deal with the large task results...
        tasks = (_ds_to_icij_worker_task(t) for t in tasks)
        tasks = [Task(**task) for task in tasks]
        return tasks

    async def get_task_state(self, id_: str) -> TaskState:
        return (await self.get_task(id_)).state

    async def get_task_result(self, id_: str) -> Any:
        url = f"/api/task/{id_}/results"
        async with self._get(url) as res:
            task_res = await res.json()
        return task_res

    async def get_task_error(self, id_: str) -> TaskError:
        url = f"/api/task/{id_}"
        async with self._get(url) as res:
            task = await res.json()
        if task is None:
            raise UnknownTask(id_)
        task_state = TaskState[task["state"]]
        if task_state != TaskState.ERROR:
            msg = f"can't find error for task {id_} in state {task_state}"
            raise ValueError(msg)
        error = TaskError(**task["error"])
        return error

    async def delete(self, id_: str):
        url = f"/api/task/{id_}"
        async with self._delete(url):
            pass

    async def delete_all_tasks(self):
        for t in await self.get_tasks():
            await self.delete(t.id)


def _generate_task_id(task_name: str) -> str:
    return f"{task_name}-{uuid.uuid4()}"


_JAVA_TASK_ATTRIBUTES = ["result", "error"]


def _ds_to_icij_worker_task(task: dict) -> dict:
    for k in _JAVA_TASK_ATTRIBUTES:
        task.pop(k, None)
    return task
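For context, a minimal usage sketch of the client introduced above. The Datashare URL, API key, task name, arguments, and group below are placeholder values, and the polling loop is an assumption for illustration; the client itself does not poll:

import asyncio

from icij_worker import TaskState

from datashare_python.task_client import DatashareTaskClient


async def main():
    # Placeholder URL and API key; adapt to your Datashare instance
    client = DatashareTaskClient("http://localhost:8080", api_key="some-api-key")
    async with client:
        # Queue a task, then poll it until it reaches a final state
        task_id = await client.create_task("hello", {"greeted": "world"}, group="PYTHON")
        state = await client.get_task_state(task_id)
        while state not in (TaskState.DONE, TaskState.ERROR):
            await asyncio.sleep(1)
            state = await client.get_task_state(task_id)
        if state is TaskState.DONE:
            print(await client.get_task_result(task_id))
        else:
            print(await client.get_task_error(task_id))


if __name__ == "__main__":
    asyncio.run(main())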
4 changes: 2 additions & 2 deletions datashare_python/tests/tasks/test_translate_docs.py
@@ -6,9 +6,9 @@
from icij_common.es import ESClient

from datashare_python.objects import Document
+from datashare_python.task_client import DatashareTaskClient
from datashare_python.tasks import create_translation_tasks
from datashare_python.tests.conftest import TEST_PROJECT
-from datashare_python.utils import DSTaskClient

logger = logging.getLogger(__name__)

@@ -21,7 +21,7 @@ async def _progress(p: float):
async def test_create_translation_tasks_integration(
    populate_es: List[Document],  # pylint: disable=unused-argument
    test_es_client: ESClient,
-    test_task_client: DSTaskClient,
+    test_task_client: DatashareTaskClient,
):
    # Given
    es_client = test_es_client
196 changes: 196 additions & 0 deletions datashare_python/tests/test_task_client.py
@@ -0,0 +1,196 @@
import uuid
from contextlib import asynccontextmanager
from datetime import datetime
from typing import Any
from unittest.mock import AsyncMock

from aiohttp.typedefs import StrOrURL
from icij_worker import Task, TaskError, TaskState
from icij_worker.objects import StacktraceItem

from datashare_python.task_client import DatashareTaskClient


async def test_task_client_create_task(monkeypatch):
    # Given
    datashare_url = "http://some-url"
    api_key = "some-api-key"
    task_name = "hello"
    task_id = f"{task_name}-{uuid.uuid4()}"
    args = {"greeted": "world"}
    group = "PYTHON"

    @asynccontextmanager
    async def _put_and_assert(_, url: StrOrURL, *, data: Any = None, **kwargs: Any):
        assert url == f"/api/task/{task_id}?group={group}"
        expected_task = {
            "@type": "Task",
            "id": task_id,
            "state": "CREATED",
            "name": "hello",
            "args": {"greeted": "world"},
        }
        expected_data = expected_task
        assert data is None
        json_data = kwargs.pop("json")
        assert not kwargs
        assert json_data == expected_data
        mocked_res = AsyncMock()
        mocked_res.json.return_value = {"taskId": task_id}
        yield mocked_res

    monkeypatch.setattr("icij_worker.utils.http.AiohttpClient._put", _put_and_assert)

    task_client = DatashareTaskClient(datashare_url, api_key=api_key)
    async with task_client:
        # When
        t_id = await task_client.create_task(task_name, args, id_=task_id, group=group)
        assert t_id == task_id


async def test_task_client_get_task(monkeypatch):
    # Given
    datashare_url = "http://some-url"
    api_key = "some-api-key"
    task_name = "hello"
    task_id = f"{task_name}-{uuid.uuid4()}"

    @asynccontextmanager
    async def _get_and_assert(
        _, url: StrOrURL, *, allow_redirects: bool = True, **kwargs: Any
    ):
        assert url == f"/api/task/{task_id}"
        task = {
            "@type": "Task",
            "id": task_id,
            "state": "CREATED",
            "createdAt": datetime.now(),
            "name": "hello",
            "args": {"greeted": "world"},
        }
        assert allow_redirects
        assert not kwargs
        mocked_res = AsyncMock()
        mocked_res.json.return_value = task
        yield mocked_res

    monkeypatch.setattr("icij_worker.utils.http.AiohttpClient._get", _get_and_assert)

    task_client = DatashareTaskClient(datashare_url, api_key=api_key)
    async with task_client:
        # When
        task = await task_client.get_task(task_id)
        assert isinstance(task, Task)


async def test_task_client_get_task_state(monkeypatch):
    # Given
    datashare_url = "http://some-url"
    api_key = "some-api-key"
    task_name = "hello"
    task_id = f"{task_name}-{uuid.uuid4()}"

    @asynccontextmanager
    async def _get_and_assert(
        _, url: StrOrURL, *, allow_redirects: bool = True, **kwargs: Any
    ):
        assert url == f"/api/task/{task_id}"
        task = {
            "@type": "Task",
            "id": task_id,
            "state": "DONE",
            "createdAt": datetime.now(),
            "completedAt": datetime.now(),
            "name": "hello",
            "args": {"greeted": "world"},
            "result": "hellow world",
        }
        assert allow_redirects
        assert not kwargs
        mocked_res = AsyncMock()
        mocked_res.json.return_value = task
        yield mocked_res

    monkeypatch.setattr("icij_worker.utils.http.AiohttpClient._get", _get_and_assert)

    task_client = DatashareTaskClient(datashare_url, api_key=api_key)
    async with task_client:
        # When
        res = await task_client.get_task_state(task_id)
        assert res == TaskState.DONE


async def test_task_client_get_task_result(monkeypatch):
    # Given
    datashare_url = "http://some-url"
    api_key = "some-api-key"
    task_name = "hello"
    task_id = f"{task_name}-{uuid.uuid4()}"

    @asynccontextmanager
    async def _get_and_assert(
        _, url: StrOrURL, *, allow_redirects: bool = True, **kwargs: Any
    ):
        assert url == f"/api/task/{task_id}/results"
        assert allow_redirects
        assert not kwargs
        mocked_res = AsyncMock()
        mocked_res.json.return_value = "hellow world"
        yield mocked_res

    monkeypatch.setattr("icij_worker.utils.http.AiohttpClient._get", _get_and_assert)

    task_client = DatashareTaskClient(datashare_url, api_key=api_key)
    async with task_client:
        # When
        res = await task_client.get_task_result(task_id)
        assert res == "hellow world"


async def test_task_client_get_task_error(monkeypatch):
    # Given
    datashare_url = "http://some-url"
    api_key = "some-api-key"
    task_name = "hello"
    task_id = f"{task_name}-{uuid.uuid4()}"

    @asynccontextmanager
    async def _get_and_assert(
        _, url: StrOrURL, *, allow_redirects: bool = True, **kwargs: Any
    ):
        assert url == f"/api/task/{task_id}"
        task = {
            "@type": "Task",
            "id": task_id,
            "state": "ERROR",
            "createdAt": datetime.now(),
            "completedAt": datetime.now(),
            "name": "hello",
            "args": {"greeted": "world"},
            "error": {
                "@type": "TaskError",
                "name": "SomeError",
                "message": "some error found",
                "cause": "i'm the culprit",
                "stacktrace": [{"lineno": 666, "file": "some_file.py", "name": "err"}],
            },
        }
        assert allow_redirects
        assert not kwargs
        mocked_res = AsyncMock()
        mocked_res.json.return_value = task
        yield mocked_res

    monkeypatch.setattr("icij_worker.utils.http.AiohttpClient._get", _get_and_assert)

    task_client = DatashareTaskClient(datashare_url, api_key=api_key)
    async with task_client:
        # When
        error = await task_client.get_task_error(task_id)
        expected_error = TaskError(
            name="SomeError",
            message="some error found",
            cause="i'm the culprit",
            stacktrace=[StacktraceItem(name="err", file="some_file.py", lineno=666)],
        )
        assert error == expected_error
21 changes: 0 additions & 21 deletions datashare_python/utils.py
@@ -3,8 +3,6 @@
from itertools import islice
from typing import AsyncIterable, AsyncIterator, Awaitable, Callable, Iterable, TypeVar

-from icij_worker.ds_task_client import DatashareTaskClient
-
T = TypeVar("T")

Predicate = Callable[[T], bool] | Callable[[T], Awaitable[bool]]
@@ -69,22 +67,3 @@ async def remainder_iterator():
        yield elm

    return true_iterator(), remainder_iterator()
-
-
-class DSTaskClient(DatashareTaskClient):
-
-    async def __aenter__(self):
-        await super().__aenter__()
-
-        async with self._get("/settings") as res:
-            # SimpleCookie doesn't seem to parse DS cookie so we perform some dirty
-            # hack here
-            session_id = [
-                item
-                for item in res.headers["Set-Cookie"].split("; ")
-                if "session_id" in item
-            ]
-            if len(session_id) != 1:
-                raise ValueError("Invalid cookie")
-            k, v = session_id[0].split("=")
-            self._session.cookie_jar.update_cookies({k: v})
