Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for dynamical simulation unit systems (G=1 or other value) #225

Merged
merged 14 commits into from
Oct 17, 2024
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ build-backend = "hatchling.build"
"F821", # undefined name '...' <- jaxtyping
"FIX002", # Line contains TODO
"ISC001", # Conflicts with formatter
"N806", # Variable in function should be lowercase
"PD", # Pandas
"PLR09", # Too many <...>
"PLR2004", # Magic value used in comparison
Expand Down
54 changes: 49 additions & 5 deletions src/unxt/_src/units/system/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,18 @@

import astropy.units as u
import equinox as eqx
import numpy as np
from astropy.constants import G as const_G # noqa: N811, pylint: disable=E0611
from plum import dispatch

from . import builtin_dimensions as ud
from .base import UNITSYSTEMS_REGISTRY, AbstractUnitSystem
from .builtin import DimensionlessUnitSystem
from .flags import AbstractUnitSystemFlag, StandardUnitSystemFlag
from .flags import (
AbstractUSysFlag,
DynamicalSimUSysFlag,
StandardUSysFlag,
)
from .realizations import NAMED_UNIT_SYSTEMS, dimensionless
from .utils import get_dimension_name
from unxt._src.dimensions.core import dimensions_of
Expand Down Expand Up @@ -192,27 +199,64 @@ def unitsystem(usys: AbstractUnitSystem, *units_: Any) -> AbstractUnitSystem:


@dispatch # type: ignore[no-redef]
def unitsystem(flag: type[AbstractUnitSystemFlag], *_: Any) -> AbstractUnitSystem:
def unitsystem(flag: type[AbstractUSysFlag], *_: Any) -> AbstractUnitSystem:
"""Raise an exception since the flag is abstract."""
msg = "Do not use the AbstractUnitSystemFlag directly, only use subclasses."
msg = "Do not use the AbstractUSysFlag directly, only use subclasses."
raise TypeError(msg)


@dispatch # type: ignore[no-redef]
def unitsystem(flag: type[StandardUnitSystemFlag], *units_: Any) -> AbstractUnitSystem:
def unitsystem(flag: type[StandardUSysFlag], *units_: Any) -> AbstractUnitSystem:
"""Create a standard unit system using the inputted units.

Examples
--------
>>> import astropy.units as u
>>> from unxt import unitsystem, unitsystems
>>> unitsystem(unitsystems.StandardUnitSystemFlag, u.kpc, u.Myr, u.Msun)
>>> unitsystem(unitsystems.StandardUSysFlag, u.kpc, u.Myr, u.Msun)
LengthTimeMassUnitSystem(length=Unit("kpc"), time=Unit("Myr"), mass=Unit("solMass"))

"""
return unitsystem(*units_)


@dispatch # type: ignore[no-redef]
def unitsystem(
flag: type[DynamicalSimUSysFlag],
*units_: Any,
G: float | int = 1.0, # noqa: N803
) -> AbstractUnitSystem:
tmp = unitsystem(*units_)

# Use G for computing the missing units below:
G = G * const_G

added = ()
if ud.length in tmp.base_dimensions and ud.mass in tmp.base_dimensions:
adrn marked this conversation as resolved.
Show resolved Hide resolved
time = 1 / np.sqrt(G * tmp["mass"] / tmp["length"] ** 3)
added = (time,)
elif ud.length in tmp.base_dimensions and ud.time in tmp.base_dimensions:
mass = 1 / G * tmp["length"] ** 3 / tmp["time"] ** 2
added = (mass,)
elif ud.length in tmp.base_dimensions and ud.speed in tmp.base_dimensions:
time = tmp["length"] / tmp["velocity"]
mass = tmp["velocity"] ** 2 / G * tmp["length"]
added = (time, mass)
elif ud.mass in tmp.base_dimensions and ud.time in tmp.base_dimensions:
length = np.cbrt(G * tmp["mass"] * tmp["time"] ** 2)
added = (length,)
elif ud.mass in tmp.base_dimensions and ud.speed in tmp.base_dimensions:
length = G * tmp["mass"] / tmp["velocity"] ** 2
time = length / tmp["velocity"]
added = (length, time)
elif ud.time in tmp.base_dimensions and ud.speed in tmp.base_dimensions:
mass = 1 / G * tmp["velocity"] ** 3 * tmp["time"]
length = G * mass / tmp["velocity"] ** 2
added = (mass, length)

return unitsystem(*tmp, *added)


# ----


Expand Down
16 changes: 12 additions & 4 deletions src/unxt/_src/units/system/flags.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,23 @@
__all__ = ["AbstractUnitSystemFlag", "StandardUnitSystemFlag"]
__all__ = [
"AbstractUSysFlag",
"StandardUSysFlag",
"DynamicalSimUSysFlag",
]

from typing import Any


class AbstractUnitSystemFlag:
class AbstractUSysFlag:
"""Abstract class for unit system flags to provide dispatch control."""

def __new__(cls, *_: Any, **__: Any) -> None: # type: ignore[misc]
msg = "unit system flag classes cannot be instantiated."
raise ValueError(msg)


class StandardUnitSystemFlag(AbstractUnitSystemFlag):
"""Unit system flag to indicate a standard unit system with no additional args."""
class StandardUSysFlag(AbstractUSysFlag):
"""Flag to indicate a standard unit system with no additional arguments."""


class DynamicalSimUSysFlag(AbstractUSysFlag):
"""Flag to indicate a unit system with optional definition of G."""
38 changes: 32 additions & 6 deletions tests/unit/test_unitsystems.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Test the `unxt.unitsystems` module."""

import itertools
import pickle
from dataclasses import dataclass
from pathlib import Path
Expand All @@ -13,9 +14,10 @@
from unxt._src.units.system.base import _UNITSYSTEMS_REGISTRY
from unxt.unitsystems import (
AbstractUnitSystem,
AbstractUnitSystemFlag,
AbstractUSysFlag,
DimensionlessUnitSystem,
StandardUnitSystemFlag,
DynamicalSimUSysFlag,
StandardUSysFlag,
dimensionless,
equivalent,
unitsystem,
Expand Down Expand Up @@ -202,17 +204,41 @@ def test_extend():
def test_abstract_usys_flag():
"""Test that the abstract unit system flag fails."""
with pytest.raises(TypeError, match="Do not use"):
unitsystem(AbstractUnitSystemFlag, u.kpc)
unitsystem(AbstractUSysFlag, u.kpc)

with pytest.raises(ValueError, match="unit system flag classes"):
AbstractUnitSystemFlag()
AbstractUSysFlag()


def test_standard_flag():
"""Test defining unit system with the standard flag."""
usys1 = unitsystem(StandardUnitSystemFlag, u.kpc, u.Myr)
usys1 = unitsystem(StandardUSysFlag, u.kpc, u.Myr)
usys2 = unitsystem(u.kpc, u.Myr)
assert usys1 == usys2

with pytest.raises(ValueError, match="unit system flag classes"):
StandardUnitSystemFlag()
StandardUSysFlag()


def test_simulation_usys():
"""Test defining the simulation unit system with expected inputs."""
from astropy.constants import G as const_G # noqa: N811

tmp_G = const_G.decompose([u.kpc, u.Myr, u.Msun])
usys1 = unitsystem(DynamicalSimUSysFlag, u.kpc, u.Myr, u.rad)
assert np.isclose((1 * usys1["mass"]).to_value(u.Msun), 1 / tmp_G.value)

usys2 = unitsystem(DynamicalSimUSysFlag, u.kpc, u.Msun, u.rad)
assert np.isclose((1 * usys2["time"]).to_value(u.Myr), 1 / np.sqrt(tmp_G.value))

base_units = (u.kpc, u.Myr, u.Msun, u.km / u.s)
for u1, u2 in itertools.product(base_units, base_units):
if u1 == u2:
continue

usys = unitsystem(DynamicalSimUSysFlag, u1, u2)

# For now, just test retrieving all three base unit types:
usys["length"]
usys["mass"]
usys["time"]
6 changes: 2 additions & 4 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.