-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feature: migrate DatashareTaskClient to datashare-python
- Loading branch information
Showing
5 changed files
with
325 additions
and
26 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
import uuid | ||
from typing import Any, Dict, Optional | ||
|
||
from icij_common.pydantic_utils import jsonable_encoder | ||
from icij_worker import Task, TaskError, TaskState | ||
from icij_worker.exceptions import UnknownTask | ||
from icij_worker.utils.http import AiohttpClient | ||
|
||
# Task fields left out of the JSON payload when a task is encoded for Datashare.
# TODO: maxRetries is not supported by Java, where it is automatically set to 3
_TASK_UNSUPPORTED = {"max_retries"}
|
||
|
||
class DatashareTaskClient(AiohttpClient):
    """Async HTTP client for the Datashare task API (``/api/task/...``).

    When an API key is provided it is sent as a ``Bearer`` token on every
    request. Otherwise, entering the client calls ``/settings`` and stores the
    session cookie found in the response so that later calls reuse it.
    """

    def __init__(self, datashare_url: str, api_key: Optional[str] = None) -> None:
        """Create a client targeting ``datashare_url``.

        :param datashare_url: base URL of the Datashare instance
        :param api_key: optional API key sent as ``Authorization: Bearer ...``
        """
        headers = None
        if api_key is not None:
            headers = {"Authorization": f"Bearer {api_key}"}
        super().__init__(datashare_url, headers=headers)

    async def __aenter__(self) -> "DatashareTaskClient":
        """Open the underlying session and, without API key, grab a session cookie.

        :return: the client itself, so ``async with ... as client`` works
        :raises ValueError: when the ``Set-Cookie`` header doesn't hold exactly
            one ``session_id`` item, or when that item has no ``=``
        """
        await super().__aenter__()
        # NOTE(review): guards against the parent keeping ``headers=None`` as is
        # when no API key was given - confirm against AiohttpClient.
        headers = self._headers or {}
        if "Authorization" not in headers:
            async with self._get("/settings") as res:
                # SimpleCookie doesn't seem to parse DS cookie so we perform some
                # dirty hack here
                session_id = [
                    item
                    for item in res.headers["Set-Cookie"].split("; ")
                    if "session_id" in item
                ]
                if len(session_id) != 1:
                    raise ValueError("Invalid cookie")
                # Split on the first "=" only: the cookie value may contain "="
                k, v = session_id[0].split("=", 1)
                self._session.cookie_jar.update_cookies({k: v})
        return self

    async def create_task(
        self,
        name: str,
        args: Dict[str, Any],
        *,
        id_: Optional[str] = None,
        group: Optional[str] = None,
    ) -> str:
        """Create a task through ``PUT /api/task/<id>``.

        :param name: name of the task to run
        :param args: task arguments
        :param id_: task id, generated from the task name when omitted
        :param group: optional task group, passed as the ``group`` query parameter
        :return: the ``taskId`` found in the server response
        :raises TypeError: when ``group`` is provided and is not a string
        """
        # Local import: keeps this block self-sufficient without touching the
        # module import section.
        from urllib.parse import quote

        if id_ is None:
            id_ = _generate_task_id(name)
        task = Task.create(task_id=id_, task_name=name, args=args)
        task = jsonable_encoder(task, exclude=_TASK_UNSUPPORTED, exclude_unset=True)
        # createdAt is not sent; tolerate its absence from the encoded payload
        task.pop("createdAt", None)
        url = f"/api/task/{id_}"
        if group is not None:
            if not isinstance(group, str):
                raise TypeError(f"expected group to be a string found {group}")
            # Percent-encode the group so reserved characters can't break the
            # query string
            url += f"?group={quote(group, safe='')}"
        async with self._put(url, json=task) as res:
            task_res = await res.json()
        return task_res["taskId"]

    async def _get_task_payload(self, id_: str) -> dict:
        """Fetch the raw JSON of a task, raising ``UnknownTask`` when it is null."""
        async with self._get(f"/api/task/{id_}") as res:
            task = await res.json()
        if task is None:
            raise UnknownTask(id_)
        return task

    async def get_task(self, id_: str) -> Task:
        """Return the task identified by ``id_``.

        :raises UnknownTask: when the server answers with a null task
        """
        task = await self._get_task_payload(id_)
        # TODO: align Java on Python here... it's not a good idea to store results
        # inside tasks since result can be quite large and we may want to get the
        # task metadata without having to deal with the large task results...
        return Task(**_ds_to_icij_worker_task(task))

    async def get_tasks(self) -> list[Task]:
        """Return all tasks listed by ``GET /api/task/all``."""
        async with self._get("/api/task/all") as res:
            tasks = await res.json()
        # TODO: align Java on Python here... it's not a good idea to store results
        # inside tasks since result can be quite large and we may want to get the
        # task metadata without having to deal with the large task results...
        return [Task(**_ds_to_icij_worker_task(t)) for t in tasks]

    async def get_task_state(self, id_: str) -> TaskState:
        """Return the state of the task identified by ``id_``."""
        return (await self.get_task(id_)).state

    async def get_task_result(self, id_: str) -> Any:
        """Return the JSON body of ``GET /api/task/<id>/results``."""
        async with self._get(f"/api/task/{id_}/results") as res:
            return await res.json()

    async def get_task_error(self, id_: str) -> TaskError:
        """Return the error attached to a task in ``ERROR`` state.

        :raises UnknownTask: when the server answers with a null task
        :raises ValueError: when the task is not in ``ERROR`` state
        """
        task = await self._get_task_payload(id_)
        task_state = TaskState[task["state"]]
        if task_state != TaskState.ERROR:
            msg = f"can't find error for task {id_} in state {task_state}"
            raise ValueError(msg)
        return TaskError(**task["error"])

    async def delete(self, id_: str) -> None:
        """Delete the task identified by ``id_`` (``DELETE /api/task/<id>``)."""
        async with self._delete(f"/api/task/{id_}"):
            pass

    async def delete_all_tasks(self) -> None:
        """Delete, one after the other, every task returned by ``get_tasks``."""
        for t in await self.get_tasks():
            await self.delete(t.id)
|
||
|
||
def _generate_task_id(task_name: str) -> str:
    """Build a task id made of the task name, a dash and a random UUID4."""
    unique_suffix = str(uuid.uuid4())
    return "-".join((task_name, unique_suffix))
|
||
|
||
# Keys of the Datashare task JSON that are dropped before building a Task
_JAVA_TASK_ATTRIBUTES = ["result", "error"]


def _ds_to_icij_worker_task(task: dict) -> dict:
    """Return a copy of ``task`` without the keys in ``_JAVA_TASK_ATTRIBUTES``.

    The input dict is left untouched so callers can still read the dropped
    attributes from the payload they fetched.

    :param task: task JSON as returned by the Datashare API
    :return: a new dict holding every other key/value of ``task``
    """
    return {k: v for k, v in task.items() if k not in _JAVA_TASK_ATTRIBUTES}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,196 @@ | ||
import uuid | ||
from contextlib import asynccontextmanager | ||
from datetime import datetime | ||
from typing import Any | ||
from unittest.mock import AsyncMock | ||
|
||
from aiohttp.typedefs import StrOrURL | ||
from icij_worker import Task, TaskError, TaskState | ||
from icij_worker.objects import StacktraceItem | ||
|
||
from datashare_python.task_client import DatashareTaskClient | ||
|
||
|
||
async def test_task_client_create_task(monkeypatch):
    """create_task PUTs the encoded task and returns the id sent by the server."""
    # Given
    task_name = "hello"
    task_id = f"{task_name}-{uuid.uuid4()}"
    group = "PYTHON"
    args = {"greeted": "world"}
    expected_url = f"/api/task/{task_id}?group={group}"
    expected_payload = {
        "@type": "Task",
        "id": task_id,
        "state": "CREATED",
        "name": "hello",
        "args": {"greeted": "world"},
    }

    @asynccontextmanager
    async def _put_and_assert(_, url: StrOrURL, *, data: Any = None, **kwargs: Any):
        # Stands in for AiohttpClient._put: checks the request, fakes the response
        assert url == expected_url
        assert data is None
        json_payload = kwargs.pop("json")
        assert not kwargs
        assert json_payload == expected_payload
        response = AsyncMock()
        response.json.return_value = {"taskId": task_id}
        yield response

    monkeypatch.setattr("icij_worker.utils.http.AiohttpClient._put", _put_and_assert)

    task_client = DatashareTaskClient("http://some-url", api_key="some-api-key")
    async with task_client:
        # When
        created_id = await task_client.create_task(
            task_name, args, id_=task_id, group=group
        )
        # Then
        assert created_id == task_id
|
||
|
||
async def test_task_client_get_task(monkeypatch):
    """get_task GETs /api/task/<id> and turns the JSON body into a Task."""
    # Given
    task_name = "hello"
    task_id = f"{task_name}-{uuid.uuid4()}"
    expected_url = f"/api/task/{task_id}"

    @asynccontextmanager
    async def _get_and_assert(
        _, url: StrOrURL, *, allow_redirects: bool = True, **kwargs: Any
    ):
        # Stands in for AiohttpClient._get: checks the request, fakes the response
        assert url == expected_url
        payload = {
            "@type": "Task",
            "id": task_id,
            "state": "CREATED",
            "createdAt": datetime.now(),
            "name": "hello",
            "args": {"greeted": "world"},
        }
        assert allow_redirects
        assert not kwargs
        response = AsyncMock()
        response.json.return_value = payload
        yield response

    monkeypatch.setattr("icij_worker.utils.http.AiohttpClient._get", _get_and_assert)

    task_client = DatashareTaskClient("http://some-url", api_key="some-api-key")
    async with task_client:
        # When
        fetched = await task_client.get_task(task_id)
        # Then
        assert isinstance(fetched, Task)
|
||
|
||
async def test_task_client_get_task_state(monkeypatch):
    """get_task_state returns the state of the task served by /api/task/<id>."""
    # Given
    task_name = "hello"
    task_id = f"{task_name}-{uuid.uuid4()}"
    expected_url = f"/api/task/{task_id}"

    @asynccontextmanager
    async def _get_and_assert(
        _, url: StrOrURL, *, allow_redirects: bool = True, **kwargs: Any
    ):
        # Stands in for AiohttpClient._get: checks the request, fakes the response
        assert url == expected_url
        payload = {
            "@type": "Task",
            "id": task_id,
            "state": "DONE",
            "createdAt": datetime.now(),
            "completedAt": datetime.now(),
            "name": "hello",
            "args": {"greeted": "world"},
            "result": "hellow world",
        }
        assert allow_redirects
        assert not kwargs
        response = AsyncMock()
        response.json.return_value = payload
        yield response

    monkeypatch.setattr("icij_worker.utils.http.AiohttpClient._get", _get_and_assert)

    task_client = DatashareTaskClient("http://some-url", api_key="some-api-key")
    async with task_client:
        # When
        state = await task_client.get_task_state(task_id)
        # Then
        assert state == TaskState.DONE
|
||
|
||
async def test_task_client_get_task_result(monkeypatch):
    """get_task_result returns the JSON body of /api/task/<id>/results."""
    # Given
    task_name = "hello"
    task_id = f"{task_name}-{uuid.uuid4()}"
    expected_url = f"/api/task/{task_id}/results"

    @asynccontextmanager
    async def _get_and_assert(
        _, url: StrOrURL, *, allow_redirects: bool = True, **kwargs: Any
    ):
        # Stands in for AiohttpClient._get: checks the request, fakes the response
        assert url == expected_url
        assert allow_redirects
        assert not kwargs
        response = AsyncMock()
        response.json.return_value = "hellow world"
        yield response

    monkeypatch.setattr("icij_worker.utils.http.AiohttpClient._get", _get_and_assert)

    task_client = DatashareTaskClient("http://some-url", api_key="some-api-key")
    async with task_client:
        # When
        result = await task_client.get_task_result(task_id)
        # Then
        assert result == "hellow world"
|
||
|
||
async def test_task_client_get_task_error(monkeypatch):
    """get_task_error parses the "error" entry of a task in ERROR state."""
    # Given
    task_name = "hello"
    task_id = f"{task_name}-{uuid.uuid4()}"
    expected_url = f"/api/task/{task_id}"
    expected_error = TaskError(
        name="SomeError",
        message="some error found",
        cause="i'm the culprit",
        stacktrace=[StacktraceItem(name="err", file="some_file.py", lineno=666)],
    )

    @asynccontextmanager
    async def _get_and_assert(
        _, url: StrOrURL, *, allow_redirects: bool = True, **kwargs: Any
    ):
        # Stands in for AiohttpClient._get: checks the request, fakes the response
        assert url == expected_url
        serialized_error = {
            "@type": "TaskError",
            "name": "SomeError",
            "message": "some error found",
            "cause": "i'm the culprit",
            "stacktrace": [{"lineno": 666, "file": "some_file.py", "name": "err"}],
        }
        payload = {
            "@type": "Task",
            "id": task_id,
            "state": "ERROR",
            "createdAt": datetime.now(),
            "completedAt": datetime.now(),
            "name": "hello",
            "args": {"greeted": "world"},
            "error": serialized_error,
        }
        assert allow_redirects
        assert not kwargs
        response = AsyncMock()
        response.json.return_value = payload
        yield response

    monkeypatch.setattr("icij_worker.utils.http.AiohttpClient._get", _get_and_assert)

    task_client = DatashareTaskClient("http://some-url", api_key="some-api-key")
    async with task_client:
        # When
        error = await task_client.get_task_error(task_id)
        # Then
        assert error == expected_error
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters