Merge pull request #8 from nicholasjng/shelf-context
Unify serde interface by introducing `Shelf.Context`
nicholasjng authored Jan 26, 2024
2 parents b928788 + b81d770 commit d5c2946
Showing 9 changed files with 207 additions and 157 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -20,7 +20,7 @@ repos:
types_or: [ python, pyi ]
args: [--ignore-missing-imports, --scripts-are-modules]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.11
rev: v0.1.14
hooks:
- id: ruff
args: [ --fix ]
17 changes: 8 additions & 9 deletions README.md
@@ -29,19 +29,18 @@ class MyModel:
return 1.


def save_to_disk(model: MyModel, tmpdir: str) -> str:
def save_to_disk(model: MyModel, ctx: shelf.Context) -> None:
"""Dumps the model to the directory ``tmpdir`` using `pickle`."""
fname = os.path.join(tmpdir, "my-model.pkl")
with open(fname, "wb") as f:
pickle.dump(model, f)
return fname
fp = ctx.file("my-model.pkl", mode="wb")
pickle.dump(model, fp)


def load_from_disk(fname: str) -> MyModel:
def load_from_disk(ctx: shelf.Context) -> MyModel:
"""Reloads the previously pickled model."""
with open(fname, "rb") as f:
model: MyModel = pickle.load(f)
return model
fname, = ctx.filenames
fp = ctx.file(fname, mode="rb")
model: MyModel = pickle.load(fp)
return model


shelf.register_type(MyModel, save_to_disk, load_from_disk)
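For orientation, here is a minimal usage sketch of how the type registered above might be shelved and reloaded with the new `Context`-based interface. The default `Shelf()` construction, the local `file://` target path, and the assumption that `MyModel` is default-constructible are illustrative and not part of this diff.

```python
# Hypothetical end-to-end sketch (not part of this commit): the target path
# and the MyModel construction are assumptions for illustration only.
import shelf

store = shelf.Shelf()
model = MyModel()  # assumed default-constructible for this sketch

# put() invokes save_to_disk(model, ctx) against a temporary directory and
# uploads the files recorded via ctx.file(...).
store.put(model, "file:///tmp/shelf-demo/my-model")

# get() downloads the remote files into the shelf's temporary directory and
# passes a Context listing them to load_from_disk(ctx).
restored = store.get("file:///tmp/shelf-demo/my-model", expected_type=MyModel)
```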
1 change: 1 addition & 0 deletions src/shelf/__init__.py
@@ -9,3 +9,4 @@

from .core import Shelf
from .registry import deregister_type, lookup, register_type
from .types import Context
167 changes: 62 additions & 105 deletions src/shelf/core.py
@@ -1,70 +1,52 @@
from __future__ import annotations

import contextlib
import os
import shutil
import tempfile
from os import PathLike
from typing import Any, Literal, TypeVar
import weakref
from typing import Any, TypeVar

from fsspec import AbstractFileSystem, filesystem
from fsspec.utils import get_protocol
from fsspec.utils import stringify_path

import shelf.registry
from shelf.util import is_fully_qualified, with_trailing_sep
from shelf.types import CacheOptions, Context
from shelf.util import filesystem_from_uri, with_trailing_sep

T = TypeVar("T")


class Shelf:
def __init__(
self,
prefix: str | os.PathLike[str] = "",
cache_dir: str | PathLike[str] | None = None,
cache_type: Literal["blockcache", "filecache", "simplecache"] = "filecache",
fsconfig: dict[str, dict[str, Any]] | None = None,
configfile: str | PathLike[str] | None = None,
):
self.prefix = str(prefix)

self.cache_type = cache_type
self.cache_dir = cache_dir

# config object holding storage options for file systems
# TODO: Validate schema for inputs
if configfile and not fsconfig:
import yaml

with open(configfile, "r") as f:
self.fsconfig = yaml.safe_load(f)
else:
self.fsconfig = fsconfig or {}
def __init__(self, cache_options: CacheOptions | None = None):
self.cache_options = cache_options

_tempdir = tempfile.mkdtemp()
self._tempdir = _tempdir
weakref.finalize(self, self._cleanup_tempdir, _tempdir)

@property
def tempdir(self) -> str:
return self._tempdir

def get(self, rpath: str, expected_type: type[T]) -> T:
# load machinery early, so that we do not download
# if the type is not registered.
@staticmethod
def _cleanup_tempdir(tempdir: str) -> None:
# TODO: Use TemporaryDirectory's builtin finalizer?
shutil.rmtree(tempdir, ignore_errors=True)

def get(
self,
rpath: str | os.PathLike[str],
expected_type: type[T],
storage_options: dict[str, Any] | None = None,
download_options: dict[str, Any] | None = None,
) -> T:
# load machinery early, so we don't download if the type is not registered.
serde = shelf.registry.lookup(expected_type)

if not is_fully_qualified(rpath):
rpath = os.path.join(self.prefix, rpath)

protocol = get_protocol(rpath)
# file system-specific options.
config = self.fsconfig.get(protocol, {})
storage_options = config.get("storage", {})

if self.cache_dir is not None:
proto = self.cache_type
kwargs = {
"target_protocol": protocol,
"target_options": storage_options,
"cache_storage": self.cache_dir,
}
else:
proto = protocol
kwargs = storage_options
rpath = stringify_path(rpath)

fs: AbstractFileSystem = filesystem(proto, **kwargs)
fs = filesystem_from_uri(rpath, self.cache_options, storage_options)

rfiles: list[str]
try:
rfiles = fs.ls(rpath, detail=False)
# some file systems (e.g. local) don't allow filenames in `ls`
@@ -74,68 +56,43 @@ def get(self, rpath: str, expected_type: type[T]) -> T:
if not rfiles:
raise FileNotFoundError(rpath)

with contextlib.ExitStack() as stack:
tmpdir = stack.enter_context(tempfile.TemporaryDirectory())
# TODO: Push a unique directory (e.g. checksum) in front to
# create a directory

# explicit file lists have the side effect that remote subdirectory structures
# are flattened.
lfiles = [os.path.join(tmpdir, os.path.basename(f)) for f in rfiles]
# explicit file lists have the side effect that remote subdirectory structures
# are flattened.
lfiles = [os.path.join(self.tempdir, os.path.basename(f)) for f in rfiles]
fs.get(rfiles, lfiles, **(download_options or {}))

download_options = config.get("download", {})
fs.get(rfiles, lfiles, **download_options)

# TODO: Support deserializer interfaces taking unraveled tuples, i.e. filenames
# as arguments in the multifile case
lpath: str | tuple[str, ...]
if len(lfiles) == 1:
lpath = lfiles[0]
else:
lpath = tuple(lfiles)

obj: T = serde.deserializer(lpath)
# TODO: For more secure access, only allow lfiles as file descriptors
with Context(self.tempdir, filenames=lfiles) as ctx:
obj = serde.deserializer(ctx)

return obj

def put(self, obj: T, rpath: str) -> None:
# load machinery early, so that we do not download
# if the type is not registered.
serde = shelf.registry.lookup(type(obj))

if not is_fully_qualified(rpath):
rpath = os.path.join(self.prefix, rpath)

protocol = get_protocol(rpath)

# file system-specific options.
fsconfig = self.fsconfig.get(protocol, {})
storage_options = fsconfig.get("storage", {})
def put(
self,
obj: T,
rpath: str | os.PathLike[str],
storage_options: dict[str, Any] | None = None,
upload_options: dict[str, Any] | None = None,
) -> None:
# load machinery early, so we don't download if the type is not registered.
objtype: type[T] = type(obj)

if self.cache_dir is not None:
proto = self.cache_type
kwargs = {
"target_protocol": protocol,
"target_options": storage_options,
"cache_storage": self.cache_dir,
}
else:
proto = protocol
kwargs = storage_options
serde = shelf.registry.lookup(objtype)

fs: AbstractFileSystem = filesystem(proto, **kwargs)
rpath = stringify_path(rpath)

with contextlib.ExitStack() as stack:
tmpdir = stack.enter_context(tempfile.TemporaryDirectory())
lpath = serde.serializer(obj, tmpdir)
fs = filesystem_from_uri(rpath, self.cache_options, storage_options)

recursive = isinstance(lpath, (list, tuple))
if recursive:
# signals fsspec to put all files into rpath directory
rpath = with_trailing_sep(rpath)
with Context(self.tempdir) as ctx:
serde.serializer(obj, ctx)
lpaths = ctx.filenames

upload_options = fsconfig.get("upload", {})
# TODO: Construct explicit lists always to hit the fast path of fs.put()
fs.put(lpath, rpath, recursive=recursive, **upload_options)
rpaths: str | list[str]
if len(lpaths) > 1:
# signals fsspec to put all files into rpath directory
# TODO: Construct list always to hit the fast path of fs.put()
rpaths = with_trailing_sep(rpath)
else:
rpaths = [rpath]

return fs.info(rpath)
fs.put(lpaths, rpaths, **(upload_options or {}))
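A minimal configuration sketch of the reworked `Shelf` surface above: caching is now set once via `CacheOptions`, and fsspec storage/upload/download options are passed per call instead of through the removed `fsconfig`/`configfile` machinery. The bucket name, paths, the `anon` option, the cache directory, and the reuse of the README's `MyModel` type are illustrative assumptions.

```python
# Configuration sketch (bucket, paths, cache directory, and storage options
# are assumptions; MyModel refers to the type registered in the README example).
from pathlib import Path

from shelf import Shelf
from shelf.types import CacheOptions

store = Shelf(cache_options=CacheOptions(directory=Path("/tmp/shelf-cache")))

# Per-call keyword options replace the old per-protocol fsconfig mapping.
store.put(
    MyModel(),                           # assumed default-constructible
    "s3://my-bucket/artifacts/my-model",
    storage_options={"anon": False},     # forwarded to the fsspec filesystem
    upload_options={},                   # forwarded to fs.put(...)
)
obj = store.get(
    "s3://my-bucket/artifacts/my-model",
    expected_type=MyModel,
    storage_options={"anon": False},
    download_options={},                 # forwarded to fs.get(...)
)
```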
18 changes: 10 additions & 8 deletions src/shelf/registry.py
@@ -1,37 +1,39 @@
from __future__ import annotations

import types
from typing import Callable, NamedTuple
from typing import Callable, TypeVar

from shelf.types import Context, IOPair

class IO(NamedTuple):
serializer: Callable
deserializer: Callable
T = TypeVar("T")


# internal, mutable
_registry: dict[type, IO] = {}
_registry: dict[type, IOPair] = {}

# external, immutable
registry = types.MappingProxyType(_registry)


def register_type(
t: type, serializer: Callable, deserializer: Callable, clobber: bool = False
t: type[T],
serializer: Callable[[T, Context], None],
deserializer: Callable[[Context], T],
clobber: bool = False,
) -> None:
"""Register serializer and deserializer for a given type t."""
if t in _registry and not clobber:
raise RuntimeError(f"type {t} is already registered, rerun with clobber=True to override")

_registry[t] = IO(serializer, deserializer)
_registry[t] = IOPair(serializer, deserializer)


def deregister_type(t: type) -> None:
"""Remove a type's serializer and deserializer from the type registry."""
_registry.pop(t, None)


def lookup(t: type, strict: bool = True, bound: type | None = None) -> IO:
def lookup(t: type[T], strict: bool = True, bound: type | None = None) -> IOPair[T]:
"""
Returns a type's registered serialization/deserialization functions.
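As a quick illustration of the typed registry above, here is a sketch that registers a `Context`-based serde pair for plain `dict` objects; the JSON-backed serializer and the `data.json` file name are assumptions made for this example only.

```python
# Illustrative registration of a Context-based serde pair for dict
# (the JSON serialization and file name are assumptions, not part of this commit).
import json

from shelf import Context, deregister_type, lookup, register_type


def save_dict(obj: dict, ctx: Context) -> None:
    json.dump(obj, ctx.file("data.json", mode="w"))


def load_dict(ctx: Context) -> dict:
    (fname,) = ctx.filenames
    return json.load(ctx.file(fname, mode="r"))


register_type(dict, save_dict, load_dict)
pair = lookup(dict)  # IOPair(serializer=save_dict, deserializer=load_dict)
register_type(dict, save_dict, load_dict, clobber=True)  # overwrite without raising
deregister_type(dict)  # pops the entry; unknown types are ignored
```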
61 changes: 61 additions & 0 deletions src/shelf/types.py
@@ -0,0 +1,61 @@
from __future__ import annotations

from collections import deque
from dataclasses import dataclass
from os import PathLike
from pathlib import Path
from typing import IO, Any, Callable, Generic, Literal, TypeVar

T = TypeVar("T")


# TODO: Consider splitting these into ReadContext and WriteContext
class Context:
def __init__(self, tmpdir: str | PathLike[str], filenames: list[str] | None = None):
self._tmpdir = Path(tmpdir)
self._fds: deque[IO] = deque()
self._filenames: list[str] = filenames or []

def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
# Unwind the file queue by closing
while self._fds:
file = self._fds.pop()
file.close()

@property
def files(self):
return self._fds

@property
def filenames(self):
return self._filenames

@property
def tmpdir(self):
return self._tmpdir

def directory(self, name: str) -> PathLike[str]:
return self.tmpdir / name

def file(self, name: str, **openkwargs: Any) -> IO:
# TODO: Assert name is not absolute
fp = self.tmpdir / name
desc = open(fp, **openkwargs)
self._fds.append(desc)
self._filenames.append(str(fp))
return desc


@dataclass(frozen=True)
class IOPair(Generic[T]):
serializer: Callable[[T, Context], None]
deserializer: Callable[[Context], T]


@dataclass(frozen=True)
class CacheOptions:
directory: PathLike[str]
type: Literal["blockcache", "filecache", "simplecache"] = "filecache"
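A standalone sketch of the `Context` helper defined above, used outside of `Shelf`: it opens files beneath its temporary directory, records their handles and names, and closes the handles when the context exits. The temporary directory handling and the file name are illustrative.

```python
# Standalone Context sketch (the temporary directory and file name are
# illustrative; inside Shelf.get/put the context is created for you).
import tempfile

from shelf import Context

tmp = tempfile.mkdtemp()
with Context(tmp) as ctx:
    # file() opens a handle under ctx.tmpdir, records it in ctx.files,
    # and appends its path to ctx.filenames.
    fp = ctx.file("payload.txt", mode="w")
    fp.write("hello")
# On exit, every handle opened through ctx.file(...) has been closed.
print(ctx.filenames)  # e.g. ['/tmp/.../payload.txt']
```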
32 changes: 30 additions & 2 deletions src/shelf/util.py
@@ -1,13 +1,41 @@
from __future__ import annotations

import os
from typing import Any

from fsspec import AbstractFileSystem, filesystem
from fsspec.utils import get_protocol, stringify_path

from shelf.types import CacheOptions


def filesystem_from_uri(
uri: str,
cache_options: CacheOptions | None = None,
storage_options: dict[str, Any] | None = None,
) -> AbstractFileSystem:
protocol = get_protocol(uri)
storage_options = storage_options or {}
if cache_options is not None:
protocol = cache_options.type
kwargs = {
"target_protocol": cache_options.type,
"target_options": storage_options,
"cache_storage": cache_options.directory,
}
else:
kwargs = storage_options

def is_fully_qualified(path: str) -> bool:
fs: AbstractFileSystem = filesystem(protocol, **kwargs)
return fs


def is_fully_qualified(path: str | os.PathLike[str]) -> bool:
path = stringify_path(path)
protocol = get_protocol(path)
return any(path.startswith(protocol + sep) for sep in ("::", "://"))


def with_trailing_sep(path: str) -> str:
def with_trailing_sep(path: str | os.PathLike[str]) -> str:
path = stringify_path(path)
return path if path.endswith(os.sep) else path + os.sep
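A small sketch of the helpers above, showing how `filesystem_from_uri` takes over the protocol/caching logic that previously lived inline in `Shelf.get` and `Shelf.put`. The `s3://` URI and the `anon` storage option are assumptions and require the matching fsspec backend (e.g. s3fs) to be installed.

```python
# Helper sketch (the s3:// URI and `anon` option are illustrative assumptions).
from shelf.util import filesystem_from_uri, is_fully_qualified, with_trailing_sep

uri = "s3://my-bucket/artifacts/my-model"
assert is_fully_qualified(uri)                    # explicit protocol prefix
assert not is_fully_qualified("artifacts/model")  # bare path, no protocol

# Without cache_options this resolves the plain target filesystem; passing a
# CacheOptions instance wraps it in the configured fsspec cache protocol.
fs = filesystem_from_uri(uri, storage_options={"anon": True})

print(with_trailing_sep("artifacts"))             # appends os.sep if missing
```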