Source code for qmctorch.wavefunction.jastrows.elec_elec_nuclei.kernels.jastrow_kernel_electron_electron_nuclei_base

import torch
from torch import nn
from torch.autograd import grad
from torch.autograd.variable import Variable


[docs]class JastrowKernelElectronElectronNucleiBase(nn.Module): def __init__(self, nup, ndown, atomic_pos, cuda, **kwargs): r"""Base Class for the elec-elec-nuc jastrow kernel Args: nup (int): number of spin up electons ndow (int): number of spin down electons atoms (torch.tensor): atomic positions of the atoms cuda (bool, optional): Turns GPU ON/OFF. Defaults to False. """ super().__init__() self.nup, self.ndown = nup, ndown self.cuda = cuda self.nelec = nup + ndown self.atoms = atomic_pos self.natoms = atomic_pos.shape[0] self.ndim = 3 self.device = torch.device('cpu') if self.cuda: self.device = torch.device('cuda') self.requires_autograd = True
[docs] def forward(self, x): """Compute the values of the kernel Args: x (torch.tensor): e-e and e-n distances distance (Nbatch, Natom, Nelec_pairs, 3) the last dimension holds the values [R_{iA}, R_{jA}, r_{ij}] in that order. Returns: torch.tensor: values of the kernel (Nbatch, Natom, Nelec_pairs, 1) """ raise NotImplementedError()
[docs] def compute_derivative(self, r, dr): """Get the elements of the derivative of the jastrow kernels.""" kernel = self.forward(r) ker_grad = self._grads(kernel, r) # return the ker * dr out = ker_grad.unsqueeze(1) * dr # sum over the atoms return out
[docs] def compute_second_derivative(self, r, dr, d2r): """Get the elements of the pure 2nd derivative of the jastrow kernels. """ dr2 = dr * dr kernel = self.forward(r) ker_hess, ker_grad = self._hess(kernel, r, self.device) jhess = ker_hess.unsqueeze(1) * \ dr2 + ker_grad.unsqueeze(1) * d2r return jhess
@staticmethod def _grads(val, pos): """Get the gradients of the jastrow values of a given orbital terms Args: pos ([type]): [description] Returns: [type]: [description] """ return grad(val, pos, grad_outputs=torch.ones_like(val))[0] @staticmethod def _hess(val, pos, device): """get the hessian of the jastrow values. Args: pos ([type]): [description] """ gval = grad(val, pos, grad_outputs=torch.ones_like(val), create_graph=True)[0] grad_out = Variable(torch.ones( *gval.shape[:-1])).to(device) hval = torch.zeros_like(gval).to(device) for idim in range(gval.shape[-1]): tmp = grad(gval[..., idim], pos, grad_outputs=grad_out, only_inputs=True, create_graph=True)[0] hval[..., idim] = tmp[..., idim] return hval, gval