Skip to content

Commit

Permalink
Simplify code leveraging the fact that ContextVar is automatically th…
Browse files Browse the repository at this point in the history
…read-local
  • Loading branch information
spanezz committed Oct 8, 2024
1 parent 16d4ae2 commit 7b3ce47
Showing 1 changed file with 21 additions and 77 deletions.
98 changes: 21 additions & 77 deletions asgiref/local.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import asyncio
import contextlib
import contextvars
import threading
from typing import Any, Union
Expand All @@ -9,36 +7,40 @@ class _CVar:
"""Storage utility for Local."""

def __init__(self) -> None:
self._thread_lock = threading.RLock()
self._data: dict[str, contextvars.ContextVar[Any]] = {}

def __getattr__(self, key: str) -> Any:
try:
var = self._data[key]
except KeyError:
raise AttributeError(f"{self!r} object has no attribute {key!r}")
with self._thread_lock:
try:
var = self._data[key]
except KeyError:
raise AttributeError(f"{self!r} object has no attribute {key!r}")

try:
return var.get()
except LookupError:
raise AttributeError(f"{self!r} object has no attribute {key!r}")

def __setattr__(self, key: str, value: Any) -> None:
if key == "_data":
if key in ("_data", "_thread_lock"):
return super().__setattr__(key, value)

var = self._data.get(key)
if var is None:
self._data[key] = var = contextvars.ContextVar(key)
with self._thread_lock:
var = self._data.get(key)
if var is None:
self._data[key] = var = contextvars.ContextVar(key)
var.set(value)

def __delattr__(self, key: str) -> None:
if key in self._data:
del self._data[key]
else:
raise AttributeError(f"{self!r} object has no attribute {key!r}")
with self._thread_lock:
if key in self._data:
del self._data[key]
else:
raise AttributeError(f"{self!r} object has no attribute {key!r}")


class Local:
def Local(thread_critical: bool = False) -> Union[threading.local, _CVar]:
"""Local storage for async tasks.
This is a namespace object (similar to `threading.local`) where data is
Expand All @@ -65,65 +67,7 @@ class Local:
Unlike plain `contextvars` objects, this utility is threadsafe.
"""

def __init__(self, thread_critical: bool = False) -> None:
self._thread_critical = thread_critical
self._thread_lock = threading.RLock()

self._storage: "Union[threading.local, _CVar]"

if thread_critical:
# Thread-local storage
self._storage = threading.local()
else:
# Contextvar storage
self._storage = _CVar()

@contextlib.contextmanager
def _lock_storage(self):
# Thread safe access to storage
if self._thread_critical:
try:
# this is a test for are we in a async or sync
# thread - will raise RuntimeError if there is
# no current loop
asyncio.get_running_loop()
except RuntimeError:
# We are in a sync thread, the storage is
# just the plain thread local (i.e, "global within
# this thread" - it doesn't matter where you are
# in a call stack you see the same storage)
yield self._storage
else:
# We are in an async thread - storage is still
# local to this thread, but additionally should
# behave like a context var (is only visible with
# the same async call stack)

# Ensure context exists in the current thread
if not hasattr(self._storage, "cvar"):
self._storage.cvar = _CVar()

# self._storage is a thread local, so the members
# can't be accessed in another thread (we don't
# need any locks)
yield self._storage.cvar
else:
# Lock for thread_critical=False as other threads
# can access the exact same storage object
with self._thread_lock:
yield self._storage

def __getattr__(self, key):
with self._lock_storage() as storage:
return getattr(storage, key)

def __setattr__(self, key, value):
if key in ("_local", "_storage", "_thread_critical", "_thread_lock"):
return super().__setattr__(key, value)
with self._lock_storage() as storage:
setattr(storage, key, value)

def __delattr__(self, key):
with self._lock_storage() as storage:
delattr(storage, key)
if thread_critical:
return threading.local()
else:
return _CVar()

0 comments on commit 7b3ce47

Please sign in to comment.