JAX API

The sphericart.jax module provides a functional-style, JAX-friendly interface. As a result, it does not follow the same syntax as the SphericalHarmonics classes of the Python and PyTorch APIs. Instead, it provides functions that are fully compatible with JAX transformations (jax.grad, jax.jit, and so on).

Depending on the device the input array is stored on and on its dtype, the calculations are performed with 32- or 64-bit floating-point arithmetic, using either the CPU or the CUDA implementation.
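Note that JAX creates 32-bit arrays by default, even when `float64` is requested, unless 64-bit mode is enabled. The sketch below, which uses only standard JAX calls, prepares 32- and 64-bit inputs and places one explicitly on the CPU; per the paragraph above, the dtype and device of such arrays are what would determine which code path sphericart uses.

```python
import jax

# Without this flag, JAX downcasts float64 requests to float32
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

xyz32 = jnp.zeros((4, 3), dtype=jnp.float32)  # 32-bit input
xyz64 = jnp.zeros((4, 3), dtype=jnp.float64)  # 64-bit input (requires x64 mode)

# Explicitly place an array on the CPU, whatever the default backend is
cpu = jax.devices("cpu")[0]
xyz64_cpu = jax.device_put(xyz64, cpu)

print(xyz32.dtype, xyz64.dtype, jax.default_backend())
```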

sphericart.jax.spherical_harmonics(xyz: Array, l_max: int)

Computes the spherical harmonics within the JAX framework. Their derivatives are available through JAX's automatic differentiation (e.g. jax.grad).

The definition of the real spherical harmonics is consistent with the Wikipedia spherical harmonics page.

Note that the l_max argument (position 1 in the signature) must be marked as static when JIT-compiling the function:

>>> import jax
>>> import sphericart.jax
>>> jitted_sph_fn = jax.jit(sphericart.jax.spherical_harmonics, static_argnums=1)

Parameters:
  • xyz – single vector or set of vectors in 3D. All dimensions are optional except for the last. Shape [..., 3].

  • l_max – the maximum degree of the spherical harmonics to be calculated (inclusive).

Returns:

Spherical harmonics expansion of xyz. Shape [..., (l_max+1)**2]. The last dimension is organized in lexicographic order of (l, m). For example, if l_max = 2, the last axis will correspond to spherical harmonics with (l, m) = (0, 0), (1, -1), (1, 0), (1, 1), (2, -2), (2, -1), (2, 0), (2, 1), (2, 2), in this order.
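From this ordering, the harmonics of degree l start at index l**2 (all lower degrees come first), and within a degree m runs from -l to l, so (l, m) sits at index l**2 + l + m. The helpers below are hypothetical, plain-Python conveniences derived from that ordering, not part of sphericart:

```python
def sph_index(l: int, m: int) -> int:
    """Position of the (l, m) harmonic along the last axis."""
    if not -l <= m <= l:
        raise ValueError("m must satisfy -l <= m <= l")
    # l**2 harmonics of lower degree come first, then m runs from -l to l
    return l * l + l + m


def degree_slice(l: int) -> slice:
    """Slice selecting all 2l+1 harmonics of degree l."""
    return slice(l * l, (l + 1) * (l + 1))


# Reproduces the l_max = 2 ordering listed above
order = [(l, m) for l in range(3) for m in range(-l, l + 1)]
print([sph_index(l, m) for l, m in order])  # [0, 1, 2, 3, 4, 5, 6, 7, 8]
```

For example, `sph[..., degree_slice(2)]` would select the five l = 2 components of an output array `sph`.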

sphericart.jax.solid_harmonics(xyz: Array, l_max: int)

Same as spherical_harmonics, but computes the solid harmonics instead.

These are a non-normalized form of the real spherical harmonics, i.e. \(r^lY^m_l\). These scaled spherical harmonics are polynomials in the Cartesian coordinates of the input points, and they are therefore less expensive to compute.
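To illustrate why these are polynomials, the sketch below writes out the solid harmonics up to l = 1 in pure JAX, assuming the Wikipedia real-harmonic conventions referenced above (Y_0^0 = 1/(2√π), and r Y_1^m proportional to y, z, x for m = -1, 0, 1). The function name `solid_harmonics_l1` is hypothetical; this is a hand-written illustration, not sphericart's implementation.

```python
import math

import jax.numpy as jnp

C0 = 0.5 / math.sqrt(math.pi)  # Y_0^0 (constant, so r^0 Y_0^0 = C0)
C1 = math.sqrt(3.0 / (4.0 * math.pi))  # prefactor of the l = 1 real harmonics


def solid_harmonics_l1(xyz):
    """Hypothetical helper: r^l Y_l^m for l <= 1, ordered (0,0), (1,-1), (1,0), (1,1)."""
    x, y, z = xyz[..., 0], xyz[..., 1], xyz[..., 2]
    # No division by r is needed: each entry is a polynomial in x, y, z
    return jnp.stack([jnp.full_like(x, C0), C1 * y, C1 * z, C1 * x], axis=-1)


out = solid_harmonics_l1(jnp.array([0.0, 0.0, 2.0]))
```

Dividing the l = 1 entries by r = |xyz| recovers the corresponding spherical harmonics Y_1^m.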