coordinax.internal

Warning

Everything in coordinax.internal is semi-public. The APIs exposed here are usable by downstream packages but are not covered by the same stability guarantees as the top-level coordinax API. Names, signatures, and behaviour may change at any time without warning in minor or patch releases. Pin to an exact version if you depend on anything here.

coordinax.internal re-exports selected internal utilities that are useful for advanced users and downstream library authors, but whose interfaces are not yet stable enough for the main public API.

Overview

The module currently provides two kinds of semi-public helpers:

  • heterogeneous unit containers for vectors and matrices

  • packing helpers for converting component dictionaries to arrays and back

These utilities are primarily useful when implementing downstream transforms, Jacobians, metric-like objects, or other chart-aware machinery that needs to preserve per-component physical units.

Quick Start

import jax.numpy as jnp
import unxt as u
from coordinax.internal import QMatrix

J = QMatrix(
    value=jnp.eye(3),
    unit=(
        (u.unit("m/m"), u.unit("m/rad"), u.unit("m/rad")),
        (u.unit("rad/m"), u.unit("rad/rad"), u.unit("rad/rad")),
        (u.unit("rad/m"), u.unit("rad/rad"), u.unit("rad/rad")),
    ),
)

QMatrix supports both 1-D and 2-D cases. This makes it suitable for heterogeneous vectors as well as Jacobians and metric tensors whose entries do not all share the same unit.

Packing Helpers

import unxt as u
import coordinax.charts as cxc
from coordinax.internal import pack_nonuniform_unit, pack_uniform_unit

p = {"x": u.Q(1, "km"), "y": u.Q(200, "m"), "z": u.Q(3, "km")}

vals, unit = pack_uniform_unit(p, ("x", "y", "z"))
restored = cxc.cdict(vals, unit, ("x", "y", "z"))

vals2, units2 = pack_nonuniform_unit(p, ("x", "y", "z"))

Use pack_uniform_unit when all components should be expressed in a shared unit before stacking into an array. Use pack_nonuniform_unit when each component should retain its own unit metadata.
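The difference between the two packing helpers can be sketched in plain Python. This is a hypothetical illustration, not the coordinax implementation: quantities are modelled as (value, unit) pairs, the conversion table `_TO_M` and both function names are invented for the sketch, and the shared unit is passed in explicitly because how pack_uniform_unit chooses its reference unit is not described here.

```python
# Hypothetical stand-ins: quantities are (value, unit) pairs, and a tiny
# conversion table maps each length unit to metres.
_TO_M = {"m": 1.0, "km": 1000.0}


def sketch_pack_uniform(cdict, keys, ref_unit):
    """Convert every component to ``ref_unit``, then stack the magnitudes."""
    vals = [cdict[k][0] * _TO_M[cdict[k][1]] / _TO_M[ref_unit] for k in keys]
    return vals, ref_unit


def sketch_pack_nonuniform(cdict, keys):
    """Stack the raw magnitudes and keep one unit per component."""
    vals = [cdict[k][0] for k in keys]
    units = tuple(cdict[k][1] for k in keys)
    return vals, units


p = {"x": (1.0, "km"), "y": (200.0, "m"), "z": (3.0, "km")}
print(sketch_pack_uniform(p, ("x", "y", "z"), "km"))  # ([1.0, 0.2, 3.0], 'km')
print(sketch_pack_nonuniform(p, ("x", "y", "z")))     # ([1.0, 200.0, 3.0], ('km', 'm', 'km'))
```

The uniform form loses the original units but yields a plain array in one unit; the non-uniform form keeps every magnitude untouched and carries the units alongside.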

Functional API

  • cdict_units: extract per-key units from a component dictionary

  • pack_uniform_unit: stack a component dictionary into an array using a shared reference unit

  • pack_nonuniform_unit: stack a component dictionary into an array while preserving a per-component unit tuple

Available Objects

Heterogeneous Unit Containers

  • QMatrix: N-D quantity container with per-element units; currently supports 1-D vectors and 2-D matrices

  • UnitsMatrix: immutable nested tuple of units with tuple-style indexing and shape metadata

Packing Utilities

  • cdict_units: unit introspection helper for component dictionaries

  • pack_uniform_unit: pack values into an array with one shared unit

  • pack_nonuniform_unit: pack values into an array with per-component units

Notes

  • This module is intended for advanced use and downstream integration, not as a stable top-level user API.

  • The exported helpers are especially useful when chart components do not all share the same physical dimension.

  • For stable end-user coordinate functionality, prefer the top-level coordinax API and its public submodules.

coordinax.internal: semi-public utilities.

Contents:

  • QMatrix

    An N-D quantity matrix/vector where every element carries its own unit. Supports both 1-D (vector) and 2-D (matrix) cases. Useful for Jacobians and metric tensors whose entries have heterogeneous physical dimensions.

  • UnitsMatrix

    Nested tuple of units with tuple-style indexing; currently only 1-D and 2-D structures are supported.

  • pack_uniform_unit

    Pack dict-of-quantities into an array, converting all entries to a common unit.

  • tree_cast_int_bool_to_float

    Tree-map over a PyTree, promoting integer and boolean leaves to the default floating-point dtype (jax.dtypes.canonicalize_dtype(jnp.float_)). Existing float and complex leaves are left unchanged. Useful for satisfying jax.jacfwd's requirement of real floating-point inputs.

  • structured

    Decorator for transparent processing of arguments and return values. It helps push packing/unpacking logic inside a JIT-compiled function.

class coordinax.internal.QMatrix(value: Shaped[Array, '...'], unit: Any)

Bases: AbstractQuantity

Quantity container whose elements may each carry different units.

QMatrix stores one numeric array together with a static UnitsMatrix describing the unit of each logical element. The shape of the unit structure determines whether the object behaves as a heterogeneous vector or matrix.

Only 1-D and 2-D logical structures are supported.
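The element-wise arithmetic in the qv + qv2 example is consistent with converting each element of the right-hand operand into the matching unit of the left-hand operand and then adding. A minimal plain-Python sketch of that rule for the 1-D case, assuming a hypothetical conversion table `_SCALE` and helper names that are not part of coordinax:

```python
# Hypothetical conversion factors: _SCALE[(src, dst)] multiplies a value
# expressed in ``src`` so that it is expressed in ``dst``.
_SCALE = {("km", "m"): 1e3, ("ms", "s"): 1e-3, ("g", "kg"): 1e-3}


def convert(value, src, dst):
    """Express ``value`` (given in unit ``src``) in unit ``dst``."""
    return value if src == dst else value * _SCALE[(src, dst)]


def add_heterogeneous(vals_a, units_a, vals_b, units_b):
    """Add two 1-D heterogeneous vectors; the result keeps ``units_a``."""
    out = [
        a + convert(b, ub, ua)
        for a, ua, b, ub in zip(vals_a, units_a, vals_b, units_b)
    ]
    return out, units_a


vals, units = add_heterogeneous(
    [1.0, 2.0, 3.0], ("m", "s", "kg"), [0.1, 200.0, 300.0], ("km", "ms", "g")
)
# vals is approximately [101.0, 2.2, 3.3]; units == ("m", "s", "kg")
```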

Parameters:
  • value (Shaped[Array, '...']) – Numeric payload. For 1-D: (..., N). For 2-D: (..., N, M). The value of element [i] (1-D) or [i, j] (2-D) is expressed in the corresponding unit.

  • unit (Any) – Per-element units. For 1-D: (u0, u1, ...). For 2-D: ((u00, u01, ...), (u10, u11, ...), ...). Must be a static (hashable) nested tuple structure whose shape matches the trailing dimensions of value.

Examples

>>> import jax.numpy as jnp
>>> import unxt as u
>>> from coordinax.internal import QMatrix

1D case (vector):

>>> qv = QMatrix(jnp.array([1.0, 2.0, 3.0]), unit=("m", "s", "kg"))
>>> qv.value
Array([1., 2., 3.], dtype=float64)
>>> qv.unit.shape
(3,)
>>> 2 * qv
QMatrix([2., 4., 6.], '(m, s, kg)')
>>> qv2 = QMatrix(jnp.array([0.1, 200.0, 300.0]), unit=("km", "ms", "g"))
>>> qv + qv2
QMatrix([101. ,   2.2,   3.3], '(m, s, kg)')

2D case (matrix):

>>> qm = QMatrix(jnp.ones((2, 2)), unit=(("m", "s"), ("kg", "rad")))
>>> qm.value.shape
(2, 2)
>>> qm.unit.shape
(2, 2)
>>> 2 * qm
QMatrix([[2., 2.],
                [2., 2.]], '((m, s), (kg, rad))')
>>> qm2 = QMatrix(jnp.array([[0.1, 200.0], [300.0, 0.5]]),
...                      unit=(("km", "ms"), ("g", "deg")))
>>> qm + qm2
QMatrix([[101.        ,   1.2       ],
                [  1.3       ,   1.00872665]], '((m, s), (kg, rad))')

Indexing:

>>> qv[0]
Q(1., 'm')
>>> qm[0]
QMatrix([1., 1.], '(m, s)')
>>> qm[1, 0]
Q(1., 'kg')

value: Shaped[Array, '...']

The value of the AbstractQuantity.

unit: UnitsMatrix

The unit associated with this value.

property ndim: int

Number of logical (non-batch) dimensions: 1 for a vector, 2 for a matrix.

property shape: tuple[int, ...]

Shape, including batch dimensions.

classmethod from_cdict(v: dict[str, Any], /, keys: tuple[str, ...] | None = None)

Pack a component dictionary into a 1-D QMatrix.

Each value in v is stripped to its numeric value and stacked into a single JAX array. Values that carry units (unxt.Quantity) retain those units in the resulting UnitsMatrix; plain arrays are treated as dimensionless.

Examples

>>> import unxt as u
>>> from coordinax.internal import QMatrix

From a dictionary of quantities:

>>> v = {"x": u.Q(1.0, "m"), "y": u.Q(2.0, "s"), "z": u.Q(3.0, "kg")}
>>> qv = QMatrix.from_cdict(v)
>>> qv.unit.to_string()
'(m, s, kg)'
>>> qv.value
Array([1., 2., 3.], dtype=float64, ...)

Selecting and reordering a subset of keys:

>>> qv2 = QMatrix.from_cdict(v, keys=("z", "x"))
>>> qv2.unit.to_string()
'(kg, m)'
>>> qv2.value
Array([3., 1.], dtype=float64, ...)

Dimensionless entries (bare arrays) are accepted:

>>> import jax.numpy as jnp
>>> v2 = {"a": jnp.array(4.0), "b": u.Q(5.0, "m")}
>>> qv3 = QMatrix.from_cdict(v2)
>>> qv3.unit.to_string()
'(, m)'
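
Parameters:
  • v (dict[str, Any]) – Component dictionary mapping keys to quantities or bare arrays.

  • keys (tuple[str, ...] | None) – Keys to pack, in the given order. If None, all keys of v are used.

Return type: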

QMatrix

diag()

Return a 1-D QMatrix containing the diagonal of this matrix.

Unlike qnp.diag, this method operates directly on the static unit structure and the raw value array, so it works correctly under jax.jit and with heterogeneous-unit matrices.

Only supported for 2-D QMatrix objects.

Returns:

1-D QMatrix of length min(n_rows, n_cols) whose unit[i] is self.unit[i, i] and whose value[..., i] is self.value[..., i, i].

Return type:

QMatrix

Examples

>>> import jax.numpy as jnp
>>> from coordinax.internal import QMatrix

Uniform units:

>>> A = QMatrix(jnp.diag(jnp.array([1.0, 4.0, 9.0])),
...                    unit=(("m", "m", "m"), ("m", "m", "m"), ("m", "m", "m")))
>>> d = A.diag()
>>> d.unit.shape
(3,)
>>> d.value
Array([1., 4., 9.], dtype=float64)

Heterogeneous units (also works under jax.jit):

>>> B = QMatrix(jnp.diag(jnp.array([1.0, 2.0, 3.0])),
...                    unit=(("m", "s", "kg"),
...                          ("m", "s", "kg"),
...                          ("m", "s", "kg")))
>>> db = B.diag()
>>> db.unit.to_string()
'(m, s, kg)'
>>> db.value
Array([1., 2., 3.], dtype=float64)
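Because the unit structure is static, the diagonal can be taken with ordinary Python indexing on the units plus a single array operation on the values, which is what makes the method safe under jax.jit. A hypothetical NumPy sketch of that split (the function name and the nested-tuple unit representation are assumptions; jax.numpy.diagonal accepts the same axis1/axis2 arguments):

```python
import numpy as np


def sketch_diag(value, units):
    """Diagonal of a (..., N, M) value array and of its nested-tuple units."""
    n = min(len(units), len(units[0]))
    # Units are static Python data, so plain indexing is enough.
    diag_units = tuple(units[i][i] for i in range(n))
    # Taking the diagonal over the last two axes leaves batch axes intact.
    diag_vals = np.diagonal(value, axis1=-2, axis2=-1)
    return diag_vals, diag_units


batch = np.stack([np.diag([1.0, 2.0, 3.0]), np.diag([4.0, 5.0, 6.0])])  # shape (2, 3, 3)
units = (("m", "s", "kg"),) * 3
vals, dunits = sketch_diag(batch, units)
print(vals.shape, dunits)  # (2, 3) ('m', 's', 'kg')
```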

property T: QMatrix

Transpose a 2-D QMatrix (swap rows/columns and units).

Returns a new QMatrix whose value array and unit structure are both transposed. Only 2-D matrices are supported.

Examples

>>> import jax.numpy as jnp
>>> import quaxed.numpy as qnp
>>> from coordinax.internal import QMatrix
>>> a = QMatrix(jnp.array([[1.0, 2.0], [3.0, 4.0]]),
...                    unit=(("m", "s"), ("kg", "rad")))
>>> aT = a.T
>>> aT.value
Array([[1., 3.],
       [2., 4.]], dtype=float64)
>>> aT.unit.to_string()
'((m, kg), (s, rad))'

Also accessible via quaxed.numpy.matrix_transpose:

>>> aT2 = qnp.matrix_transpose(a)
>>> aT2.value
Array([[1., 3.],
       [2., 4.]], dtype=float64)
>>> aT2.unit.to_string()
'((m, kg), (s, rad))'
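The transpose likewise splits into an array operation on the trailing two axes and a pure-Python transpose of the static units. A hypothetical NumPy sketch, with nested tuples of strings standing in for UnitsMatrix:

```python
import numpy as np


def sketch_transpose(value, units):
    """Swap the last two axes of ``value`` and transpose the nested-tuple units."""
    value_t = np.swapaxes(value, -1, -2)  # leading batch axes are untouched
    units_t = tuple(zip(*units))          # rows of units become columns
    return value_t, units_t


vals, units = sketch_transpose(
    np.array([[1.0, 2.0], [3.0, 4.0]]), (("m", "s"), ("kg", "rad"))
)
# vals == [[1., 3.], [2., 4.]]; units == (("m", "kg"), ("s", "rad"))
```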

argmax(*args: Any, **kwargs: Any)

Return the indices of the maximum value.

Examples

>>> import unxt as u
>>> q = u.Quantity([1, 2, 3], "m")
>>> q.argmax()
Array(2, dtype=int32)

Return type:

Array

argmin(*args: Any, **kwargs: Any)

Return the indices of the minimum value.

Examples

>>> import unxt as u
>>> q = u.Quantity([1, 2, 3], "m")
>>> q.argmin()
Array(0, dtype=int32)

Return type:

Array

astype(*args: Any, **kwargs: Any)

Copy the array and cast to a specified dtype.

Examples

>>> import unxt as u
>>> q = u.Quantity([1, 2, 3], "m")
>>> q.dtype
dtype('int32')
>>> q.astype(float)
Quantity(Array([1., 2., 3.], dtype=float32), unit='m')

Return type:

AbstractQuantity

property at: _QuantityIndexUpdateHelper

Helper property for index update functionality.

The at property provides a functionally pure equivalent of in-place array modifications.

In particular:

Alternate syntax                    Equivalent in-place expression
x = x.at[idx].set(y)                x[idx] = y
x = x.at[idx].add(y)                x[idx] += y
x = x.at[idx].subtract(y)           x[idx] -= y
x = x.at[idx].multiply(y)           x[idx] *= y
x = x.at[idx].divide(y)             x[idx] /= y
x = x.at[idx].power(y)              x[idx] **= y
x = x.at[idx].min(y)                x[idx] = minimum(x[idx], y)
x = x.at[idx].max(y)                x[idx] = maximum(x[idx], y)
x = x.at[idx].apply(ufunc)          ufunc.at(x, idx)
x = x.at[idx].get()                 x = x[idx]

None of the x.at expressions modify the original x; instead they return a modified copy of x. However, inside a jit() compiled function, expressions like x = x.at[idx].set(y) are guaranteed to be applied in-place.

Unlike NumPy in-place operations such as x[idx] += y, if multiple indices refer to the same location, all updates will be applied (NumPy would only apply the last update, rather than applying all updates.) The order in which conflicting updates are applied is implementation-defined and may be nondeterministic (e.g., due to concurrency on some hardware platforms).
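The duplicate-index behaviour can be checked with plain arrays. This sketch uses a jax.Array rather than a Quantity and assumes the Quantity helper follows the same scatter semantics:

```python
import jax.numpy as jnp
import numpy as np

idx = np.array([1, 1])  # the same location, twice

# JAX: the scatter-add applies both updates to index 1.
x = jnp.zeros(3)
y = x.at[idx].add(1.0)
print(y)  # [0. 2. 0.]

# NumPy: the buffered in-place update keeps only one increment.
xn = np.zeros(3)
xn[idx] += 1.0
print(xn)  # [0. 1. 0.]
```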

By default, JAX assumes that all indices are in-bounds. Alternative out-of-bound index semantics can be specified via the mode parameter (see below).

Parameters:
  • mode –

    string specifying out-of-bound indexing mode. Options are:

    • "promise_in_bounds": (default) The user promises that indices are in bounds. No additional checking will be performed. In practice, this means that out-of-bounds indices in get() will be clipped, and out-of-bounds indices in set(), add(), etc. will be dropped.

    • "clip": clamp out of bounds indices into valid range.

    • "drop": ignore out-of-bound indices.

    • "fill": alias for "drop". For get(), the optional fill_value argument specifies the value that will be returned.

    See jax.lax.GatherScatterMode for more details.

  • wrap_negative_indices โ€“ If True (default) then negative indices indicate position from the end of the array, similar to Python and NumPy indexing. If False, then negative indices are considered out-of-bounds and behave according to the mode parameter.

  • fill_value โ€“ Only applies to the get() method: the fill value to return for out-of-bounds slices when mode is 'fill'. Ignored otherwise. Defaults to NaN for inexact types, the largest negative value for signed types, the largest positive value for unsigned types, and True for booleans.

  • indices_are_sorted โ€“ If True, the implementation will assume that the (normalized) indices passed to at[] are sorted in ascending order, which can lead to more efficient execution on some backends. If True but the indices are not actually sorted, the output is undefined.

  • unique_indices โ€“ If True, the implementation will assume that the (normalized) indices passed to at[] are unique, which can result in more efficient execution on some backends. If True but the indices are not actually unique, the output is undefined.

Examples

>>> import jax.numpy as jnp
>>> x = jnp.arange(5.0)
>>> x
Array([0., 1., 2., 3., 4.], dtype=float32)
>>> x.at[2].get()
Array(2., dtype=float32)
>>> x.at[2].add(10)
Array([ 0.,  1., 12.,  3.,  4.], dtype=float32)

By default, out-of-bound indices are ignored in updates, but this behavior can be controlled with the mode parameter:

>>> x.at[10].add(10)  # dropped
Array([0., 1., 2., 3., 4.], dtype=float32)
>>> x.at[20].add(10, mode='clip')  # clipped
Array([ 0.,  1.,  2.,  3., 14.], dtype=float32)

For get(), out-of-bound indices are clipped by default:

>>> x.at[20].get()  # out-of-bounds indices clipped
Array(4., dtype=float32)
>>> x.at[20].get(mode='fill')  # out-of-bounds indices filled with NaN
Array(nan, dtype=float32)
>>> x.at[20].get(mode='fill', fill_value=-1)  # custom fill value
Array(-1., dtype=float32)

Negative indices count from the end of the array, but this behavior can be disabled by setting wrap_negative_indices = False:

>>> x.at[-1].set(99)
Array([ 0.,  1.,  2.,  3., 99.], dtype=float32)
>>> x.at[-1].set(99, wrap_negative_indices=False, mode='drop')  # dropped!
Array([0., 1., 2., 3., 4.], dtype=float32)

block_until_ready()

Block until the array is ready.

Return type:

AbstractQuantity

Examples

>>> import unxt as u
>>> q = u.Quantity(1, "m")
>>> q.block_until_ready() is q
True

decompose(bases: Sequence[Unit | UnitBase | CompositeUnit | str], /)

Decompose the quantity into the given bases.

Examples

>>> from unxt import Quantity
>>> q = Quantity(1, "m")
>>> q.decompose(["cm", "s"])
Quantity(Array(100., dtype=float32, ...), unit='cm')
Parameters:

bases (Sequence[Unit | UnitBase | CompositeUnit | str])

Return type:

AbstractQuantity

property device: Device

Device where the array is located.

Examples

>>> import unxt as u
>>> u.Quantity(1, "m").device
CpuDevice(id=0)

devices()

Return the devices where the array is located.

Return type:

set[Device]

Examples

>>> import unxt as u
>>> q = u.Quantity(1, "m")
>>> q.devices()
{CpuDevice(id=0)}

property dtype: dtype

Data type of the array.

Examples

>>> import unxt as u
>>> u.Quantity(1, "m").dtype
dtype('int32')

flatten()

Return a flattened version of the array.

Return type:

AbstractQuantity

Examples

>>> import unxt as u
>>> q = u.Quantity([[1, 2], [3, 4]], "m")
>>> q.flatten()
Quantity(Array([1, 2, 3, 4], dtype=int32), unit='m')

classmethod from_(cls: type[AbstractQuantity], *args: Any, **kwargs: Any)
from_(cls: type[AbstractQuantity], value: ArrayLike | list[jaxtyping.Shaped[Array, ''] | jaxtyping.Shaped[ndarray, ''] | bool | number | bool | int | float | complex] | tuple[jaxtyping.Shaped[Array, ''] | jaxtyping.Shaped[ndarray, ''] | bool | number | bool | int | float | complex, ...], unit: Any, /, *, dtype: Any = None) → AbstractQuantity
Return type:

AbstractQuantity

Construct a unxt.Quantity from an array-like value and a unit.

Parameters:
  • value – The array-like value.

  • unit – The unit of the value.

  • dtype – The data type of the array (keyword-only).

  • args (Any)

  • kwargs (Any)

Return type:

AbstractQuantity

Examples

For this example we'll use the Quantity class. The same applies to any subclass of AbstractQuantity.

>>> import jax.numpy as jnp
>>> import unxt as u
>>> x = jnp.array([1.0, 2, 3])
>>> u.Quantity.from_(x, "m")
Quantity(Array([1., 2., 3.], dtype=float32), unit='m')
>>> u.Quantity.from_([1.0, 2, 3], "m")
Quantity(Array([1., 2., 3.], dtype=float32), unit='m')
>>> u.Quantity.from_((1.0, 2, 3), "m")
Quantity(Array([1., 2., 3.], dtype=float32), unit='m')

from_(cls: type[AbstractQuantity], value: ArrayLike | list[jaxtyping.Shaped[Array, ''] | jaxtyping.Shaped[ndarray, ''] | bool | number | bool | int | float | complex] | tuple[jaxtyping.Shaped[Array, ''] | jaxtyping.Shaped[ndarray, ''] | bool | number | bool | int | float | complex, ...], /, *, unit: Any, dtype: Any = None) → AbstractQuantity
Return type:

AbstractQuantity

Make a unxt.AbstractQuantity from an array-like value and a unit kwarg.

Examples

For this example we'll use the unxt.Quantity class. The same applies to any subclass of unxt.AbstractQuantity.

>>> import unxt as u
>>> u.Quantity.from_([1.0, 2, 3], unit="m")
Quantity(Array([1., 2., 3.], dtype=float32), unit='m')

from_(cls: type[AbstractQuantity], *, value: Any, unit: Any, dtype: Any = None) → AbstractQuantity
Return type:

AbstractQuantity

Construct an AbstractQuantity from value and unit keyword arguments.

Examples

For this example we'll use the Quantity class. The same applies to any subclass of AbstractQuantity.

>>> import unxt as u
>>> u.Quantity.from_(value=[1.0, 2, 3], unit="m")
Quantity(Array([1., 2., 3.], dtype=float32), unit='m')

from_(cls: type[AbstractQuantity], mapping: Mapping[str, Any]) → AbstractQuantity
Return type:

AbstractQuantity

Construct a Quantity from a Mapping.

Examples

For this example we'll use the Quantity class. The same applies to any subclass of AbstractQuantity.

>>> import jax.numpy as jnp
>>> import unxt as u
>>> x = jnp.array([1.0, 2, 3])
>>> q = u.Quantity.from_({"value": x, "unit": "m"})
>>> q
Quantity(Array([1., 2., 3.], dtype=float32), unit='m')
>>> u.Quantity.from_({"value": q, "unit": "km"})
Quantity(Array([0.001, 0.002, 0.003], dtype=float32), unit='km')

from_(cls: type[AbstractQuantity], value: AbstractQuantity, unit: Any, /, *, dtype: Any = None) → AbstractQuantity
Return type:

AbstractQuantity

Construct a Quantity from another Quantity.

The value is converted to the new unit.

Examples

>>> import unxt as u
>>> q = u.Quantity(1, "m")
>>> u.Quantity.from_(q, "cm")
Quantity(Array(100., dtype=float32, ...), unit='cm')

from_(cls: type[AbstractQuantity], value: AbstractQuantity, unit: NoneType, /, *, dtype: Any = None) → AbstractQuantity
Return type:

AbstractQuantity

Construct a Quantity from another Quantity.

With unit=None the value keeps its original unit and is not converted.

Examples

>>> import unxt as u
>>> q = u.Quantity(1, "m")
>>> u.Quantity.from_(q, None)
Quantity(Array(1, dtype=int32, ...), unit='m')

from_(cls: type[AbstractQuantity], value: AbstractQuantity, /, *, unit: Any | None = None, dtype: Any = None) → AbstractQuantity
Return type:

AbstractQuantity

Construct a Quantity from another Quantity, with no unit change.


from_(cls: type[AbstractQuantity], value: Quantity, /, **kwargs: Any) → AbstractQuantity
Return type:

AbstractQuantity

Construct a Quantity from an astropy Quantity.

The value keeps its original unit.

Examples

>>> import unxt as u
>>> import astropy.units as apyu
>>> u.Quantity.from_(apyu.Quantity(1, "m"))
Quantity(Array(1., dtype=float32), unit='m')

from_(cls: type[AbstractQuantity], value: Quantity, u: Any, /, **kwargs: Any) → AbstractQuantity
Return type:

AbstractQuantity

Construct a Quantity from an astropy Quantity.

The value is converted to the given unit.

Examples

>>> import unxt as u
>>> import astropy.units as apyu
>>> u.Quantity.from_(apyu.Quantity(1, "m"), "cm")
Quantity(Array(100., dtype=float32), unit='cm')

from_(cls: type[Distance], value: ArrayLike, unit: Any, /, **kw: Any) → Distance
Return type:

AbstractQuantity

Construct a distance.

>>> import unxt as u
>>> import coordinax.distances as cxd
>>> cxd.Distance.from_(1, "kpc")
Distance(1, 'kpc')

from_(cls: type[Distance], d: Distance, /, **kw: Any) → Distance
Return type:

AbstractQuantity

Compute distance from distance.

>>> import unxt as u
>>> import coordinax.distances as cxd
>>> d = cxd.Distance(1, "kpc")
>>> cxd.Distance.from_(d) is d
True
>>> cxd.Distance.from_(d, dtype=float)
Distance(1., 'kpc')

from_(cls: type[Distance], d: Quantity[PhysicalType('length')], /, **kw: Any) → Distance
Return type:

AbstractQuantity

Compute distance from a length quantity.

>>> import unxt as u
>>> import coordinax.distances as cxd
>>> q = u.Q(1, "kpc")
>>> cxd.Distance.from_(q, dtype=float)
Distance(1., 'kpc')

from_(cls: type[Distance], p: Quantity[PhysicalType('angle')], /, **kw: Any) → Distance
Return type:

AbstractQuantity

Compute distance from parallax.

>>> import unxt as u
>>> import coordinax.distances as cxd
>>> q = u.Q(1, "mas")
>>> cxd.Distance.from_(q).uconvert("pc").round(2)
Distance(1000., 'pc')

from_(cls: type[Distance], dm: Quantity[PhysicalType('unknown')], /, **kw: Any) → Distance
Return type:

AbstractQuantity

Compute distance from distance modulus.

>>> import unxt as u
>>> import coordinax.distances as cxd
>>> q = u.Q(10, "mag")
>>> cxd.Distance.from_(q).uconvert("pc").round(2)
Distance(1000., 'pc')
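The Distance, DistanceModulus, and Parallax conversions shown here agree with the standard relations d[pc] = 1 / p[arcsec] and μ = 5 log10(d / 10 pc). A plain-Python check of the numbers quoted in these examples (the helper names are hypothetical, not part of coordinax.distances or coordinax.astro):

```python
import math


def parallax_mas_to_pc(p_mas):
    """d [pc] = 1 / p [arcsec] = 1000 / p [mas]."""
    return 1000.0 / p_mas


def pc_to_parallax_mas(d_pc):
    """p [mas] = 1000 / d [pc]."""
    return 1000.0 / d_pc


def modulus_to_pc(dm_mag):
    """Invert mu = 5 * log10(d / 10 pc), i.e. d = 10 ** (mu / 5 + 1) pc."""
    return 10.0 ** (dm_mag / 5.0 + 1.0)


def pc_to_modulus(d_pc):
    """mu = 5 * log10(d / 10 pc)."""
    return 5.0 * math.log10(d_pc / 10.0)


print(parallax_mas_to_pc(1.0))   # 1000.0 (pc, for a 1 mas parallax)
print(pc_to_parallax_mas(10.0))  # 100.0 (mas, for a 10 pc distance)
print(modulus_to_pc(10.0))       # 1000.0 (pc, for mu = 10 mag)
# pc_to_modulus(1.0) gives mu = -5 mag, up to floating-point rounding
```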

from_(cls: type[DistanceModulus], value: ArrayLike, unit: Any, /, **kw: Any) → DistanceModulus
Return type:

AbstractQuantity

Construct a distance modulus.

>>> import unxt as u
>>> from coordinax.astro import DistanceModulus
>>> DistanceModulus.from_(1, "mag")
DistanceModulus(1, 'mag')

from_(cls: type[DistanceModulus], dm: DistanceModulus, /, **kw: Any) → DistanceModulus
Return type:

AbstractQuantity

Compute distance modulus from distance modulus.

>>> import unxt as u
>>> from coordinax.astro import DistanceModulus
>>> dm = DistanceModulus(1, "mag")
>>> DistanceModulus.from_(dm) is dm
True
>>> DistanceModulus.from_(dm, dtype=float)
DistanceModulus(1., 'mag')

from_(cls: type[DistanceModulus], dm: Quantity[PhysicalType('unknown')], /, **kw: Any) → DistanceModulus
Return type:

AbstractQuantity

Compute distance modulus from a magnitude quantity.

>>> import unxt as u
>>> from coordinax.astro import DistanceModulus
>>> q = u.Q(1, "mag")
>>> DistanceModulus.from_(q)
DistanceModulus(1, 'mag')

from_(cls: type[DistanceModulus], d: Distance, /, **kw: Any) → DistanceModulus
Return type:

AbstractQuantity

Compute distance modulus from distance.

>>> import coordinax.distances as cxd
>>> from coordinax.astro import DistanceModulus
>>> d = cxd.Distance(1, "pc")
>>> DistanceModulus.from_(d)
DistanceModulus(-5., 'mag')

from_(cls: type[DistanceModulus], d: Quantity[PhysicalType('length')], /, **kw: Any) → DistanceModulus
Return type:

AbstractQuantity

Compute distance modulus from distance.

>>> import unxt as u
>>> from coordinax.astro import DistanceModulus
>>> q = u.Q(1, "pc")
>>> DistanceModulus.from_(q)
DistanceModulus(-5., 'mag')

from_(cls: type[DistanceModulus], p: Quantity[PhysicalType('angle')], /, **kw: Any) → DistanceModulus
Return type:

AbstractQuantity

Compute distance modulus from parallax.

>>> import unxt as u
>>> from coordinax.astro import DistanceModulus
>>> q = u.Q(1, "mas")
>>> DistanceModulus.from_(q)
DistanceModulus(10., 'mag')

from_(cls: type[Distance], dm: DistanceModulus, /, **kw: Any) → Distance
Return type:

AbstractQuantity

Compute distance from distance modulus.

>>> import coordinax.distances as cxd
>>> from coordinax.astro import DistanceModulus
>>> dm = DistanceModulus(10, "mag")
>>> cxd.Distance.from_(dm).uconvert("pc").round(2)
Distance(1000., 'pc')

from_(cls: type[Parallax], value: ArrayLike, unit: Any, /, **kw: Any) → Parallax
Return type:

AbstractQuantity

Construct a parallax.

>>> import unxt as u
>>> from coordinax.astro import Parallax
>>> Parallax.from_(1, "mas")
Parallax(1, 'mas')

from_(cls: type[Parallax], p: Parallax, /, **kw: Any) → Parallax
Return type:

AbstractQuantity

Compute parallax from parallax.

>>> import unxt as u
>>> from coordinax.astro import Parallax
>>> p = Parallax(1, "mas")
>>> Parallax.from_(p) is p
True
>>> Parallax.from_(p, dtype=float)
Parallax(1., 'mas')

from_(cls: type[Parallax], p: Quantity[PhysicalType('angle')], /, **kw: Any) → Parallax
Return type:

AbstractQuantity

Compute parallax from an angle quantity.

>>> import unxt as u
>>> from coordinax.astro import Parallax
>>> q = u.Q(1, "mas")
>>> Parallax.from_(q, dtype=float)
Parallax(1., 'mas')

from_(cls: type[Parallax], d: Distance | Quantity[PhysicalType('length')], /, **kw: Any) → Parallax
Return type:

AbstractQuantity

Compute parallax from distance.

>>> import unxt as u
>>> import coordinax.distances as cxd
>>> from coordinax.astro import Parallax
>>> d = cxd.Distance(10, "pc")
>>> Parallax.from_(d).uconvert("mas").round(2)
Parallax(100., 'mas')
>>> q = u.Q(10, "pc")
>>> Parallax.from_(q).uconvert("mas").round(2)
Parallax(100., 'mas')

from_(cls: type[Parallax], dm: Quantity[PhysicalType('unknown')], /, **kw: Any) → Parallax
Return type:

AbstractQuantity

Convert distance modulus to parallax.

>>> import unxt as u
>>> from coordinax.astro import Parallax
>>> dm = u.Q(10, "mag")
>>> Parallax.from_(dm).uconvert("mas").round(2)
Parallax(1., 'mas')

from_(cls: type[Distance], p: Parallax, /, **kw: Any) → Distance
Return type:

AbstractQuantity

Compute distance from parallax.

>>> import coordinax.distances as cxd
>>> from coordinax.astro import Parallax
>>> p = Parallax(1, "mas")
>>> cxd.Distance.from_(p).uconvert("pc").round(2)
Distance(1000., 'pc')

from_(cls: type[DistanceModulus], p: Parallax, /, **kw: Any) → DistanceModulus
Return type:

AbstractQuantity

Compute distance modulus from parallax.

>>> from coordinax.astro import DistanceModulus, Parallax
>>> p = Parallax(1, "mas")
>>> DistanceModulus.from_(p)
DistanceModulus(10., 'mag')

from_(cls: type[Parallax], dm: DistanceModulus, /, **kw: Any) → Parallax
Return type:

AbstractQuantity

Convert distance modulus to parallax.

>>> import unxt as u
>>> from coordinax.astro import DistanceModulus, Parallax
>>> dm = DistanceModulus(10, "mag")
>>> Parallax.from_(dm).uconvert("mas").round(2)
Parallax(1., 'mas')

Return type:

AbstractQuantity

property mT: AbstractQuantity

Matrix transpose of the array.

Examples

>>> import unxt as u
>>> q = u.Quantity([[0, 1], [1, 2]], "m")
>>> q.mT
Quantity(Array([[0, 1],
                          [1, 2]], dtype=int32), unit='m')

max(*args: Any, **kwargs: Any)

Return the maximum value.

Examples

>>> import unxt as u
>>> q = u.Quantity([1, 2, 3], "m")
>>> q.max()
Quantity(Array(3, dtype=int32), unit='m')

Return type:

AbstractQuantity

mean(*args: Any, **kwargs: Any)

Return the mean value.

Examples

>>> import unxt as u
>>> q = u.Quantity([1, 2, 3], "m")
>>> q.mean()
Quantity(Array(2., dtype=float32), unit='m')

Return type:

AbstractQuantity

min(*args: Any, **kwargs: Any)

Return the minimum value.

Examples

>>> import unxt as u
>>> q = u.Quantity([1, 2, 3], "m")
>>> q.min()
Quantity(Array(1, dtype=int32), unit='m')

Return type:

AbstractQuantity

ravel()

Return a flattened version of the array.

Return type:

AbstractQuantity

Examples

>>> import unxt as u
>>> q = u.Quantity([[1, 2], [3, 4]], "m")
>>> q.ravel()
Quantity(Array([1, 2, 3, 4], dtype=int32), unit='m')

reshape(*args: Any, order: str = 'C')

Return a reshaped version of the array.

Examples

>>> import unxt as u
>>> q = u.Quantity([1, 2, 3, 4], "m")
>>> q.reshape(2, 2)
Quantity(Array([[1, 2],
                          [3, 4]], dtype=int32), unit='m')

Return type:

AbstractQuantity

round(*args: Any, **kwargs: Any)

Round the array to the given number of decimals.

Examples

>>> import unxt as u
>>> q = u.Quantity([1.1, 2.2, 3.3], "m")
>>> q.round(0)
Quantity(Array([1., 2., 3.], dtype=float32), unit='m')

Return type:

AbstractQuantity

property sharding: Any

Return the sharding configuration of the array.

Examples

>>> import unxt as u
>>> q = u.Quantity([1, 2, 3], "m")
>>> q.sharding
SingleDeviceSharding(device=..., memory_kind=...)

property size: int

Total number of elements.

Examples

>>> import unxt as u
>>> q = u.Quantity([1, 2, 3], "m")
>>> q.size
3

squeeze(*args: Any, **kwargs: Any)

Return the array with all single-dimensional entries removed.

Examples

>>> import unxt as u
>>> q = u.Quantity([[[1], [2], [3]]], "m")
>>> q.squeeze()
Quantity(Array([1, 2, 3], dtype=int32), unit='m')

Return type:

AbstractQuantity

to(u: Any, /)

Convert the quantity to the given units.

See unxt.quantity.AbstractQuantity.uconvert.

Examples

>>> from unxt import Quantity
>>> q = Quantity(1, "m")
>>> q.to("cm")
Quantity(Array(100., dtype=float32, ...), unit='cm')
Parameters:

u (Any)

Return type:

AbstractQuantity

to_device(device: None | Device = None)

Move the array to a new device.

Examples

>>> import unxt as u
>>> q = u.Quantity(1, "m")
>>> q.to_device(None)
Quantity(Array(1, dtype=int32, weak_type=True), unit='m')
Parameters:

device (None | Device)

Return type:

AbstractQuantity

to_value(u: Any, /)

Return the value in the given units.

See unxt.AbstractQuantity.ustrip.

Examples

>>> from unxt import Quantity
>>> q = Quantity(1, "m")
>>> q.to_value("cm")
Array(100., dtype=float32, weak_type=True)
Parameters:

u (Any)

Return type:

Union[Array, ndarray, bool, number, bool, int, float, complex]

uconvert(u: Any, /)

Convert the quantity to the given units.

See also

unxt.uconvert: convert a quantity to a new unit.

Examples

>>> import unxt as u
>>> q = u.Quantity(1, "m")
>>> q.uconvert("cm")
Quantity(Array(100., dtype=float32, ...), unit='cm')
Parameters:

u (Any)

Return type:

AbstractQuantity

ustrip(u: Any, /)

Return the value in the given units.

See also

unxt.ustrip: strip the units from a quantity.

Examples

>>> import unxt as u
>>> q = u.Quantity(1, "m")
>>> q.ustrip("cm")
Array(100., dtype=float32, weak_type=True)
Parameters:

u (Any)

Return type:

Array

final class coordinax.internal.UnitsMatrix(iterable: Any, /)

Bases: object

Immutable, hashable unit structure for QMatrix.

UnitsMatrix wraps a NumPy object array (dtype=object) of unxt.AbstractUnit elements. Only 1-D and 2-D structures are accepted.

The class supports tuple-style indexing, iteration, to_tuple(), and to_string(). It is not a subclass of astropy.units.StructuredUnit; bidirectional converters to/from StructuredUnit are provided in coordinax.interop.astropy.

Hashability is achieved via hash(self.to_tuple()), so the underlying AbstractUnit objects must themselves be hashable (they are).

For 1-D: UnitsMatrix(("m", "s", "kg"))
For 2-D: UnitsMatrix((("m", "s"), ("kg", "rad")))

Examples

>>> import unxt as u
>>> from coordinax.internal import UnitsMatrix

1D case:

>>> units_1d = UnitsMatrix(("m", "s", "kg"))
>>> units_1d.shape
(3,)
>>> units_1d[0]
Unit("m")
>>> units_1d.to_string()
'(m, s, kg)'

2D case:

>>> units_2d = UnitsMatrix((("m", "s"), ("kg", "rad")))
>>> units_2d.shape
(2, 2)
>>> units_2d[0, 1]
Unit("s")
>>> units_2d.to_string()
'((m, s), (kg, rad))'
Parameters:

iterable (Any)
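The tuple-based hashing described above can be sketched as follows. Unit strings stand in for unxt.AbstractUnit objects and the class name SketchUnits is hypothetical; the point is that equal structures compare and hash equal, which is what lets a unit structure act as static, hashable metadata under jax.jit.

```python
import numpy as np


class SketchUnits:
    """Hypothetical stand-in for UnitsMatrix: object array plus tuple-based hash."""

    def __init__(self, nested):
        self._arr = np.array(nested, dtype=object)
        if self._arr.ndim not in (1, 2):
            raise ValueError("only 1-D and 2-D unit structures are supported")

    @property
    def shape(self):
        return self._arr.shape

    def to_tuple(self):
        # Rebuild a nested tuple so that equality and hashing are by value.
        if self._arr.ndim == 2:
            return tuple(tuple(row) for row in self._arr)
        return tuple(self._arr)

    def __getitem__(self, idx):
        return self._arr[idx]  # tuple-style indexing, e.g. units[0, 1]

    def __eq__(self, other):
        return isinstance(other, SketchUnits) and self.to_tuple() == other.to_tuple()

    def __hash__(self):
        return hash(self.to_tuple())


a = SketchUnits((("m", "s"), ("kg", "rad")))
b = SketchUnits((("m", "s"), ("kg", "rad")))
# a == b and hash(a) == hash(b); a[0, 1] == "s"; a.shape == (2, 2)
```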

property shape: tuple[int, ...]

Shape of the N-D unit structure.

property ndim: int

Number of dimensions.

property T: UnitsMatrix

Transpose the unit structure by reversing all of its axes; a 1-D structure is returned unchanged.

Examples

>>> from coordinax.internal import UnitsMatrix
>>> units = UnitsMatrix(("m", "s"))
>>> units.T
UnitsMatrix("(m, s)")
>>> units = UnitsMatrix((("m", "s"), ("kg", "rad")))
>>> units.T
UnitsMatrix("((m, kg), (s, rad))")
>>> units = UnitsMatrix((("m", "s", "kg"), ("Hz", "candela", "km")))
>>> units.T
UnitsMatrix("((m, Hz), (s, cd), (kg, km))")
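
For intuition, transposing a unit structure behaves like numpy.ndarray.T on an object array: a 1-D structure comes back unchanged and a 2-D structure has its rows and columns swapped. The minimal sketch below uses strings in place of units; transpose_units is a hypothetical helper, not the coordinax implementation.

```python
import numpy as np


def transpose_units(units):
    """Hypothetical helper: reverse all axes of a nested tuple of unit names."""
    arr = np.array(units, dtype=object)
    # .T reverses every axis; for 1-D input it is a no-op.
    if arr.ndim == 2:
        return tuple(tuple(row) for row in arr.T)
    return tuple(arr.T)


print(transpose_units(("m", "s")))
# ('m', 's')
print(transpose_units((("m", "s", "kg"), ("Hz", "cd", "km"))))
# (('m', 'Hz'), ('s', 'cd'), ('kg', 'km'))
```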

inverse()#

Inverse unit structure: each unit raised to the power -1.

For a 1-D (diagonal) UnitsMatrix the inversion is done entry by entry in O(n), which is cheaper than the general 2-D case. For a 2-D UnitsMatrix whose entries all share one unit, the reciprocal is computed only once and then broadcast to every entry; mixed-unit 2-D structures fall back to an element-wise O(nm) loop.

Return type:

UnitsMatrix

Examples

>>> from coordinax.internal import UnitsMatrix

1-D (diagonal) case, element-wise reciprocal:

>>> UnitsMatrix(("m2", "s2")).inverse()
UnitsMatrix("(1 / m2, 1 / s2)")

2-D uniform-unit case:

>>> UnitsMatrix((("m2", "m2"), ("m2", "m2"))).inverse()
UnitsMatrix("((1 / m2, 1 / m2), (1 / m2, 1 / m2))")

2-D mixed-unit case:

>>> UnitsMatrix((("m2", "s2"), ("s2", "rad2"))).inverse()
UnitsMatrix("((1 / m2, 1 / s2), (1 / s2, 1 / rad2))")
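
The three code paths can be sketched with strings standing in for units, spelling the reciprocal of u as "1 / u" to match the examples above. The helper names are hypothetical; a real implementation raises unit objects to the power -1 rather than formatting strings.

```python
from collections.abc import Callable


def reciprocal(u: str) -> str:
    # Hypothetical stand-in for raising a unit to the power -1.
    return f"1 / {u}"


def invert_units(units: tuple, recip: Callable[[str], str] = reciprocal) -> tuple:
    """Hypothetical sketch of the 1-D / uniform 2-D / mixed 2-D inversion paths."""
    if all(isinstance(u, str) for u in units):
        # 1-D (diagonal) case: one reciprocal per entry, O(n).
        return tuple(recip(u) for u in units)
    distinct = {u for row in units for u in row}
    if len(distinct) == 1:
        # Uniform 2-D case: compute the reciprocal once, then broadcast it.
        inv = recip(next(iter(distinct)))
        return tuple(tuple(inv for _ in row) for row in units)
    # Mixed 2-D case: element-wise loop, O(nm).
    return tuple(tuple(recip(u) for u in row) for row in units)


print(invert_units(("m2", "s2")))
# ('1 / m2', '1 / s2')
print(invert_units((("m2", "s2"), ("s2", "rad2"))))
# (('1 / m2', '1 / s2'), ('1 / s2', '1 / rad2'))
```

Note that detecting the uniform case still visits every entry once. The saving lies in calling the (potentially expensive) unit reciprocal a single time.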

to_tuple()#

Convert to a nested tuple of unxt.AbstractUnit objects.

Return type:

Union[Unit, UnitBase, CompositeUnit, tuple[Union[TypeVar(T), tuple[NestedTuple[T], ...]], ...]]

Examples

>>> from coordinax.internal import UnitsMatrix
>>> import unxt as u
>>> UnitsMatrix(("m", "s")).to_tuple()
(Unit("m"), Unit("s"))

to_string()#

Return a human-readable string representation of the unit structure.

Return type:

str

Examples

>>> from coordinax.internal import UnitsMatrix
>>> UnitsMatrix(("m", "s", "kg")).to_string()
'(m, s, kg)'
>>> UnitsMatrix((("m", "s"), ("kg", "rad"))).to_string()
'((m, s), (kg, rad))'
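
The string form is a parenthesised, comma-separated rendering of the nested structure. The minimal recursive sketch below works over plain strings; format_units is a hypothetical helper, not the coordinax implementation.

```python
def format_units(x) -> str:
    """Hypothetical helper: render nested tuples of unit names like to_string()."""
    if isinstance(x, tuple):
        # Recurse into each entry and join with ", " inside parentheses.
        return "(" + ", ".join(format_units(e) for e in x) + ")"
    return str(x)


print(format_units(("m", "s", "kg")))             # (m, s, kg)
print(format_units((("m", "s"), ("kg", "rad"))))  # ((m, s), (kg, rad))
```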

coordinax.internal.tree_cast_int_bool_to_float(tree: PyTree[jaxtyping.Bool[Array, '...'] | jaxtyping.Int[Array, '...'] | jaxtyping.Float[Array, '...'] | jaxtyping.Complex[Array, '...']], /)#

Tree-map integer/bool leaves to the configured default float dtype.

Float and complex leaves are returned unchanged. Complex leaves in particular are never cast, since converting them to a real float dtype would silently discard the imaginary part.

Examples

>>> import jax.numpy as jnp
>>> from coordinax.internal import tree_cast_int_bool_to_float
>>> x = {
...     "i": jnp.array([1, 2], dtype=jnp.int32),
...     "b": jnp.array([True, False], dtype=jnp.bool_),
...     "f": jnp.array([1.5], dtype=jnp.float32),
...     "c": jnp.array([1 + 2j], dtype=jnp.complex64),
... }
>>> tree_cast_int_bool_to_float(x)
{'b': Array([1., 0.], dtype=float64),
 'c': Array([1.+2.j], dtype=complex64),
 'f': Array([1.5], dtype=float32),
 'i': Array([1., 2.], dtype=float64)}
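
Parameters:

tree (PyTree[jaxtyping.Bool[Array, '...'] | jaxtyping.Int[Array, '...'] | jaxtyping.Float[Array, '...'] | jaxtyping.Complex[Array, '...']])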

Return type:

PyTree[jaxtyping.Float[Array, '...'] | jaxtyping.Complex[Array, '...']]
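
The casting rule can be reproduced with plain NumPy on a flat dict of arrays, which makes the dtype logic easy to check outside JAX. This sketch makes two simplifying assumptions: it hard-codes float64 as the target (coordinax uses the configured default float dtype), and it handles only a flat dict rather than an arbitrary PyTree. The name cast_int_bool_to_float is hypothetical.

```python
import numpy as np


def cast_int_bool_to_float(tree: dict, float_dtype=np.float64) -> dict:
    """Hypothetical flat-dict analogue of tree_cast_int_bool_to_float."""
    out = {}
    for key, leaf in tree.items():
        arr = np.asarray(leaf)
        # NumPy does not treat bool as an integer subtype, so check both.
        if np.issubdtype(arr.dtype, np.bool_) or np.issubdtype(arr.dtype, np.integer):
            out[key] = arr.astype(float_dtype)
        else:
            # Float and complex leaves pass through untouched, so no
            # imaginary part is ever discarded.
            out[key] = arr
    return out


x = {
    "i": np.array([1, 2], dtype=np.int32),
    "b": np.array([True, False]),
    "f": np.array([1.5], dtype=np.float32),
    "c": np.array([1 + 2j], dtype=np.complex64),
}
for k, v in cast_int_bool_to_float(x).items():
    print(k, v.dtype)
# i float64
# b float64
# f float32
# c complex64
```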

coordinax.internal.pack_uniform_unit(p: dict[str, Any], /, keys: tuple[str, ...])#
Overloads:
  • p (dict[CKey, Any]), keys (tuple[CKey, ...]) → tuple[jnp.ndarray, u.AbstractUnit]

  • p (dict[CKey, ArrayLike]), keys (tuple[CKey, ...]) → tuple[jnp.ndarray, None]

Return type:

tuple[Array, Unit | UnitBase | CompositeUnit | None]

Pack a component dictionary into one array using a shared unit.

The first requested key chooses the reference unit when the data is quantity-valued. Remaining entries are converted into that unit before the values are stacked along the trailing axis. If the entries are plain arrays or scalars, the returned unit is None.

Parameters:
  • p (dict[str, Any]) – Component dictionary to pack.

  • keys (tuple[str, ...]) – Ordered keys to extract and stack.

Returns:

Packed values together with the shared unit used for conversion, or None for unitless inputs.

Return type:

tuple[jnp.ndarray, u.AbstractUnit | None]

Examples

>>> import unxt as u
>>> from coordinax.internal import pack_uniform_unit
>>> p = {"x": u.Q(1.0, "km"), "y": u.Q(200.0, "m")}
>>> vals, unit = pack_uniform_unit(p, ("x", "y"))
>>> unit
Unit("km")
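
The uniform-unit rule works like this: the first key picks the reference unit, every other entry is converted into it, and the values are stacked. That rule can be sketched without unxt by modelling each quantity as a (value, unit) pair. The pair representation, the _TO_METRE scale table, and the function name are hypothetical assumptions for illustration, and a list stands in for the stacked array.

```python
# Hypothetical length scale factors; a real implementation would use unxt units.
_TO_METRE = {"m": 1.0, "km": 1000.0}


def pack_uniform_sketch(p: dict, keys: tuple):
    """Stack values in the unit of the first key (toy (value, unit) pairs)."""
    first = p[keys[0]]
    if not isinstance(first, tuple):
        # Plain numbers: stack as-is; there is no unit to report.
        return [p[k] for k in keys], None
    ref_unit = first[1]
    vals = []
    for k in keys:
        value, unit = p[k]
        # Convert each entry into the reference unit before stacking.
        vals.append(value * _TO_METRE[unit] / _TO_METRE[ref_unit])
    return vals, ref_unit


vals, unit = pack_uniform_sketch({"x": (1.0, "km"), "y": (200.0, "m")}, ("x", "y"))
print(vals, unit)  # [1.0, 0.2] km
```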

coordinax.internal.cdict_units(p: dict[str, Any], keys: tuple[str, ...], /)#

Extract per-key units from a component dictionary.

Non-quantity entries yield None, so the output tuple can be used for heterogeneous dictionaries containing both quantity and non-quantity data.

>>> import unxt as u
>>> from coordinax.internal import cdict_units
>>> d = {'x': u.Q(1.0, 'm'), 'y': 2.0, 'z': u.Q(3.0, 'kg')}
>>> cdict_units(d, ('x', 'y', 'z'))
(Unit("m"), None, Unit("kg"))
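
Parameters:
  • p (dict[str, Any]) – Component dictionary.

  • keys (tuple[str, ...]) – Ordered keys whose units are extracted.

Return type: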

tuple[Unit | UnitBase | CompositeUnit | None, ...]

coordinax.internal.pack_nonuniform_unit(p: dict[str, Any], /, keys: tuple[str, ...])#

Pack a component dictionary into an array with per-component units.

Unlike pack_uniform_unit, this helper does not choose a single reference unit. Each requested component is stripped in its own native unit and the resulting unit tuple is returned alongside the stacked values.

This is the appropriate packing mode when different coordinates naturally have different physical dimensions or when preserving the original unit of each component is important.

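Parameters:
  • p (dict[str, Any]) – Component dictionary to pack.

  • keys (tuple[str, ...]) – Ordered keys to extract and stack.

Return type: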

tuple[Array, tuple[Unit | UnitBase | CompositeUnit | None, ...]]
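
By contrast with the uniform case, non-uniform packing leaves each component in its own unit. A dependency-free sketch follows, assuming a hypothetical (value, unit) pair representation for quantities. Plain numbers get None as their unit, mirroring the None entries that the return type allows; the function name is hypothetical, and a list stands in for the stacked array.

```python
def pack_nonuniform_sketch(p: dict, keys: tuple):
    """Strip each entry in its own unit; plain numbers get a None unit."""
    vals, units = [], []
    for k in keys:
        entry = p[k]
        if isinstance(entry, tuple):
            value, unit = entry  # toy quantity: (value, unit)
        else:
            value, unit = entry, None  # unitless entry
        vals.append(value)
        units.append(unit)
    return vals, tuple(units)


p = {"r": (2.0, "km"), "phi": (0.5, "rad"), "w": 3.0}
print(pack_nonuniform_sketch(p, ("r", "phi", "w")))
# ([2.0, 0.5, 3.0], ('km', 'rad', None))
```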

coordinax.internal.pack_with_usys(p: dict[str, Any], /, keys: tuple[str, ...], usys: AbstractUnitSystem)#

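Pack a component dictionary into an array with per-component units, using the unit system usys to choose the unit of each component.

Parameters:
  • p (dict[str, Any]) – Component dictionary to pack.

  • keys (tuple[str, ...]) – Ordered keys to extract and stack.

  • usys (AbstractUnitSystem) – Unit system used to select each component's unit.

Return type: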

tuple[Array, tuple[Unit | UnitBase | CompositeUnit, ...]]

coordinax.internal.pack_to_qmatrix(p: dict[str, Any], /, keys: tuple[str, ...] | None = None)#

Pack a component dictionary into a QMatrix or plain Array.

Components are ordered according to keys. If the values are unxt.AbstractQuantity instances, a 1-D coordinax.internal.QMatrix is returned with per-component units. If the values are plain arrays, a stacked JAX array is returned.

Parameters:
  • p (dict[str, Any]) – Component dictionary to pack.

  • keys (tuple[str, ...] | None) – Ordered keys to extract and stack.

Returns:

Packed representation of the component dictionary.

Return type:

Array | QMatrix

Examples

>>> import jax.numpy as jnp
>>> import unxt as u
>>> from coordinax.internal import pack_to_qmatrix
>>> p = {"x": u.Q(1.0, "km"), "y": u.Q(2.0, "km"), "z": u.Q(3.0, "km")}
>>> pack_to_qmatrix(p, ("x", "y", "z"))
QMatrix([1., 2., 3.], '(km, km, km)')

coordinax.internal.pos_named_objs(pairs: Iterable[tuple[str, Any]], pos_names: tuple[str, ...] | list[str], fields: dict[str, Any], /, *, hide_defaults: bool = True, **kw: Any)#

Render positional fields first, then non-default named fields.

Parameters:
  • pairs (Iterable[tuple[str, Any]]) – Field name-value pairs (e.g., from field_items(self)).

  • pos_names (tuple[str, ...] | list[str]) – Names of fields to render positionally (in order), listed first.

  • fields (dict[str, Any]) – Field descriptors (self.__dataclass_fields__). Used to read default values for filtering optional named fields.

  • hide_defaults (bool) – If True (default), named fields whose value equals the field's default are omitted from the output.

  • **kw (Any) – Extra keyword arguments forwarded to wadler_lindig.pdoc() and wadler_lindig.named_objs().

Returns:

Positional docs followed by named non-default docs.

Return type:

list[AbstractDoc]
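
The ordering and filtering rules can be illustrated with a dataclass, using plain strings instead of wadler_lindig docs. The helper pos_named_strs below is hypothetical. Like the documented function, it reads each field's default from the dataclass field descriptors to decide which named fields to hide.

```python
import dataclasses


def pos_named_strs(pairs, pos_names, fields, *, hide_defaults=True):
    """Hypothetical string-based analogue of pos_named_objs."""
    values = dict(pairs)
    # Positional fields first, in the order given by pos_names.
    out = [repr(values[name]) for name in pos_names]
    for name, value in values.items():
        if name in pos_names:
            continue
        default = fields[name].default
        if hide_defaults and default is not dataclasses.MISSING and value == default:
            continue  # omit named fields that still hold their default value
        out.append(f"{name}={value!r}")
    return out


@dataclasses.dataclass
class Point:
    x: float
    y: float
    frame: str = "icrs"
    label: str = ""


pt = Point(1.0, 2.0, label="a")
pairs = [(f.name, getattr(pt, f.name)) for f in dataclasses.fields(pt)]
print(pos_named_strs(pairs, ("x", "y"), pt.__dataclass_fields__))
# ['1.0', '2.0', "label='a'"]
```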

coordinax.internal.jax_scalar_handler(obj: Any, /)#

Handler to render concrete 0-d JAX arrays as Python scalars.

Pass this as custom=jax_scalar_handler to wadler_lindig.pdoc(), wadler_lindig.pformat(), or wadler_lindig.pprint() so that concrete 0-d JAX arrays are displayed as plain Python numbers (10.0) rather than the default array-summary form (f64[](jax)).

Rendering rules:

  • JAX Tracer (inside jax.jit): returns None so that wadler_lindig falls back to its default shape/dtype summary.

  • Concrete 0-d array (has .item()): returns a doc for the plain Python scalar.

  • Everything else: returns None (default behaviour).

Examples

>>> import jax.numpy as jnp
>>> import wadler_lindig as wl
>>> from coordinax.internal import jax_scalar_handler

Concrete scalars inside a dict are shown as Python numbers:

>>> d = {"x": jnp.array(10.0), "y": jnp.array(0.0)}
>>> wl.pformat(d, custom=jax_scalar_handler)
"{'x': 10.0, 'y': 0.0}"

Callables are unaffected: the handler returns None, so wadler_lindig uses its default rendering:

>>> f = lambda t: {"x": t}
>>> wl.pformat(f, custom=jax_scalar_handler)
'<function <lambda>>'

Parameters:

obj (Any)

Return type:

AbstractDoc | None
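
The handler follows a simple protocol: return a replacement for objects it recognises, and None to defer to the printer's default. The sketch below illustrates that protocol with NumPy and a toy printer. It does not use wadler_lindig, and it omits the JAX tracer check. Both scalar_handler_sketch and render are hypothetical names.

```python
import numpy as np


def scalar_handler_sketch(obj):
    """Hypothetical analogue: unwrap 0-d arrays, defer everything else."""
    if isinstance(obj, np.ndarray) and obj.ndim == 0:
        return obj.item()  # concrete 0-d array -> plain Python scalar
    return None  # None means "use the default rendering"


def render(obj, custom=None) -> str:
    """Toy pretty-printer that consults `custom` before its default repr."""
    if isinstance(obj, dict):
        inner = ", ".join(f"{k!r}: {render(v, custom)}" for k, v in obj.items())
        return "{" + inner + "}"
    if custom is not None:
        replacement = custom(obj)
        if replacement is not None:
            return repr(replacement)
    return repr(obj)


d = {"x": np.array(10.0), "y": np.array(0.0)}
print(render(d, custom=scalar_handler_sketch))  # {'x': 10.0, 'y': 0.0}
```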

coordinax.internal.det(x: Array, /)#

Compute the determinant of a square matrix via the det_p primitive.

Delegates to det_p, a custom JAX primitive that supports JIT, forward and reverse differentiation, and batching (vmap).

For plain arrays the result is a bare Array. For QMatrix inputs the Quax dispatch intercepts the call (see _det_p_QMatrix) and returns an AbstractQuantity.

Parameters:

x (Array) – Square matrix or batch of square matrices.

Returns:

Determinant of each matrix.

Return type:

Array

Examples

>>> import jax.numpy as jnp
>>> from coordinax.internal import det

Plain 2x2 diagonal matrix:

>>> det(jnp.array([[2.0, 0.0], [0.0, 3.0]]))
Array(6., dtype=float64)

Under JIT:

>>> import jax
>>> jax.jit(det)(jnp.array([[2.0, 0.0], [0.0, 3.0]]))
Array(6., dtype=float64)

Gradient (via reverse-mode autodiff):

>>> jax.grad(det)(jnp.array([[2.0, 0.0], [0.0, 3.0]]))
Array([[3., 0.],
       [0., 2.]], dtype=float64)

Batched (vmap):

>>> A = jnp.stack([jnp.diag(jnp.array([2.0, 3.0])),
...                jnp.diag(jnp.array([4.0, 5.0]))])
>>> jax.vmap(det)(A)
Array([ 6., 20.], dtype=float64)
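
The gradient in the example follows from Jacobi's formula. For an invertible matrix A:

```latex
\frac{\partial \det A}{\partial A_{ij}} = \det(A)\,\bigl(A^{-1}\bigr)_{ji},
\qquad
\nabla_A \det A = \det(A)\,A^{-\mathsf{T}} .
```

For A = diag(2, 3), det A = 6 and A^{-T} = diag(1/2, 1/3), so the gradient is diag(3, 2), which matches the reverse-mode output above.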

coordinax.internal.inv(x: Array, /)#

Compute the matrix inverse of a square matrix via the inv_p primitive.

Delegates to inv_p, a custom JAX primitive that supports JIT, forward and reverse differentiation, and batching (vmap).

For plain arrays the result is a bare Array. For QMatrix inputs the Quax dispatch intercepts the call (see _inv_p_QMatrix) and returns a QMatrix with reciprocal units.

Parameters:

x (Array) – Square matrix or batch of square matrices.

Returns:

Matrix inverse of each square matrix.

Return type:

Array

Examples

>>> import jax.numpy as jnp
>>> from coordinax.internal import inv

Plain 2x2 diagonal matrix:

>>> inv(jnp.array([[2.0, 0.0], [0.0, 4.0]]))
Array([[0.5 , 0.  ],
       [0.  , 0.25]], dtype=float64)

Under JIT:

>>> import jax
>>> jax.jit(inv)(jnp.array([[2.0, 0.0], [0.0, 4.0]]))
Array([[0.5 , 0.  ],
       [0.  , 0.25]], dtype=float64)

Jacobian (via reverse-mode autodiff); for a 2x2 input it is a rank-4 array:

>>> jac = jax.jacobian(inv)(jnp.array([[2.0, 0.0], [0.0, 4.0]]))
>>> jac.shape
(2, 2, 2, 2)

Batched (vmap):

>>> A = jnp.stack([jnp.diag(jnp.array([2.0, 4.0])),
...                jnp.diag(jnp.array([1.0, 2.0]))])
>>> jax.vmap(inv)(A)
Array([[[0.5 , 0.  ],
        [0.  , 0.25]],

       [[1.  , 0.  ],
        [0.  , 0.5 ]]], dtype=float64)
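
The rank-4 shape reflects that every entry of A^{-1} depends on every entry of A. Differentiating A A^{-1} = I gives d(A^{-1}) = -A^{-1} dA A^{-1}, that is, ∂(A^{-1})_{ij}/∂A_{kl} = -(A^{-1})_{ik} (A^{-1})_{lj}. The NumPy sketch below checks that closed form against central finite differences for the same diagonal matrix. It does not use coordinax, and the helper names are hypothetical. The axis layout (output indices first, then input indices) matches the shape that jax.jacobian returns.

```python
import numpy as np


def inv_jacobian(a: np.ndarray) -> np.ndarray:
    """Closed form: d inv(A)[i, j] / dA[k, l] = -inv(A)[i, k] * inv(A)[l, j]."""
    a_inv = np.linalg.inv(a)
    return -np.einsum("ik,lj->ijkl", a_inv, a_inv)


def inv_jacobian_fd(a: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Central finite-difference estimate with the same (i, j, k, l) layout."""
    n = a.shape[0]
    jac = np.zeros((n, n, n, n))
    for row in range(n):
        for col in range(n):
            step = np.zeros_like(a)
            step[row, col] = eps
            diff = np.linalg.inv(a + step) - np.linalg.inv(a - step)
            jac[:, :, row, col] = diff / (2 * eps)
    return jac


a = np.array([[2.0, 0.0], [0.0, 4.0]])
jac = inv_jacobian(a)
print(jac.shape)                                        # (2, 2, 2, 2)
print(np.allclose(jac, inv_jacobian_fd(a), atol=1e-6))  # True
```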