Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Limit concurrency #329

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 34 additions & 16 deletions adlfs/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from __future__ import absolute_import, division, print_function

import asyncio
import contextlib
import io
import logging
import os
Expand All @@ -12,6 +13,7 @@
import weakref
from datetime import datetime, timedelta
from glob import has_magic
from typing import Optional

from azure.core.exceptions import (
HttpResponseError,
Expand All @@ -31,6 +33,7 @@
from fsspec.utils import infer_storage_options, tokenize

from .utils import (
_nullcontext,
close_container_client,
close_credential,
close_service_client,
Expand Down Expand Up @@ -349,6 +352,12 @@ class AzureBlobFileSystem(AsyncFileSystem):
default_cache_type: string ('bytes')
If given, the default cache_type value used for "open()". Set to none if no caching
is desired. Docs in fsspec
max_concurrency : int, optional
The maximum number of BlobClient connections that can exist simultaneously for this
filesystem instance. By default, there is no limit. Setting this might be helpful if
you have a very large number of small, independent blob operations to perform. By
default a single BlobClient is created per blob, which might cause high memory usage
and clog the asyncio event loop as many instances are created and quickly destroyed.

Pass on to fsspec:

Expand Down Expand Up @@ -410,6 +419,7 @@ def __init__(
asynchronous: bool = False,
default_fill_cache: bool = True,
default_cache_type: str = "bytes",
max_concurrency: Optional[int] = None,
**kwargs,
):
super_kwargs = {
Expand Down Expand Up @@ -438,6 +448,13 @@ def __init__(
self.blocksize = blocksize
self.default_fill_cache = default_fill_cache
self.default_cache_type = default_cache_type
self.max_concurrency = max_concurrency

if self.max_concurrency is None:
self._blob_client_semaphore = _nullcontext()
else:
self._blob_client_semaphore = asyncio.Semaphore(max_concurrency)

if (
self.credential is None
and self.account_key is None
Expand Down Expand Up @@ -519,6 +536,15 @@ def _get_kwargs_from_urls(urlpath):
out["account_name"] = account_name
return out

@contextlib.asynccontextmanager
async def _get_blob_client(self, container_name, path):
    """
    Get a blob client, respecting `self.max_concurrency` if set.

    Parameters
    ----------
    container_name: str
        Name of the container holding the blob.
    path: str
        Path of the blob within the container.

    Yields
    ------
    An open BlobClient for the blob; it is closed when the context exits.

    Notes
    -----
    The semaphore (a no-op context when ``max_concurrency`` is None) is
    held for the entire lifetime of the client, so at most
    ``max_concurrency`` BlobClient instances exist at any one time.
    """
    async with self._blob_client_semaphore:
        async with self.service_client.get_blob_client(container_name, path) as bc:
            yield bc

def _get_credential_from_service_principal(self):
"""
Create a Credential for authentication. This can include a TokenCredential
Expand Down Expand Up @@ -1366,9 +1392,7 @@ async def _isfile(self, path):
return False
else:
try:
async with self.service_client.get_blob_client(
container_name, path
) as bc:
async with self._get_blob_client(container_name, path) as bc:
props = await bc.get_blob_properties()
if props["metadata"]["is_directory"] == "false":
return True
Expand Down Expand Up @@ -1427,7 +1451,7 @@ async def _exists(self, path):
# Empty paths exist by definition
return True

async with self.service_client.get_blob_client(container_name, path) as bc:
async with self._get_blob_client(container_name, path) as bc:
if await bc.exists():
return True

Expand All @@ -1445,9 +1469,7 @@ async def _exists(self, path):
async def _pipe_file(self, path, value, overwrite=True, **kwargs):
"""Set the bytes of given file"""
container_name, path = self.split_path(path)
async with self.service_client.get_blob_client(
container=container_name, blob=path
) as bc:
async with self._get_blob_client(container_name, path) as bc:
result = await bc.upload_blob(
data=value, overwrite=overwrite, metadata={"is_directory": "false"}
)
Expand All @@ -1464,9 +1486,7 @@ async def _cat_file(self, path, start=None, end=None, **kwargs):
else:
length = None
container_name, path = self.split_path(path)
async with self.service_client.get_blob_client(
container=container_name, blob=path
) as bc:
async with self._get_blob_client(container_name, path) as bc:
try:
stream = await bc.download_blob(offset=start, length=length)
except ResourceNotFoundError as e:
Expand Down Expand Up @@ -1528,7 +1548,7 @@ async def _url(self, path, expires=3600, **kwargs):
expiry=datetime.utcnow() + timedelta(seconds=expires),
)

async with self.service_client.get_blob_client(container_name, blob) as bc:
async with self._get_blob_client(container_name, blob) as bc:
url = f"{bc.url}?{sas_token}"
return url

Expand Down Expand Up @@ -1603,9 +1623,7 @@ async def _put_file(
else:
try:
with open(lpath, "rb") as f1:
async with self.service_client.get_blob_client(
container_name, path
) as bc:
async with self._get_blob_client(container_name, path) as bc:
await bc.upload_blob(
f1,
overwrite=overwrite,
Expand Down Expand Up @@ -1659,7 +1677,7 @@ async def _get_file(
return
container_name, path = self.split_path(rpath, delimiter=delimiter)
try:
async with self.service_client.get_blob_client(
async with self._get_blob_client(
container_name, path.rstrip(delimiter)
) as bc:
with open(lpath, "wb") as my_blob:
Expand All @@ -1681,7 +1699,7 @@ def getxattr(self, path, attr):
async def _setxattrs(self, rpath, **kwargs):
container_name, path = self.split_path(rpath)
try:
async with self.service_client.get_blob_client(container_name, path) as bc:
async with self._get_blob_client(container_name, path) as bc:
await bc.set_blob_metadata(metadata=kwargs)
self.invalidate_cache(self._parent(rpath))
except Exception as e:
Expand Down
16 changes: 16 additions & 0 deletions adlfs/tests/test_spec.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import asyncio
import datetime
import os
import tempfile
from unittest import mock

import dask.dataframe as dd
import numpy as np
Expand Down Expand Up @@ -1424,3 +1426,17 @@ def test_find_with_prefix(storage):
assert test_1s == [test_bucket_name + "/prefixes/test_1"] + [
test_bucket_name + f"/prefixes/test_{cursor}" for cursor in range(10, 20)
]


def test_max_concurrency(storage):
    """With max_concurrency set, blob-client creation is guarded by a semaphore
    that is entered once per blob operation."""
    fs = AzureBlobFileSystem(
        account_name=storage.account_name, connection_string=CONN_STR, max_concurrency=2
    )

    # max_concurrency=2 should install a real asyncio.Semaphore rather than
    # the no-op null context used when the limit is unset.
    assert isinstance(fs._blob_client_semaphore, asyncio.Semaphore)

    # Replace the semaphore with a MagicMock spec'd on it so we can count
    # how many times it is entered.
    # NOTE(review): relies on MagicMock configuring __aenter__/__aexit__ as
    # async mocks (Python 3.8+) — confirm minimum supported Python.
    fs._blob_client_semaphore = mock.MagicMock(fs._blob_client_semaphore)
    path = {f"/data/{i}": b"value" for i in range(10)}
    fs.pipe(path)

    # One semaphore acquisition per blob written.
    assert fs._blob_client_semaphore.__aenter__.call_count == 10
22 changes: 21 additions & 1 deletion adlfs/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import contextlib
import sys


async def filter_blobs(blobs, target_path, delimiter="/"):
"""
Filters out blobs that do not come from target_path
Expand Down Expand Up @@ -45,9 +49,25 @@ async def close_container_client(file_obj):
await file_obj.container_client.close()


if sys.version_info < (3, 10):
# Python 3.10 added support for async to nullcontext
@contextlib.asynccontextmanager
async def _nullcontext(*args):
yield

else:
_nullcontext = contextlib.nullcontext


async def close_credential(file_obj):
    """
    Asynchronously close the credential attached to an AzureBlobFile object.

    A ``credential`` of ``None`` is skipped, and a missing ``credential``
    attribute (or one lacking an async ``close``) is tolerated: any
    AttributeError raised while closing is silently ignored.
    """
    with contextlib.suppress(AttributeError):
        credential = file_obj.credential
        if credential is not None:
            await credential.close()
pass