Skip to content

Commit

Permalink
Merge branch 'main' into logfire
Browse files Browse the repository at this point in the history
  • Loading branch information
zzstoatzz authored May 1, 2024
2 parents c16a63f + d825790 commit d847f00
Showing 1 changed file with 26 additions and 7 deletions.
33 changes: 26 additions & 7 deletions src/marvin/utilities/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,28 +26,47 @@ class ScopedContext:
```
"""

def __init__(self):
"""Initializes the ScopedContext with a default empty dictionary."""
def __init__(self, initial_value: dict = None):
"""Initializes the ScopedContext with an initial valuedictionary."""
self._context_storage = contextvars.ContextVar(
"scoped_context_storage", default={}
"scoped_context_storage", default=initial_value or {}
)

def get(self, key: str, default: Any = None) -> Any:
return self._context_storage.get().get(key, default)

def __getitem__(self, key: str) -> Any:
notfound = object()
result = self.get(key, default=notfound)
if result == notfound:
raise KeyError(key)
return result

def set(self, **kwargs: Any) -> None:
ctx = self._context_storage.get()
updated_ctx = {**ctx, **kwargs}
self._context_storage.set(updated_ctx)
token = self._context_storage.set(updated_ctx)
return token

@contextmanager
def __call__(self, **kwargs: Any) -> Generator[None, None, Any]:
current_context = self._context_storage.get().copy()
self.set(**kwargs)
current_context_copy = self._context_storage.get().copy()
token = self.set(**kwargs)
try:
yield
finally:
self._context_storage.set(current_context)
try:
self._context_storage.reset(token)
except ValueError as exc:
if "was created in a different context" in str(exc).lower():
# the only way we can reach this line is if the setup and
# teardown of this context are run in different frames or
# threads (which happens with pytest fixtures!), in which case
# the token is considered invalid. This catch serves as a
# "manual" reset of the context values
self._context_storage.set(current_context_copy)
else:
raise


ctx = ScopedContext()

0 comments on commit d847f00

Please sign in to comment.