# Source code for qmctorch.wavefunction.jastrows.elec_elec.kernels.pade_jastrow_kernel

import torch
from torch import nn

from .....utils import register_extra_attributes
from .jastrow_kernel_electron_electron_base import JastrowKernelElectronElectronBase


class PadeJastrowKernel(JastrowKernelElectronElectronBase):
    """Simple Pade-Jastrow electron-electron kernel.

    The kernel has a single trainable parameter ``weight`` (``w``) and a
    fixed, spin-dependent prefactor ``static_weight`` (``w_0``).
    """

    def __init__(self, nup, ndown, cuda, w=1.):
        """Computes the Simple Pade-Jastrow factor

        .. math::
            B_{ij} = \\frac{w_0 r_{ij}}{1 + w r_{ij}}

        where :math:`w_0` equals 0.25 for parallel spin and 0.5
        for antiparallel spin (see :meth:`get_static_weight`).

        Args:
            nup (int): number of spin up electrons
            ndown (int): number of spin down electrons
            cuda (bool): Turns GPU ON/OFF.
            w (float, optional): Value of the variational parameter.
                Defaults to 1.
        """
        super().__init__(nup, ndown, cuda)

        # Single-element trainable parameter holding the variational weight.
        self.weight = nn.Parameter(torch.as_tensor([w]), requires_grad=True)
        register_extra_attributes(self, ['weight'])

        # Fixed spin-dependent prefactors, one entry per electron pair.
        self.static_weight = self.get_static_weight()

        # NOTE(review): presumably tells the caller that the analytic
        # derivatives below should be used instead of autograd -- confirm
        # against the base class.
        self.requires_autograd = False

    def get_static_weight(self):
        """Get the vector of static weights

        A (nup + ndown) x (nup + ndown) matrix is assembled with 0.25 on
        the same-spin blocks (up-up, down-down) and 0.5 on the
        opposite-spin blocks (up-down, down-up). Only the strict upper
        triangle is kept, flattened in row-major order.

        Returns:
            torch.tensor: 1-D tensor with Nelec (Nelec - 1) / 2 entries,
            one static weight per electron pair (i < j)
        """
        nup, ndown = self.nup, self.ndown

        # Rows belonging to the spin-up electrons: [up-up | up-down].
        bup = torch.cat((0.25 * torch.ones(nup, nup),
                         0.5 * torch.ones(nup, ndown)), dim=1)

        # Rows belonging to the spin-down electrons: [down-up | down-down].
        bdown = torch.cat((0.5 * torch.ones(ndown, nup),
                           0.25 * torch.ones(ndown, ndown)), dim=1)

        static_weight = torch.cat((bup, bdown), dim=0).to(self.device)

        # Strict upper-triangle mask. It is derived from ``static_weight``
        # so it already lives on the same device; ``.bool()`` converts the
        # dtype in place of the former ``torch.BoolTensor`` cast, which
        # forced a copy to the CPU and back.
        mask_tri_up = torch.triu(
            torch.ones_like(static_weight), diagonal=1).bool()

        return static_weight.masked_select(mask_tri_up)

    def forward(self, r):
        """Get the jastrow kernel.

        .. math::
            B_{ij} = \\frac{w_0 r_{i,j}}{1+w r_{i,j}}

        Args:
            r (torch.tensor): e-e distances. The last dimension must
                broadcast against the 1-D static weights, i.e. presumably
                Nbatch x Npairs with Npairs = Nelec (Nelec - 1) / 2
                (TODO confirm against the caller)

        Returns:
            torch.tensor: jastrow kernel values, same shape as the
            broadcast of ``r`` and the static weights
        """
        return self.static_weight * r / (1.0 + self.weight * r)

    def compute_derivative(self, r, dr):
        """Get the elements of the derivative of the jastrow kernels
        wrt to the first electrons

        Relation quoted from the upstream documentation (it depends on how
        the caller builds ``dr`` and cannot be checked from this class):

        .. math::
            \\frac{d B_{ij}}{d k_i} = \\frac{d B_{ij}}{ d k_j }
            = - \\frac{d B_{ji}}{d k_i}

        The returned value is the chain-rule derivative of the kernel:

        .. math::
            \\text{out}_{k,i,j} = A1 + A2

        .. math::
            A1_{kij} = w_0 \\frac{dr_{ij}}{dk_i} \\frac{1}{1 + w r_{ij}}

        .. math::
            A2_{kij} = - w_0 w r_{ij} \\frac{dr_{ij}}{dk_i}
            \\frac{1}{(1 + w r_{ij})^2}

        Args:
            r (torch.tensor): e-e distances; a dimension is inserted at
                position 1 so that it broadcasts against ``dr``.
                Presumably Nbatch x Npairs (TODO confirm against the caller)
            dr (torch.tensor): derivative of the e-e distances;
                presumably Nbatch x Ndim x Npairs (TODO confirm)

        Returns:
            torch.tensor: derivative of the jastrow elements, same shape
            as ``dr`` after broadcasting
        """
        # Insert the axis that lines up with the spatial-dimension axis of dr.
        r_ = r.unsqueeze(1)

        # ``denom`` is the reciprocal 1 / (1 + w r), not the denominator.
        denom = 1. / (1.0 + self.weight * r_)

        a = self.static_weight * dr * denom
        b = -self.static_weight * self.weight * r_ * dr * denom**2

        return (a + b)

    def compute_second_derivative(self, r, dr, d2r):
        """Get the elements of the pure 2nd derivative of the jastrow
        kernels wrt to the first electron

        Relation quoted from the upstream documentation (it depends on how
        the caller builds ``dr`` and ``d2r``):

        .. math ::
            \\frac{d^2 B_{ij}}{d k_i^2} = \\frac{d^2 B_{ij}}{d k_j^2}
            = \\frac{d^2 B_{ji}}{ d k_i^2}

        With :math:`D = 1 / (1 + w r)` the returned value is

        .. math ::
            w_0 r'' D - 2 w_0 w r'^2 D^2 - w_0 w r r'' D^2
            + 2 w_0 w^2 r r'^2 D^3

        Args:
            r (torch.tensor): e-e distances; a dimension is inserted at
                position 1 so that it broadcasts against ``dr`` / ``d2r``.
                Presumably Nbatch x Npairs (TODO confirm against the caller)
            dr (torch.tensor): derivative of the e-e distances;
                presumably Nbatch x Ndim x Npairs (TODO confirm)
            d2r (torch.tensor): 2nd derivative of the e-e distances;
                presumably Nbatch x Ndim x Npairs (TODO confirm)

        Returns:
            torch.tensor: pure 2nd derivative of the jastrow elements,
            same shape as ``dr`` after broadcasting
        """
        # Insert the axis that lines up with the spatial-dimension axis of dr.
        r_ = r.unsqueeze(1)

        # ``denom`` is the reciprocal 1 / (1 + w r), not the denominator.
        denom = 1. / (1.0 + self.weight * r_)
        denom2 = denom**2
        dr_square = dr * dr

        # The four terms of the product/chain rule applied twice.
        a = self.static_weight * d2r * denom
        b = -2 * self.static_weight * self.weight * dr_square * denom2
        c = -self.static_weight * self.weight * r_ * d2r * denom2
        d = 2 * self.static_weight * self.weight**2 * r_ * dr_square * denom**3

        return a + b + c + d