Source code for qmctorch.wavefunction.jastrows.elec_nuclei.jastrow_factor_electron_nuclei

import torch
from torch import nn
from ..distance.electron_nuclei_distance import ElectronNucleiDistance


class JastrowFactorElectronNuclei(nn.Module):
    r"""Electron-nuclei Jastrow factor of the form

    .. math::
        J = \prod_{a,i} \exp(A(r_{ai}))

    where :math:`A` is the kernel built from ``jastrow_kernel`` and
    :math:`r_{ai}` are the electron-nuclei distances computed by
    ``ElectronNucleiDistance``.
    """

    def __init__(self, nup, ndown, atomic_pos, jastrow_kernel,
                 kernel_kwargs=None, cuda=False):
        r"""Base class for electron-nuclei Jastrow factors of the form:

        .. math::
            J = \prod_{a,i} \exp(A(r_{ai}))

        Args:
            nup (int): number of spin up electrons
            ndown (int): number of spin down electrons
            atomic_pos (torch.tensor): positions of the atoms;
                ``atomic_pos.shape[0]`` is used as the number of atoms
            jastrow_kernel (callable): kernel class, instantiated as
                ``jastrow_kernel(nup, ndown, atomic_pos, cuda, **kernel_kwargs)``
            kernel_kwargs (dict, optional): extra keyword arguments passed
                to the kernel constructor. Defaults to None (no extra
                arguments).
            cuda (bool, optional): Turns GPU ON/OFF. Defaults to False.
        """
        super().__init__()

        # avoid a shared mutable default dict
        if kernel_kwargs is None:
            kernel_kwargs = {}

        self.nup = nup
        self.ndown = ndown
        self.nelec = nup + ndown

        # NOTE(review): this boolean instance attribute shadows the
        # nn.Module.cuda() method on this object; kept because other code
        # may read ``.cuda`` as a flag.
        self.cuda = cuda
        self.device = torch.device('cuda') if cuda else torch.device('cpu')

        self.atoms = atomic_pos.to(self.device)
        self.natoms = atomic_pos.shape[0]
        self.ndim = 3

        # kernel function A(r) (receives atomic_pos as passed in, not the
        # device-moved copy stored in self.atoms)
        self.jastrow_kernel = jastrow_kernel(
            nup, ndown, atomic_pos, cuda, **kernel_kwargs)

        # requires autograd to compute derivatives
        self.requires_autograd = self.jastrow_kernel.requires_autograd

        # elec-nuc distances
        self.edist = ElectronNucleiDistance(
            self.nelec, self.atoms, self.ndim)

    def forward(self, pos, derivative=0, sum_grad=True):
        """Compute the Jastrow factors.

        Args:
            pos (torch.tensor): Positions of the electrons
                Size : Nbatch, Nelec x Ndim
            derivative (int or sequence, optional): order of the derivative
                (0, 1, 2), or the sequence [0, 1, 2] (list or tuple) to get
                all three at once. Defaults to 0.
            sum_grad (bool, optional): Return the sum_grad (i.e. the sum of
                the derivatives over the spatial dimensions) or the
                individual terms. Only used for the first derivative.
                Defaults to True.

        Returns:
            torch.tensor: value of the jastrow parameter for all confs
                derivative = 0  (Nmo) x Nbatch x 1
                derivative = 1  (Nmo) x Nbatch x Nelec (for sum_grad = True)
                derivative = 1  (Nmo) x Nbatch x Ndim x Nelec (for sum_grad = False)
                derivative = 2  (Nmo) x Nbatch x Nelec
            For derivative = [0, 1, 2], a tuple of the three values above.

        Raises:
            ValueError: if ``pos`` does not have Nelec x Ndim columns, or if
                ``derivative`` is not 0, 1, 2 or [0, 1, 2].
        """
        # validate the request before doing any (possibly expensive) work
        all_orders = isinstance(derivative, (list, tuple))
        if all_orders:
            if list(derivative) != [0, 1, 2]:
                raise ValueError(
                    f"derivative must be 0, 1, 2 or [0, 1, 2], got {derivative!r}")
        elif derivative not in (0, 1, 2):
            raise ValueError(
                f"derivative must be 0, 1, 2 or [0, 1, 2], got {derivative!r}")

        if pos.dim() < 2 or pos.shape[1] != self.nelec * self.ndim:
            raise ValueError(
                f"pos must have {self.nelec * self.ndim} columns "
                f"(Nelec x Ndim), got shape {tuple(pos.shape)}")

        r = self.edist(pos)
        kern_vals = self.jastrow_kernel(r)

        # J = exp(sum of A over the last two axes of the kernel values)
        jast = torch.exp(kern_vals.sum([-1, -2])).unsqueeze(-1)

        if all_orders:
            dr = self.edist(pos, derivative=1)
            d2r = self.edist(pos, derivative=2)
            return (jast,
                    self.jastrow_factor_derivative(r, dr, jast, sum_grad),
                    self.jastrow_factor_second_derivative(r, dr, d2r, jast))

        if derivative == 0:
            return jast

        dr = self.edist(pos, derivative=1)
        if derivative == 1:
            return self.jastrow_factor_derivative(r, dr, jast, sum_grad)

        d2r = self.edist(pos, derivative=2)
        return self.jastrow_factor_second_derivative(r, dr, d2r, jast)

    def jastrow_factor_derivative(self, r, dr, jast, sum_grad):
        """Compute the value of the derivative of the Jastrow factor.

        Uses d/dx exp(sum A) = exp(sum A) * sum dA/dx.

        Args:
            r (torch.tensor): electron-nuclei distances, as returned by
                ``self.edist(pos)``
            dr (torch.tensor): derivative of the distances, as returned by
                ``self.edist(pos, derivative=1)``
            jast (torch.tensor): values of the Jastrow factor,
                Nbatch x 1 as computed in ``forward``
            sum_grad (bool): if True, sum the derivative over the spatial
                dimensions

        Returns:
            torch.tensor: gradient of the jastrow factors
                Nbatch x Nelec (sum_grad = True)
                Nbatch x Ndim x Nelec (sum_grad = False)
        """
        # assumed kernel derivative layout: Nbatch x Ndim x Nelec x Natom
        # (axis 1 = dims, axis 3 = atoms) -- TODO confirm against the kernel
        djast = self.jastrow_kernel.compute_derivative(r, dr)

        if sum_grad:
            return djast.sum((1, 3)) * jast
        return djast.sum(3) * jast.unsqueeze(-1)

    def jastrow_factor_second_derivative(self, r, dr, d2r, jast):
        """Compute the value of the pure 2nd derivative of the Jastrow factor.

        The result is assembled as exp(f) * (sum of pure second derivatives
        of f + sum over dims of (first derivative of f)**2), with
        f = sum A; this matches the chain rule for exp(f).

        Args:
            r (torch.tensor): electron-nuclei distances, as returned by
                ``self.edist(pos)``
            dr (torch.tensor): first derivative of the distances, as
                returned by ``self.edist(pos, derivative=1)``
            d2r (torch.tensor): second derivative of the distances, as
                returned by ``self.edist(pos, derivative=2)``
            jast (torch.tensor): values of the Jastrow factor,
                Nbatch x 1 as computed in ``forward``

        Returns:
            torch.tensor: pure second derivatives of the jastrow factors,
                summed over the spatial dimensions (Nbatch x Nelec)
        """
        # pure second derivative terms, summed over dims and atoms
        d2jast = self.jastrow_kernel.compute_second_derivative(
            r, dr, d2r).sum((1, 3))

        # squared first derivative: sum over atoms first, square, then
        # sum over the spatial dimensions
        djast = self.jastrow_kernel.compute_derivative(r, dr)
        grad_sq = ((djast.sum(3)) ** 2).sum(1)

        return (d2jast + grad_sq) * jast