Skip to content

Commit

Permalink
Merge pull request #4 from nicholasjng/github-actions
Browse files Browse the repository at this point in the history
Add GitHub Actions lint+test setup
  • Loading branch information
nicholasjng authored Dec 17, 2023
2 parents 46d4c7d + c312c48 commit f08f3cf
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 15 deletions.
63 changes: 63 additions & 0 deletions .github/workflows/lint-and-test.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
name: Lint and test shelf

on:
push:
branches:
- master
pull_request:
branches:
- master

jobs:
lint:
runs-on: ubuntu-latest
env:
MYPY_CACHE_DIR: "${{ github.workspace }}/.cache/mypy"
RUFF_CACHE_DIR: "${{ github.workspace }}/.cache/ruff"
PRE_COMMIT_HOME: "${{ github.workspace }}/.cache/pre-commit"
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Set up Python and dependencies
uses: actions/setup-python@v4
with:
python-version: 3.11
cache: pip
cache-dependency-path: |
requirements-dev.txt
pyproject.toml
- name: Install dependencies
run: |
pip install -r requirements-dev.txt
pip install -e . --no-deps
- name: Cache pre-commit tools
uses: actions/cache@v3
with:
path: |
${{ env.MYPY_CACHE_DIR }}
${{ env.RUFF_CACHE_DIR }}
${{ env.PRE_COMMIT_HOME }}
key: ${{ hashFiles('requirements-dev.txt', '.pre-commit-config.yaml') }}-linter-cache
- name: Run pre-commit checks
run: pre-commit run --all-files --verbose --show-diff-on-failure
test:
name: Test shelf on ${{ matrix.os }}
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ ubuntu-latest, macos-latest, windows-latest ]
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Set up oldest supported Python on ${{ matrix.os }}
uses: actions/setup-python@v4
with:
python-version: 3.9
- name: Run tests on oldest supported Python
run: |
python -m pip install --upgrade pip
python -m pip install ".[dev]"
python -m pytest
7 changes: 4 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ Here's how you register a custom neural network type that uses [pickle](https://
import numpy as np
import pickle
import shelf
import os


class MyModel:
Expand All @@ -28,9 +29,9 @@ class MyModel:
return 1.


def save_to_disk(model: MyModel) -> str:
"""Dumps the model to disk using `pickle`."""
fname = "my-model.pkl"
def save_to_disk(model: MyModel, tmpdir: str) -> str:
"""Dumps the model to the directory ``tmpdir`` using `pickle`."""
fname = os.path.join(tmpdir, "my-model.pkl")
with open(fname, "wb") as f:
pickle.dump(model, f)
return fname
Expand Down
2 changes: 2 additions & 0 deletions src/shelf/registry.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import types
from typing import Callable, NamedTuple

Expand Down
21 changes: 11 additions & 10 deletions src/shelf/shelf.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import contextlib
import os
import tempfile
Expand All @@ -6,7 +8,7 @@
from typing import Any, Literal, TypeVar

from fsspec import AbstractFileSystem, filesystem
from fsspec.utils import get_protocol
from fsspec.utils import get_protocol, stringify_path

import shelf.registry as registry
from shelf.util import is_fully_qualified
Expand Down Expand Up @@ -55,7 +57,7 @@ def __init__(

def get(self, rpath: str, expected_type: type[T]) -> T:
if not is_fully_qualified(rpath):
rpath = self.prefix + rpath
rpath = os.path.join(self.prefix, rpath)

# load machinery early, so that we do not download
# if the type is not registered.
Expand Down Expand Up @@ -83,15 +85,16 @@ def get(self, rpath: str, expected_type: type[T]) -> T:

with contextlib.ExitStack() as stack:
tmpdir = stack.enter_context(tempfile.TemporaryDirectory())
stack.enter_context(contextlib.chdir(tmpdir))

# trailing slash tells fsspec to download files into `lpath`
lpath = tmpdir + os.sep
lpath = stringify_path(tmpdir.rstrip(os.sep) + os.sep)
fs.get(rpath, lpath, **download_options)

# TODO: Find a way to pass files in expected order
filenames = [p.name for p in Path(tmpdir).iterdir() if p.is_file()]
obj: T = serde.deserializer(*filenames)
files = [str(p) for p in Path(tmpdir).iterdir() if p.is_file()]
if not files:
raise ValueError(f"no files found for rpath {rpath!r}")
obj: T = serde.deserializer(*files)

return obj

Expand All @@ -101,7 +104,7 @@ def put(self, obj: T, rpath: str) -> None:
serde = registry.lookup(type(obj))

if not is_fully_qualified(rpath):
rpath = self.prefix + rpath
rpath = os.path.join(self.prefix, rpath)

protocol = get_protocol(rpath)

Expand All @@ -124,10 +127,8 @@ def put(self, obj: T, rpath: str) -> None:

with contextlib.ExitStack() as stack:
tmpdir = stack.enter_context(tempfile.TemporaryDirectory())
# chdir into the temporary to be able to work with filenames only
stack.enter_context(contextlib.chdir(tmpdir))
# TODO: What about multiple lpaths?
lpath = serde.serializer(obj)
lpath = serde.serializer(obj, tmpdir)

upload_options = fsconfig.get("upload", {})
fs.put(lpath, rpath, **upload_options)
Expand Down
5 changes: 3 additions & 2 deletions tests/test_shelf.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import os
from pathlib import Path

import shelf
Expand All @@ -7,8 +8,8 @@
def test_json_roundtrip(tmp_path: Path) -> None:
"""Test a simple data artifact JSON roundtrip."""

def json_dump(d: dict) -> str:
fname = "dump.json"
def json_dump(d: dict, tmpdir: str) -> str:
fname = os.path.join(tmpdir, "dump.json")
with open(fname, "w") as f:
json.dump(d, f)
return fname
Expand Down

0 comments on commit f08f3cf

Please sign in to comment.