PyTorch

The PyTorch implementation closely follows the syntax and usage of the Python implementation, while also supporting backpropagation. The example shows how to compute gradients with respect to the input coordinates using backward(), and it also illustrates the computation of second derivatives by reverse-mode automatic differentiation. The sphericart.torch.SphericalHarmonics object can also be used inside a torch.nn.Module, which can then be compiled with TorchScript.

import argparse

import numpy as np
import torch

import sphericart.torch


# Description of this example, shown by ``argparse`` (e.g. with ``--help``)
# when the file is run as a script. Each line of the text is spelled out as its
# own literal so the embedded line breaks are explicit.
docstring = (
    "\n"
    "An example of the use of the PyTorch interface of the `sphericart` library.\n"
    "\n"
    "Simply computes Cartesian spherical harmonics for the given parameters, for an\n"
    "array of random 3D points, using both 32-bit and 64-bit arithmetics.\n"
)


class SHModule(torch.nn.Module):
    """Example of how to use SphericalHarmonics from within a
    `torch.nn.Module`.

    The module wraps a ``sphericart.torch.SphericalHarmonics`` calculator and
    returns its ``compute`` output in ``forward``. This lets the module be
    compiled with ``torch.jit.script`` like any other ``torch.nn.Module``.

    Args:
        l_max: maximum angular momentum of the harmonics to compute.
        normalized: if True, compute normalized (standard Ylm) harmonics
            instead of the default scaled ones.
    """

    def __init__(self, l_max, normalized=False):
        # ``torch.nn.Module.__init__`` must run before any attribute is
        # assigned. It creates the internal registries (``_parameters``,
        # ``_modules``, ``_buffers``) that ``Module.__setattr__`` relies on.
        # Assigning a submodule or parameter before it raises AttributeError.
        super().__init__()
        self._sph = sphericart.torch.SphericalHarmonics(l_max, normalized=normalized)

    def forward(self, xyz):
        """Compute the spherical harmonics for the points in ``xyz``.

        Args:
            xyz: tensor of Cartesian coordinates; presumably shape (n, 3), as
                elsewhere in this example. TODO confirm against the library.

        Returns:
            The result of ``SphericalHarmonics.compute(xyz)``.
        """
        return self._sph.compute(xyz)


def sphericart_example(l_max=10, n_samples=10000, normalized=False):
    """Tour the PyTorch interface of `sphericart`, printing consistency checks.

    The function only prints diagnostics and returns ``None``. Steps:

    * draws ``n_samples`` random points (float64, CPU) and a float32 copy;
    * computes the harmonics with ``compute``, ``compute_with_gradients`` and
      ``compute_with_hessians``, and prints the float32-vs-float64 relative
      error of the harmonics;
    * backpropagates through ``compute`` and prints the relative difference
      between the autograd gradient and the one built from the forward
      derivatives;
    * on the first 5 points, computes second derivatives with a double
      ``torch.autograd.grad`` call and with
      ``torch.autograd.functional.hessian`` (neither result is printed). This
      uses a calculator created with ``backward_second_derivatives=True``;
    * scripts ``SHModule`` with ``torch.jit.script`` and prints its difference
      from a direct call;
    * if CUDA is available, prints CPU-vs-CUDA differences for the forward
      derivatives and for the backpropagated gradient.

    Args:
        l_max: maximum angular momentum, forwarded to ``SphericalHarmonics``.
        n_samples: number of random points (rows of ``xyz``).
        normalized: if True, compute normalized (standard Ylm) harmonics
            instead of the default scaled ones.
    """
    # `sphericart` provides a SphericalHarmonics object that initializes the
    # calculation and then can be called on any n x 3 arrays of Cartesian
    # coordinates. It computes _all_ SPH up to a given l_max, and can compute
    # scaled (default) and normalized (standard Ylm) harmonics.

    # ===== set up the calculation =====

    # initializes the Cartesian coordinates of n_samples random points
    # (standard-normal entries, float64, on the CPU)
    xyz = torch.randn((n_samples, 3), dtype=torch.float64, device="cpu")

    # float32 copy of the same points, used to compare single vs double precision
    xyz_f = xyz.clone().detach().type(torch.float32).to("cpu")

    # ===== API calls =====

    sh_calculator = sphericart.torch.SphericalHarmonics(l_max, normalized=normalized)

    # the interface can also return the forward derivatives directly (up to
    # second order), similar to the Python version.
    # Each call rebinds sh_sphericart (and dsh_sphericart). The dsh_sphericart
    # returned by compute_with_hessians is the one used in the checks below.
    # ddsh_sphericart (second derivatives) is not used further in this example.
    sh_sphericart = sh_calculator.compute(xyz)
    sh_sphericart, dsh_sphericart = sh_calculator.compute_with_gradients(xyz)
    (
        sh_sphericart,
        dsh_sphericart,
        ddsh_sphericart,
    ) = sh_calculator.compute_with_hessians(xyz)

    # the same calculator applied to the float32 points;
    # dsh_sphericart_f is not used further in this example
    sh_sphericart_f, dsh_sphericart_f = sh_calculator.compute_with_gradients(xyz_f)

    # ===== check results =====

    # relative norm of the float64 - float32 difference of the harmonics,
    # computed with NumPy on the detached CPU tensors
    print(
        "Float vs double relative error: %12.8e"
        % (
            np.linalg.norm(sh_sphericart.detach() - sh_sphericart_f.detach())
            / np.linalg.norm(sh_sphericart.detach())
        )
    )

    # ===== autograd integration =====

    # the implementation also supports backpropagation.
    # the input tensor must be tagged to have `requires_grad`
    # (xyz_ag holds the same values as xyz, so the result can be compared
    # with the forward derivatives computed above)
    xyz_ag = xyz.clone().detach().type(torch.float64).to("cpu").requires_grad_()
    sh_sphericart = sh_calculator.compute(xyz_ag)

    # then the spherical harmonics **but not their derivatives**
    # can be used with the usual PyTorch backward() workflow
    # nb: we sum only the even-indexed terms ([:, ::2]) because the total norm
    # of a Ylm is constant; presumably summing every term would make the
    # gradient check uninformative
    sph_norm = torch.sum(sh_sphericart[:, ::2] ** 2)
    sph_norm.backward()

    # checks the derivative is correct using the forward call:
    # d/dxyz sum_j sh_j**2 = 2 * sum_j sh_j * dsh_j. The einsum contracts the
    # harmonic index j of dsh_sphericart (sample i, direction a, harmonic j)
    # with sh_sphericart (sample i, harmonic j), giving the (i, a) layout of
    # xyz_ag.grad. The printed value is a relative difference.
    delta = torch.norm(
        xyz_ag.grad
        - 2
        * torch.einsum("iaj,ij->ia", dsh_sphericart[:, :, ::2], sh_sphericart[:, ::2])
    ) / torch.norm(xyz_ag.grad)
    print(f"Check derivative difference (FW vs BW): {delta}")

    # double derivatives. In order to access them via backpropagation, an additional
    # flag must be specified at class instantiation:
    sh_calculator_2 = sphericart.torch.SphericalHarmonics(
        l_max, normalized=normalized, backward_second_derivatives=True
    )

    # double grad() call on the first 5 points: create_graph=True builds a
    # graph for the first gradient so that it can be differentiated again.
    # grad_grad is computed but not printed.
    xyz_ag2 = xyz[:5].clone().detach().type(torch.float64).to("cpu").requires_grad_()
    sh_sphericart_2 = sh_calculator_2.compute(xyz_ag2)
    sph_norm = torch.sum(sh_sphericart_2[:, ::2] ** 2)
    grad = torch.autograd.grad(sph_norm, xyz_ag2, retain_graph=True, create_graph=True)[
        0
    ]
    grad_grad = torch.autograd.grad(torch.sum(grad), xyz_ag2)[0]

    # hessian() call: the same scalar function of the first 5 points,
    # differentiated twice by torch.autograd.functional.hessian.
    # The result is not used further in this example.
    xyz_ag2 = xyz[:5].clone().detach().type(torch.float64).to("cpu").requires_grad_()

    def func(xyz):
        # scalar built the same way as sph_norm above, but with sh_calculator_2
        sh_sphericart_2 = sh_calculator_2.compute(xyz)
        return torch.sum(sh_sphericart_2[:, ::2] ** 2)

    hessian = torch.autograd.functional.hessian(func, xyz_ag2)

    # ===== torchscript integration =====
    xyz_jit = xyz.clone().detach().type(torch.float64).to("cpu").requires_grad_()

    module = SHModule(l_max, normalized)

    # JIT compilation of the module
    script = torch.jit.script(module)
    sh_jit = script(xyz_jit)

    # sh_sphericart here is the result of compute(xyz_ag) from the autograd
    # section, evaluated on the same points as xyz_jit
    print(f"jit vs direct call: {torch.norm(sh_jit - sh_sphericart)}")

    # ===== GPU implementation ======

    if torch.cuda.is_available():
        # the same calculator object is used with CUDA input
        xyz_cuda = xyz.clone().detach().type(torch.float64).to("cuda")

        sh_sphericart_cuda, dsh_sphericart_cuda = sh_calculator.compute_with_gradients(
            xyz_cuda
        )

        # compare with the CPU forward derivatives (dsh_sphericart from
        # compute_with_hessians above)
        norm_dsph = torch.norm(dsh_sphericart_cuda.to("cpu") - dsh_sphericart)
        print(f"Check fw derivative difference CPU vs CUDA: {norm_dsph}")

        xyz_cuda_bw = (
            xyz.clone().detach().type(torch.float64).to("cuda").requires_grad_()
        )
        sh_sphericart_cuda_bw = sh_calculator.compute(xyz_cuda_bw)

        # then the spherical harmonics **but not their derivatives**
        # can be used with the usual PyTorch backward() workflow
        sph_norm_cuda = torch.sum(sh_sphericart_cuda_bw[:, ::2] ** 2)
        sph_norm_cuda.backward()

        # relative difference from the CPU autograd gradient xyz_ag.grad
        # computed in the autograd section above
        delta = torch.norm(xyz_ag.grad - xyz_cuda_bw.grad.to("cpu")) / torch.norm(
            xyz_ag.grad
        )
        print(f"Check derivative difference CPU vs CUDA: {delta}")


if __name__ == "__main__":
    # Command-line entry point: read the options, then run the example once.
    cli = argparse.ArgumentParser(description=docstring)
    cli.add_argument("-l", type=int, default=10, help="maximum angular momentum")
    cli.add_argument("-s", type=int, default=1000, help="number of samples")
    cli.add_argument(
        "--normalized",
        action="store_true",
        default=False,
        help="compute normalized spherical harmonics",
    )
    options = cli.parse_args()

    # Forward the parsed values by keyword so each option maps visibly onto
    # the matching parameter of the example function.
    sphericart_example(
        l_max=options.l, n_samples=options.s, normalized=options.normalized
    )