Skip to content

Commit

Permalink
aiohttp client session seems leaking connection, init only one for As…
Browse files Browse the repository at this point in the history
…yncClient
  • Loading branch information
Xiao Li committed Jul 12, 2021
1 parent e042d23 commit 37c6798
Show file tree
Hide file tree
Showing 9 changed files with 56 additions and 46 deletions.
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ https://pypi.org/project/diem/
>>> import asyncio
>>>
>>> async def main():
... client = AsyncClient(JSON_RPC_URL)
... print(await client.get_metadata())
... # Use with statement to close client after usage
... # or call client.close() when initialized without with statement
... with AsyncClient(JSON_RPC_URL) as client:
... print(await client.get_metadata())
...
>>> asyncio.run(main())

Expand Down
4 changes: 2 additions & 2 deletions src/diem/jsonrpc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
>>> import asyncio
>>>
>>> async def main():
... client = AsyncClient(JSON_RPC_URL)
... print(await client.get_metadata())
... with AsyncClient(JSON_RPC_URL) as client:
... print(await client.get_metadata())
...
>>> asyncio.run(main())
Expand Down
28 changes: 18 additions & 10 deletions src/diem/jsonrpc/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import copy
import dataclasses
import google.protobuf.json_format as parser
import asyncio
import asyncio, aiohttp
import typing
import random
import functools
Expand Down Expand Up @@ -146,11 +146,20 @@ def __init__(
session_factory: typing.Callable[[], ClientSession] = ClientSession,
) -> None:
self._url: str = server_url
self._session_factory = session_factory
self._last_known_server_state: State = State(chain_id=-1, version=-1, timestamp_usecs=-1)
self._retry: Retry = retry or Retry(DEFAULT_MAX_RETRIES, DEFAULT_RETRY_DELAY, StaleResponseError)
self._rs: RequestStrategy = rs or RequestStrategy()
self._logger: Logger = logger or getLogger(__name__)
self._session: aiohttp.ClientSession = session_factory()

async def close(self) -> None:
await self._session.close()

async def __aenter__(self) -> "AsyncClient":
return self

async def __aexit__(self, *args, **kwargs) -> None: # pyre-ignore
await self.close()

# high level functions

Expand Down Expand Up @@ -571,14 +580,13 @@ async def _send_http_request(
) -> typing.Dict[str, typing.Any]:
self._logger.debug("http request body: %s", request)
headers = {"User-Agent": USER_AGENT_HTTP_HEADER}
async with self._session_factory() as session:
async with session.post(url, json=request, headers=headers) as response:
self._logger.debug("http response body: %s", response.text)
response.raise_for_status()
try:
json = await response.json()
except ValueError as e:
raise InvalidServerResponse(f"Parse response as json failed: {e}, response: {response.text}")
async with self._session.post(url, json=request, headers=headers) as response:
self._logger.debug("http response body: %s", response.text)
response.raise_for_status()
try:
json = await response.json()
except ValueError as e:
raise InvalidServerResponse(f"Parse response as json failed: {e}, response: {response.text}")

# check stable response before check jsonrpc error
try:
Expand Down
19 changes: 9 additions & 10 deletions src/diem/offchain/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# SPDX-License-Identifier: Apache-2.0


import typing, dataclasses, uuid, math, warnings, aiohttp
import typing, dataclasses, uuid, math, warnings

from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PublicKey
from cryptography.exceptions import InvalidSignature
Expand Down Expand Up @@ -144,15 +144,14 @@ async def send_request(
http_header.X_REQUEST_SENDER_ADDRESS: request_sender_address,
}
url = f"{base_url.rstrip('/')}/v2/command"
async with aiohttp.ClientSession() as session:
async with session.post(url, data=request_bytes, headers=headers) as response:
if response.status not in [200, 400]:
response.raise_for_status()

cmd_resp = _deserialize_jws(await response.read(), CommandResponseObject, public_key)
if cmd_resp.status == CommandResponseStatus.failure:
raise CommandResponseError(cmd_resp)
return cmd_resp
async with self.jsonrpc_client._session.post(url, data=request_bytes, headers=headers) as response:
if response.status not in [200, 400]:
response.raise_for_status()

cmd_resp = _deserialize_jws(await response.read(), CommandResponseObject, public_key)
if cmd_resp.status == CommandResponseStatus.failure:
raise CommandResponseError(cmd_resp)
return cmd_resp

async def process_inbound_request(self, request_sender_address: str, request_bytes: bytes) -> Command:
"""Deprecated
Expand Down
8 changes: 4 additions & 4 deletions src/diem/testing/faucet.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@

import functools, os

from aiohttp import ClientSession
from diem import diem_types, bcs
from diem.jsonrpc.async_client import AsyncClient, Retry, TransactionExecutionFailed
from diem.testing.local_account import LocalAccount
Expand Down Expand Up @@ -103,9 +102,10 @@ async def _mint_without_retry(
}
if diem_id_domain:
params["diem_id_domain"] = diem_id_domain
async with ClientSession(raise_for_status=True) as session:
async with session.post(self._url, params=params) as response:
de = bcs.BcsDeserializer(bytes.fromhex(await response.text()))

async with self._client._session.post(self._url, params=params) as response:
response.raise_for_status()
de = bcs.BcsDeserializer(bytes.fromhex(await response.text()))

for i in range(de.deserialize_len()):
txn = de.deserialize_any(diem_types.SignedTransaction)
Expand Down
21 changes: 9 additions & 12 deletions src/diem/testing/suites/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
dmw_stub_diem_account_config,
dmw_stub_diem_account_hrp,
)
from aiohttp import ClientSession
from typing import Optional, Tuple, Dict, List, Any, Generator, Callable, AsyncGenerator, Awaitable
from dataclasses import asdict
import pytest, json, uuid, time, warnings, asyncio
Expand Down Expand Up @@ -53,10 +52,10 @@ async def target_client(diem_client: AsyncClient) -> AsyncGenerator[RestClient,


@pytest.fixture(scope="package")
def diem_client() -> AsyncClient:
client = create_client()
print("Diem JSON-RPC URL: %s" % client._url)
return client
async def diem_client() -> AsyncGenerator[AsyncClient, None]:
async with create_client() as client:
print("Diem JSON-RPC URL: %s" % client._url)
yield client


@pytest.fixture(scope="package")
Expand Down Expand Up @@ -175,13 +174,11 @@ async def send_request_json(
base_url, public_key = await diem_client.get_base_url_and_compliance_key(account_address)
if request_body is None:
request_body = jws.encode(request_json, sender_account.compliance_key.sign)
async with ClientSession() as session:
url = f"{base_url.rstrip('/')}/v2/command"
async with session.post(url, data=request_body, headers=headers) as resp:
cmd_resp_obj = offchain.jws.deserialize(
await resp.read(), offchain.CommandResponseObject, public_key.verify
)
return (resp.status, cmd_resp_obj)

url = f"{base_url.rstrip('/')}/v2/command"
async with diem_client._session.post(url, data=request_body, headers=headers) as resp:
cmd_resp_obj = offchain.jws.deserialize(await resp.read(), offchain.CommandResponseObject, public_key.verify)
return (resp.status, cmd_resp_obj)


def payment_command_request_sample(
Expand Down
6 changes: 4 additions & 2 deletions tests/jsonrpc/test_async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)
from diem.jsonrpc import AsyncClient
from diem.testing import LocalAccount, Faucet, create_client, XUS, DD_ADDRESS
from typing import AsyncGenerator

import time, aiohttp, asyncio
import pytest
Expand All @@ -23,8 +24,9 @@


@pytest.fixture
def client() -> AsyncClient:
return create_client()
async def client() -> AsyncGenerator[AsyncClient, None]:
async with create_client() as client:
yield client


async def test_get_metadata(client: AsyncClient):
Expand Down
7 changes: 4 additions & 3 deletions tests/miniwallet/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from diem.jsonrpc import AsyncClient
from diem.testing import create_client, XUS
from diem.testing.miniwallet import RestClient, AppConfig, App
from typing import Generator, Tuple
from typing import Generator, Tuple, AsyncGenerator
import asyncio, pytest


Expand Down Expand Up @@ -40,8 +40,9 @@ async def stub_client(diem_client: AsyncClient) -> RestClient:


@pytest.fixture(scope="package")
def diem_client() -> AsyncClient:
return create_client()
async def diem_client() -> AsyncGenerator[AsyncClient, None]:
async with create_client() as client:
yield client


@pytest.fixture
Expand Down
3 changes: 2 additions & 1 deletion tests/miniwallet/test_diem_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from diem.testing import LocalAccount, create_client, Faucet, XUS
from diem.testing.miniwallet.app import Transaction
from diem.testing.miniwallet.app.diem_account import DiemAccount
from diem import utils
import pytest


Expand Down Expand Up @@ -41,7 +42,7 @@ async def test_ensure_account_balance_is_always_enough():
await faucet.mint(account.auth_key.hex(), 1, XUS)
da = DiemAccount(account, [], client)
account_data = await client.must_get_account(account.account_address)
amount = account_data.balances[0].amount + 1
amount = utils.balance(account_data, XUS) + 1
payee = await faucet.gen_account()
txn = await da.submit_p2p(gen_txn(payee=payee.account_identifier(), amount=amount), (b"", b""))
await client.wait_for_transaction(txn)
Expand Down

0 comments on commit 37c6798

Please sign in to comment.