diff --git a/README.md b/README.md
index 47e176d..0885b66 100644
--- a/README.md
+++ b/README.md
@@ -5,6 +5,7 @@
![Language](https://img.shields.io/badge/Language-Cython-FEDF5B)
![Python Implementation](https://img.shields.io/pypi/implementation/spatium?label=Implementation)
+[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
[![Codegen](https://github.com/shBLOCK/spatium/actions/workflows/codegen.yml/badge.svg)](https://github.com/shBLOCK/spatium/actions/workflows/codegen.yml)
[![Tests](https://github.com/shBLOCK/spatium/actions/workflows/tests.yml/badge.svg)](https://github.com/shBLOCK/spatium/actions/workflows/tests.yml)
diff --git a/benchmark/benchmarking.py b/benchmark/benchmarking.py
index bb053c6..6c57884 100644
--- a/benchmark/benchmarking.py
+++ b/benchmark/benchmarking.py
@@ -3,12 +3,23 @@
import json
import math
import time
+
# noinspection PyPep8Naming
from datetime import datetime as DateTime
import itertools
from pathlib import Path
from types import FunctionType
-from typing import Self, Callable, Dict, overload, Iterable, Sequence, Optional, Generator, Any
+from typing import (
+ Self,
+ Callable,
+ Dict,
+ overload,
+ Iterable,
+ Sequence,
+ Optional,
+ Generator,
+ Any,
+)
import colorama
@@ -30,7 +41,7 @@
"log",
"temp_log",
"indent_log",
- "CI"
+ "CI",
)
TIMER = time.perf_counter_ns
@@ -38,6 +49,7 @@
AUTO_NUMBER_TARGET_TIME = 0.5e9
# AUTO_NUMBER_TARGET_TIME = 0.005e9
+
class Subject:
"""
Define a benchmark subject. (e.g. spatium, Numpy, ...)
@@ -45,10 +57,13 @@ class Subject:
the decorated function will become the common setup procedure of this subject.
(e.g. `from spatium import Vec3`)
"""
+
instances: list["Subject"] = []
get_instance = classmethod(_instance_from_id)
- def __init__(self, name: str, *, color: str = None, identifier: str = None, sort: int = None):
+ def __init__(
+ self, name: str, *, color: str = None, identifier: str = None, sort: int = None
+ ):
self.setup: SourceLines = None
self.id: str = identifier
self.name = name
@@ -59,8 +74,7 @@ def __init__(self, name: str, *, color: str = None, identifier: str = None, sort
def __call__(self, setup: FunctionType) -> Self:
self.id = setup.__name__
self.setup = extract_and_validate_source(
- setup,
- f""
+ setup, f""
)
return self
@@ -73,7 +87,7 @@ def to_json(self) -> Dict:
"name": self.name,
"color": self.color,
"setup": self.setup,
- "sort": self.sort
+ "sort": self.sort,
}
@classmethod
@@ -82,16 +96,18 @@ def from_json(cls, data: Dict) -> Self:
name=data["name"],
color=data["color"],
identifier=data["id"],
- sort=data["sort"]
+ sort=data["sort"],
)
inst.setup = tuple(data["setup"])
return inst
+
class Benchmark:
"""
Represents a benchmark, can have multiple test cases for multiple subject.
Best be instantiated in a class as a class property, so that id can be automatically assigned.
"""
+
instances: list["Benchmark"] = []
get_instance = classmethod(_instance_from_id)
@@ -103,13 +119,16 @@ def __init__(self, name: str, identifier: str = None):
if self.id is None:
# infer id from call site
import re
+
line = inspect.stack()[1].code_context[0]
pattern = r"\s*(?P\w+)\s*=\s*Benchmark\s*\("
mat = re.match(pattern, line)
if mat is not None:
self.id = mat.group("name")
else:
- raise SyntaxError(f"Benchmark id infer failed: only instantiation source line matching the regex \"{pattern}\" supports inferring.")
+ raise SyntaxError(
+ f'Benchmark id infer failed: only instantiation source line matching the regex "{pattern}" supports inferring.'
+ )
Benchmark.instances.append(self)
@@ -127,10 +146,12 @@ def _get_or_create_case(self, subject: Subject) -> "TestCase":
def setup(self, subject: Subject) -> Callable[[FunctionType], None]:
"""Function decorator. Set the setup routine of a test case."""
self._check()
+
def inner(func: FunctionType):
case = self._get_or_create_case(subject)
case.setup_src = extract_source(func)
case.update()
+
return inner
@overload
@@ -144,8 +165,11 @@ def __call__(self, func: FunctionType) -> Subject:
The setup routine must be set first if needed!
"""
+
@overload
- def __call__(self, *subjects: Subject, number: int = None) -> Callable[[FunctionType], None]:
+ def __call__(
+ self, *subjects: Subject, number: int = None
+ ) -> Callable[[FunctionType], None]:
"""
Decorate a function to add test cases for multiple subjects.
The decorator returns None.
@@ -175,8 +199,10 @@ def inner(i_func: FunctionType):
# is "@benchmark_name()" (not "@benchmark_name")
if func is None:
if not all(c == "_" for c in i_func.__name__):
- raise NameError("Function name must only consist of underscores "
- "if subjects are specified via arguments.")
+ raise NameError(
+ "Function name must only consist of underscores "
+ "if subjects are specified via arguments."
+ )
for subject in subjects:
case = self._get_or_create_case(subject)
@@ -191,11 +217,13 @@ def inner(i_func: FunctionType):
return subjects[0]
return inner
- def run(self, order_permutations: bool = False, min_runs_per_case: int = 100) -> list["TestCaseResult"]:
+ def run(
+ self, order_permutations: bool = False, min_runs_per_case: int = 100
+ ) -> list["TestCaseResult"]:
log(f"Benchmark - {self.id}:")
with indent_log():
all_results = []
- results_by_subject = {s:[] for s in self.testcases.keys()}
+ results_by_subject = {s: [] for s in self.testcases.keys()}
permutations: Sequence[Sequence[TestCase]]
if order_permutations:
# noinspection PyTypeChecker
@@ -229,8 +257,11 @@ def run(self, order_permutations: bool = False, min_runs_per_case: int = 100) ->
with NoGC:
for rep in range(repeats):
with temp_log():
- log(f"Sequence[{(i * repeats + rep + 1)}"
- f"/{len(permutations) * repeats}]: ", False)
+ log(
+ f"Sequence[{(i * repeats + rep + 1)}"
+ f"/{len(permutations) * repeats}]: ",
+ False,
+ )
for case in cases:
result = case.run()
result.sequence = cases
@@ -242,10 +273,12 @@ def run(self, order_permutations: bool = False, min_runs_per_case: int = 100) ->
with indent_log():
for subject, results in results_by_subject.items():
times = tuple(r.runtime for r in results)
- log(f"{subject.id}({len(results)} runs): "
+ log(
+ f"{subject.id}({len(results)} runs): "
f"Min {min(times)/1e9:.3f}s, "
f"Avg {sum(times)/len(times)/1e9:.3f}s, "
- f"Max {max(times)/1e9:.3f}s")
+ f"Max {max(times)/1e9:.3f}s"
+ )
log()
@@ -258,15 +291,12 @@ def to_json(self) -> Dict:
return {
"id": self.id,
"name": self.name,
- "testcases": tuple(c.to_json() for c in self.testcases.values())
+ "testcases": tuple(c.to_json() for c in self.testcases.values()),
}
@classmethod
def from_json(cls, data: Dict) -> Self:
- inst = cls(
- name=data["name"],
- identifier=data["id"]
- )
+ inst = cls(name=data["name"], identifier=data["id"])
for case_dat in data["testcases"]:
case = TestCase.from_json(case_dat)
inst.testcases[case.subject] = case
@@ -294,14 +324,18 @@ def _generate_benchmark_func(self):
"for _i in _it:",
))
src.extend(indent(self.main_src))
- src.extend((
- "return _timer() - _begin",
- ))
+ src.extend(("return _timer() - _begin",))
src[1:] = indent(tuple(src[1:]))
try:
ls = {}
- exec(compile_src(src, f""), {}, ls)
+ exec(
+ compile_src(
+ src, f""
+ ),
+ {},
+ ls,
+ )
self.benchmark_func = ls["_benchmark"]
except Exception as e:
raise RuntimeError("Benchmark function generation failed.") from e
@@ -315,15 +349,11 @@ def update(self):
name = f""
if self.main_src is not None:
validate_source(
- self.subject.setup + self.setup_src + self.main_src,
- name % "full"
+ self.subject.setup + self.setup_src + self.main_src, name % "full"
)
self._generate_benchmark_func()
else:
- validate_source(
- self.subject.setup + self.setup_src,
- name % "setup"
- )
+ validate_source(self.subject.setup + self.setup_src, name % "setup")
def auto_number(self) -> int:
assert self.is_auto_number
@@ -341,22 +371,21 @@ def run(self) -> "TestCaseResult":
assert self.number is not None, "Number of runs not set."
datetime = utc_now()
with NoGC:
- return TestCaseResult(self, self.benchmark_func(self.number, TIMER), datetime)
+ return TestCaseResult(
+ self, self.benchmark_func(self.number, TIMER), datetime
+ )
def __repr__(self):
return f""
def to_json(self, id_only=False) -> Dict:
- data = {
- "benchmark_id": self.benchmark.id,
- "subject_id": self.subject.id
- }
+ data = {"benchmark_id": self.benchmark.id, "subject_id": self.subject.id}
if not id_only:
data.update({
"setup_src": self.setup_src,
"main_src": self.main_src,
"is_auto_number": self.is_auto_number,
- "number": self.number
+ "number": self.number,
})
return data
@@ -370,7 +399,9 @@ def from_json(cls, data: Dict, *, data_only=True, id_only=False) -> Self:
if id_only:
return benchmark.testcases[subject]
else:
- assert subject not in benchmark.testcases, f"TestCase of subject {subject} already exists in benchmark {benchmark}."
+ assert (
+ subject not in benchmark.testcases
+ ), f"TestCase of subject {subject} already exists in benchmark {benchmark}."
inst = cls(benchmark, subject)
inst.setup_src = tuple(data["setup_src"])
inst.main_src = tuple(data["main_src"])
@@ -382,7 +413,13 @@ def from_json(cls, data: Dict, *, data_only=True, id_only=False) -> Self:
class TestCaseResult:
- def __init__(self, testcase: TestCase, runtime: int, datetime: DateTime, sequence: Sequence[TestCase] = None):
+ def __init__(
+ self,
+ testcase: TestCase,
+ runtime: int,
+ datetime: DateTime,
+ sequence: Sequence[TestCase] = None,
+ ):
self.testcase = testcase
self.runtime = runtime
self.datetime = datetime
@@ -396,8 +433,11 @@ def to_json(self) -> Dict:
**self.testcase.to_json(id_only=True),
"runtime": self.runtime,
"timestamp": self.datetime.timestamp(),
- "sequence": ([case.to_json(id_only=True) for case in self.sequence]
- if self.sequence is not None else None)
+ "sequence": (
+ [case.to_json(id_only=True) for case in self.sequence]
+ if self.sequence is not None
+ else None
+ ),
}
@classmethod
@@ -406,15 +446,20 @@ def from_json(cls, data: Dict) -> Self:
testcase=TestCase.from_json(data, id_only=True),
runtime=data["runtime"],
datetime=from_utc_stamp(data["timestamp"]),
- sequence=(tuple(TestCase.from_json(d, id_only=True) for d in data["sequence"])
- if data["sequence"] is not None else None)
+ sequence=(
+ tuple(TestCase.from_json(d, id_only=True) for d in data["sequence"])
+ if data["sequence"] is not None
+ else None
+ ),
)
+
class BenchmarkMetadata:
def __init__(self):
log("Getting metadata...")
with indent_log():
import platform, cpuinfo, os
+
self.system = platform.system()
log(f"System: {self.system}")
self.ci = CI
@@ -444,7 +489,7 @@ def to_json(self) -> Dict:
"arch": self.arch,
"cpu": self.cpu,
"py_impl": self.py_impl,
- "py_ver": self.py_ver
+ "py_ver": self.py_ver,
}
@classmethod
@@ -458,6 +503,7 @@ def from_json(cls, data: Dict) -> Self:
inst.py_ver = data["py_ver"]
return inst
+
class BenchmarkResult:
def __init__(self, datetime: DateTime, metadata: BenchmarkMetadata = None):
self.datetime = datetime
@@ -471,7 +517,9 @@ def add_results(self, *results: TestCaseResult):
def benchmarks(self) -> Iterable[Benchmark]:
return iter_identity(r.testcase.benchmark for r in self.raw_results)
- def get_results(self, *, testcase: TestCase = None, subject: Subject = None) -> Generator[TestCaseResult, None, None]:
+ def get_results(
+ self, *, testcase: TestCase = None, subject: Subject = None
+ ) -> Generator[TestCaseResult, None, None]:
for result in self.raw_results:
if testcase is not None and result.testcase != testcase:
continue
@@ -483,14 +531,14 @@ def to_json(self) -> Dict:
return {
"timestamp": self.datetime.timestamp(),
"metadata": self.metadata.to_json(),
- "raw_results": tuple(r.to_json() for r in self.raw_results)
+ "raw_results": tuple(r.to_json() for r in self.raw_results),
}
@classmethod
def from_json(cls, data: Dict) -> Self:
inst = cls(
datetime=from_utc_stamp(data["timestamp"]),
- metadata=BenchmarkMetadata.from_json(data["metadata"])
+ metadata=BenchmarkMetadata.from_json(data["metadata"]),
)
inst.add_results(*(TestCaseResult.from_json(rd) for rd in data["raw_results"]))
return inst
@@ -519,11 +567,14 @@ def serialize(result: BenchmarkResult) -> Dict:
return {
"subjects": tuple(s.to_json() for s in Subject.instances),
"benchmarks": tuple(b.to_json() for b in Benchmark.instances),
- "result": result.to_json()
+ "result": result.to_json(),
}
+
def deserialize(data: Dict) -> BenchmarkResult:
- assert not Benchmark.instances or not Subject.instances, "Deserialization prohibited: environment not clean."
+ assert (
+ not Benchmark.instances or not Subject.instances
+ ), "Deserialization prohibited: environment not clean."
for d in data["subjects"]:
Subject.from_json(d)
@@ -531,10 +582,12 @@ def deserialize(data: Dict) -> BenchmarkResult:
Benchmark.from_json(d)
return BenchmarkResult.from_json(data["result"])
+
def clear_env():
Benchmark.instances.clear()
Subject.instances.clear()
+
def save_result(result: BenchmarkResult, file: Path):
"""Save the result and the current environment to a gzipped json file."""
file.parent.mkdir(parents=True, exist_ok=True)
@@ -542,6 +595,7 @@ def save_result(result: BenchmarkResult, file: Path):
with gzip.open(file, "wt", encoding="utf8") as f:
json.dump(serialize(result), f, ensure_ascii=False)
+
def load_result(file: Path) -> BenchmarkResult:
"""Restore the result and the environment from a gzipped json file."""
log(f"Loading from {file}...")
diff --git a/benchmark/benchmarks.py b/benchmark/benchmarks.py
index b42ab97..2fd71ab 100644
--- a/benchmark/benchmarks.py
+++ b/benchmark/benchmarks.py
@@ -1,3 +1,4 @@
+# fmt: off
# begin: noinspection PyUnresolvedReferences
# begin: noinspection PyUnusedLocal
# begin: noinspection PyUnboundLocalVariable
diff --git a/benchmark/charting.py b/benchmark/charting.py
index 514d654..7f67e7d 100644
--- a/benchmark/charting.py
+++ b/benchmark/charting.py
@@ -4,13 +4,15 @@
import matplotlib.pyplot as plt
from matplotlib import patches, transforms
+
def get_runtime(result: TestCaseResult) -> int:
return result.runtime
+
_TEXT = "#e6edf3"
-class BackgroundFancyBboxPatch(patches.FancyBboxPatch):
+class BackgroundFancyBboxPatch(patches.FancyBboxPatch):
def __init__(self, **kwargs):
super().__init__((0, 0), 1, 1, **kwargs)
@@ -28,14 +30,24 @@ def draw(self, renderer):
# Don't apply any transforms
self._draw_paths_with_artist_properties(
renderer,
- [(self.get_path(), transforms.IdentityTransform(),
- # Work around a bug in the PDF and SVG renderers, which
- # do not draw the hatches if the facecolor is fully
- # transparent, but do if it is None.
- self._facecolor if self._facecolor[3] else None)])
-
-
-def chart(result: BenchmarkResult, baseline: Subject, *, fig_height=5, subtitles: Iterable[str] = ()) -> plt.Figure:
+ [(
+ self.get_path(),
+ transforms.IdentityTransform(),
+ # Work around a bug in the PDF and SVG renderers, which
+ # do not draw the hatches if the facecolor is fully
+ # transparent, but do if it is None.
+ self._facecolor if self._facecolor[3] else None,
+ )],
+ )
+
+
+def chart(
+ result: BenchmarkResult,
+ baseline: Subject,
+ *,
+ fig_height=5,
+ subtitles: Iterable[str] = (),
+) -> plt.Figure:
fig: plt.Figure
ax: plt.Axes
fig, ax = plt.subplots(
@@ -48,28 +60,21 @@ def chart(result: BenchmarkResult, baseline: Subject, *, fig_height=5, subtitles
# Background
bg_edge = 2
bg = BackgroundFancyBboxPatch(
- facecolor="#161b22",
- linewidth=bg_edge,
- edgecolor="#30363d",
- figure=fig
+ facecolor="#161b22", linewidth=bg_edge, edgecolor="#30363d", figure=fig
)
bg.set_boxstyle("round", pad=-bg_edge / 2, rounding_size=fig.dpi * 0.5)
fig.patch = bg
# Title
- ax.set_title(f"Benchmark - {baseline.name} Implementation as Baseline\n"
- f"Operations Per Second"
- + (("\n" + "\n".join(subtitles)) if subtitles else ""),
- color=_TEXT)
+ ax.set_title(
+ f"Benchmark - {baseline.name} Implementation as Baseline\n"
+ f"Operations Per Second" + (("\n" + "\n".join(subtitles)) if subtitles else ""),
+ color=_TEXT,
+ )
ax.margins(x=0.01)
- ax.axhline(
- y=1,
- color=baseline.color,
- linestyle="--",
- label="Baseline"
- )
+ ax.axhline(y=1, color=baseline.color, linestyle="--", label="Baseline")
highest = 0
x_pos = 0
@@ -78,22 +83,19 @@ def chart(result: BenchmarkResult, baseline: Subject, *, fig_height=5, subtitles
subject_ops_per_sec: dict[Subject, float] = {}
for testcase in benchmark.testcases.values():
min_runtime = min(r.runtime for r in result.get_results(testcase=testcase))
- subject_ops_per_sec[testcase.subject] = testcase.number / (min_runtime / 1e9)
+ subject_ops_per_sec[testcase.subject] = testcase.number / (
+ min_runtime / 1e9
+ )
begin_pos = x_pos
for subject, ops_per_sec in sorted(
- subject_ops_per_sec.items(),
- key=lambda pair: pair[0].sort
+ subject_ops_per_sec.items(), key=lambda pair: pair[0].sort
):
height = ops_per_sec / (subject_ops_per_sec[baseline])
highest = max(height, highest)
rect = ax.bar(
- x=x_pos,
- height=height,
- width=1,
- label=subject.name,
- color=subject.color
+ x=x_pos, height=height, width=1, label=subject.name, color=subject.color
)
ax.bar_label(
rect,
@@ -101,7 +103,7 @@ def chart(result: BenchmarkResult, baseline: Subject, *, fig_height=5, subtitles
labels=[numerize(ops_per_sec, decimals=1)],
fontsize=10,
rotation=90,
- color=_TEXT
+ color=_TEXT,
)
x_pos += 1
@@ -112,11 +114,7 @@ def chart(result: BenchmarkResult, baseline: Subject, *, fig_height=5, subtitles
ax.margins(y=18 / highest / fig_height)
- ax.set_xticks(
- *zip(*x_ticks),
- rotation=15,
- color=_TEXT
- )
+ ax.set_xticks(*zip(*x_ticks), rotation=15, color=_TEXT)
[s.set_color("#ced0d6") for s in ax.spines.values()]
ax.yaxis.set_major_formatter(lambda x, pos: f"{x}x")
@@ -127,13 +125,14 @@ def chart(result: BenchmarkResult, baseline: Subject, *, fig_height=5, subtitles
handles, labels = ax.get_legend_handles_labels()
temp = {k: v for k, v in zip(labels, handles)}
legend = ax.legend(
- temp.values(), temp.keys(),
+ temp.values(),
+ temp.keys(),
loc="upper left",
labelcolor=_TEXT,
facecolor="#3d3f42",
edgecolor="#3d3f42",
- fancybox = True,
- framealpha = 0.5
+ fancybox=True,
+ framealpha=0.5,
)
legend.legendPatch.set_boxstyle("round", rounding_size=0.5)
diff --git a/benchmark/gen_charts.py b/benchmark/gen_charts.py
index c6fb824..eb07daa 100644
--- a/benchmark/gen_charts.py
+++ b/benchmark/gen_charts.py
@@ -11,7 +11,9 @@
files = [Path(f) for f in os.listdir(RESULTS) if f.endswith(".dat")]
# Latest to earliest
-files.sort(reverse=True, key=lambda f: datetime.datetime.strptime(f.stem, "%Y%m%d_%H-%M-%S"))
+files.sort(
+ reverse=True, key=lambda f: datetime.datetime.strptime(f.stem, "%Y%m%d_%H-%M-%S")
+)
for file in files:
log(f"{file}:")
@@ -27,9 +29,8 @@
f"{result.metadata.py_impl} {result.metadata.py_ver} · "
f"{result.metadata.system} · "
f"{result.metadata.cpu}"
- + (" (GitHub Actions)" if result.metadata.ci else "")
- ,
- )
+ + (" (GitHub Actions)" if result.metadata.ci else ""),
+ ),
)
chart.savefig(CHARTS / file.with_suffix(".svg"))
clear_env()
@@ -45,4 +46,6 @@
f.write("---\n")
f.write("\n")
for file in files:
- f.write(f"[![{file.stem}](./{file.with_suffix(".svg")})](./{file.with_suffix(".svg")})\n")
+ f.write(
+ f"[![{file.stem}](./{file.with_suffix(".svg")})](./{file.with_suffix(".svg")})\n"
+ )
diff --git a/benchmark/pure_python_impl.py b/benchmark/pure_python_impl.py
index 2d2563d..0932033 100644
--- a/benchmark/pure_python_impl.py
+++ b/benchmark/pure_python_impl.py
@@ -21,17 +21,9 @@ def __init__(self, x: float, y: float, z: float):
def __add__(self, other: Union["Vec3", float]):
if isinstance(other, Vec3):
- return Vec3(
- self.x + other.x,
- self.y + other.y,
- self.z + other.z
- )
+ return Vec3(self.x + other.x, self.y + other.y, self.z + other.z)
elif isinstance(other, float | int):
- return Vec3(
- self.x + other,
- self.y + other,
- self.z + other
- )
+ return Vec3(self.x + other, self.y + other, self.z + other)
else:
raise TypeError()
@@ -59,7 +51,7 @@ def __xor__(self, other: "Vec3"):
return Vec3(
self.y * other.z - self.z * other.y,
self.z * other.x - self.x * other.z,
- self.x * other.y - self.y * other.x
+ self.x * other.y - self.y * other.x,
)
else:
raise TypeError()
diff --git a/benchmark/utils.py b/benchmark/utils.py
index a3cc43c..bb51f87 100644
--- a/benchmark/utils.py
+++ b/benchmark/utils.py
@@ -3,6 +3,7 @@
import inspect
import itertools
from contextlib import contextmanager
+
# noinspection PyPep8Naming
from datetime import datetime as DateTime
from datetime import timezone
@@ -32,25 +33,29 @@
"indent_log",
"temp_log",
"iter_identity",
- "auto_number_series"
+ "auto_number_series",
)
SourceLines = Tuple[str, ...]
+
def auto_number_series() -> Generator[int, None, None]:
for expo in itertools.count(2):
# E12 series
for value in (10, 12, 15, 18, 22, 27, 33, 39, 47, 56, 68, 82):
- yield value * 10 ** expo
+ yield value * 10**expo
+
def indent(src: SourceLines) -> SourceLines:
- return tuple(" "*4 + line for line in src)
+ return tuple(" " * 4 + line for line in src)
+
def dedent(src: SourceLines) -> SourceLines:
for i in itertools.count():
- if any(line[i:i+1] != " " for line in src):
+ if any(line[i : i + 1] != " " for line in src):
return tuple(line[i:] for line in src)
+
def extract_source(func: FunctionType) -> SourceLines:
"""
Extract the source code of the body of func.
@@ -58,32 +63,38 @@ def extract_source(func: FunctionType) -> SourceLines:
"""
src = inspect.getsource(func).splitlines()
src = list(dedent(src))
- if src[0][0] == "@": # delete decorator
+ if src[0][0] == "@": # delete decorator
del src[0]
- del src[0] # delete function header
+ del src[0] # delete function header
return dedent(src)
+
def compile_src(src: SourceLines, name: str = ""):
return compile("\n".join(src), name, "exec", optimize=2)
+
def validate_source(src: SourceLines, name: str):
try:
exec(compile_src(src, name))
except Exception as e:
raise RuntimeError("Code validation failed.") from e
+
def extract_and_validate_source(func: FunctionType, name: str):
src = extract_source(func)
name += f"({func.__name__} in {func.__module__})"
validate_source(src, name)
return src
+
def utc_now() -> DateTime:
return DateTime.now(timezone.utc)
+
def from_utc_stamp(stamp: float) -> DateTime:
return DateTime.fromtimestamp(stamp, timezone.utc)
+
class _NoGC:
def __init__(self):
self._enter_depth = 0
@@ -98,16 +109,22 @@ def __exit__(self, *_):
if self._enter_depth == 0:
gc.enable()
+
NoGC = _NoGC()
T = TypeVar("T")
+
+
def _instance_from_id(cls: Type[T], identifier: str) -> Optional[T]:
for inst in cls.instances:
if inst.id == identifier:
return inst
return None
+
_log_indent = 0
+
+
@contextmanager
def indent_log(n=1):
global _log_indent
@@ -116,8 +133,11 @@ def indent_log(n=1):
_log_indent -= n
assert _log_indent >= 0
+
_log_chrs_stack = []
_log_temp_log = False
+
+
@contextmanager
def temp_log(disable=False):
if disable:
@@ -134,7 +154,10 @@ def temp_log(disable=False):
_should_indent = old_should_indent
_log_temp_log = False
+
_should_indent = True
+
+
def log(msg="", new_line=True, color: str = None):
if CI and _log_temp_log:
return
@@ -153,6 +176,7 @@ def log(msg="", new_line=True, color: str = None):
_should_indent = new_line
+
def iter_identity(it: Iterable[T]) -> Generator[T, None, None]:
got = set()
for item in it:
diff --git a/codegen/codegen_helper.py b/codegen/codegen_helper.py
index 7f07f9e..51bce93 100644
--- a/codegen/codegen_helper.py
+++ b/codegen/codegen_helper.py
@@ -89,9 +89,16 @@ def apply_params(in_line: str) -> str:
if ":" in line:
gen_pos = line.index("#:")
prefix = line[:gen_pos]
- expr = line[gen_pos + 7:]
+ expr = line[gen_pos + 7 :]
try:
- result = eval(expr, globals() if _globals is None else globals().copy().update(_globals))
+ result = eval(
+ expr,
+ (
+ globals()
+ if _globals is None
+ else globals().copy().update(_globals)
+ ),
+ )
generated = True
except Exception as e:
if isinstance(e, AssertionError):
@@ -138,6 +145,7 @@ def __init__(self, name: str, c_params: tuple[str], c_ret: str):
case _:
self.ret = c_ret
+
class _Overload:
def __init__(self, name: str):
self.name = name
@@ -146,8 +154,9 @@ def __init__(self, name: str):
def add(self, func: _Func):
if len(self._funcs) > 0:
- assert (Self in self._funcs[0].params) == (Self in func.params), \
- "function/method mismatch"
+ assert (Self in self._funcs[0].params) == (
+ Self in func.params
+ ), "function/method mismatch"
self._funcs.append(func)
self._possible_param_types_cache = None
@@ -181,10 +190,12 @@ def possible_param_types(self) -> Sequence[Sequence]:
if p not in self._possible_param_types_cache[i]:
self._possible_param_types_cache[i].append(p)
- for optional_param in self._possible_param_types_cache[self.min_params:]:
+ for optional_param in self._possible_param_types_cache[self.min_params :]:
optional_param.append(None)
- self._possible_param_types_cache = tuple(tuple(ts) for ts in self._possible_param_types_cache)
+ self._possible_param_types_cache = tuple(
+ tuple(ts) for ts in self._possible_param_types_cache
+ )
return self._possible_param_types_cache
@@ -204,19 +215,19 @@ def param_names(self) -> tuple[str]:
@staticmethod
def _type_check_expression(*types):
- if float in types and int in types: # py_float
+ if float in types and int in types: # py_float
assert len(types) == 2, "py_float type should be a discrete branch!"
# Insignificant optimization
# return "(PyFloat_CheckExact({value}) or PyLong_CheckExact({value}))"
return "(PyFloat_Check({value}) or PyLong_Check({value}))"
- elif int in types: # py_int
+ elif int in types: # py_int
assert len(types) == 1, "py_int type should be a discrete branch!"
# Insignificant optimization
# return "PyLong_CheckExact({value})"
return "PyLong_Check({value})"
else:
out = ""
- for i,t in enumerate(types):
+ for i, t in enumerate(types):
if type(t) is str:
out += f"isinstance({{value}}, {t})"
elif t is None:
@@ -269,29 +280,35 @@ def _gen_type_no_match_exception(self, param_types: tuple[tuple]) -> str:
type_strs.append(self._type_str(pts[0]))
else:
type_strs.append(f"({' | '.join(map(self._type_str, pts))})")
- text = (f"raise TypeError(\"No matching overload function for parameter types: "
- f"{', '.join(type_strs)}")
+ text = (
+ f'raise TypeError("No matching overload function for parameter types: '
+ f"{', '.join(type_strs)}"
+ )
if len(param_types) != self.max_params:
text += ", ..."
- text += "\")"
+ text += '")'
return text
def _gen_type_cast(self, var_name: str, var_types: tuple):
multiple_type_err = "A branch should have only one general(cast-able) type!"
- if float in var_types and int in var_types: # py_float
+ if float in var_types and int in var_types: # py_float
assert len(var_types) == 2, multiple_type_err
# Insignificant optimization
# return f"PyFloat_AS_DOUBLE({var_name}) if PyFloat_CheckExact({var_name}) else PyLong_AsDouble({var_name})"
return f"PyFloat_AsDouble({var_name})"
- elif int in var_types: # py_int
+ elif int in var_types: # py_int
assert len(var_types) == 1, multiple_type_err
return f"PyLong_AsLongLong({var_name})"
else:
assert len(var_types) == 1, multiple_type_err
- assert isinstance(var_types[0], str), f"Don't know how to cast to {var_types[0]}."
+ assert isinstance(
+ var_types[0], str
+ ), f"Don't know how to cast to {var_types[0]}."
return f"<{var_types[0]}> {var_name}"
- def _gen_dispatch_tree(self, params_types: tuple = None) -> tuple[Sequence[str], bool]:
+ def _gen_dispatch_tree(
+ self, params_types: tuple = None
+ ) -> tuple[Sequence[str], bool]:
"""Generate overload dispatch tree recursively."""
if params_types is None:
params_types = ((Self,),) if self._funcs[0].params[0] is Self else ()
@@ -302,12 +319,16 @@ def _gen_dispatch_tree(self, params_types: tuple = None) -> tuple[Sequence[str],
if len(funcs) == 0:
return [self._gen_type_no_match_exception(params_types)], False
if len(funcs) > 1:
- assert False, f"Multiple matching functions for: {', '.join(str(p) for p in params_first_types)}"
+ assert (
+ False
+ ), f"Multiple matching functions for: {', '.join(str(p) for p in params_first_types)}"
func = funcs[0]
- out = [f"{'self.' if self.is_method else ''}"
- f"{func.name}"
- f"({', '.join(self._gen_type_cast(pn, pts) for pts,pn in zip(params_types, self.param_names) if pn != 'self' and None not in pts)})"]
+ out = [
+ f"{'self.' if self.is_method else ''}"
+ f"{func.name}"
+ f"({', '.join(self._gen_type_cast(pn, pts) for pts,pn in zip(params_types, self.param_names) if pn != 'self' and None not in pts)})"
+ ]
if func.ret is None:
out.append("return")
else:
@@ -326,7 +347,9 @@ def _gen_dispatch_tree(self, params_types: tuple = None) -> tuple[Sequence[str],
if len(matches) == 0:
continue
for branch in branches:
- if set(matches) == set(self._func_from_params(params_first_types + (branch[0],))):
+ if set(matches) == set(
+ self._func_from_params(params_first_types + (branch[0],))
+ ):
branch.append(t)
break
else:
@@ -337,8 +360,10 @@ def _gen_dispatch_tree(self, params_types: tuple = None) -> tuple[Sequence[str],
any_hit = False
for i, t in enumerate(branches):
- out.append(f"{'if' if i == 0 else 'elif'} "
- f"{self._type_check_expression(*t).format(value=self.param_names[len(params_types)])}:")
+ out.append(
+ f"{'if' if i == 0 else 'elif'} "
+ f"{self._type_check_expression(*t).format(value=self.param_names[len(params_types)])}:"
+ )
branch_params = params_types + (t,)
dt, hit = self._gen_dispatch_tree(branch_params)
if hit:
@@ -347,12 +372,16 @@ def _gen_dispatch_tree(self, params_types: tuple = None) -> tuple[Sequence[str],
for l in dt:
out.append(f" {l}")
else:
- out.append(" " + self._gen_type_no_match_exception(branch_params))
+ out.append(
+ " " + self._gen_type_no_match_exception(branch_params)
+ )
out.append("else:")
expected_type_strs = [self._type_str(t) for t in possible_types]
expected_types_str = " | ".join(expected_type_strs)
- out.append(f" raise TypeError(f\"The {len(params_first_types) + 1}th parameter expected {expected_types_str}, "
- f"got {{{self.param_names[len(params_first_types)]}}}\")")
+ out.append(
+ f' raise TypeError(f"The {len(params_first_types) + 1}th parameter expected {expected_types_str}, '
+ f'got {{{self.param_names[len(params_first_types)]}}}")'
+ )
return out, any_hit
@@ -370,11 +399,14 @@ def gen_dispatcher(self) -> Sequence[str]:
ret_types = list(set(_Overload._type_str(f.ret) for f in self._funcs))
- lines = [f"def {self.name}({', '.join(('object ' if i > 0 else '') + p for i,p in enumerate(param_strs))}, /) -> {' | '.join(ret_types)}:"]
+ lines = [
+ f"def {self.name}({', '.join(('object ' if i > 0 else '') + p for i,p in enumerate(param_strs))}, /) -> {' | '.join(ret_types)}:"
+ ]
disp_lines, _ = self._gen_dispatch_tree()
lines += [f" {dl}" for dl in disp_lines]
return lines
+
def process_overloads(file: str) -> str:
import regex
@@ -398,7 +430,7 @@ def process_overloads(file: str) -> str:
r"\("
r"\s*(?:(?P\w+\s+\w+|self)(?:\s*,\s*(?P\w+\s+\w+))*)?\s*"
r"\)",
- line
+ line,
)
assert m is not None, line
name = m.group("name")
@@ -410,7 +442,7 @@ def process_overloads(file: str) -> str:
overload = overloads[name]
new_name = f"_{name}_{len(overload):d}"
name_span = m.span("name")
- lines[i] = line[:name_span[0]] + new_name + line[name_span[1]:]
+ lines[i] = line[: name_span[0]] + new_name + line[name_span[1] :]
overload.add(_Func(new_name, c_params, c_ret))
lines = [line for line in lines if "" not in line]
@@ -418,14 +450,17 @@ def process_overloads(file: str) -> str:
while True:
for i, line in enumerate(lines):
if ":" in line:
- m = regex.match(r"(?P\s*)#(?:\s|#)*:(?P\w+)", line)
+ m = regex.match(
+ r"(?P\s*)#(?:\s|#)*:(?P\w+)",
+ line,
+ )
assert m is not None
prefix = m.group("prefix")
name = m.group("name")
assert name in overloads
disp_lines = overloads[name].gen_dispatcher()
disp_lines = [prefix + disp_line for disp_line in disp_lines]
- lines[i:i + 1] = disp_lines
+ lines[i : i + 1] = disp_lines
break
else:
break
@@ -433,17 +468,24 @@ def process_overloads(file: str) -> str:
return "".join(f"{line}\n" for line in lines)
-def step_generate(template_file: str, output_file: str = None, write_file: bool = False, params: dict = None, _globals: dict = None, overload: bool = False):
+def step_generate(
+ template_file: str,
+ output_file: str = None,
+ write_file: bool = False,
+ params: dict = None,
+ _globals: dict = None,
+ overload: bool = False,
+):
print(f"Step Generate: {template_file}")
if output_file is None:
output_file = template_file
import os
+
if not os.path.exists("output"):
os.mkdir("output")
-
template = open(f"templates/{template_file}", encoding="utf8").read()
t = time.perf_counter()
if _globals is not None:
@@ -469,10 +511,13 @@ def step_generate(template_file: str, output_file: str = None, write_file: bool
def step_gen_stub(source_file: str, output_file: str):
import stub_generator
+
source = open(f"output/{source_file}", encoding="utf8").read()
t = time.perf_counter()
result = stub_generator.gen_stub(source)
- print(f"Step Gen Stub: {source_file} -> {output_file} completed in {time.perf_counter() - t:.3f}s")
+ print(
+ f"Step Gen Stub: {source_file} -> {output_file} completed in {time.perf_counter() - t:.3f}s"
+ )
with open(f"output/{output_file}", "w", encoding="utf8") as output:
output.write(result)
@@ -482,17 +527,16 @@ def step_cythonize(file: str):
import subprocess
print("########## Cythonize ##########")
- proc = subprocess.Popen((
- "cythonize.exe",
- "-a",
- "-i",
- f"output/{file}.pyx"
- ), stdout=sys.stdout, stderr=sys.stderr)
+ proc = subprocess.Popen(
+ ("cythonize.exe", "-a", "-i", f"output/{file}.pyx"),
+ stdout=sys.stdout,
+ stderr=sys.stderr,
+ )
t = time.perf_counter()
proc.wait()
print(
f"Cythonize finished in {time.perf_counter() - t}s with exit code {proc.returncode}",
- file=sys.stdout if proc.returncode == 0 else sys.stderr
+ file=sys.stdout if proc.returncode == 0 else sys.stderr,
)
@@ -506,4 +550,3 @@ def step_move_to_dest(final_dest: str, file_prefix: str, file_suffix: str):
dest = os.path.join(final_dest, file)
shutil.copy(path, dest)
print(f"Coping {file} to {dest}")
-
diff --git a/codegen/gen_all.py b/codegen/gen_all.py
index e4e8ad9..9695bc3 100644
--- a/codegen/gen_all.py
+++ b/codegen/gen_all.py
@@ -3,24 +3,35 @@
def main():
import vector_codegen
+
codegen.step_generate(
"_spatium.pyx",
write_file=True,
- _globals={vector_codegen.__name__: vector_codegen})
+ _globals={vector_codegen.__name__: vector_codegen},
+ )
codegen.step_gen_stub("_spatium.pyx", "_spatium.pyi")
import sys
+
if "--install" in sys.argv:
- print("#"*15 + " Install " + "#"*15)
+ print("#" * 15 + " Install " + "#" * 15)
codegen.step_move_to_dest("../src/spatium/", "_spatium", ".pyx")
codegen.step_move_to_dest("../src/spatium/", "_spatium", ".pyi")
import sys
import subprocess
- print("#"*15 + " pip uninstall spatium -y " + "#"*15)
- subprocess.call(f"{sys.executable} -m pip uninstall spatium -y", stdout=sys.stdout)
- print("#"*15 + " pip install -v -v -v .. " + "#"*15)
- subprocess.call(f"{sys.executable} -m pip install -v -v -v ..", stdout=sys.stdout, stderr=sys.stdout)
-if __name__ == '__main__':
+ print("#" * 15 + " pip uninstall spatium -y " + "#" * 15)
+ subprocess.call(
+ f"{sys.executable} -m pip uninstall spatium -y", stdout=sys.stdout
+ )
+ print("#" * 15 + " pip install -v -v -v .. " + "#" * 15)
+ subprocess.call(
+ f"{sys.executable} -m pip install -v -v -v ..",
+ stdout=sys.stdout,
+ stderr=sys.stdout,
+ )
+
+
+if __name__ == "__main__":
main()
diff --git a/codegen/stub_generator.py b/codegen/stub_generator.py
index 9ee28de..31e7dca 100644
--- a/codegen/stub_generator.py
+++ b/codegen/stub_generator.py
@@ -20,6 +20,7 @@ def convert_type(org: str) -> str:
types.append(f'"{raw}"')
return types[0] if len(types) == 1 else f"Union[{', '.join(types)}]"
+
class StubProperty:
def __init__(self, name: str, ptype: str, mutable: bool = True):
self.name = name
@@ -29,10 +30,7 @@ def __init__(self, name: str, ptype: str, mutable: bool = True):
self.setter_doc = []
def stub(self) -> list[str]:
- stub = [
- "@property",
- f"def {self.name}(self) -> {self.type}:"
- ]
+ stub = ["@property", f"def {self.name}(self) -> {self.type}:"]
if self.getter_doc:
stub.extend(self.getter_doc)
stub.append(" ...")
@@ -53,7 +51,13 @@ def stub(self) -> list[str]:
class StubMethod:
class Param:
- def __init__(self, name: str, ptype: str, default: Optional[str] = None, const_mapping: dict[str, str] = None):
+ def __init__(
+ self,
+ name: str,
+ ptype: str,
+ default: Optional[str] = None,
+ const_mapping: dict[str, str] = None,
+ ):
self.name = name
self.ptype = ptype
self.default = default
@@ -68,7 +72,14 @@ def stub(self) -> str:
stub += f" = {default}"
return stub
- def __init__(self, name: str, rtype: str, params: Sequence[Param | str], is_cdef: bool, is_static: bool):
+ def __init__(
+ self,
+ name: str,
+ rtype: str,
+ params: Sequence[Param | str],
+ is_cdef: bool,
+ is_static: bool,
+ ):
self.name = name
self.rtype = rtype
self.params = params
@@ -88,7 +99,9 @@ def stub(self, name_override: Optional[str] = None) -> list[str]:
else:
param_list.append(p.stub())
- stub.append(f"def {name_override or self.name}({', '.join(param_list)}) -> {self.rtype}:")
+ stub.append(
+ f"def {name_override or self.name}({', '.join(param_list)}) -> {self.rtype}:"
+ )
if self.docstring:
stub.extend(self.docstring)
stub.append(" ...")
@@ -97,6 +110,7 @@ def stub(self, name_override: Optional[str] = None) -> list[str]:
return stub
+
class StubClass:
def __init__(self, name: str):
self.name = name
@@ -112,7 +126,7 @@ def indent(src: list[str]) -> list[str]:
stub = [
"# noinspection SpellCheckingInspection,GrazieInspection",
- f"class {self.name}:"
+ f"class {self.name}:",
]
if self.docstring:
stub.extend(indent(self.docstring))
@@ -135,7 +149,9 @@ def indent(src: list[str]) -> list[str]:
del methods[ol_name]
assert ol_method.is_cdef
stub.extend(indent(["@overload"]))
- stub.extend(indent(ol_method.stub(name_override=method.name)))
+ stub.extend(
+ indent(ol_method.stub(name_override=method.name))
+ )
else:
break
# Not Overloaded
@@ -153,9 +169,10 @@ def gen_stub(source: str) -> str:
in_docstring = False
docstring_dest: Optional[list[str]] = None
+
def add_docstring_line(doc_line: str):
assert docstring_dest is not None
- docstring_dest.append(doc_line[4:]) # remove one indent level
+ docstring_dest.append(doc_line[4:]) # remove one indent level
print("gen_stub: reading source...")
source_lines = source.splitlines(keepends=False)
@@ -181,7 +198,10 @@ def add_docstring_line(doc_line: str):
docstring_dest = current_class.docstring
decorators.clear()
# Property
- elif m := regex.match(r"\s+cdef\s+public\s+(?P\w+)\s+(?P\w+)(?:\s*,\s*(?P\w+))*", line):
+ elif m := regex.match(
+ r"\s+cdef\s+public\s+(?P\w+)\s+(?P\w+)(?:\s*,\s*(?P\w+))*",
+ line,
+ ):
ptype = convert_type(m.group("type"))
for pname in m.captures("names"):
current_class.properties[pname] = StubProperty(pname, ptype)
@@ -189,32 +209,32 @@ def add_docstring_line(doc_line: str):
decorators.clear()
# Method (including @property)
elif (
- (cdef_m := regex.match( # cdef methods can not be static, self always present
- r"\s+cdef\s+"
- r"(?:inline\s+)?"
- r"(?P\w+)?\s+"
- r"(?P\w+)\s*"
- r"\(\s*"
- r"(?:self\s*)?"
- r"(?:,\s*(?P\w+\s+\w+)\s*)*"
- r"\)\s*"
- r"(?:noexcept)?\s*"
- r":",
- line
- ))
- or
- (def_m := regex.match(
- r"\s+def\s+"
- r"(?P\w+)\s*"
- r"\(\s*"
- r"(?:(?:self\s*)|(?&_param))?"
- r"(?:,\s*(?P<_param>(?P\w+\s+\w+(?:\s*=\s*[^,)]+)?)|(?P/))\s*)*"
- r"\)\s*"
- r"(?:->\s*(?P[^:]+))?\s*"
- r":",
- line
- ))
- ):
+ cdef_m := regex.match( # cdef methods can not be static, self always present
+ r"\s+cdef\s+"
+ r"(?:inline\s+)?"
+ r"(?P\w+)?\s+"
+ r"(?P\w+)\s*"
+ r"\(\s*"
+ r"(?:self\s*)?"
+ r"(?:,\s*(?P\w+\s+\w+)\s*)*"
+ r"\)\s*"
+ r"(?:noexcept)?\s*"
+ r":",
+ line,
+ )
+ ) or (
+ def_m := regex.match(
+ r"\s+def\s+"
+ r"(?P\w+)\s*"
+ r"\(\s*"
+ r"(?:(?:self\s*)|(?&_param))?"
+ r"(?:,\s*(?P<_param>(?P\w+\s+\w+(?:\s*=\s*[^,)]+)?)|(?P/))\s*)*"
+ r"\)\s*"
+ r"(?:->\s*(?P[^:]+))?\s*"
+ r":",
+ line,
+ )
+ ):
m = cdef_m or def_m
is_cdef = cdef_m is not None
name = m.group("name")
@@ -223,7 +243,9 @@ def add_docstring_line(doc_line: str):
# property getter
if "property" in decorators:
assert name not in current_class.properties
- current_class.properties[name] = prop = StubProperty(name, convert_type(m.group("return")), mutable=False)
+ current_class.properties[name] = prop = StubProperty(
+ name, convert_type(m.group("return")), mutable=False
+ )
docstring_dest = prop.getter_doc
# proeprty setter
elif setter_dec := [d for d in decorators if d.endswith(".setter")]:
@@ -247,18 +269,30 @@ def add_docstring_line(doc_line: str):
# default values are not supported for cdef methods yet
pm = regex.fullmatch(r"(?P\w+)\s+(?P\w+)", param)
assert pm is not None
- params.append(StubMethod.Param(pm.group("name"), convert_type(pm.group("type")), const_mapping=const_mapping))
+ params.append(
+ StubMethod.Param(
+ pm.group("name"),
+ convert_type(pm.group("type")),
+ const_mapping=const_mapping,
+ )
+ )
else:
- pm = regex.fullmatch(r"(?P\w+)\s+(?P\w+)(?:\s*=\s*(?P[^,]+))?", param)
+ pm = regex.fullmatch(
+ r"(?P\w+)\s+(?P\w+)(?:\s*=\s*(?P[^,]+))?",
+ param,
+ )
assert pm is not None
- params.append(StubMethod.Param(pm.group("name"), convert_type(pm.group("type")), pm.group("default"), const_mapping=const_mapping))
+ params.append(
+ StubMethod.Param(
+ pm.group("name"),
+ convert_type(pm.group("type")),
+ pm.group("default"),
+ const_mapping=const_mapping,
+ )
+ )
current_class.methods[name] = method = StubMethod(
- name,
- rtype,
- params,
- is_cdef,
- is_static="staticmethod" in decorators
+ name, rtype, params, is_cdef, is_static="staticmethod" in decorators
)
docstring_dest = method.docstring
@@ -280,7 +314,7 @@ def add_docstring_line(doc_line: str):
out_lines = [
"# noinspection PyUnresolvedReferences",
"from typing import overload, Self, Any, Union",
- ""
+ "",
]
for cls in classes:
print(f"gen_stub: generating class {cls.name}")
@@ -289,7 +323,7 @@ def add_docstring_line(doc_line: str):
return "".join(f"{line}\n" for line in out_lines)
-if __name__ == '__main__':
+if __name__ == "__main__":
result = gen_stub(open("output/_spatium.pyx", encoding="utf8").read())
with open("output/_spatium.pyi", "w") as f:
f.write(result)
diff --git a/codegen/vector_codegen.py b/codegen/vector_codegen.py
index e646d0f..7a882da 100644
--- a/codegen/vector_codegen.py
+++ b/codegen/vector_codegen.py
@@ -65,7 +65,9 @@ def gen_combination_constructors(dims: int, vtype: Type) -> str:
# Generate params
self_dim = 0
for in_dims in combination:
- param_types.append(f"{get_c_type(vtype) if in_dims == 1 else get_vec_class_name(in_dims, vtype)}")
+ param_types.append(
+ f"{get_c_type(vtype) if in_dims == 1 else get_vec_class_name(in_dims, vtype)}"
+ )
param_name = ""
for in_dim in range(in_dims):
param_name += DIMS[self_dim]
@@ -79,12 +81,16 @@ def gen_combination_constructors(dims: int, vtype: Type) -> str:
if in_dims == 1:
assigns.append(f"self.{DIMS[self_dim]} = {param_name}")
else:
- assigns.append(f"self.{DIMS[self_dim]} = {param_name}.{DIMS[in_dim]}")
+ assigns.append(
+ f"self.{DIMS[self_dim]} = {param_name}.{DIMS[in_dim]}"
+ )
self_dim += 1
func = "#\n"
- func += (f"cdef inline void __init__(self, "
- f"{', '.join(f'{t} {n}' for t, n in zip(param_types, param_names))}) noexcept:\n")
+ func += (
+ f"cdef inline void __init__(self, "
+ f"{', '.join(f'{t} {n}' for t, n in zip(param_types, param_names))}) noexcept:\n"
+ )
docstring = f"Create a {dims}D vector from "
for i, param_dims in enumerate(combination):
@@ -156,7 +162,9 @@ def _gen_swizzle_get(swizzle: str, vtype: Type) -> str:
def _gen_swizzle_set(swizzle: str, vtype: Type):
out = f"@{swizzle}.setter\n"
- out += f"def {swizzle}(self, {get_vec_class_name(len(swizzle), vtype)} vec) -> None:\n"
+ out += (
+ f"def {swizzle}(self, {get_vec_class_name(len(swizzle), vtype)} vec) -> None:\n"
+ )
docstring = f"Set the value of the {', '.join(swizzle.upper())}"
docstring = docstring[:-3] + " and " + docstring[-1]
@@ -196,10 +204,7 @@ def gen_swizzle_properties(dims: int, vtype: Type) -> str:
def gen_for_each_dim(template: str, dims: int, join="\n") -> str:
out = ""
for i, dim in enumerate(range(dims)):
- out += template.format_map({
- "dim": DIMS[dim],
- "index": i
- })
+ out += template.format_map({"dim": DIMS[dim], "index": i})
if dim != dims - 1:
out += join
return out
@@ -216,8 +221,13 @@ def gen_repr(dims: int, vtype: Type) -> str:
out += ')"'
return out
+
def gen_common_binary_and_inplace_op(op: str, name: str, readable_name: str) -> str:
- return from_template(open("templates/common_binary_and_inplace_op.pyx").read(), {"Op": op, "OpName": name, "OpReadableName": readable_name})
+ return from_template(
+ open("templates/common_binary_and_inplace_op.pyx").read(),
+ {"Op": op, "OpName": name, "OpReadableName": readable_name},
+ )
+
def gen_item_op(dims: int, op: str) -> str:
out = ""
@@ -226,9 +236,10 @@ def gen_item_op(dims: int, op: str) -> str:
out += f"key == {dim}:\n"
out += f" {op.format(dim=DIMS[dim])}\n"
out += "else:\n"
- out += " raise KeyError(f\"_VecClassName_ index out of range: {key}\")"
+ out += ' raise KeyError(f"_VecClassName_ index out of range: {key}")'
return out
+
def gen_iterator_next(dims: int) -> str:
out = ""
for dim in range(dims):
diff --git a/pyproject.toml b/pyproject.toml
index eef79ae..4db87d0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -63,3 +63,9 @@ archs = ["x86_64", "i686"]
[tool.cibuildwheel.macos]
archs = ["x86_64", "arm64"]
+
+[tool.black]
+preview = true
+enable-unstable-feature = [
+ "hug_parens_with_braces_and_square_brackets"
+]
diff --git a/setup.py b/setup.py
index 827a79c..2499a2d 100644
--- a/setup.py
+++ b/setup.py
@@ -13,8 +13,7 @@
),
],
packages=find_packages(
- where="src",
- exclude=["tests", "spatium/*.c", "spatium/*.cpp"]
+ where="src", exclude=["tests", "spatium/*.c", "spatium/*.cpp"]
),
- package_dir={'': 'src'},
+ package_dir={"": "src"},
)
diff --git a/tests/test_transform_2d.py b/tests/test_transform_2d.py
index 930b8df..490f6c2 100644
--- a/tests/test_transform_2d.py
+++ b/tests/test_transform_2d.py
@@ -10,6 +10,7 @@ def test_normal_constructor_and_get_components():
assert t.y == Vec2(3, 4)
assert t.origin == Vec2(5, 6)
+
def test_comparison():
def gen_diff_at(index, n):
return Transform2D(*[n if i == index else 0.5 for i in range(6)])
@@ -20,6 +21,7 @@ def gen_diff_at(index, n):
result = gen_diff_at(a, num) == gen_diff_at(b, num)
assert result if a == b else not result
+
def test_set_components():
t = Transform2D()
t.x = Vec2(1, 2)
@@ -27,11 +29,16 @@ def test_set_components():
t.origin = Vec2(5, 6)
assert t == Transform2D(1, 2, 3, 4, 5, 6)
+
def test_empty_constructor():
assert Transform2D() == Transform2D(1, 0, 0, 1, 0, 0)
+
def test_component_constructor():
- assert Transform2D(Vec2(1, 2), Vec2(3, 4), Vec2(5, 6)) == Transform2D(1, 2, 3, 4, 5, 6)
+ assert Transform2D(Vec2(1, 2), Vec2(3, 4), Vec2(5, 6)) == Transform2D(
+ 1, 2, 3, 4, 5, 6
+ )
+
def test_copy_constructor():
t = Transform2D(*range(6))
@@ -39,19 +46,23 @@ def test_copy_constructor():
assert id(t) != id(tc)
assert t == tc
+
def test_translation_constructor():
t = Transform2D.translating(Vec2(1, 2))
assert t == Transform2D(1, 0, 0, 1, 1, 2)
+
def test_rotating_constructor():
r = 1.23
t = Transform2D.rotating(r)
assert t.is_close(Transform2D(cos(r), sin(r), -sin(r), cos(r), 0, 0), 1e-7)
+
def test_scaling_constructor():
t = Transform2D.scaling(Vec2(1, 2))
assert t == Transform2D(1, 0, 0, 2, 0, 0)
+
def test_vector_xform():
t = Transform2D(1, 2, 3, 4, 5, 6)
v = Vec2(1, 2)
@@ -59,11 +70,13 @@ def test_vector_xform():
assert (t * v).is_close(ans)
assert t(v).is_close(ans)
+
def test_vector_inverse_xform():
print(Vec2(1, 2) * Transform2D(1, 2, 3, 4, 5, 6))
print((~Transform2D(*range(1, 7))) * Vec2(1, 2))
assert (Vec2(1, 2) * Transform2D(1, 2, 3, 4, 5, 6)).is_close(Vec2(-12, -28))
+
def test_matmul():
t1 = Transform2D(1, 2, 3, 4, 5, 6)
t2 = Transform2D(3, 1, 2, 6, 5, 4)
@@ -71,6 +84,7 @@ def test_matmul():
assert (t1 @ t2).is_close(ans)
assert t1(t2).is_close(ans)
+
def test_imatmul():
t1 = Transform2D(1, 2, 3, 4, 5, 6)
t2 = Transform2D(3, 1, 2, 6, 5, 4)
@@ -78,8 +92,12 @@ def test_imatmul():
t3 @= t2
assert t2 @ t1 == t3
+
def test_determinant():
assert isclose(Transform2D(1.6, 2.5, 3.4, 4.3, 5.2, 6.1).determinant, -1.62)
+
def test_inverse():
- assert (~Transform2D(1, 2, 3, 4, 5, 6)).is_close(Transform2D(-2, 1, 1.5, -0.5, 1, -2))
+ assert (~Transform2D(1, 2, 3, 4, 5, 6)).is_close(
+ Transform2D(-2, 1, 1.5, -0.5, 1, -2)
+ )
diff --git a/tests/test_transform_3d.py b/tests/test_transform_3d.py
index 5d590da..7fc310b 100644
--- a/tests/test_transform_3d.py
+++ b/tests/test_transform_3d.py
@@ -11,6 +11,7 @@ def test_normal_constructor_and_get_components():
assert t.z == Vec3(7, 8, 9)
assert t.origin == Vec3(10, 11, 12)
+
def test_comparison():
def gen_diff_at(index, n):
return Transform3D(*[n if i == index else 0.5 for i in range(12)])
@@ -21,6 +22,7 @@ def gen_diff_at(index, n):
result = gen_diff_at(a, num) == gen_diff_at(b, num)
assert result if a == b else not result
+
def test_set_components():
t = Transform3D()
t.x = Vec3(1, 2, 3)
@@ -29,11 +31,16 @@ def test_set_components():
t.origin = Vec3(10, 11, 12)
assert t == Transform3D(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12)
+
def test_empty_constructor():
assert Transform3D() == Transform3D(1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0)
+
def test_component_constructor():
- assert Transform3D(Vec3(1, 2, 3), Vec3(4, 5, 6), Vec3(7, 8, 9), Vec3(10, 11, 12)) == Transform3D(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12)
+ assert Transform3D(
+ Vec3(1, 2, 3), Vec3(4, 5, 6), Vec3(7, 8, 9), Vec3(10, 11, 12)
+ ) == Transform3D(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12)
+
def test_copy_constructor():
t = Transform3D(*range(12))
@@ -41,18 +48,38 @@ def test_copy_constructor():
assert id(t) != id(tc)
assert t == tc
+
def test_translation_constructor():
t = Transform3D.translating(Vec3(1, 2, 3))
assert t == Transform3D(1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 2, 3)
+
def test_rotating_constructor():
t = Transform3D.rotating(Vec3(1.85, 9.23, 3.42).normalized, 1.98)
- assert t.is_close(Transform3D(-0.350185, 0.551229, -0.757309, -0.075323, 0.789313, 0.609353, 0.933647, 0.270429, -0.234886, 0, 0, 0), rel_tol=1e-5)
+ assert t.is_close(
+ Transform3D(
+ -0.350185,
+ 0.551229,
+ -0.757309,
+ -0.075323,
+ 0.789313,
+ 0.609353,
+ 0.933647,
+ 0.270429,
+ -0.234886,
+ 0,
+ 0,
+ 0,
+ ),
+ rel_tol=1e-5,
+ )
+
def test_scaling_constructor():
t = Transform3D.scaling(Vec3(1, 2, 3))
assert t == Transform3D(1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0)
+
def test_vector_xform():
t = Transform3D(*range(1, 13))
v = Vec3(1, 2, 3)
@@ -60,9 +87,11 @@ def test_vector_xform():
assert (t * v).is_close(ans)
assert t(v).is_close(ans)
+
def test_vector_inverse_xform():
assert (Vec3(1, 2, 3) * Transform3D(*range(1, 13))).is_close(Vec3(-54, -135, -216))
+
def test_matmul():
t1 = Transform3D(*range(1, 13))
t2 = Transform3D(10, 8, 4, 6, 12, 7, 5, 3, 2, 1, 11, 9)
@@ -70,6 +99,7 @@ def test_matmul():
assert (t1 @ t2).is_close(ans)
assert t1(t2).is_close(ans)
+
def test_imatmul():
t1 = Transform3D(*range(1, 13))
t2 = Transform3D(10, 8, 4, 6, 12, 7, 5, 3, 2, 1, 11, 9)
@@ -77,8 +107,13 @@ def test_imatmul():
t3 @= t2
assert t2 @ t1 == t3
+
def test_determinant():
- assert isclose(Transform3D(1.12, 2.11, 0, 0, 5.8, 6.7, 7.6, 8.5, 0, 1, 2, 3).determinant, 43.6572)
+ assert isclose(
+ Transform3D(1.12, 2.11, 0, 0, 5.8, 6.7, 7.6, 8.5, 0, 1, 2, 3).determinant,
+ 43.6572,
+ )
+
def test_inverse():
t = Transform3D(1, 2, 3, 6, 5, 4, 7, 9, 8, 1, 2, 3)
diff --git a/tests/test_vector.py b/tests/test_vector.py
index 311f949..3d9c2dc 100644
--- a/tests/test_vector.py
+++ b/tests/test_vector.py
@@ -8,6 +8,7 @@ def test_empty_constructor():
vec = Vec3(0)
assert vec.x == vec.y == vec.z == 0
+
def test_normal_constructor():
vec = Vec3(1, 2, 3)
assert vec.x == 1
@@ -19,6 +20,7 @@ def test_normal_constructor():
assert vec.y == 2
assert vec.z == 3
+
def test_combined_constructors():
vec = Vec4(Vec2(1, 2), Vec2(3, 4))
assert vec.x == 1
@@ -31,16 +33,19 @@ def test_combined_constructors():
assert vec.y == 2
assert vec.z == 3
+
def test_type_conversion_constructor():
vec = Vec3(Vec3i(1, 2, 3))
assert vec.x == 1
assert vec.y == 2
assert vec.z == 3
+
def test_repr():
assert repr(Vec3(1, 2, 3)) == "Vec3(1.0, 2.0, 3.0)"
assert repr(Vec3i(1, 2, 3)) == "Vec3i(1, 2, 3)"
+
def test_comparison():
a = Vec3(1, 2, 3)
b = Vec3(3, 2, 1)
@@ -50,6 +55,7 @@ def test_comparison():
assert (a == b) == False
assert (a != c) == False
+
def test_swizzle():
v2 = Vec2(1, 2)
v3 = Vec3(1, 2, 3)
@@ -60,6 +66,7 @@ def test_swizzle():
assert v3.ylo == Vec3(2, 1, 0)
assert v4.wzyx == Vec4(4, 3, 2, 1)
+
def test_copy():
a = Vec3(1, 2, 3)
b = +a
@@ -67,62 +74,75 @@ def test_copy():
assert a == Vec3(1, 2, 3)
assert b == Vec3(5, 2, 3)
+
def test_neg():
a = Vec3(1, 2, 3)
b = -a
assert b == Vec3(-1, -2, -3)
+
def test_add():
a = Vec3(1, 2, 3)
b = Vec3(4, 5, 6)
assert a + b == Vec3(5, 7, 9)
assert a == Vec3(1, 2, 3)
+
def test_iadd():
a = Vec3(1, 2, 3)
b = Vec3(4, 5, 6)
a += b
assert a == Vec3(5, 7, 9)
+
def test_add_float():
a = Vec3(1, 2, 3)
assert a + 1.5 == Vec3(2.5, 3.5, 4.5)
+
def test_add_int():
a = Vec3(1, 2, 3)
assert a + 1 == Vec3(2, 3, 4)
+
def test_iadd_float():
a = Vec3(1, 2, 3)
a += 1.5
assert a == Vec3(2.5, 3.5, 4.5)
+
def test_sub():
a = Vec3(1, 2, 3)
b = Vec3(6, 5, 4)
assert a - b == Vec3(-5, -3, -1)
+
def test_mul():
a = Vec3(1, 2, 3)
b = Vec3(0, 2, 4)
assert a * b == Vec3(0, 4, 12)
+
def test_div():
a = Vec3(1, 2, 3)
b = Vec3(0.5, 2, 1.5)
assert a / b == Vec3(2, 1, 2)
+
def test_div_zero():
a = Vec3(1, 2, 3)
b = Vec3(0, 1, 2)
import math
+
assert a / b == Vec3(math.inf, 2, 1.5)
+
def test_dot():
a = Vec3(1, 2, 3)
b = Vec3(2, 1, 3)
assert a @ b == 13.0
+
def test_cross():
a = Vec3(1, 2, 3)
b = Vec3(3, 7, 5)
@@ -134,6 +154,7 @@ def test_cross():
# noinspection PyStatementEffect
a ^ b
+
def test_len():
assert len(Vec2()) == 2
assert len(Vec3()) == 3
@@ -142,9 +163,11 @@ def test_len():
assert len(Vec3i()) == 3
assert len(Vec4i()) == 4
+
def test_iter():
for i, value in enumerate(Vec3(1, 2, 3)):
- assert value == i+1
+ assert value == i + 1
+
def test_unpack():
x, y, z = Vec3(1, 2, 3)
@@ -152,35 +175,43 @@ def test_unpack():
assert y == 2
assert z == 3
+
def test_getitem():
assert Vec3(0, 1.5, 0)[1] == 1.5
+
def test_setitem():
v = Vec3(1, 3, 3)
v[1] = 2
assert v == Vec3(1, 2, 3)
+
def test_length():
assert math.isclose(Vec3(1, 2, 3).length, 3.7416573867739413)
+
def test_length_squared():
assert Vec3(1, 2, 3).length_sqr == 14
+
def test_distance_to():
a = Vec3(1, 2, 3)
b = Vec3(4, 5, 6)
assert math.isclose(a.distance_to(b), (a - b).length)
+
def test_distance_to_with_or_operator():
a = Vec3(1, 2, 3)
b = Vec3(4, 5, 6)
assert math.isclose(a | b, (a - b).length)
+
def test_distance_squared_to():
a = Vec3(1, 2, 3)
b = Vec3(4, 5, 6)
assert a.distance_sqr_to(b) == 27
+
def test_normalized():
a = Vec3(1, 2, 3)
a = a.normalized
@@ -188,8 +219,10 @@ def test_normalized():
assert math.isclose(a.y, 0.5345224838248488)
assert math.isclose(a.z, 0.8017837257372732)
+
def test_big_int():
- Vec2i(2**63-1)
+ Vec2i(2**63 - 1)
+
def test_float_precision():
assert Vec2(math.ulp(0)).x == math.ulp(0)