Source code for qmctorch.wavefunction.jastrows.elec_nuclei.kernels.quadratic_pade_kernel

import torch
from torch import nn

from .....utils import register_extra_attributes
from .jastrow_kernel_electron_nuclei_base import JastrowKernelElectronNucleiBase


class QuadraticPadeJastrowKernel(JastrowKernelElectronNucleiBase):
    """Electron-nucleus Jastrow kernel with a quadratic Pade form.

    The kernel evaluated by :meth:`forward` is

    .. math::
        J(r) = \\frac{a r + b r^2}{1 + c r}

    where ``a``, ``b`` and ``c`` are stored as the trainable parameters
    ``rweight``, ``r2weight`` and ``weight`` respectively.
    """

    def __init__(
        self,
        nup: int,
        ndown: int,
        atomic_pos: torch.Tensor,
        cuda: bool,
        a: float = 1.0,
        b: float = 1.0,
        c: float = 1.0,
    ) -> None:
        """
        Initializes the Quadratic Pade-Jastrow kernel with the given parameters.

        Args:
            nup (int): Number of spin-up electrons.
            ndown (int): Number of spin-down electrons.
            atomic_pos (torch.Tensor): Tensor containing the atomic positions.
            cuda (bool): Flag to indicate whether to use CUDA for computations.
            a (float, optional): Initial value for the rweight parameter
                (coefficient of ``r`` in the numerator). Defaults to 1.0.
            b (float, optional): Initial value for the r2weight parameter
                (coefficient of ``r**2`` in the numerator). Defaults to 1.0.
            c (float, optional): Initial value for the weight parameter
                (coefficient of ``r`` in the denominator). Defaults to 1.0.
        """
        super().__init__(nup, ndown, atomic_pos, cuda)

        # Build each tensor on the target device *before* wrapping it in
        # nn.Parameter. Calling ``.to(device)`` on an existing Parameter
        # returns a plain, non-leaf Tensor whenever a copy is needed (e.g.
        # CPU -> CUDA), so the attribute would no longer be an nn.Parameter.
        self.rweight = nn.Parameter(
            torch.as_tensor([a]).to(self.device), requires_grad=True
        )
        self.r2weight = nn.Parameter(
            torch.as_tensor([b]).to(self.device), requires_grad=True
        )
        self.weight = nn.Parameter(
            torch.as_tensor([c]).to(self.device), requires_grad=True
        )

        register_extra_attributes(self, ["weight", "r2weight", "rweight"])
        self.requires_autograd = True

    def forward(self, r: torch.Tensor) -> torch.Tensor:
        """
        Computes the Quadratic Pade-Jastrow kernel.

        .. math::
            J(r) = \\frac{a r + b r^2}{1 + c r}

        Args:
            r (torch.Tensor): Tensor containing the e-n distances.

        Returns:
            torch.Tensor: Tensor containing the computed Jastrow kernel.
        """
        numerator = self.rweight * r + self.r2weight * r**2
        denominator = 1.0 + self.weight * r
        return numerator / denominator