PyTorch API
The classes for computing spherical harmonics using a torch-compatible framework follow a syntax similar to the Python versions sphericart.SphericalHarmonics and sphericart.SolidHarmonics, while inheriting from torch.nn.Module.
The dtype of the input tensor determines whether the calculations are performed in 32- or 64-bit floating-point arithmetic, and the device it is stored on determines whether the CPU or the CUDA implementation is used.
- class sphericart.torch.SphericalHarmonics(l_max: int, backward_second_derivatives: bool = False)
Spherical harmonics calculator, which computes the real spherical harmonics \(Y^m_l\) up to degree l_max. The calculated spherical harmonics are consistent with the definition of real spherical harmonics from Wikipedia.
This class can be used similarly to sphericart.SphericalHarmonics (its Python/NumPy counterpart). If the class is called directly, the outputs support single and, optionally, double backpropagation:
>>> import torch
>>> import sphericart.torch
>>> sh = sphericart.torch.SphericalHarmonics(l_max=8)
>>> xyz = torch.rand(size=(10, 3), dtype=torch.float64, requires_grad=True)
>>> sh_values = sh(xyz)  # or sh.compute(xyz)
>>> sh_values.sum().backward()
>>> _, sh_grads = sh.compute_with_gradients(xyz.detach())
>>> torch.allclose(xyz.grad, sh_grads.sum(dim=-1))
True
By default, only single backpropagation with respect to xyz is enabled (this includes mixed second derivatives where xyz appears in only one of the differentiation steps). To activate support for double backpropagation with respect to xyz, set backward_second_derivatives=True at class creation.
Warning: if backward_second_derivatives is not set to True and double differentiation with respect to xyz is requested, the results may be incorrect, but a warning will be displayed. This trade-off is necessary to provide optimal performance for both use cases. In particular, the following will happen:
  - when using torch.autograd.grad as the second backpropagation step, a warning will be displayed and torch will raise an error;
  - when using torch.autograd.grad with allow_unused=True as the second backpropagation step, the results will be incorrect and only a warning will be displayed;
  - when using backward as the second backpropagation step, the results will be incorrect and only a warning will be displayed;
  - when using torch.autograd.functional.hessian, the results will be incorrect and only a warning will be displayed.
Alternatively, the class can return explicit forward gradients and/or Hessians of the spherical harmonics. For example:
>>> import torch
>>> import sphericart.torch
>>> sh = sphericart.torch.SphericalHarmonics(l_max=8)
>>> xyz = torch.rand(size=(10, 3))
>>> sh_values, sh_grads = sh.compute_with_gradients(xyz)
>>> sh_grads.shape
torch.Size([10, 3, 81])
>>> sh_values, sh_grads, sh_hessians = sh.compute_with_hessians(xyz)
>>> sh_hessians.shape
torch.Size([10, 3, 3, 81])
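The 81 in the shapes above is (l_max+1)**2 for l_max = 8: each degree l contributes 2l + 1 orders m, and summing 2l + 1 over l = 0, ..., l_max gives (l_max+1)**2. The helper below is a hypothetical convenience function (not part of sphericart) that simply records the output shapes documented for forward, compute_with_gradients and compute_with_hessians, so they can be reasoned about without running the library.

```python
def expected_shapes(n_samples: int, l_max: int) -> dict:
    """Documented output shapes for a calculator of maximum degree l_max (hypothetical helper)."""
    # Degree l has 2l + 1 orders m, and sum(2l + 1 for l in 0..l_max) == (l_max + 1)**2
    n_sph = sum(2 * l + 1 for l in range(l_max + 1))
    assert n_sph == (l_max + 1) ** 2
    return {
        "values": (n_samples, n_sph),              # forward(xyz)
        "gradients": (n_samples, 3, n_sph),        # second item of compute_with_gradients
        "hessians": (n_samples, 3, 3, n_sph),      # third item of compute_with_hessians
    }


print(expected_shapes(10, 8))
```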
This class supports TorchScript.
- Parameters:
l_max – the maximum degree of the spherical harmonics to be calculated
backward_second_derivatives – if this parameter is set to True, second derivatives of the spherical harmonics are calculated and stored during forward calls to compute (provided that xyz.requires_grad is True), making it possible to perform double reverse-mode differentiation with respect to xyz. If False, only the first derivatives will be computed and only a single reverse-mode differentiation step will be possible with respect to xyz.
- Returns:
a calculator, in the form of a SphericalHarmonics object
- forward(xyz: Tensor) → Tensor
Calculates the spherical harmonics for a set of 3D points.
The coordinates should be stored in the xyz array. If xyz has requires_grad = True, the forward derivatives are stored and then used in the backward pass. The dtype of xyz determines the precision used, and the device the tensor is stored on determines whether the CPU or CUDA implementation is used as the calculation backend. Single reverse-mode differentiation is always supported, and double reverse-mode differentiation is supported if backward_second_derivatives was set to True during class creation.
- Parameters:
xyz – The Cartesian coordinates of the 3D points, as a torch.Tensor with shape (n_samples, 3).
- Returns:
A tensor of shape (n_samples, (l_max+1)**2) containing all the spherical harmonics up to degree l_max in lexicographic order. For example, if l_max = 2, the last axis will correspond to spherical harmonics with (l, m) = (0, 0), (1, -1), (1, 0), (1, 1), (2, -2), (2, -1), (2, 0), (2, 1), (2, 2), in this order.
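Because the last axis is ordered lexicographically in (l, m), the column holding \(Y^m_l\) is at position l*l + l + m: the l*l entries of lower degree come first, followed by m = -l, ..., l. The pair of helpers below (lm_to_index and index_to_lm are hypothetical names, not sphericart functions) converts between the two, which can be handy for slicing a single (l, m) channel out of the output.

```python
import math


def lm_to_index(l: int, m: int) -> int:
    """Position of Y_l^m along the last axis (hypothetical helper)."""
    if not -l <= m <= l:
        raise ValueError(f"m={m} is out of range for l={l}")
    # l*l entries precede degree l; within degree l, m runs from -l to l
    return l * l + (m + l)


def index_to_lm(index: int) -> tuple:
    """Inverse of lm_to_index (hypothetical helper)."""
    l = math.isqrt(index)  # largest l with l*l <= index
    return l, index - l * l - l


order = [index_to_lm(i) for i in range((2 + 1) ** 2)]
print(order)
```

Under this convention, sh_values[:, lm_to_index(2, 0)] would select the \(Y^0_2\) column.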
- compute_with_gradients(xyz: Tensor) → Tuple[Tensor, Tensor]
Calculates the spherical harmonics for a set of 3D points, and also returns the forward-mode derivatives.
The coordinates should be stored in the xyz array. The dtype of xyz determines the precision used, and the device the tensor is stored on determines whether the CPU or CUDA implementation is used as the calculation backend. Reverse-mode differentiation is not supported for this function.
- Parameters:
xyz – The Cartesian coordinates of the 3D points, as a torch.Tensor with shape (n_samples, 3).
- Returns:
A tuple that contains:
  - A (n_samples, (l_max+1)**2) tensor containing all the spherical harmonics up to degree l_max in lexicographic order. For example, if l_max = 2, the last axis will correspond to spherical harmonics with (l, m) = (0, 0), (1, -1), (1, 0), (1, 1), (2, -2), (2, -1), (2, 0), (2, 1), (2, 2), in this order.
  - A tensor of shape (n_samples, 3, (l_max+1)**2) containing all the spherical harmonics' derivatives up to degree l_max. The last axis is organized in the same way as in the spherical harmonics return array, while the second-to-last axis refers to derivatives in the x, y, and z directions, respectively.
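To make the (n_samples, 3, (l_max+1)**2) gradient layout concrete, the sketch below evaluates the real spherical harmonics up to l_max = 1 for a single point in plain Python and builds the gradient array by central finite differences. It assumes the explicit formulas from the Wikipedia table of real spherical harmonics (\(Y^0_0 = 1/(2\sqrt{\pi})\), and \(Y^{-1}_1, Y^0_1, Y^1_1\) proportional to y/r, z/r, x/r); sph_l1 and fd_gradients are illustrative names, and this is not how sphericart computes derivatives. Row k holds the derivative with respect to the k-th Cartesian coordinate, and each column follows the same (l, m) order as the values.

```python
import math

# Assumed Wikipedia normalization constants for l = 0 and l = 1
C0 = 0.5 / math.sqrt(math.pi)
C1 = math.sqrt(3.0 / (4.0 * math.pi))


def sph_l1(x, y, z):
    """Real spherical harmonics up to l_max = 1, ordered (0,0), (1,-1), (1,0), (1,1)."""
    r = math.sqrt(x * x + y * y + z * z)
    return [C0, C1 * y / r, C1 * z / r, C1 * x / r]


def fd_gradients(f, point, h=1e-6):
    """Central-difference derivatives in the (3, n_sph) layout: row k is d/d(coordinate k)."""
    grads = []
    for k in range(3):
        plus = list(point)
        minus = list(point)
        plus[k] += h
        minus[k] -= h
        f_plus, f_minus = f(*plus), f(*minus)
        grads.append([(a - b) / (2.0 * h) for a, b in zip(f_plus, f_minus)])
    return grads


point = (0.3, -0.5, 0.8)
grads = fd_gradients(sph_l1, point)

# Compare d/dx of Y_1^1 = C1 * x / r with its analytic form C1 * (r**2 - x**2) / r**3
x, y, z = point
r = math.sqrt(x * x + y * y + z * z)
print(grads[0][3], C1 * (r * r - x * x) / r**3)
```

Summing each row over the last axis gives the gradient of the sum of all harmonics, which is the quantity compared against xyz.grad in the backpropagation example at the start of this section.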
- compute_with_hessians(xyz: Tensor) → Tuple[Tensor, Tensor, Tensor]
Calculates the spherical harmonics for a set of 3D points, and also returns the forward derivatives and second derivatives.
The coordinates should be stored in the xyz array. The dtype of xyz determines the precision used, and the device the tensor is stored on determines whether the CPU or CUDA implementation is used as the calculation backend. Reverse-mode differentiation is not supported for this function.
- Parameters:
xyz – The Cartesian coordinates of the 3D points, as a torch.Tensor with shape (n_samples, 3).
- Returns:
A tuple that contains:
  - A (n_samples, (l_max+1)**2) tensor containing all the spherical harmonics up to degree l_max in lexicographic order. For example, if l_max = 2, the last axis will correspond to spherical harmonics with (l, m) = (0, 0), (1, -1), (1, 0), (1, 1), (2, -2), (2, -1), (2, 0), (2, 1), (2, 2), in this order.
  - A tensor of shape (n_samples, 3, (l_max+1)**2) containing all the spherical harmonics' derivatives up to degree l_max. The last axis is organized in the same way as in the spherical harmonics return array, while the second-to-last axis refers to derivatives in the x, y, and z directions, respectively.
  - A tensor of shape (n_samples, 3, 3, (l_max+1)**2) containing all the spherical harmonics' second derivatives up to degree l_max. The last axis is organized in the same way as in the spherical harmonics return array, while the two intermediate axes represent the Hessian dimensions, i.e. the pair of Cartesian directions (x, y, z) of each second derivative.
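The same finite-difference idea illustrates the (n_samples, 3, 3, (l_max+1)**2) Hessian layout: entry [i][j][k] approximates the second derivative of the k-th harmonic with respect to coordinates i and j, so the two intermediate axes form a symmetric 3×3 matrix. This plain-Python sketch again assumes the Wikipedia formulas for l_max = 1; sph_l1, shifted and fd_hessians are illustrative names, and the code only demonstrates the layout, not sphericart's analytic implementation.

```python
import math

# Assumed Wikipedia normalization constants for l = 0 and l = 1
C0 = 0.5 / math.sqrt(math.pi)
C1 = math.sqrt(3.0 / (4.0 * math.pi))


def sph_l1(p):
    """Real spherical harmonics up to l_max = 1, ordered (0,0), (1,-1), (1,0), (1,1)."""
    x, y, z = p
    r = math.sqrt(x * x + y * y + z * z)
    return [C0, C1 * y / r, C1 * z / r, C1 * x / r]


def shifted(p, i, di, j, dj):
    """Copy of p with coordinate i moved by di and coordinate j moved by dj."""
    q = list(p)
    q[i] += di
    q[j] += dj
    return q


def fd_hessians(f, p, h=1e-4):
    """Central-difference second derivatives in the (3, 3, n_sph) layout."""
    n_sph = len(f(p))
    hess = [[[0.0] * n_sph for _ in range(3)] for _ in range(3)]
    for i in range(3):
        for j in range(3):
            fpp = f(shifted(p, i, h, j, h))
            fpm = f(shifted(p, i, h, j, -h))
            fmp = f(shifted(p, i, -h, j, h))
            fmm = f(shifted(p, i, -h, j, -h))
            for k in range(n_sph):
                # For i == j this reduces to a second difference with step 2h
                hess[i][j][k] = (fpp[k] - fpm[k] - fmp[k] + fmm[k]) / (4.0 * h * h)
    return hess


p = [0.3, -0.5, 0.8]
hess = fd_hessians(sph_l1, p)
print(hess[0][1][3], hess[1][0][3])  # the two intermediate axes are symmetric
```

For \(Y^1_1 = \sqrt{3/(4\pi)}\, x/r\), the analytic second derivative with respect to x is \(\sqrt{3/(4\pi)}\, 3x(x^2 - r^2)/r^5\), which entry [0][0][3] reproduces up to the finite-difference error.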
- omp_num_threads()
Returns the number of threads available for calculations on the CPU.
- l_max()
Returns the maximum angular momentum setting for this calculator.
- class sphericart.torch.SolidHarmonics(l_max: int, backward_second_derivatives: bool = False)
Solid harmonics calculator, up to degree l_max.
This class computes the solid harmonics, a non-normalized form of the real spherical harmonics, i.e. \(r^lY^m_l\). These scaled spherical harmonics are polynomials in the Cartesian coordinates of the input points.
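Since \(r^lY^m_l\) is a homogeneous polynomial of degree l in x, y and z, scaling a point by s multiplies the solid harmonics by \(s^l\) while leaving the spherical harmonics unchanged. The plain-Python sketch below checks this for l = 2, assuming the Wikipedia expressions for the real \(Y^m_2\) and the (l, m) ordering m = -2, ..., 2; solid_l2 and sph_l2 are hypothetical names used only for illustration.

```python
import math

# Assumed Wikipedia prefactors for the real Y_2^m
A = 0.5 * math.sqrt(15.0 / math.pi)
B = 0.25 * math.sqrt(5.0 / math.pi)
C = 0.25 * math.sqrt(15.0 / math.pi)


def solid_l2(x, y, z):
    """r**2 * Y_2^m for m = -2, ..., 2: plain polynomials in x, y, z."""
    r2 = x * x + y * y + z * z
    return [A * x * y, A * y * z, B * (3.0 * z * z - r2), A * x * z, C * (x * x - y * y)]


def sph_l2(x, y, z):
    """Y_2^m recovered by dividing the solid harmonics by r**l with l = 2."""
    r2 = x * x + y * y + z * z
    return [s / r2 for s in solid_l2(x, y, z)]


p = (0.3, -0.5, 0.8)
scaled = tuple(2.0 * c for c in p)
print(solid_l2(*scaled)[0] / solid_l2(*p)[0])  # 4.0, i.e. 2.0**2 (degree-2 homogeneity)
```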
The usage of this class is identical to that of sphericart.SphericalHarmonics.
- Parameters:
l_max – the maximum degree of the solid harmonics to be calculated
backward_second_derivatives – if this parameter is set to True, second derivatives of the solid harmonics are calculated and stored during forward calls to compute (provided that xyz.requires_grad is True), making it possible to perform double reverse-mode differentiation with respect to xyz. If False, only the first derivatives will be computed and only a single reverse-mode differentiation step will be possible with respect to xyz.
- Returns:
a calculator, in the form of a SolidHarmonics object
- omp_num_threads()
Returns the number of threads available for calculations on the CPU.
- l_max()
Returns the maximum angular momentum setting for this calculator.
The implementation also contains a couple of utility functions to facilitate the integration of sphericart into code using e3nn.
- sphericart.torch.e3nn_spherical_harmonics(l_list: List[int] | int, x: Tensor, normalize: bool | None = False, normalization: str | None = 'integral') → Tensor
Computes spherical harmonics with an interface similar to the e3nn package.
Provides an interface that is similar to e3nn.o3.spherical_harmonics() but uses SphericalHarmonics for the actual calculation. It uses the same ordering of the [x, y, z] axes, and supports the same options for input and harmonics normalization as e3nn. However, it does not support defining the irreps through an e3nn.o3._irreps.Irreps or a string specification, only as a single integer or a list of integers.
- Parameters:
l_list – Either a single integer or a list of integers specifying which \(Y^m_l\) should be computed. All values up to the maximum l value are computed, so this may be inefficient for use cases requiring a single, or few, angular momentum channels.
x – A torch.Tensor containing the coordinates, in the same format expected by the e3nn function.
normalize – Flag specifying whether the input positions should be normalized (resulting in the computation of the spherical harmonics \(Y^m_l\)), or whether the function should compute the solid harmonics \(r^lY^m_l\).
normalization – String that can be "integral", "norm", or "component", which controls a further scaling of the \(Y^m_l\). See the documentation of e3nn.o3.spherical_harmonics() for a detailed explanation of the different conventions.
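The normalization option only rescales each degree l by a constant. Assuming the conventions described in the e3nn documentation ("integral": each \(Y^m_l\) has unit L2 norm over the sphere; "component": the squared norm of the degree-l vector is 2l + 1; "norm": that norm is 1), the addition theorem gives a squared norm of (2l + 1)/(4π) for orthonormal real harmonics on the unit sphere, so the factors relative to "integral" are √(4π) for "component" and √(4π/(2l + 1)) for "norm". The plain-Python sketch below checks this for l = 1; rescale is a hypothetical helper, and because e3nn's sign and axis conventions may differ from the Wikipedia formulas used here, only squared norms are compared.

```python
import math


def sph_l1_integral(x, y, z):
    """Y_1^m for m = -1, 0, 1, orthonormal over the unit sphere (assumed Wikipedia convention)."""
    r = math.sqrt(x * x + y * y + z * z)
    c = math.sqrt(3.0 / (4.0 * math.pi))
    return [c * y / r, c * z / r, c * x / r]


def rescale(values, l, normalization):
    """Convert degree-l values from 'integral' to another e3nn-style convention (hypothetical)."""
    if normalization == "integral":
        return list(values)
    if normalization == "component":   # sum_m Y**2 == 2l + 1 on the unit sphere
        factor = math.sqrt(4.0 * math.pi)
    elif normalization == "norm":      # sum_m Y**2 == 1 on the unit sphere
        factor = math.sqrt(4.0 * math.pi / (2 * l + 1))
    else:
        raise ValueError(f"unknown normalization: {normalization!r}")
    return [factor * v for v in values]


vals = sph_l1_integral(0.3, -0.5, 0.8)
for name in ("integral", "component", "norm"):
    print(name, sum(v * v for v in rescale(vals, 1, name)))
```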
- sphericart.torch.patch_e3nn(e3nn_module: ModuleType) → None
Patches the e3nn module so that sphericart.torch.e3nn_spherical_harmonics() is called in lieu of the built-in function.
- Parameters:
e3nn_module – The alias that has been chosen for the e3nn module, usually just e3nn.
- sphericart.torch.unpatch_e3nn(e3nn_module: ModuleType) → None
Restore the original spherical_harmonics function in the e3nn module.