diff --git a/README.md b/README.md index 47e176d..0885b66 100644 --- a/README.md +++ b/README.md @@ -5,6 +5,7 @@
![Language](https://img.shields.io/badge/Language-Cython-FEDF5B) ![Python Implementation](https://img.shields.io/pypi/implementation/spatium?label=Implementation) +[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
[![Codegen](https://github.com/shBLOCK/spatium/actions/workflows/codegen.yml/badge.svg)](https://github.com/shBLOCK/spatium/actions/workflows/codegen.yml) [![Tests](https://github.com/shBLOCK/spatium/actions/workflows/tests.yml/badge.svg)](https://github.com/shBLOCK/spatium/actions/workflows/tests.yml) diff --git a/benchmark/benchmarking.py b/benchmark/benchmarking.py index bb053c6..6c57884 100644 --- a/benchmark/benchmarking.py +++ b/benchmark/benchmarking.py @@ -3,12 +3,23 @@ import json import math import time + # noinspection PyPep8Naming from datetime import datetime as DateTime import itertools from pathlib import Path from types import FunctionType -from typing import Self, Callable, Dict, overload, Iterable, Sequence, Optional, Generator, Any +from typing import ( + Self, + Callable, + Dict, + overload, + Iterable, + Sequence, + Optional, + Generator, + Any, +) import colorama @@ -30,7 +41,7 @@ "log", "temp_log", "indent_log", - "CI" + "CI", ) TIMER = time.perf_counter_ns @@ -38,6 +49,7 @@ AUTO_NUMBER_TARGET_TIME = 0.5e9 # AUTO_NUMBER_TARGET_TIME = 0.005e9 + class Subject: """ Define a benchmark subject. (e.g. spatium, Numpy, ...) @@ -45,10 +57,13 @@ class Subject: the decorated function will become the common setup procedure of this subject. (e.g. `from spatium import Vec3`) """ + instances: list["Subject"] = [] get_instance = classmethod(_instance_from_id) - def __init__(self, name: str, *, color: str = None, identifier: str = None, sort: int = None): + def __init__( + self, name: str, *, color: str = None, identifier: str = None, sort: int = None + ): self.setup: SourceLines = None self.id: str = identifier self.name = name @@ -59,8 +74,7 @@ def __init__(self, name: str, *, color: str = None, identifier: str = None, sort def __call__(self, setup: FunctionType) -> Self: self.id = setup.__name__ self.setup = extract_and_validate_source( - setup, - f"" + setup, f"" ) return self @@ -73,7 +87,7 @@ def to_json(self) -> Dict: "name": self.name, "color": self.color, "setup": self.setup, - "sort": self.sort + "sort": self.sort, } @classmethod @@ -82,16 +96,18 @@ def from_json(cls, data: Dict) -> Self: name=data["name"], color=data["color"], identifier=data["id"], - sort=data["sort"] + sort=data["sort"], ) inst.setup = tuple(data["setup"]) return inst + class Benchmark: """ Represents a benchmark, can have multiple test cases for multiple subject. Best be instantiated in a class as a class property, so that id can be automatically assigned. """ + instances: list["Benchmark"] = [] get_instance = classmethod(_instance_from_id) @@ -103,13 +119,16 @@ def __init__(self, name: str, identifier: str = None): if self.id is None: # infer id from call site import re + line = inspect.stack()[1].code_context[0] pattern = r"\s*(?P\w+)\s*=\s*Benchmark\s*\(" mat = re.match(pattern, line) if mat is not None: self.id = mat.group("name") else: - raise SyntaxError(f"Benchmark id infer failed: only instantiation source line matching the regex \"{pattern}\" supports inferring.") + raise SyntaxError( + f'Benchmark id infer failed: only instantiation source line matching the regex "{pattern}" supports inferring.' + ) Benchmark.instances.append(self) @@ -127,10 +146,12 @@ def _get_or_create_case(self, subject: Subject) -> "TestCase": def setup(self, subject: Subject) -> Callable[[FunctionType], None]: """Function decorator. Set the setup routine of a test case.""" self._check() + def inner(func: FunctionType): case = self._get_or_create_case(subject) case.setup_src = extract_source(func) case.update() + return inner @overload @@ -144,8 +165,11 @@ def __call__(self, func: FunctionType) -> Subject: The setup routine must be set first if needed! """ + @overload - def __call__(self, *subjects: Subject, number: int = None) -> Callable[[FunctionType], None]: + def __call__( + self, *subjects: Subject, number: int = None + ) -> Callable[[FunctionType], None]: """ Decorate a function to add test cases for multiple subjects. The decorator returns None. @@ -175,8 +199,10 @@ def inner(i_func: FunctionType): # is "@benchmark_name()" (not "@benchmark_name") if func is None: if not all(c == "_" for c in i_func.__name__): - raise NameError("Function name must only consist of underscores " - "if subjects are specified via arguments.") + raise NameError( + "Function name must only consist of underscores " + "if subjects are specified via arguments." + ) for subject in subjects: case = self._get_or_create_case(subject) @@ -191,11 +217,13 @@ def inner(i_func: FunctionType): return subjects[0] return inner - def run(self, order_permutations: bool = False, min_runs_per_case: int = 100) -> list["TestCaseResult"]: + def run( + self, order_permutations: bool = False, min_runs_per_case: int = 100 + ) -> list["TestCaseResult"]: log(f"Benchmark - {self.id}:") with indent_log(): all_results = [] - results_by_subject = {s:[] for s in self.testcases.keys()} + results_by_subject = {s: [] for s in self.testcases.keys()} permutations: Sequence[Sequence[TestCase]] if order_permutations: # noinspection PyTypeChecker @@ -229,8 +257,11 @@ def run(self, order_permutations: bool = False, min_runs_per_case: int = 100) -> with NoGC: for rep in range(repeats): with temp_log(): - log(f"Sequence[{(i * repeats + rep + 1)}" - f"/{len(permutations) * repeats}]: ", False) + log( + f"Sequence[{(i * repeats + rep + 1)}" + f"/{len(permutations) * repeats}]: ", + False, + ) for case in cases: result = case.run() result.sequence = cases @@ -242,10 +273,12 @@ def run(self, order_permutations: bool = False, min_runs_per_case: int = 100) -> with indent_log(): for subject, results in results_by_subject.items(): times = tuple(r.runtime for r in results) - log(f"{subject.id}({len(results)} runs): " + log( + f"{subject.id}({len(results)} runs): " f"Min {min(times)/1e9:.3f}s, " f"Avg {sum(times)/len(times)/1e9:.3f}s, " - f"Max {max(times)/1e9:.3f}s") + f"Max {max(times)/1e9:.3f}s" + ) log() @@ -258,15 +291,12 @@ def to_json(self) -> Dict: return { "id": self.id, "name": self.name, - "testcases": tuple(c.to_json() for c in self.testcases.values()) + "testcases": tuple(c.to_json() for c in self.testcases.values()), } @classmethod def from_json(cls, data: Dict) -> Self: - inst = cls( - name=data["name"], - identifier=data["id"] - ) + inst = cls(name=data["name"], identifier=data["id"]) for case_dat in data["testcases"]: case = TestCase.from_json(case_dat) inst.testcases[case.subject] = case @@ -294,14 +324,18 @@ def _generate_benchmark_func(self): "for _i in _it:", )) src.extend(indent(self.main_src)) - src.extend(( - "return _timer() - _begin", - )) + src.extend(("return _timer() - _begin",)) src[1:] = indent(tuple(src[1:])) try: ls = {} - exec(compile_src(src, f""), {}, ls) + exec( + compile_src( + src, f"" + ), + {}, + ls, + ) self.benchmark_func = ls["_benchmark"] except Exception as e: raise RuntimeError("Benchmark function generation failed.") from e @@ -315,15 +349,11 @@ def update(self): name = f"" if self.main_src is not None: validate_source( - self.subject.setup + self.setup_src + self.main_src, - name % "full" + self.subject.setup + self.setup_src + self.main_src, name % "full" ) self._generate_benchmark_func() else: - validate_source( - self.subject.setup + self.setup_src, - name % "setup" - ) + validate_source(self.subject.setup + self.setup_src, name % "setup") def auto_number(self) -> int: assert self.is_auto_number @@ -341,22 +371,21 @@ def run(self) -> "TestCaseResult": assert self.number is not None, "Number of runs not set." datetime = utc_now() with NoGC: - return TestCaseResult(self, self.benchmark_func(self.number, TIMER), datetime) + return TestCaseResult( + self, self.benchmark_func(self.number, TIMER), datetime + ) def __repr__(self): return f"" def to_json(self, id_only=False) -> Dict: - data = { - "benchmark_id": self.benchmark.id, - "subject_id": self.subject.id - } + data = {"benchmark_id": self.benchmark.id, "subject_id": self.subject.id} if not id_only: data.update({ "setup_src": self.setup_src, "main_src": self.main_src, "is_auto_number": self.is_auto_number, - "number": self.number + "number": self.number, }) return data @@ -370,7 +399,9 @@ def from_json(cls, data: Dict, *, data_only=True, id_only=False) -> Self: if id_only: return benchmark.testcases[subject] else: - assert subject not in benchmark.testcases, f"TestCase of subject {subject} already exists in benchmark {benchmark}." + assert ( + subject not in benchmark.testcases + ), f"TestCase of subject {subject} already exists in benchmark {benchmark}." inst = cls(benchmark, subject) inst.setup_src = tuple(data["setup_src"]) inst.main_src = tuple(data["main_src"]) @@ -382,7 +413,13 @@ def from_json(cls, data: Dict, *, data_only=True, id_only=False) -> Self: class TestCaseResult: - def __init__(self, testcase: TestCase, runtime: int, datetime: DateTime, sequence: Sequence[TestCase] = None): + def __init__( + self, + testcase: TestCase, + runtime: int, + datetime: DateTime, + sequence: Sequence[TestCase] = None, + ): self.testcase = testcase self.runtime = runtime self.datetime = datetime @@ -396,8 +433,11 @@ def to_json(self) -> Dict: **self.testcase.to_json(id_only=True), "runtime": self.runtime, "timestamp": self.datetime.timestamp(), - "sequence": ([case.to_json(id_only=True) for case in self.sequence] - if self.sequence is not None else None) + "sequence": ( + [case.to_json(id_only=True) for case in self.sequence] + if self.sequence is not None + else None + ), } @classmethod @@ -406,15 +446,20 @@ def from_json(cls, data: Dict) -> Self: testcase=TestCase.from_json(data, id_only=True), runtime=data["runtime"], datetime=from_utc_stamp(data["timestamp"]), - sequence=(tuple(TestCase.from_json(d, id_only=True) for d in data["sequence"]) - if data["sequence"] is not None else None) + sequence=( + tuple(TestCase.from_json(d, id_only=True) for d in data["sequence"]) + if data["sequence"] is not None + else None + ), ) + class BenchmarkMetadata: def __init__(self): log("Getting metadata...") with indent_log(): import platform, cpuinfo, os + self.system = platform.system() log(f"System: {self.system}") self.ci = CI @@ -444,7 +489,7 @@ def to_json(self) -> Dict: "arch": self.arch, "cpu": self.cpu, "py_impl": self.py_impl, - "py_ver": self.py_ver + "py_ver": self.py_ver, } @classmethod @@ -458,6 +503,7 @@ def from_json(cls, data: Dict) -> Self: inst.py_ver = data["py_ver"] return inst + class BenchmarkResult: def __init__(self, datetime: DateTime, metadata: BenchmarkMetadata = None): self.datetime = datetime @@ -471,7 +517,9 @@ def add_results(self, *results: TestCaseResult): def benchmarks(self) -> Iterable[Benchmark]: return iter_identity(r.testcase.benchmark for r in self.raw_results) - def get_results(self, *, testcase: TestCase = None, subject: Subject = None) -> Generator[TestCaseResult, None, None]: + def get_results( + self, *, testcase: TestCase = None, subject: Subject = None + ) -> Generator[TestCaseResult, None, None]: for result in self.raw_results: if testcase is not None and result.testcase != testcase: continue @@ -483,14 +531,14 @@ def to_json(self) -> Dict: return { "timestamp": self.datetime.timestamp(), "metadata": self.metadata.to_json(), - "raw_results": tuple(r.to_json() for r in self.raw_results) + "raw_results": tuple(r.to_json() for r in self.raw_results), } @classmethod def from_json(cls, data: Dict) -> Self: inst = cls( datetime=from_utc_stamp(data["timestamp"]), - metadata=BenchmarkMetadata.from_json(data["metadata"]) + metadata=BenchmarkMetadata.from_json(data["metadata"]), ) inst.add_results(*(TestCaseResult.from_json(rd) for rd in data["raw_results"])) return inst @@ -519,11 +567,14 @@ def serialize(result: BenchmarkResult) -> Dict: return { "subjects": tuple(s.to_json() for s in Subject.instances), "benchmarks": tuple(b.to_json() for b in Benchmark.instances), - "result": result.to_json() + "result": result.to_json(), } + def deserialize(data: Dict) -> BenchmarkResult: - assert not Benchmark.instances or not Subject.instances, "Deserialization prohibited: environment not clean." + assert ( + not Benchmark.instances or not Subject.instances + ), "Deserialization prohibited: environment not clean." for d in data["subjects"]: Subject.from_json(d) @@ -531,10 +582,12 @@ def deserialize(data: Dict) -> BenchmarkResult: Benchmark.from_json(d) return BenchmarkResult.from_json(data["result"]) + def clear_env(): Benchmark.instances.clear() Subject.instances.clear() + def save_result(result: BenchmarkResult, file: Path): """Save the result and the current environment to a gzipped json file.""" file.parent.mkdir(parents=True, exist_ok=True) @@ -542,6 +595,7 @@ def save_result(result: BenchmarkResult, file: Path): with gzip.open(file, "wt", encoding="utf8") as f: json.dump(serialize(result), f, ensure_ascii=False) + def load_result(file: Path) -> BenchmarkResult: """Restore the result and the environment from a gzipped json file.""" log(f"Loading from {file}...") diff --git a/benchmark/benchmarks.py b/benchmark/benchmarks.py index b42ab97..2fd71ab 100644 --- a/benchmark/benchmarks.py +++ b/benchmark/benchmarks.py @@ -1,3 +1,4 @@ +# fmt: off # begin: noinspection PyUnresolvedReferences # begin: noinspection PyUnusedLocal # begin: noinspection PyUnboundLocalVariable diff --git a/benchmark/charting.py b/benchmark/charting.py index 514d654..7f67e7d 100644 --- a/benchmark/charting.py +++ b/benchmark/charting.py @@ -4,13 +4,15 @@ import matplotlib.pyplot as plt from matplotlib import patches, transforms + def get_runtime(result: TestCaseResult) -> int: return result.runtime + _TEXT = "#e6edf3" -class BackgroundFancyBboxPatch(patches.FancyBboxPatch): +class BackgroundFancyBboxPatch(patches.FancyBboxPatch): def __init__(self, **kwargs): super().__init__((0, 0), 1, 1, **kwargs) @@ -28,14 +30,24 @@ def draw(self, renderer): # Don't apply any transforms self._draw_paths_with_artist_properties( renderer, - [(self.get_path(), transforms.IdentityTransform(), - # Work around a bug in the PDF and SVG renderers, which - # do not draw the hatches if the facecolor is fully - # transparent, but do if it is None. - self._facecolor if self._facecolor[3] else None)]) - - -def chart(result: BenchmarkResult, baseline: Subject, *, fig_height=5, subtitles: Iterable[str] = ()) -> plt.Figure: + [( + self.get_path(), + transforms.IdentityTransform(), + # Work around a bug in the PDF and SVG renderers, which + # do not draw the hatches if the facecolor is fully + # transparent, but do if it is None. + self._facecolor if self._facecolor[3] else None, + )], + ) + + +def chart( + result: BenchmarkResult, + baseline: Subject, + *, + fig_height=5, + subtitles: Iterable[str] = (), +) -> plt.Figure: fig: plt.Figure ax: plt.Axes fig, ax = plt.subplots( @@ -48,28 +60,21 @@ def chart(result: BenchmarkResult, baseline: Subject, *, fig_height=5, subtitles # Background bg_edge = 2 bg = BackgroundFancyBboxPatch( - facecolor="#161b22", - linewidth=bg_edge, - edgecolor="#30363d", - figure=fig + facecolor="#161b22", linewidth=bg_edge, edgecolor="#30363d", figure=fig ) bg.set_boxstyle("round", pad=-bg_edge / 2, rounding_size=fig.dpi * 0.5) fig.patch = bg # Title - ax.set_title(f"Benchmark - {baseline.name} Implementation as Baseline\n" - f"Operations Per Second" - + (("\n" + "\n".join(subtitles)) if subtitles else ""), - color=_TEXT) + ax.set_title( + f"Benchmark - {baseline.name} Implementation as Baseline\n" + f"Operations Per Second" + (("\n" + "\n".join(subtitles)) if subtitles else ""), + color=_TEXT, + ) ax.margins(x=0.01) - ax.axhline( - y=1, - color=baseline.color, - linestyle="--", - label="Baseline" - ) + ax.axhline(y=1, color=baseline.color, linestyle="--", label="Baseline") highest = 0 x_pos = 0 @@ -78,22 +83,19 @@ def chart(result: BenchmarkResult, baseline: Subject, *, fig_height=5, subtitles subject_ops_per_sec: dict[Subject, float] = {} for testcase in benchmark.testcases.values(): min_runtime = min(r.runtime for r in result.get_results(testcase=testcase)) - subject_ops_per_sec[testcase.subject] = testcase.number / (min_runtime / 1e9) + subject_ops_per_sec[testcase.subject] = testcase.number / ( + min_runtime / 1e9 + ) begin_pos = x_pos for subject, ops_per_sec in sorted( - subject_ops_per_sec.items(), - key=lambda pair: pair[0].sort + subject_ops_per_sec.items(), key=lambda pair: pair[0].sort ): height = ops_per_sec / (subject_ops_per_sec[baseline]) highest = max(height, highest) rect = ax.bar( - x=x_pos, - height=height, - width=1, - label=subject.name, - color=subject.color + x=x_pos, height=height, width=1, label=subject.name, color=subject.color ) ax.bar_label( rect, @@ -101,7 +103,7 @@ def chart(result: BenchmarkResult, baseline: Subject, *, fig_height=5, subtitles labels=[numerize(ops_per_sec, decimals=1)], fontsize=10, rotation=90, - color=_TEXT + color=_TEXT, ) x_pos += 1 @@ -112,11 +114,7 @@ def chart(result: BenchmarkResult, baseline: Subject, *, fig_height=5, subtitles ax.margins(y=18 / highest / fig_height) - ax.set_xticks( - *zip(*x_ticks), - rotation=15, - color=_TEXT - ) + ax.set_xticks(*zip(*x_ticks), rotation=15, color=_TEXT) [s.set_color("#ced0d6") for s in ax.spines.values()] ax.yaxis.set_major_formatter(lambda x, pos: f"{x}x") @@ -127,13 +125,14 @@ def chart(result: BenchmarkResult, baseline: Subject, *, fig_height=5, subtitles handles, labels = ax.get_legend_handles_labels() temp = {k: v for k, v in zip(labels, handles)} legend = ax.legend( - temp.values(), temp.keys(), + temp.values(), + temp.keys(), loc="upper left", labelcolor=_TEXT, facecolor="#3d3f42", edgecolor="#3d3f42", - fancybox = True, - framealpha = 0.5 + fancybox=True, + framealpha=0.5, ) legend.legendPatch.set_boxstyle("round", rounding_size=0.5) diff --git a/benchmark/gen_charts.py b/benchmark/gen_charts.py index c6fb824..eb07daa 100644 --- a/benchmark/gen_charts.py +++ b/benchmark/gen_charts.py @@ -11,7 +11,9 @@ files = [Path(f) for f in os.listdir(RESULTS) if f.endswith(".dat")] # Latest to earliest -files.sort(reverse=True, key=lambda f: datetime.datetime.strptime(f.stem, "%Y%m%d_%H-%M-%S")) +files.sort( + reverse=True, key=lambda f: datetime.datetime.strptime(f.stem, "%Y%m%d_%H-%M-%S") +) for file in files: log(f"{file}:") @@ -27,9 +29,8 @@ f"{result.metadata.py_impl} {result.metadata.py_ver} · " f"{result.metadata.system} · " f"{result.metadata.cpu}" - + (" (GitHub Actions)" if result.metadata.ci else "") - , - ) + + (" (GitHub Actions)" if result.metadata.ci else ""), + ), ) chart.savefig(CHARTS / file.with_suffix(".svg")) clear_env() @@ -45,4 +46,6 @@ f.write("---\n") f.write("\n") for file in files: - f.write(f"[![{file.stem}](./{file.with_suffix(".svg")})](./{file.with_suffix(".svg")})\n") + f.write( + f"[![{file.stem}](./{file.with_suffix(".svg")})](./{file.with_suffix(".svg")})\n" + ) diff --git a/benchmark/pure_python_impl.py b/benchmark/pure_python_impl.py index 2d2563d..0932033 100644 --- a/benchmark/pure_python_impl.py +++ b/benchmark/pure_python_impl.py @@ -21,17 +21,9 @@ def __init__(self, x: float, y: float, z: float): def __add__(self, other: Union["Vec3", float]): if isinstance(other, Vec3): - return Vec3( - self.x + other.x, - self.y + other.y, - self.z + other.z - ) + return Vec3(self.x + other.x, self.y + other.y, self.z + other.z) elif isinstance(other, float | int): - return Vec3( - self.x + other, - self.y + other, - self.z + other - ) + return Vec3(self.x + other, self.y + other, self.z + other) else: raise TypeError() @@ -59,7 +51,7 @@ def __xor__(self, other: "Vec3"): return Vec3( self.y * other.z - self.z * other.y, self.z * other.x - self.x * other.z, - self.x * other.y - self.y * other.x + self.x * other.y - self.y * other.x, ) else: raise TypeError() diff --git a/benchmark/utils.py b/benchmark/utils.py index a3cc43c..bb51f87 100644 --- a/benchmark/utils.py +++ b/benchmark/utils.py @@ -3,6 +3,7 @@ import inspect import itertools from contextlib import contextmanager + # noinspection PyPep8Naming from datetime import datetime as DateTime from datetime import timezone @@ -32,25 +33,29 @@ "indent_log", "temp_log", "iter_identity", - "auto_number_series" + "auto_number_series", ) SourceLines = Tuple[str, ...] + def auto_number_series() -> Generator[int, None, None]: for expo in itertools.count(2): # E12 series for value in (10, 12, 15, 18, 22, 27, 33, 39, 47, 56, 68, 82): - yield value * 10 ** expo + yield value * 10**expo + def indent(src: SourceLines) -> SourceLines: - return tuple(" "*4 + line for line in src) + return tuple(" " * 4 + line for line in src) + def dedent(src: SourceLines) -> SourceLines: for i in itertools.count(): - if any(line[i:i+1] != " " for line in src): + if any(line[i : i + 1] != " " for line in src): return tuple(line[i:] for line in src) + def extract_source(func: FunctionType) -> SourceLines: """ Extract the source code of the body of func. @@ -58,32 +63,38 @@ def extract_source(func: FunctionType) -> SourceLines: """ src = inspect.getsource(func).splitlines() src = list(dedent(src)) - if src[0][0] == "@": # delete decorator + if src[0][0] == "@": # delete decorator del src[0] - del src[0] # delete function header + del src[0] # delete function header return dedent(src) + def compile_src(src: SourceLines, name: str = ""): return compile("\n".join(src), name, "exec", optimize=2) + def validate_source(src: SourceLines, name: str): try: exec(compile_src(src, name)) except Exception as e: raise RuntimeError("Code validation failed.") from e + def extract_and_validate_source(func: FunctionType, name: str): src = extract_source(func) name += f"({func.__name__} in {func.__module__})" validate_source(src, name) return src + def utc_now() -> DateTime: return DateTime.now(timezone.utc) + def from_utc_stamp(stamp: float) -> DateTime: return DateTime.fromtimestamp(stamp, timezone.utc) + class _NoGC: def __init__(self): self._enter_depth = 0 @@ -98,16 +109,22 @@ def __exit__(self, *_): if self._enter_depth == 0: gc.enable() + NoGC = _NoGC() T = TypeVar("T") + + def _instance_from_id(cls: Type[T], identifier: str) -> Optional[T]: for inst in cls.instances: if inst.id == identifier: return inst return None + _log_indent = 0 + + @contextmanager def indent_log(n=1): global _log_indent @@ -116,8 +133,11 @@ def indent_log(n=1): _log_indent -= n assert _log_indent >= 0 + _log_chrs_stack = [] _log_temp_log = False + + @contextmanager def temp_log(disable=False): if disable: @@ -134,7 +154,10 @@ def temp_log(disable=False): _should_indent = old_should_indent _log_temp_log = False + _should_indent = True + + def log(msg="", new_line=True, color: str = None): if CI and _log_temp_log: return @@ -153,6 +176,7 @@ def log(msg="", new_line=True, color: str = None): _should_indent = new_line + def iter_identity(it: Iterable[T]) -> Generator[T, None, None]: got = set() for item in it: diff --git a/codegen/codegen_helper.py b/codegen/codegen_helper.py index 7f07f9e..51bce93 100644 --- a/codegen/codegen_helper.py +++ b/codegen/codegen_helper.py @@ -89,9 +89,16 @@ def apply_params(in_line: str) -> str: if ":" in line: gen_pos = line.index("#:") prefix = line[:gen_pos] - expr = line[gen_pos + 7:] + expr = line[gen_pos + 7 :] try: - result = eval(expr, globals() if _globals is None else globals().copy().update(_globals)) + result = eval( + expr, + ( + globals() + if _globals is None + else globals().copy().update(_globals) + ), + ) generated = True except Exception as e: if isinstance(e, AssertionError): @@ -138,6 +145,7 @@ def __init__(self, name: str, c_params: tuple[str], c_ret: str): case _: self.ret = c_ret + class _Overload: def __init__(self, name: str): self.name = name @@ -146,8 +154,9 @@ def __init__(self, name: str): def add(self, func: _Func): if len(self._funcs) > 0: - assert (Self in self._funcs[0].params) == (Self in func.params), \ - "function/method mismatch" + assert (Self in self._funcs[0].params) == ( + Self in func.params + ), "function/method mismatch" self._funcs.append(func) self._possible_param_types_cache = None @@ -181,10 +190,12 @@ def possible_param_types(self) -> Sequence[Sequence]: if p not in self._possible_param_types_cache[i]: self._possible_param_types_cache[i].append(p) - for optional_param in self._possible_param_types_cache[self.min_params:]: + for optional_param in self._possible_param_types_cache[self.min_params :]: optional_param.append(None) - self._possible_param_types_cache = tuple(tuple(ts) for ts in self._possible_param_types_cache) + self._possible_param_types_cache = tuple( + tuple(ts) for ts in self._possible_param_types_cache + ) return self._possible_param_types_cache @@ -204,19 +215,19 @@ def param_names(self) -> tuple[str]: @staticmethod def _type_check_expression(*types): - if float in types and int in types: # py_float + if float in types and int in types: # py_float assert len(types) == 2, "py_float type should be a discrete branch!" # Insignificant optimization # return "(PyFloat_CheckExact({value}) or PyLong_CheckExact({value}))" return "(PyFloat_Check({value}) or PyLong_Check({value}))" - elif int in types: # py_int + elif int in types: # py_int assert len(types) == 1, "py_int type should be a discrete branch!" # Insignificant optimization # return "PyLong_CheckExact({value})" return "PyLong_Check({value})" else: out = "" - for i,t in enumerate(types): + for i, t in enumerate(types): if type(t) is str: out += f"isinstance({{value}}, {t})" elif t is None: @@ -269,29 +280,35 @@ def _gen_type_no_match_exception(self, param_types: tuple[tuple]) -> str: type_strs.append(self._type_str(pts[0])) else: type_strs.append(f"({' | '.join(map(self._type_str, pts))})") - text = (f"raise TypeError(\"No matching overload function for parameter types: " - f"{', '.join(type_strs)}") + text = ( + f'raise TypeError("No matching overload function for parameter types: ' + f"{', '.join(type_strs)}" + ) if len(param_types) != self.max_params: text += ", ..." - text += "\")" + text += '")' return text def _gen_type_cast(self, var_name: str, var_types: tuple): multiple_type_err = "A branch should have only one general(cast-able) type!" - if float in var_types and int in var_types: # py_float + if float in var_types and int in var_types: # py_float assert len(var_types) == 2, multiple_type_err # Insignificant optimization # return f"PyFloat_AS_DOUBLE({var_name}) if PyFloat_CheckExact({var_name}) else PyLong_AsDouble({var_name})" return f"PyFloat_AsDouble({var_name})" - elif int in var_types: # py_int + elif int in var_types: # py_int assert len(var_types) == 1, multiple_type_err return f"PyLong_AsLongLong({var_name})" else: assert len(var_types) == 1, multiple_type_err - assert isinstance(var_types[0], str), f"Don't know how to cast to {var_types[0]}." + assert isinstance( + var_types[0], str + ), f"Don't know how to cast to {var_types[0]}." return f"<{var_types[0]}> {var_name}" - def _gen_dispatch_tree(self, params_types: tuple = None) -> tuple[Sequence[str], bool]: + def _gen_dispatch_tree( + self, params_types: tuple = None + ) -> tuple[Sequence[str], bool]: """Generate overload dispatch tree recursively.""" if params_types is None: params_types = ((Self,),) if self._funcs[0].params[0] is Self else () @@ -302,12 +319,16 @@ def _gen_dispatch_tree(self, params_types: tuple = None) -> tuple[Sequence[str], if len(funcs) == 0: return [self._gen_type_no_match_exception(params_types)], False if len(funcs) > 1: - assert False, f"Multiple matching functions for: {', '.join(str(p) for p in params_first_types)}" + assert ( + False + ), f"Multiple matching functions for: {', '.join(str(p) for p in params_first_types)}" func = funcs[0] - out = [f"{'self.' if self.is_method else ''}" - f"{func.name}" - f"({', '.join(self._gen_type_cast(pn, pts) for pts,pn in zip(params_types, self.param_names) if pn != 'self' and None not in pts)})"] + out = [ + f"{'self.' if self.is_method else ''}" + f"{func.name}" + f"({', '.join(self._gen_type_cast(pn, pts) for pts,pn in zip(params_types, self.param_names) if pn != 'self' and None not in pts)})" + ] if func.ret is None: out.append("return") else: @@ -326,7 +347,9 @@ def _gen_dispatch_tree(self, params_types: tuple = None) -> tuple[Sequence[str], if len(matches) == 0: continue for branch in branches: - if set(matches) == set(self._func_from_params(params_first_types + (branch[0],))): + if set(matches) == set( + self._func_from_params(params_first_types + (branch[0],)) + ): branch.append(t) break else: @@ -337,8 +360,10 @@ def _gen_dispatch_tree(self, params_types: tuple = None) -> tuple[Sequence[str], any_hit = False for i, t in enumerate(branches): - out.append(f"{'if' if i == 0 else 'elif'} " - f"{self._type_check_expression(*t).format(value=self.param_names[len(params_types)])}:") + out.append( + f"{'if' if i == 0 else 'elif'} " + f"{self._type_check_expression(*t).format(value=self.param_names[len(params_types)])}:" + ) branch_params = params_types + (t,) dt, hit = self._gen_dispatch_tree(branch_params) if hit: @@ -347,12 +372,16 @@ def _gen_dispatch_tree(self, params_types: tuple = None) -> tuple[Sequence[str], for l in dt: out.append(f" {l}") else: - out.append(" " + self._gen_type_no_match_exception(branch_params)) + out.append( + " " + self._gen_type_no_match_exception(branch_params) + ) out.append("else:") expected_type_strs = [self._type_str(t) for t in possible_types] expected_types_str = " | ".join(expected_type_strs) - out.append(f" raise TypeError(f\"The {len(params_first_types) + 1}th parameter expected {expected_types_str}, " - f"got {{{self.param_names[len(params_first_types)]}}}\")") + out.append( + f' raise TypeError(f"The {len(params_first_types) + 1}th parameter expected {expected_types_str}, ' + f'got {{{self.param_names[len(params_first_types)]}}}")' + ) return out, any_hit @@ -370,11 +399,14 @@ def gen_dispatcher(self) -> Sequence[str]: ret_types = list(set(_Overload._type_str(f.ret) for f in self._funcs)) - lines = [f"def {self.name}({', '.join(('object ' if i > 0 else '') + p for i,p in enumerate(param_strs))}, /) -> {' | '.join(ret_types)}:"] + lines = [ + f"def {self.name}({', '.join(('object ' if i > 0 else '') + p for i,p in enumerate(param_strs))}, /) -> {' | '.join(ret_types)}:" + ] disp_lines, _ = self._gen_dispatch_tree() lines += [f" {dl}" for dl in disp_lines] return lines + def process_overloads(file: str) -> str: import regex @@ -398,7 +430,7 @@ def process_overloads(file: str) -> str: r"\(" r"\s*(?:(?P\w+\s+\w+|self)(?:\s*,\s*(?P\w+\s+\w+))*)?\s*" r"\)", - line + line, ) assert m is not None, line name = m.group("name") @@ -410,7 +442,7 @@ def process_overloads(file: str) -> str: overload = overloads[name] new_name = f"_{name}_{len(overload):d}" name_span = m.span("name") - lines[i] = line[:name_span[0]] + new_name + line[name_span[1]:] + lines[i] = line[: name_span[0]] + new_name + line[name_span[1] :] overload.add(_Func(new_name, c_params, c_ret)) lines = [line for line in lines if "" not in line] @@ -418,14 +450,17 @@ def process_overloads(file: str) -> str: while True: for i, line in enumerate(lines): if ":" in line: - m = regex.match(r"(?P\s*)#(?:\s|#)*:(?P\w+)", line) + m = regex.match( + r"(?P\s*)#(?:\s|#)*:(?P\w+)", + line, + ) assert m is not None prefix = m.group("prefix") name = m.group("name") assert name in overloads disp_lines = overloads[name].gen_dispatcher() disp_lines = [prefix + disp_line for disp_line in disp_lines] - lines[i:i + 1] = disp_lines + lines[i : i + 1] = disp_lines break else: break @@ -433,17 +468,24 @@ def process_overloads(file: str) -> str: return "".join(f"{line}\n" for line in lines) -def step_generate(template_file: str, output_file: str = None, write_file: bool = False, params: dict = None, _globals: dict = None, overload: bool = False): +def step_generate( + template_file: str, + output_file: str = None, + write_file: bool = False, + params: dict = None, + _globals: dict = None, + overload: bool = False, +): print(f"Step Generate: {template_file}") if output_file is None: output_file = template_file import os + if not os.path.exists("output"): os.mkdir("output") - template = open(f"templates/{template_file}", encoding="utf8").read() t = time.perf_counter() if _globals is not None: @@ -469,10 +511,13 @@ def step_generate(template_file: str, output_file: str = None, write_file: bool def step_gen_stub(source_file: str, output_file: str): import stub_generator + source = open(f"output/{source_file}", encoding="utf8").read() t = time.perf_counter() result = stub_generator.gen_stub(source) - print(f"Step Gen Stub: {source_file} -> {output_file} completed in {time.perf_counter() - t:.3f}s") + print( + f"Step Gen Stub: {source_file} -> {output_file} completed in {time.perf_counter() - t:.3f}s" + ) with open(f"output/{output_file}", "w", encoding="utf8") as output: output.write(result) @@ -482,17 +527,16 @@ def step_cythonize(file: str): import subprocess print("########## Cythonize ##########") - proc = subprocess.Popen(( - "cythonize.exe", - "-a", - "-i", - f"output/{file}.pyx" - ), stdout=sys.stdout, stderr=sys.stderr) + proc = subprocess.Popen( + ("cythonize.exe", "-a", "-i", f"output/{file}.pyx"), + stdout=sys.stdout, + stderr=sys.stderr, + ) t = time.perf_counter() proc.wait() print( f"Cythonize finished in {time.perf_counter() - t}s with exit code {proc.returncode}", - file=sys.stdout if proc.returncode == 0 else sys.stderr + file=sys.stdout if proc.returncode == 0 else sys.stderr, ) @@ -506,4 +550,3 @@ def step_move_to_dest(final_dest: str, file_prefix: str, file_suffix: str): dest = os.path.join(final_dest, file) shutil.copy(path, dest) print(f"Coping {file} to {dest}") - diff --git a/codegen/gen_all.py b/codegen/gen_all.py index e4e8ad9..9695bc3 100644 --- a/codegen/gen_all.py +++ b/codegen/gen_all.py @@ -3,24 +3,35 @@ def main(): import vector_codegen + codegen.step_generate( "_spatium.pyx", write_file=True, - _globals={vector_codegen.__name__: vector_codegen}) + _globals={vector_codegen.__name__: vector_codegen}, + ) codegen.step_gen_stub("_spatium.pyx", "_spatium.pyi") import sys + if "--install" in sys.argv: - print("#"*15 + " Install " + "#"*15) + print("#" * 15 + " Install " + "#" * 15) codegen.step_move_to_dest("../src/spatium/", "_spatium", ".pyx") codegen.step_move_to_dest("../src/spatium/", "_spatium", ".pyi") import sys import subprocess - print("#"*15 + " pip uninstall spatium -y " + "#"*15) - subprocess.call(f"{sys.executable} -m pip uninstall spatium -y", stdout=sys.stdout) - print("#"*15 + " pip install -v -v -v .. " + "#"*15) - subprocess.call(f"{sys.executable} -m pip install -v -v -v ..", stdout=sys.stdout, stderr=sys.stdout) -if __name__ == '__main__': + print("#" * 15 + " pip uninstall spatium -y " + "#" * 15) + subprocess.call( + f"{sys.executable} -m pip uninstall spatium -y", stdout=sys.stdout + ) + print("#" * 15 + " pip install -v -v -v .. " + "#" * 15) + subprocess.call( + f"{sys.executable} -m pip install -v -v -v ..", + stdout=sys.stdout, + stderr=sys.stdout, + ) + + +if __name__ == "__main__": main() diff --git a/codegen/stub_generator.py b/codegen/stub_generator.py index 9ee28de..31e7dca 100644 --- a/codegen/stub_generator.py +++ b/codegen/stub_generator.py @@ -20,6 +20,7 @@ def convert_type(org: str) -> str: types.append(f'"{raw}"') return types[0] if len(types) == 1 else f"Union[{', '.join(types)}]" + class StubProperty: def __init__(self, name: str, ptype: str, mutable: bool = True): self.name = name @@ -29,10 +30,7 @@ def __init__(self, name: str, ptype: str, mutable: bool = True): self.setter_doc = [] def stub(self) -> list[str]: - stub = [ - "@property", - f"def {self.name}(self) -> {self.type}:" - ] + stub = ["@property", f"def {self.name}(self) -> {self.type}:"] if self.getter_doc: stub.extend(self.getter_doc) stub.append(" ...") @@ -53,7 +51,13 @@ def stub(self) -> list[str]: class StubMethod: class Param: - def __init__(self, name: str, ptype: str, default: Optional[str] = None, const_mapping: dict[str, str] = None): + def __init__( + self, + name: str, + ptype: str, + default: Optional[str] = None, + const_mapping: dict[str, str] = None, + ): self.name = name self.ptype = ptype self.default = default @@ -68,7 +72,14 @@ def stub(self) -> str: stub += f" = {default}" return stub - def __init__(self, name: str, rtype: str, params: Sequence[Param | str], is_cdef: bool, is_static: bool): + def __init__( + self, + name: str, + rtype: str, + params: Sequence[Param | str], + is_cdef: bool, + is_static: bool, + ): self.name = name self.rtype = rtype self.params = params @@ -88,7 +99,9 @@ def stub(self, name_override: Optional[str] = None) -> list[str]: else: param_list.append(p.stub()) - stub.append(f"def {name_override or self.name}({', '.join(param_list)}) -> {self.rtype}:") + stub.append( + f"def {name_override or self.name}({', '.join(param_list)}) -> {self.rtype}:" + ) if self.docstring: stub.extend(self.docstring) stub.append(" ...") @@ -97,6 +110,7 @@ def stub(self, name_override: Optional[str] = None) -> list[str]: return stub + class StubClass: def __init__(self, name: str): self.name = name @@ -112,7 +126,7 @@ def indent(src: list[str]) -> list[str]: stub = [ "# noinspection SpellCheckingInspection,GrazieInspection", - f"class {self.name}:" + f"class {self.name}:", ] if self.docstring: stub.extend(indent(self.docstring)) @@ -135,7 +149,9 @@ def indent(src: list[str]) -> list[str]: del methods[ol_name] assert ol_method.is_cdef stub.extend(indent(["@overload"])) - stub.extend(indent(ol_method.stub(name_override=method.name))) + stub.extend( + indent(ol_method.stub(name_override=method.name)) + ) else: break # Not Overloaded @@ -153,9 +169,10 @@ def gen_stub(source: str) -> str: in_docstring = False docstring_dest: Optional[list[str]] = None + def add_docstring_line(doc_line: str): assert docstring_dest is not None - docstring_dest.append(doc_line[4:]) # remove one indent level + docstring_dest.append(doc_line[4:]) # remove one indent level print("gen_stub: reading source...") source_lines = source.splitlines(keepends=False) @@ -181,7 +198,10 @@ def add_docstring_line(doc_line: str): docstring_dest = current_class.docstring decorators.clear() # Property - elif m := regex.match(r"\s+cdef\s+public\s+(?P\w+)\s+(?P\w+)(?:\s*,\s*(?P\w+))*", line): + elif m := regex.match( + r"\s+cdef\s+public\s+(?P\w+)\s+(?P\w+)(?:\s*,\s*(?P\w+))*", + line, + ): ptype = convert_type(m.group("type")) for pname in m.captures("names"): current_class.properties[pname] = StubProperty(pname, ptype) @@ -189,32 +209,32 @@ def add_docstring_line(doc_line: str): decorators.clear() # Method (including @property) elif ( - (cdef_m := regex.match( # cdef methods can not be static, self always present - r"\s+cdef\s+" - r"(?:inline\s+)?" - r"(?P\w+)?\s+" - r"(?P\w+)\s*" - r"\(\s*" - r"(?:self\s*)?" - r"(?:,\s*(?P\w+\s+\w+)\s*)*" - r"\)\s*" - r"(?:noexcept)?\s*" - r":", - line - )) - or - (def_m := regex.match( - r"\s+def\s+" - r"(?P\w+)\s*" - r"\(\s*" - r"(?:(?:self\s*)|(?&_param))?" - r"(?:,\s*(?P<_param>(?P\w+\s+\w+(?:\s*=\s*[^,)]+)?)|(?P/))\s*)*" - r"\)\s*" - r"(?:->\s*(?P[^:]+))?\s*" - r":", - line - )) - ): + cdef_m := regex.match( # cdef methods can not be static, self always present + r"\s+cdef\s+" + r"(?:inline\s+)?" + r"(?P\w+)?\s+" + r"(?P\w+)\s*" + r"\(\s*" + r"(?:self\s*)?" + r"(?:,\s*(?P\w+\s+\w+)\s*)*" + r"\)\s*" + r"(?:noexcept)?\s*" + r":", + line, + ) + ) or ( + def_m := regex.match( + r"\s+def\s+" + r"(?P\w+)\s*" + r"\(\s*" + r"(?:(?:self\s*)|(?&_param))?" + r"(?:,\s*(?P<_param>(?P\w+\s+\w+(?:\s*=\s*[^,)]+)?)|(?P/))\s*)*" + r"\)\s*" + r"(?:->\s*(?P[^:]+))?\s*" + r":", + line, + ) + ): m = cdef_m or def_m is_cdef = cdef_m is not None name = m.group("name") @@ -223,7 +243,9 @@ def add_docstring_line(doc_line: str): # property getter if "property" in decorators: assert name not in current_class.properties - current_class.properties[name] = prop = StubProperty(name, convert_type(m.group("return")), mutable=False) + current_class.properties[name] = prop = StubProperty( + name, convert_type(m.group("return")), mutable=False + ) docstring_dest = prop.getter_doc # proeprty setter elif setter_dec := [d for d in decorators if d.endswith(".setter")]: @@ -247,18 +269,30 @@ def add_docstring_line(doc_line: str): # default values are not supported for cdef methods yet pm = regex.fullmatch(r"(?P\w+)\s+(?P\w+)", param) assert pm is not None - params.append(StubMethod.Param(pm.group("name"), convert_type(pm.group("type")), const_mapping=const_mapping)) + params.append( + StubMethod.Param( + pm.group("name"), + convert_type(pm.group("type")), + const_mapping=const_mapping, + ) + ) else: - pm = regex.fullmatch(r"(?P\w+)\s+(?P\w+)(?:\s*=\s*(?P[^,]+))?", param) + pm = regex.fullmatch( + r"(?P\w+)\s+(?P\w+)(?:\s*=\s*(?P[^,]+))?", + param, + ) assert pm is not None - params.append(StubMethod.Param(pm.group("name"), convert_type(pm.group("type")), pm.group("default"), const_mapping=const_mapping)) + params.append( + StubMethod.Param( + pm.group("name"), + convert_type(pm.group("type")), + pm.group("default"), + const_mapping=const_mapping, + ) + ) current_class.methods[name] = method = StubMethod( - name, - rtype, - params, - is_cdef, - is_static="staticmethod" in decorators + name, rtype, params, is_cdef, is_static="staticmethod" in decorators ) docstring_dest = method.docstring @@ -280,7 +314,7 @@ def add_docstring_line(doc_line: str): out_lines = [ "# noinspection PyUnresolvedReferences", "from typing import overload, Self, Any, Union", - "" + "", ] for cls in classes: print(f"gen_stub: generating class {cls.name}") @@ -289,7 +323,7 @@ def add_docstring_line(doc_line: str): return "".join(f"{line}\n" for line in out_lines) -if __name__ == '__main__': +if __name__ == "__main__": result = gen_stub(open("output/_spatium.pyx", encoding="utf8").read()) with open("output/_spatium.pyi", "w") as f: f.write(result) diff --git a/codegen/vector_codegen.py b/codegen/vector_codegen.py index e646d0f..7a882da 100644 --- a/codegen/vector_codegen.py +++ b/codegen/vector_codegen.py @@ -65,7 +65,9 @@ def gen_combination_constructors(dims: int, vtype: Type) -> str: # Generate params self_dim = 0 for in_dims in combination: - param_types.append(f"{get_c_type(vtype) if in_dims == 1 else get_vec_class_name(in_dims, vtype)}") + param_types.append( + f"{get_c_type(vtype) if in_dims == 1 else get_vec_class_name(in_dims, vtype)}" + ) param_name = "" for in_dim in range(in_dims): param_name += DIMS[self_dim] @@ -79,12 +81,16 @@ def gen_combination_constructors(dims: int, vtype: Type) -> str: if in_dims == 1: assigns.append(f"self.{DIMS[self_dim]} = {param_name}") else: - assigns.append(f"self.{DIMS[self_dim]} = {param_name}.{DIMS[in_dim]}") + assigns.append( + f"self.{DIMS[self_dim]} = {param_name}.{DIMS[in_dim]}" + ) self_dim += 1 func = "#\n" - func += (f"cdef inline void __init__(self, " - f"{', '.join(f'{t} {n}' for t, n in zip(param_types, param_names))}) noexcept:\n") + func += ( + f"cdef inline void __init__(self, " + f"{', '.join(f'{t} {n}' for t, n in zip(param_types, param_names))}) noexcept:\n" + ) docstring = f"Create a {dims}D vector from " for i, param_dims in enumerate(combination): @@ -156,7 +162,9 @@ def _gen_swizzle_get(swizzle: str, vtype: Type) -> str: def _gen_swizzle_set(swizzle: str, vtype: Type): out = f"@{swizzle}.setter\n" - out += f"def {swizzle}(self, {get_vec_class_name(len(swizzle), vtype)} vec) -> None:\n" + out += ( + f"def {swizzle}(self, {get_vec_class_name(len(swizzle), vtype)} vec) -> None:\n" + ) docstring = f"Set the value of the {', '.join(swizzle.upper())}" docstring = docstring[:-3] + " and " + docstring[-1] @@ -196,10 +204,7 @@ def gen_swizzle_properties(dims: int, vtype: Type) -> str: def gen_for_each_dim(template: str, dims: int, join="\n") -> str: out = "" for i, dim in enumerate(range(dims)): - out += template.format_map({ - "dim": DIMS[dim], - "index": i - }) + out += template.format_map({"dim": DIMS[dim], "index": i}) if dim != dims - 1: out += join return out @@ -216,8 +221,13 @@ def gen_repr(dims: int, vtype: Type) -> str: out += ')"' return out + def gen_common_binary_and_inplace_op(op: str, name: str, readable_name: str) -> str: - return from_template(open("templates/common_binary_and_inplace_op.pyx").read(), {"Op": op, "OpName": name, "OpReadableName": readable_name}) + return from_template( + open("templates/common_binary_and_inplace_op.pyx").read(), + {"Op": op, "OpName": name, "OpReadableName": readable_name}, + ) + def gen_item_op(dims: int, op: str) -> str: out = "" @@ -226,9 +236,10 @@ def gen_item_op(dims: int, op: str) -> str: out += f"key == {dim}:\n" out += f" {op.format(dim=DIMS[dim])}\n" out += "else:\n" - out += " raise KeyError(f\"_VecClassName_ index out of range: {key}\")" + out += ' raise KeyError(f"_VecClassName_ index out of range: {key}")' return out + def gen_iterator_next(dims: int) -> str: out = "" for dim in range(dims): diff --git a/pyproject.toml b/pyproject.toml index eef79ae..4db87d0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,3 +63,9 @@ archs = ["x86_64", "i686"] [tool.cibuildwheel.macos] archs = ["x86_64", "arm64"] + +[tool.black] +preview = true +enable-unstable-feature = [ + "hug_parens_with_braces_and_square_brackets" +] diff --git a/setup.py b/setup.py index 827a79c..2499a2d 100644 --- a/setup.py +++ b/setup.py @@ -13,8 +13,7 @@ ), ], packages=find_packages( - where="src", - exclude=["tests", "spatium/*.c", "spatium/*.cpp"] + where="src", exclude=["tests", "spatium/*.c", "spatium/*.cpp"] ), - package_dir={'': 'src'}, + package_dir={"": "src"}, ) diff --git a/tests/test_transform_2d.py b/tests/test_transform_2d.py index 930b8df..490f6c2 100644 --- a/tests/test_transform_2d.py +++ b/tests/test_transform_2d.py @@ -10,6 +10,7 @@ def test_normal_constructor_and_get_components(): assert t.y == Vec2(3, 4) assert t.origin == Vec2(5, 6) + def test_comparison(): def gen_diff_at(index, n): return Transform2D(*[n if i == index else 0.5 for i in range(6)]) @@ -20,6 +21,7 @@ def gen_diff_at(index, n): result = gen_diff_at(a, num) == gen_diff_at(b, num) assert result if a == b else not result + def test_set_components(): t = Transform2D() t.x = Vec2(1, 2) @@ -27,11 +29,16 @@ def test_set_components(): t.origin = Vec2(5, 6) assert t == Transform2D(1, 2, 3, 4, 5, 6) + def test_empty_constructor(): assert Transform2D() == Transform2D(1, 0, 0, 1, 0, 0) + def test_component_constructor(): - assert Transform2D(Vec2(1, 2), Vec2(3, 4), Vec2(5, 6)) == Transform2D(1, 2, 3, 4, 5, 6) + assert Transform2D(Vec2(1, 2), Vec2(3, 4), Vec2(5, 6)) == Transform2D( + 1, 2, 3, 4, 5, 6 + ) + def test_copy_constructor(): t = Transform2D(*range(6)) @@ -39,19 +46,23 @@ def test_copy_constructor(): assert id(t) != id(tc) assert t == tc + def test_translation_constructor(): t = Transform2D.translating(Vec2(1, 2)) assert t == Transform2D(1, 0, 0, 1, 1, 2) + def test_rotating_constructor(): r = 1.23 t = Transform2D.rotating(r) assert t.is_close(Transform2D(cos(r), sin(r), -sin(r), cos(r), 0, 0), 1e-7) + def test_scaling_constructor(): t = Transform2D.scaling(Vec2(1, 2)) assert t == Transform2D(1, 0, 0, 2, 0, 0) + def test_vector_xform(): t = Transform2D(1, 2, 3, 4, 5, 6) v = Vec2(1, 2) @@ -59,11 +70,13 @@ def test_vector_xform(): assert (t * v).is_close(ans) assert t(v).is_close(ans) + def test_vector_inverse_xform(): print(Vec2(1, 2) * Transform2D(1, 2, 3, 4, 5, 6)) print((~Transform2D(*range(1, 7))) * Vec2(1, 2)) assert (Vec2(1, 2) * Transform2D(1, 2, 3, 4, 5, 6)).is_close(Vec2(-12, -28)) + def test_matmul(): t1 = Transform2D(1, 2, 3, 4, 5, 6) t2 = Transform2D(3, 1, 2, 6, 5, 4) @@ -71,6 +84,7 @@ def test_matmul(): assert (t1 @ t2).is_close(ans) assert t1(t2).is_close(ans) + def test_imatmul(): t1 = Transform2D(1, 2, 3, 4, 5, 6) t2 = Transform2D(3, 1, 2, 6, 5, 4) @@ -78,8 +92,12 @@ def test_imatmul(): t3 @= t2 assert t2 @ t1 == t3 + def test_determinant(): assert isclose(Transform2D(1.6, 2.5, 3.4, 4.3, 5.2, 6.1).determinant, -1.62) + def test_inverse(): - assert (~Transform2D(1, 2, 3, 4, 5, 6)).is_close(Transform2D(-2, 1, 1.5, -0.5, 1, -2)) + assert (~Transform2D(1, 2, 3, 4, 5, 6)).is_close( + Transform2D(-2, 1, 1.5, -0.5, 1, -2) + ) diff --git a/tests/test_transform_3d.py b/tests/test_transform_3d.py index 5d590da..7fc310b 100644 --- a/tests/test_transform_3d.py +++ b/tests/test_transform_3d.py @@ -11,6 +11,7 @@ def test_normal_constructor_and_get_components(): assert t.z == Vec3(7, 8, 9) assert t.origin == Vec3(10, 11, 12) + def test_comparison(): def gen_diff_at(index, n): return Transform3D(*[n if i == index else 0.5 for i in range(12)]) @@ -21,6 +22,7 @@ def gen_diff_at(index, n): result = gen_diff_at(a, num) == gen_diff_at(b, num) assert result if a == b else not result + def test_set_components(): t = Transform3D() t.x = Vec3(1, 2, 3) @@ -29,11 +31,16 @@ def test_set_components(): t.origin = Vec3(10, 11, 12) assert t == Transform3D(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) + def test_empty_constructor(): assert Transform3D() == Transform3D(1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0) + def test_component_constructor(): - assert Transform3D(Vec3(1, 2, 3), Vec3(4, 5, 6), Vec3(7, 8, 9), Vec3(10, 11, 12)) == Transform3D(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) + assert Transform3D( + Vec3(1, 2, 3), Vec3(4, 5, 6), Vec3(7, 8, 9), Vec3(10, 11, 12) + ) == Transform3D(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) + def test_copy_constructor(): t = Transform3D(*range(12)) @@ -41,18 +48,38 @@ def test_copy_constructor(): assert id(t) != id(tc) assert t == tc + def test_translation_constructor(): t = Transform3D.translating(Vec3(1, 2, 3)) assert t == Transform3D(1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 2, 3) + def test_rotating_constructor(): t = Transform3D.rotating(Vec3(1.85, 9.23, 3.42).normalized, 1.98) - assert t.is_close(Transform3D(-0.350185, 0.551229, -0.757309, -0.075323, 0.789313, 0.609353, 0.933647, 0.270429, -0.234886, 0, 0, 0), rel_tol=1e-5) + assert t.is_close( + Transform3D( + -0.350185, + 0.551229, + -0.757309, + -0.075323, + 0.789313, + 0.609353, + 0.933647, + 0.270429, + -0.234886, + 0, + 0, + 0, + ), + rel_tol=1e-5, + ) + def test_scaling_constructor(): t = Transform3D.scaling(Vec3(1, 2, 3)) assert t == Transform3D(1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0) + def test_vector_xform(): t = Transform3D(*range(1, 13)) v = Vec3(1, 2, 3) @@ -60,9 +87,11 @@ def test_vector_xform(): assert (t * v).is_close(ans) assert t(v).is_close(ans) + def test_vector_inverse_xform(): assert (Vec3(1, 2, 3) * Transform3D(*range(1, 13))).is_close(Vec3(-54, -135, -216)) + def test_matmul(): t1 = Transform3D(*range(1, 13)) t2 = Transform3D(10, 8, 4, 6, 12, 7, 5, 3, 2, 1, 11, 9) @@ -70,6 +99,7 @@ def test_matmul(): assert (t1 @ t2).is_close(ans) assert t1(t2).is_close(ans) + def test_imatmul(): t1 = Transform3D(*range(1, 13)) t2 = Transform3D(10, 8, 4, 6, 12, 7, 5, 3, 2, 1, 11, 9) @@ -77,8 +107,13 @@ def test_imatmul(): t3 @= t2 assert t2 @ t1 == t3 + def test_determinant(): - assert isclose(Transform3D(1.12, 2.11, 0, 0, 5.8, 6.7, 7.6, 8.5, 0, 1, 2, 3).determinant, 43.6572) + assert isclose( + Transform3D(1.12, 2.11, 0, 0, 5.8, 6.7, 7.6, 8.5, 0, 1, 2, 3).determinant, + 43.6572, + ) + def test_inverse(): t = Transform3D(1, 2, 3, 6, 5, 4, 7, 9, 8, 1, 2, 3) diff --git a/tests/test_vector.py b/tests/test_vector.py index 311f949..3d9c2dc 100644 --- a/tests/test_vector.py +++ b/tests/test_vector.py @@ -8,6 +8,7 @@ def test_empty_constructor(): vec = Vec3(0) assert vec.x == vec.y == vec.z == 0 + def test_normal_constructor(): vec = Vec3(1, 2, 3) assert vec.x == 1 @@ -19,6 +20,7 @@ def test_normal_constructor(): assert vec.y == 2 assert vec.z == 3 + def test_combined_constructors(): vec = Vec4(Vec2(1, 2), Vec2(3, 4)) assert vec.x == 1 @@ -31,16 +33,19 @@ def test_combined_constructors(): assert vec.y == 2 assert vec.z == 3 + def test_type_conversion_constructor(): vec = Vec3(Vec3i(1, 2, 3)) assert vec.x == 1 assert vec.y == 2 assert vec.z == 3 + def test_repr(): assert repr(Vec3(1, 2, 3)) == "Vec3(1.0, 2.0, 3.0)" assert repr(Vec3i(1, 2, 3)) == "Vec3i(1, 2, 3)" + def test_comparison(): a = Vec3(1, 2, 3) b = Vec3(3, 2, 1) @@ -50,6 +55,7 @@ def test_comparison(): assert (a == b) == False assert (a != c) == False + def test_swizzle(): v2 = Vec2(1, 2) v3 = Vec3(1, 2, 3) @@ -60,6 +66,7 @@ def test_swizzle(): assert v3.ylo == Vec3(2, 1, 0) assert v4.wzyx == Vec4(4, 3, 2, 1) + def test_copy(): a = Vec3(1, 2, 3) b = +a @@ -67,62 +74,75 @@ def test_copy(): assert a == Vec3(1, 2, 3) assert b == Vec3(5, 2, 3) + def test_neg(): a = Vec3(1, 2, 3) b = -a assert b == Vec3(-1, -2, -3) + def test_add(): a = Vec3(1, 2, 3) b = Vec3(4, 5, 6) assert a + b == Vec3(5, 7, 9) assert a == Vec3(1, 2, 3) + def test_iadd(): a = Vec3(1, 2, 3) b = Vec3(4, 5, 6) a += b assert a == Vec3(5, 7, 9) + def test_add_float(): a = Vec3(1, 2, 3) assert a + 1.5 == Vec3(2.5, 3.5, 4.5) + def test_add_int(): a = Vec3(1, 2, 3) assert a + 1 == Vec3(2, 3, 4) + def test_iadd_float(): a = Vec3(1, 2, 3) a += 1.5 assert a == Vec3(2.5, 3.5, 4.5) + def test_sub(): a = Vec3(1, 2, 3) b = Vec3(6, 5, 4) assert a - b == Vec3(-5, -3, -1) + def test_mul(): a = Vec3(1, 2, 3) b = Vec3(0, 2, 4) assert a * b == Vec3(0, 4, 12) + def test_div(): a = Vec3(1, 2, 3) b = Vec3(0.5, 2, 1.5) assert a / b == Vec3(2, 1, 2) + def test_div_zero(): a = Vec3(1, 2, 3) b = Vec3(0, 1, 2) import math + assert a / b == Vec3(math.inf, 2, 1.5) + def test_dot(): a = Vec3(1, 2, 3) b = Vec3(2, 1, 3) assert a @ b == 13.0 + def test_cross(): a = Vec3(1, 2, 3) b = Vec3(3, 7, 5) @@ -134,6 +154,7 @@ def test_cross(): # noinspection PyStatementEffect a ^ b + def test_len(): assert len(Vec2()) == 2 assert len(Vec3()) == 3 @@ -142,9 +163,11 @@ def test_len(): assert len(Vec3i()) == 3 assert len(Vec4i()) == 4 + def test_iter(): for i, value in enumerate(Vec3(1, 2, 3)): - assert value == i+1 + assert value == i + 1 + def test_unpack(): x, y, z = Vec3(1, 2, 3) @@ -152,35 +175,43 @@ def test_unpack(): assert y == 2 assert z == 3 + def test_getitem(): assert Vec3(0, 1.5, 0)[1] == 1.5 + def test_setitem(): v = Vec3(1, 3, 3) v[1] = 2 assert v == Vec3(1, 2, 3) + def test_length(): assert math.isclose(Vec3(1, 2, 3).length, 3.7416573867739413) + def test_length_squared(): assert Vec3(1, 2, 3).length_sqr == 14 + def test_distance_to(): a = Vec3(1, 2, 3) b = Vec3(4, 5, 6) assert math.isclose(a.distance_to(b), (a - b).length) + def test_distance_to_with_or_operator(): a = Vec3(1, 2, 3) b = Vec3(4, 5, 6) assert math.isclose(a | b, (a - b).length) + def test_distance_squared_to(): a = Vec3(1, 2, 3) b = Vec3(4, 5, 6) assert a.distance_sqr_to(b) == 27 + def test_normalized(): a = Vec3(1, 2, 3) a = a.normalized @@ -188,8 +219,10 @@ def test_normalized(): assert math.isclose(a.y, 0.5345224838248488) assert math.isclose(a.z, 0.8017837257372732) + def test_big_int(): - Vec2i(2**63-1) + Vec2i(2**63 - 1) + def test_float_precision(): assert Vec2(math.ulp(0)).x == math.ulp(0)