JAX API

The sphericart.jax module aims to provide a functional-style and JAX-friendly framework. As a result, it does not follow the same syntax as the Python and PyTorch SphericalHarmonics classes. Instead, it provides a function that is fully compatible with JAX transformations (jax.grad, jax.jit, and so on).

The dtype of the input array (float32 or float64) determines whether the calculations use 32- or 64-bit floating-point arithmetic, and the device the array is stored on determines whether the CPU or the CUDA implementation is used.

sphericart.jax.spherical_harmonics(xyz, l_max, normalized=False)

Computes the spherical harmonics within the JAX framework; their derivatives are obtained through JAX's automatic differentiation.

This function supports jit, vmap, and up to two rounds of forward and/or backward automatic differentiation (grad, jacfwd, jacrev, hessian, …). For the moment, it does not support pmap.

Note that the l_max and normalized arguments (positions 1 and 2 in the signature, counting from 0) should be marked as static when compiling the function with jax.jit:

>>> import jax
>>> import sphericart.jax
>>> jitted_sph_function = jax.jit(sphericart.jax.spherical_harmonics, static_argnums=(1, 2))

Parameters

xyz : jax array […, 3]

single vector or set of vectors in 3D. All dimensions are optional except for the last, which must have size 3.

l_max : int

maximum order of the spherical harmonics (inclusive)

normalized : bool

whether the function computes Cartesian solid harmonics (normalized=False, default) or normalized spherical harmonics (normalized=True)

Returns

jax array […, (l_max+1)**2]

Spherical harmonics expansion of xyz