Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: add extension_kwargs kwargs and nitpicky stuff #3

Merged
merged 4 commits into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,6 @@ repos:
rev: v1.13.0
hooks:
- id: mypy
additional_dependencies:
- cython
- setuptools
38 changes: 26 additions & 12 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,31 +1,24 @@
[build-system]
requires = ["setuptools", "setuptools-scm"]
build-backend = "setuptools.build_meta"
requires = ["setuptools", "wheel"]

[project]
name = "witty"
description = "Well-in-Time Compiler for Cython Modules"
readme = "README.md"
requires-python = ">=3.9"
classifiers = [
"Programming Language :: Python :: 3",
]
classifiers = ["Programming Language :: Python :: 3"]
keywords = []
license = { text = "BSD 3-Clause License" }
authors = [
{ email = "[email protected]", name = "Jan Funke" },
{ email = "[email protected]", name = "Jan Funke" },
{ email = "[email protected]", name = "Talley Lambert" },
]
dynamic = ["version"]
dependencies = ["cython", "setuptools; python_version >= '3.12'"]

[project.optional-dependencies]
dev = [
'pytest',
'ruff',
'mypy',
'pdoc',
'pre-commit'
]
dev = ['pytest', 'ruff', 'mypy', 'pdoc', 'pre-commit']

[project.urls]
homepage = "https://github.com/funkelab/witty"
Expand All @@ -34,3 +27,24 @@ repository = "https://github.com/funkelab/witty"
[tool.ruff]
target-version = "py39"
src = ["src"]

[tool.ruff.lint]
select = [
"E", # style errors
"F", # flakes
"W", # warnings
"I", # isort
"UP", # pyupgrade
]

[tool.mypy]
files = "src/**/*.py"
strict = true
disallow_any_generics = false
disallow_subclassing_any = false
show_error_codes = true
pretty = true

[[tool.mypy.overrides]]
module = ["tests.*"]
disallow_untyped_defs = false
10 changes: 8 additions & 2 deletions src/witty/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
from .compile_module import compile_module
from importlib.metadata import PackageNotFoundError, version

try:
__version__ = version("witty")
except PackageNotFoundError:
# package is not installed
__version__ = "unknown"

__version__ = "0.1"
from .compile_module import compile_module

__all__ = ["compile_module"]
255 changes: 153 additions & 102 deletions src/witty/compile_module.py
Original file line number Diff line number Diff line change
@@ -1,101 +1,86 @@
import os
import Cython
from __future__ import annotations

import hashlib
import importlib.util
import json
import os
import sys
from Cython.Build import cythonize
from Cython.Build.Inline import to_unicode, _get_build_extension
from Cython.Utils import get_cython_cache_dir
from contextlib import contextmanager
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal

try:
from distutils.core import Extension
except ImportError:
from setuptools import Extension # type: ignore [no-redef]

import Cython
from Cython.Build import cythonize
from Cython.Build.Inline import build_ext
from Cython.Utils import get_cython_cache_dir
from setuptools import Distribution, Extension

def load_dynamic(module_name, module_lib):
spec = importlib.util.spec_from_file_location(module_name, module_lib)
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
return sys.modules[module_name]
if TYPE_CHECKING:
from collections.abc import Iterator, Sequence
from types import ModuleType


def compile_module(
source_pyx,
source_files=None,
include_dirs=None,
library_dirs=None,
language="c",
extra_compile_args=None,
extra_link_args=None,
name=None,
force_rebuild=False,
quiet=False,
):
source_pyx: str,
*,
source_files: Sequence[Path | str] = (),
include_dirs: Sequence[Path | str] = (".",),
library_dirs: Sequence[Path | str] = (),
language: Literal["c", "c++"] | None = None,
extra_compile_args: list[str] | None = None,
extra_link_args: list[str] | None = None,
name: str = "_witty_module",
force_rebuild: bool = False,
quiet: bool = False,
**extension_kwargs: Any,
) -> ModuleType:
"""Compile a Cython module given as a PYX source string.

The module will be stored in Cython's cache directory. Called with the same
``source_pyx``, the cached module will be returned.

Args:

source_pyx (``str``):

The PYX source code.

source_files (list of ``Path``s, optional):

Additional source files the PYX code depends on. Changes to those
files will trigger re-compilation of the module.

include_dirs (list of ``Path``s, optional):
library_dirs (list of ``Path``s, optional):
language (``str``, optional):
extra_compile_args (list of ``str``, optional):
extra_link_args (list of ``str``, optional):

Arguments to forward to the Cython extension.

name (``str``, optional):

The base-name of the module file. Defaults to ``_witty_module``.

force_rebuild (``bool``, optional):

Force a rebuild even if a module with that name/hash already
exists.

quiet (``bool``, optional):

Supress output except errors and warnings.

Returns:

The module will be stored in
[Cython's cache directory](https://cython.readthedocs.io/en/latest/src/userguide/source_files_and_compilation.html#cython-cache).
Called with the same `source_pyx`, the cached module will be returned.

Parameters
----------
source_pyx : str
The PYX source code.
source_files : list of Path, optional
Additional source files the PYX code depends on. Changes to these
files will trigger re-compilation of the module.
include_dirs : list of Path, optional
List of directories to search for C/C++ header files (in Unix
form for portability).
library_dirs : list of Path, optional
List of directories to search for C/C++ libraries at link time.
language : str, optional
Extension language (i.e., "c", "c++", "objc"). Will be detected
from the source extensions if not provided.
extra_compile_args : list of str, optional
Extra platform- and compiler-specific information to use when
compiling the source files in 'sources'. This is typically a
list of command-line arguments for platforms and compilers where
"command line" makes sense.
extra_link_args : list of str, optional
Extra platform- and compiler-specific information to use when
linking object files to create the extension (or a new static
Python interpreter). Has a similar interpretation as for 'extra_compile_args'.
name : str, optional
The base name of the module file. Defaults to "_witty_module".
force_rebuild : bool, optional
Force a rebuild even if a module with the same name/hash already exists.
quiet : bool, optional
Suppress output except for errors and warnings.
extension_kwargs : dict, optional
Additional keyword arguments passed to the distutils `Extension` constructor.

Returns
-------
ModuleType
The compiled module.
"""

if source_files is None:
source_files = []
if include_dirs is None:
include_dirs = ["."]
if library_dirs is None:
library_dirs = []
if name is None:
name = "_witty_module"

source_pyx = to_unicode(source_pyx)
sources = [source_pyx]

for source_file in source_files:
sources.append(open(source_file, "r").read())

source_hashes = [
hashlib.md5(source.encode("utf-8")).hexdigest() for source in sources
]
source_key = (source_hashes, sys.version_info, sys.executable, Cython.__version__)
module_hash = hashlib.md5(str(source_key).encode("utf-8")).hexdigest()
module_hash = _generate_hash(
source_pyx, source_files, extra_compile_args, extra_link_args, extension_kwargs
)
module_name = name + "_" + module_hash

# already loaded?
Expand All @@ -107,61 +92,127 @@ def compile_module(
module_dir = Path(get_cython_cache_dir()) / "witty"
module_pyx = (module_dir / module_name).with_suffix(".pyx")
module_lib = (module_dir / module_name).with_suffix(module_ext)
module_lock = (module_dir / module_name).with_suffix(".lock")

if not quiet:
print(f"Compiling {module_name} into {module_lib}...")

module_dir.mkdir(parents=True, exist_ok=True)

# make sure the same module is not build concurrently
with open(module_lock, "w") as lock_f:
lock_file(lock_f)

with _module_locked(module_pyx):
# already compiled?
if module_lib.is_file() and not force_rebuild:
if not quiet:
print(f"Reusing already compiled module from {module_lib}")
return load_dynamic(module_name, module_lib)
return _load_dynamic(module_name, module_lib)

# create pyx file
with open(module_pyx, "w") as f:
f.write(source_pyx)
module_pyx.write_text(source_pyx)

extension = Extension(
module_name,
sources=[str(module_pyx)],
include_dirs=include_dirs,
library_dirs=library_dirs,
include_dirs=[str(x) for x in include_dirs],
library_dirs=[str(x) for x in library_dirs],
language=language,
extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args,
**(extension_kwargs or {}),
)

build_extension.extensions = cythonize(
[extension], compiler_directives={"language_level": "3"}, quiet=quiet
[extension],
compiler_directives={"language_level": "3"},
quiet=quiet,
)
build_extension.build_temp = str(module_dir)
build_extension.build_lib = str(module_dir)
build_extension.run()

return load_dynamic(module_name, module_lib)
return _load_dynamic(module_name, module_lib)


def _load_dynamic(module_name: str, module_path: Path) -> ModuleType:
"""Dynamically load a module from a path."""
spec = importlib.util.spec_from_file_location(module_name, module_path)
if not spec or not spec.loader:
raise ImportError(f"Failed to load module {module_name} from {module_path}")
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
return sys.modules[module_name]


def _generate_hash(
source_pyx: str, source_files: Sequence[Path | str] = (), *args: dict | list | None
) -> str:
"""Generate a hash key for a `source_pyx` along with other source file paths."""
sources = [source_pyx] + [Path(source).read_text() for source in source_files]
src_hashes = [hashlib.md5(source.encode("utf-8")).hexdigest() for source in sources]
arg_hash = _hash_args(args)
source_key = (
src_hashes,
arg_hash,
sys.version_info,
sys.executable,
Cython.__version__,
)
return hashlib.md5(str(source_key).encode("utf-8")).hexdigest()


def _hash_args(containers: tuple[dict | list | None, ...]) -> str:
"""Hash a bunch of mutable arg container objects in a reproducible way.

This is for stuff like extra_compile_args, extra_link_args, and extension_kwargs.
"""
hash_obj = hashlib.md5()
for container in containers:
# sort dict keys for reproducibility
serialized = json.dumps(container, sort_keys=True)
# Update the hash object with the serialized container
hash_obj.update(serialized.encode())
return hash_obj.hexdigest()


@contextmanager
def _module_locked(module_path: Path) -> Iterator[None]:
"""Temporarily lock a module file to prevent concurrent compilation."""
module_lock_file = module_path.with_suffix(".lock")
with open(module_lock_file, "w") as lock_fd:
_lock_file(lock_fd)
try:
yield
finally:
_unlock_file(lock_fd)


def _get_build_extension() -> build_ext:
# same as `cythonize` Build.Inline._get_build_extension
# vendored to avoid using a private API
dist = Distribution()
# Ensure the build respects distutils configuration by parsing
# the configuration files
config_files = dist.find_config_files()
dist.parse_config_files(config_files)
build_extension = build_ext(dist)
build_extension.finalize_options()
return build_extension


if os.name == "nt":
import msvcrt

def lock_file(file):
msvcrt.locking(file.fileno(), msvcrt.LK_LOCK, os.path.getsize(file.name))
def _lock_file(file: Any) -> None:
msvcrt.locking(file.fileno(), msvcrt.LK_LOCK, os.path.getsize(file.name)) # type: ignore

def unlock_file(file):
msvcrt.locking(file.fileno(), msvcrt.LK_UNLCK, os.path.getsize(file.name))
def _unlock_file(file: Any) -> None:
msvcrt.locking(file.fileno(), msvcrt.LK_UNLCK, os.path.getsize(file.name)) # type: ignore

else:
import fcntl

def lock_file(file):
def _lock_file(file: Any) -> None:
fcntl.lockf(file, fcntl.LOCK_EX)

def unlock_file(file):
def _unlock_file(file: Any) -> None:
fcntl.lockf(file, fcntl.LOCK_UN)
Empty file added tests/__init__.py
Empty file.
Loading