JAX API
The sphericart.jax module aims to provide a functional-style, JAX-friendly framework. As a result, it does not follow the same syntax as the SphericalHarmonics classes of the Python and PyTorch APIs. Instead, it provides a single function that is fully compatible with JAX transformations (jax.grad, jax.jit, and so on).
Depending on the dtype of the input array and on the device it is stored on, the calculations are performed in 32- or 64-bit floating-point arithmetic, using either the CPU or the CUDA implementation.
- sphericart.jax.spherical_harmonics(xyz, l_max, normalized=False)
Computes the spherical harmonics and their derivatives within the JAX framework.
This function supports jit, vmap, and up to two rounds of forward and/or backward automatic differentiation (grad, jacfwd, jacrev, hessian, …). For the moment, it does not support pmap.
Note that the l_max and normalized arguments (zero-based positions 1 and 2 in the signature) should be marked as static when jit-ing the function:
>>> import jax
>>> import sphericart.jax
>>> jitted_sph_function = jax.jit(sphericart.jax.spherical_harmonics, static_argnums=(1, 2))
Parameters
- xyz : jax array […, 3]
A single vector or a set of vectors in 3D. All dimensions are optional except for the last one, which must have size 3.
- l_max : int
Maximum degree of the spherical harmonics to compute (inclusive).
- normalized : bool
Whether the function computes Cartesian solid harmonics (normalized=False, the default) or normalized spherical harmonics (normalized=True).
Returns
- jax array […, (l_max+1)**2]
The spherical (or solid) harmonics of xyz, with the same leading dimensions as xyz and (l_max+1)**2 values along the last axis.