Working With Plain Arrays As Coordinates
This tutorial covers using bare JAX arrays as coordinate data in coordinax. A plain array carries no units, no component names, no chart, and no frame, so all of that metadata must be supplied explicitly at every call site. This is the lowest level of the coordinax object tower: it adds no object-construction overhead and interoperates directly with raw JAX code.
You will learn how to:
Apply transforms to arrays with act (requires chart, rep, and usys)
Understand the usys (unit system) requirement
Decompose arrays into CDicts for chart conversion
Upgrade arrays to higher-level objects
Use arrays with JAX
Object Levels
Coordinax supports five levels of coordinate representation, each adding more metadata than the one below it. This tutorial covers the bottom level: plain arrays.
| Level | Type | See tutorial |
|---|---|---|
| Coordinate | | |
| Vector | | |
| CDict | | |
| Quantity | | |
| Array | | this page |
Setup
>>> import coordinax.main as cx
>>> import coordinax.charts as cxc
>>> import coordinax.frames as cxf
>>> import coordinax.representations as cxr
>>> import coordinax.transforms as cxfm
>>> import unxt as u
>>> import jax.numpy as jnp
>>> import jax
When To Use Plain Arrays
Plain arrays are the right choice when:
You are in a performance-critical inner loop and cannot afford object-construction overhead.
You are interfacing with raw JAX code (e.g. existing numerical solvers, neural network outputs).
You already know the coordinate system and units and can supply them explicitly.
You are prototyping or teaching and want minimal boilerplate.
The trade-off: you must pass chart, representation, and a unit system to every coordinax function call. If you find yourself repeating the same metadata, upgrade to a Quantity or Vector.
Applying Transforms To Arrays
Use cxfm.act() with explicit chart, representation, and unit system:
>>> usys = u.unitsystem("m", "s", "kg", "rad")
>>> rot90z = cxfm.Rotate.from_euler("z", u.Q(90, "deg"))
>>> arr = jnp.array([1.0, 0.0, 0.0])
>>> result = cxfm.act(rot90z, None, arr, cxc.cart3d, cxr.point, usys=usys)
>>> isinstance(result, jnp.ndarray)
True
The arguments:
rot90z: the transform
None: the time parameter (None for static transforms)
arr: the data
cxc.cart3d: the chart (coordinate system)
cxr.point: the representation (point geometry)
usys=usys: the unit system (maps physical dimensions to concrete units)
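As a numerical cross-check on the rotation example above, the sketch below applies a 90° rotation about z to the same array in plain JAX by building the rotation matrix by hand. It assumes an active, counterclockwise rotation; whether Rotate.from_euler uses this or the passive convention is not confirmed here, and rotation_z is a hypothetical helper, not part of coordinax.

```python
import jax.numpy as jnp


def rotation_z(angle_rad):
    """Hypothetical helper: 3x3 matrix rotating vectors counterclockwise about +z."""
    c, s = jnp.cos(angle_rad), jnp.sin(angle_rad)
    return jnp.array([
        [c, -s, 0.0],
        [s, c, 0.0],
        [0.0, 0.0, 1.0],
    ])


arr = jnp.array([1.0, 0.0, 0.0])
# Rotating the x unit vector by 90 degrees lands (up to float error) on the y unit vector.
rotated = rotation_z(jnp.pi / 2) @ arr
```

Under the passive convention the result would instead be [0, -1, 0], so comparing against the array returned by act is a quick way to tell which convention applies.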
Why usys Is Required
Transforms like Translate store their offsets with units (e.g. Translate({"x": Q(1, "km"), ...})). A bare array has no units, so coordinax cannot add a length in kilometres to a unitless number. The usys tells coordinax how to interpret the array's numbers: with u.unitsystem("km", "s", "kg", "rad"), as below, the array is read as "these numbers are in kilometres".
>>> usys = u.unitsystem("km", "s", "kg", "rad")
>>> shift = cxfm.Translate.from_([1, 2, 3], "km")
>>> arr = jnp.array([0.0, 0.0, 0.0])
>>> result = cxfm.act(shift, None, arr, cxc.cart3d, cxr.point, usys=usys)
>>> isinstance(result, jnp.ndarray)
True
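The role of usys can be sketched in plain JAX: the unit system fixes which unit the bare numbers are in, so a unitful offset must be converted into that unit before it is added. The TO_METRES table and the translate_bare helper are hypothetical stand-ins for what unxt and coordinax do internally, shown only to make the arithmetic concrete.

```python
import jax.numpy as jnp

# Hypothetical conversion factors from each length unit to metres.
TO_METRES = {"m": 1.0, "km": 1000.0}


def translate_bare(arr, offset, offset_unit, usys_length_unit):
    """Add a unitful offset to a bare array whose numbers are in usys_length_unit."""
    factor = TO_METRES[offset_unit] / TO_METRES[usys_length_unit]
    return arr + jnp.asarray(offset) * factor


arr = jnp.array([0.0, 0.0, 0.0])
offset = [1.0, 2.0, 3.0]  # the same numbers as Translate.from_([1, 2, 3], "km")

# Reading the bare numbers as kilometres, the offset is added unchanged.
in_km = translate_bare(arr, offset, "km", "km")  # [1., 2., 3.]
# Reading them as metres, the same offset becomes 1000 times larger.
in_m = translate_bare(arr, offset, "km", "m")  # [1000., 2000., 3000.]
```

The same bare array therefore describes different positions depending on the unit system, which is why act cannot guess it.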
Decomposing Arrays Into CDicts
For chart conversion, first convert the array into a CDict using cxc.cdict() with a chart and unit:
>>> d = cxc.cdict(jnp.array([1.0, 2.0, 3.0]), "km", cxc.cart3d)
>>> sorted(d.keys())
['x', 'y', 'z']
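Conceptually, decomposing an array into a CDict splits its component axis into named entries. The sketch below does that with a plain Python dict of JAX arrays, assuming the components lie on the last axis; it drops units entirely and split_components is a hypothetical helper, so it shows only the name-to-column mapping, not the real CDict type.

```python
import jax.numpy as jnp


def split_components(arr, names):
    """Hypothetical helper: map each component name to one slice of the last axis."""
    if arr.shape[-1] != len(names):
        raise ValueError(f"expected {len(names)} components, got {arr.shape[-1]}")
    return {name: arr[..., i] for i, name in enumerate(names)}


# A batch of two Cartesian points, shape (2, 3).
points = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
d = split_components(points, ("x", "y", "z"))
# d["x"] is [1., 4.], d["y"] is [2., 5.], d["z"] is [3., 6.]
```

Keeping batch axes in front and components last lets the same helper handle a single point of shape (3,) or a batch of shape (N, 3).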
Then convert charts via pt_map:
>>> d_sph = cxc.pt_map(d, cxc.cart3d, cxc.sph3d)
>>> sorted(d_sph.keys())
['phi', 'r', 'theta']
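To see the numbers behind the conversion above, here is a plain-JAX sketch of Cartesian-to-spherical conversion. It assumes the physics convention (theta is the polar angle measured from +z, phi the azimuth in the x-y plane), which fits the key names r, theta and phi but is not confirmed here to match coordinax's sph3d definition exactly; cart_to_sph is a hypothetical helper.

```python
import jax.numpy as jnp


def cart_to_sph(x, y, z):
    """Hypothetical helper: Cartesian (x, y, z) to spherical (r, theta, phi)."""
    r = jnp.sqrt(x**2 + y**2 + z**2)
    theta = jnp.arccos(z / r)  # polar angle from +z, in [0, pi]; undefined at r = 0
    phi = jnp.arctan2(y, x)  # azimuth from +x toward +y, in [-pi, pi]
    return {"r": r, "theta": theta, "phi": phi}


# Same numbers as the CDict above, with units dropped (they would be kilometres).
sph = cart_to_sph(jnp.array(1.0), jnp.array(2.0), jnp.array(3.0))
# sph["r"] is sqrt(14), about 3.742
```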
The Upgrade Path
Arrays sit at the bottom of the coordinax tower. You can upgrade step-by-step:
Array → Quantity
Attach units:
>>> arr = jnp.array([1.0, 2.0, 3.0])
>>> q = u.Q(arr, "km")
>>> q.unit
Unit("km")
Quantity → Vector
Attach chart and representation:
>>> v = cx.Point.from_(q)
>>> v.chart
Cart3D(M=Rn(3))
Vector → Coordinate
Attach a reference frame:
>>> coord = cx.Point.from_(v, cxf.alice)
>>> coord.frame
Alice()
Shortcut: Array → Vector
Skip the quantity step:
>>> v = cx.Point.from_([1, 2, 3], "km")
>>> v.chart
Cart3D(M=Rn(3))
JAX Integration
Plain arrays are native JAX arrays, so jit, vmap, and grad work on them without any special handling:
>>> usys = u.unitsystem("m", "s", "kg", "rad")
>>> rot90z = cxfm.Rotate.from_euler("z", u.Q(90, "deg"))
>>> @jax.jit
... def rotate_array(x):
...     return cxfm.act(rot90z, None, x, cxc.cart3d, cxr.point, usys=usys)
>>> arr = jnp.array([1.0, 0.0, 0.0])
>>> result = rotate_array(arr)
>>> isinstance(result, jnp.ndarray)
True
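The example above only exercises jit. The sketch below shows vmap and grad on plain arrays too, using a hand-written rotation matrix (the hypothetical rotation_z and rotate helpers) instead of cxfm.act, so it illustrates the JAX mechanics rather than coordinax's own transform code.

```python
import jax
import jax.numpy as jnp


def rotation_z(angle_rad):
    """Hypothetical helper: 3x3 counterclockwise rotation about +z."""
    c, s = jnp.cos(angle_rad), jnp.sin(angle_rad)
    return jnp.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


def rotate(x, angle_rad):
    """Rotate a single point of shape (3,)."""
    return rotation_z(angle_rad) @ x


# vmap: rotate a batch of shape (N, 3) point by point, sharing one angle; jit compiles it.
batch = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
rotated = jax.jit(jax.vmap(rotate, in_axes=(0, None)))(batch, jnp.pi / 2)


# grad: derivative of the rotated x unit vector's y-coordinate with respect to the angle.
def y_after_rotation(angle_rad):
    return rotate(jnp.array([1.0, 0.0, 0.0]), angle_rad)[1]


dy_dangle = jax.grad(y_after_rotation)(0.0)  # d/da sin(a) at a = 0 is cos(0) = 1
```

The transformations compose freely here because every input and output is a plain array with no attached metadata.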
Comparison With Higher Levels
| Feature | Array | Quantity | CDict | Vector | Coordinate |
|---|---|---|---|---|---|
| Units | ✗ | ✓ | ✓ | ✓ | ✓ |
| Component names | ✗ | ✗ | ✓ | ✓ | ✓ |
| Chart | ✗ | ✗ | ✗ | ✓ | ✓ |
| Representation | ✗ | ✗ | ✗ | ✓ | ✓ |
| Frame | ✗ | ✗ | ✗ | ✗ | ✓ |
| Arguments you must pass | chart, rep, usys | (chart, rep) | chart, rep | none | none |
| | β | β | β | | both |
The further up the tower you go, the more metadata is attached and the fewer arguments you need to pass. Choose the level that matches your needs.