Skip to content

Commit

Permalink
51 problems getting started with readme example (#59)
Browse files Browse the repository at this point in the history
* added annotations from module with impunised functions and a check for aliases in Annotated Assignment nodes

* added tests for the module-level annotations

* removing legacy numpy random.rand calls

* refactoring typing for 3.12
  • Loading branch information
achevrot authored Nov 15, 2024
1 parent 66b8066 commit 137213c
Show file tree
Hide file tree
Showing 15 changed files with 88 additions and 47 deletions.
6 changes: 4 additions & 2 deletions scripts/performance/astropy_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@

import numpy as np

a: Annotated[Any, "meters"] = np.random.rand(100000) * u.m
b: Annotated[Any, "hours"] = np.random.rand(100000) * u.h
rng = np.random.default_rng()

a: Annotated[Any, "meters"] = rng.random(100000) * u.m
b: Annotated[Any, "hours"] = rng.random(100000) * u.h


@u.quantity_input(x=u.m, y=u.s)
Expand Down
6 changes: 3 additions & 3 deletions scripts/performance/astropy_noconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@

import numpy as np

np.random.seed(0)
rng = np.random.default_rng(0)

a: Annotated[Any, "meters"] = np.random.rand(100000) * u.m
b: Annotated[Any, "seconds"] = np.random.rand(100000) * u.s
a: Annotated[Any, "meters"] = rng.random(100000) * u.m
b: Annotated[Any, "seconds"] = rng.random(100000) * u.s


@u.quantity_input(x=u.m, y=u.s)
Expand Down
6 changes: 4 additions & 2 deletions scripts/performance/baseline_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@

import numpy as np

a: Annotated[Any, "meters"] = np.random.rand(100000)
b: Annotated[Any, "hours"] = np.random.rand(100000) / 3600
rng = np.random.default_rng()

a: Annotated[Any, "meters"] = rng.random(100000)
b: Annotated[Any, "hours"] = rng.random(100000) / 3600


def g(
Expand Down
6 changes: 4 additions & 2 deletions scripts/performance/baseline_noconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@

import numpy as np

a: Annotated[Any, "meters"] = np.random.rand(100000)
b: Annotated[Any, "seconds"] = np.random.rand(100000)
rng = np.random.default_rng()

a: Annotated[Any, "meters"] = rng.random(100000)
b: Annotated[Any, "seconds"] = rng.random(100000)


def g(
Expand Down
10 changes: 7 additions & 3 deletions scripts/performance/impunity_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,16 @@
import numpy as np
from impunity import impunity

a: Annotated[Any, "meters"] = np.random.rand(100000)
b: Annotated[Any, "hours"] = np.random.rand(100000)
rng = np.random.default_rng()

a: Annotated[Any, "meters"] = rng.random(100000)
b: Annotated[Any, "hours"] = rng.random(100000)


@impunity(rewrite="log.txt")
def g(x: Annotated[Any, "meters"], y: Annotated[Any, "hours"]) -> Annotated[Any, "m/s"]:
def g(
x: Annotated[Any, "meters"], y: Annotated[Any, "hours"]
) -> Annotated[Any, "m/s"]:
return x / y


Expand Down
6 changes: 4 additions & 2 deletions scripts/performance/impunity_noconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
import numpy as np
from impunity import impunity

a: Annotated[Any, "meters"] = np.random.rand(100000)
b: Annotated[Any, "seconds"] = np.random.rand(100000)
rng = np.random.default_rng()

a: Annotated[Any, "meters"] = rng.random(100000)
b: Annotated[Any, "seconds"] = rng.random(100000)


@impunity(rewrite="log_no.txt")
Expand Down
6 changes: 4 additions & 2 deletions scripts/performance/numericalunits_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@

import numpy as np

a: Annotated[Any, "meters"] = np.random.rand(100000) * m
b: Annotated[Any, "hours"] = np.random.rand(100000) * hour
rng = np.random.default_rng()

a: Annotated[Any, "meters"] = rng.random(100000) * m
b: Annotated[Any, "hours"] = rng.random(100000) * hour


def g(
Expand Down
6 changes: 4 additions & 2 deletions scripts/performance/numericalunits_noconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@

import numpy as np

a: Annotated[Any, "meters"] = np.random.rand(100000) * m
b: Annotated[Any, "hours"] = np.random.rand(100000) * s
rng = np.random.default_rng()

a: Annotated[Any, "meters"] = rng.random(100000) * m
b: Annotated[Any, "hours"] = rng.random(100000) * s


def g(
Expand Down
5 changes: 3 additions & 2 deletions scripts/performance/pint_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@

ureg = pint.UnitRegistry()
Q_ = ureg.Quantity
rng = np.random.default_rng()

a: Annotated[Any, "meters"] = Q_(np.random.rand(100000), "meter")
b: Annotated[Any, "seconds"] = Q_(np.random.rand(100000), "hours")
a: Annotated[Any, "meters"] = Q_(rng.random(100000), "meter")
b: Annotated[Any, "seconds"] = Q_(rng.random(100000), "hours")


@ureg.wraps(None, ("meter", "seconds"))
Expand Down
5 changes: 3 additions & 2 deletions scripts/performance/pint_noconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@

ureg = pint.UnitRegistry()
Q_ = ureg.Quantity
rng = np.random.default_rng()

a: Annotated[Any, "meters"] = Q_(np.random.rand(100000), "meter")
b: Annotated[Any, "seconds"] = Q_(np.random.rand(100000), "seconds")
a: Annotated[Any, "meters"] = Q_(rng.random(100000), "meter")
b: Annotated[Any, "seconds"] = Q_(rng.random(100000), "seconds")


@ureg.wraps(None, ("meter", "seconds"))
Expand Down
6 changes: 4 additions & 2 deletions scripts/performance/quantities_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@

import numpy as np

a: Annotated[Any, "meters"] = np.random.rand(100000) * pq.m
b: Annotated[Any, "hours"] = np.random.rand(100000) * pq.h
rng = np.random.default_rng()

a: Annotated[Any, "meters"] = rng.random(100000) * pq.m
b: Annotated[Any, "hours"] = rng.random(100000) * pq.h


def g(
Expand Down
6 changes: 4 additions & 2 deletions scripts/performance/quantities_noconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@

import numpy as np

a: Annotated[Any, "meters"] = np.random.rand(100000) * pq.m
b: Annotated[Any, "seconds"] = np.random.rand(100000) * pq.s
rng = np.random.default_rng()

a: Annotated[Any, "meters"] = rng.random(100000) * pq.m
b: Annotated[Any, "seconds"] = rng.random(100000) * pq.s


def g(
Expand Down
33 changes: 13 additions & 20 deletions src/impunity/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,11 @@ def get_annotation_unit(self, node: ast.expr) -> Optional[str]:
if isinstance(unit_node, ast.Constant):
unit = unit_node.value

# case where node's parent is an AnnAssign
elif isinstance(node, ast.Name):
if isinstance(node.parent, ast.AnnAssign): # type: ignore
unit = self.get_node_unit(node).unit

return unit

def visit(self, root: ast.AST) -> ast.AST:
Expand Down Expand Up @@ -264,9 +269,7 @@ def node_convert(
e10 = r10.to(expected_unit)

if r0.m == e0.m:
conv_value = expected_pint_unit.from_( # type: ignore
received_pint_unit
).m
conv_value = expected_pint_unit.from_(received_pint_unit).m # type: ignore

if conv_value == 1:
new_node = received_node
Expand All @@ -279,9 +282,7 @@ def node_convert(

elif (e1.m - e0.m) == 1:
conv_value = (
expected_pint_unit.from_( # type: ignore
received_pint_unit
).m
expected_pint_unit.from_(received_pint_unit).m # type: ignore
) - 1

# if conv_value == 0:
Expand Down Expand Up @@ -331,7 +332,9 @@ def module_loading(self) -> None:
self.vars[name] = anno

# Adding all annotations from imported modules
for _, val in self.fun_globals.items():
for var_name, val in self.fun_globals.items():
if is_annotated(val):
self.vars[var_name] = val.__metadata__[0]
if isinstance(val, types.ModuleType):
annotations = getattr(val, "__annotations__", {})
for name, anno in annotations.items():
Expand Down Expand Up @@ -430,9 +433,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
if decorator.func.id == "impunity":
for kw in decorator.keywords:
if hasattr(kw, "value"):
if (
kw.arg == "ignore" and kw.value.value # type: ignore
):
if kw.arg == "ignore" and kw.value.value: # type: ignore
self.impunity_func.pop(
(self.current_module, node.name), False
)
Expand Down Expand Up @@ -550,11 +551,7 @@ def get_node_unit(self, node: Optional[ast.expr]) -> QuantityNode:
)

if pint.Unit(left.unit).is_compatible_with(pint.Unit(right.unit)):
conv_value = (
pint.Unit(left.unit) # type: ignore
.from_(pint.Unit(right.unit))
.m
)
conv_value = pint.Unit(left.unit).from_(pint.Unit(right.unit)).m # type: ignore
new_node = ast.BinOp(
left.node, # type:ignore
node.op,
Expand Down Expand Up @@ -597,11 +594,7 @@ def get_node_unit(self, node: Optional[ast.expr]) -> QuantityNode:
right.unit = right.unit.__metadata__[0]

if pint.Unit(left.unit).is_compatible_with(pint.Unit(right.unit)):
conv_value = (
pint.Unit(left.unit) # type: ignore
.from_(pint.Unit(right.unit))
.m
)
conv_value = pint.Unit(left.unit).from_(pint.Unit(right.unit)).m # type: ignore
new_node = ast.BinOp(
left.node, # type: ignore
node.op,
Expand Down
14 changes: 14 additions & 0 deletions tests/sample_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,15 @@

from typing_extensions import Annotated

import numpy as np
import numpy.typing as npt
from impunity import impunity

NDArrayFloat = npt.NDArray[np.float64]
meters = Annotated[NDArrayFloat, "m"]
seconds = Annotated[float, "s"]
meters_per_second = Annotated[NDArrayFloat, "m/s"]


@impunity
def speed_to_test(
Expand All @@ -17,3 +24,10 @@ def speed_altitude_to_test(
d: Annotated[Any, "m"], t: Annotated[Any, "s"], a: Annotated[Any, "m"] = 0
) -> Annotated[Any, "m/s"]:
return d / t + a * 1.3654


@impunity
def speed_with_annotated_to_test(
distance: meters, time: seconds
) -> meters_per_second:
return distance / time
14 changes: 13 additions & 1 deletion tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@
import numpy as np
from impunity import impunity

from .sample_module import speed_altitude_to_test, speed_to_test
from .sample_module import (
speed_altitude_to_test,
speed_to_test,
speed_with_annotated_to_test,
)

m = Annotated[Any, "m"]
K = Annotated[Any, "K"]
Expand Down Expand Up @@ -216,6 +220,14 @@ def test_keyword(self) -> None:
res2 = speed_altitude_to_test(d, t)
self.assertAlmostEqual(res2, 1, delta=1e-2)

@impunity
def test_conversion_with_module(self) -> None:
# Using meters instead of Annotated[float, "m"]
altitudes: Annotated[Any, "meters"] = np.arange(0, 1000, 100)
duration: Annotated[float, "min"] = 100
result = speed_with_annotated_to_test(altitudes, duration)
self.assertAlmostEqual(result[3], 0.05, delta=1e-2)


if __name__ == "__main__":
unittest.main()

0 comments on commit 137213c

Please sign in to comment.