JAX API
The sphericart.jax module aims to provide a functional-style and
JAX-friendly framework. As a result, it does not follow the same syntax as
the Python and PyTorch SphericalHarmonics classes. Instead, it provides
functions that are fully compatible with JAX transformations (jax.grad,
jax.jit, and so on).
Depending on the device the array is stored on, as well as its dtype, the calculations will be performed using 32- or 64-bit floating-point arithmetic, and using the CPU or CUDA implementation.
- sphericart.jax.spherical_harmonics(xyz: Array, l_max: int)
Computes the spherical harmonics and their derivatives within the JAX framework.
The definition of the real spherical harmonics is consistent with the Wikipedia spherical harmonics page.
Note that the l_max argument (position 1 in the signature) should be tagged as static when jit-ing the function:
>>> import jax
>>> import sphericart.jax
>>> jitted_sph_fn = jax.jit(sphericart.jax.spherical_harmonics, static_argnums=1)
- Parameters:
xyz – single vector or set of vectors in 3D. All dimensions are optional except for the last. Shape [..., 3].
l_max – the maximum degree of the spherical harmonics to be calculated (inclusive)
- Returns:
Spherical harmonics expansion of xyz. Shape [..., (l_max+1)**2]. The last dimension is organized in lexicographic order. For example, if l_max = 2, the last axis will correspond to spherical harmonics with (l, m) = (0, 0), (1, -1), (1, 0), (1, 1), (2, -2), (2, -1), (2, 0), (2, 1), (2, 2), in this order.
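The lexicographic ordering of the last axis can be turned into a small index helper. The sketch below is plain Python and does not call sphericart. The names lm_index and lm_pairs are hypothetical helpers introduced here for illustration only, and the formula assumes the (l, m) ordering described above.

```python
def lm_index(l, m):
    """Flat position of (l, m) on the last axis, assuming lexicographic (l, m) order."""
    if abs(m) > l:
        raise ValueError("m must satisfy -l <= m <= l")
    # Degrees 0..l-1 occupy the first l**2 slots; within degree l, m runs from -l to l.
    return l * l + l + m


def lm_pairs(l_max):
    """All (l, m) pairs up to l_max, in the order used by the last axis."""
    return [(l, m) for l in range(l_max + 1) for m in range(-l, l + 1)]


pairs = lm_pairs(2)
print(pairs)           # the nine pairs from (0, 0) to (2, 2)
print(len(pairs))      # 9, i.e. (l_max + 1)**2
print(lm_index(2, 0))  # 6
```

With such a helper, the component for a given (l, m) could be selected from an output array sph as sph[..., lm_index(l, m)].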
- sphericart.jax.solid_harmonics(xyz: Array, l_max: int)
Same as spherical_harmonics, but computes the solid harmonics instead.
These are a non-normalized form of the real spherical harmonics, i.e. \(r^lY^m_l\). These scaled spherical harmonics are polynomials in the Cartesian coordinates of the input points, and they are therefore less expensive to compute.
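To make the polynomial form concrete, here is a hand-written reference for l_max = 1 in plain Python. It is an illustration only, not sphericart's implementation: the function names reference_spherical_l1 and reference_solid_l1 are hypothetical, and the prefactors assume the real spherical harmonics convention of the Wikipedia page mentioned above.

```python
import math

# Prefactors of the real spherical harmonics for l = 0 and l = 1
# (assumed Wikipedia convention).
C0 = 0.5 / math.sqrt(math.pi)
C1 = math.sqrt(3.0 / (4.0 * math.pi))


def reference_spherical_l1(x, y, z):
    """Y_l^m for l <= 1, ordered (0, 0), (1, -1), (1, 0), (1, 1)."""
    r = math.sqrt(x * x + y * y + z * z)  # normalization requires the norm
    return [C0, C1 * y / r, C1 * z / r, C1 * x / r]


def reference_solid_l1(x, y, z):
    """r^l Y_l^m for l <= 1: plain polynomials, no square root or division."""
    return [C0, C1 * y, C1 * z, C1 * x]


x, y, z = 1.0, 2.0, 2.0  # a point with r = 3
r = math.sqrt(x * x + y * y + z * z)
spherical = reference_spherical_l1(x, y, z)
solid = reference_solid_l1(x, y, z)

# Each solid harmonic equals r**l times the matching spherical harmonic.
for l, s, t in zip([0, 1, 1, 1], spherical, solid):
    assert math.isclose(t, r**l * s)
```

The solid version skips the normalization by r entirely, which is where the saving described above comes from in this small case.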