Source code for qmctorch.wavefunction.jastrows.elec_elec.kernels.jastrow_kernel_electron_electron_base

import torch
from torch import nn
from torch.autograd import grad
from torch import nn


class JastrowKernelElectronElectronBase(nn.Module):
    """Base class for the electron-electron Jastrow kernels.

    Subclasses implement :meth:`forward`, which maps electron-electron
    distances to kernel values. The first and pure second derivatives of
    the kernel are then obtained here by automatic differentiation.
    """

    def __init__(self, nup, ndown, cuda=False, **kwargs):
        r"""Store the electron counts and select the torch device.

        Args:
            nup (int): number of spin-up electrons
            ndown (int): number of spin-down electrons
            cuda (bool, optional): use the ``'cuda'`` device instead of
                ``'cpu'``. Defaults to False.
            **kwargs: accepted so subclasses can forward extra options;
                ignored by this base class.
        """
        super().__init__()
        self.nup, self.ndown = nup, ndown

        # NOTE: this instance attribute shadows ``nn.Module.cuda()`` on
        # kernel instances. It is kept because it is part of the public
        # attribute set of the class; read ``self.device`` to get the
        # selected device.
        self.cuda = cuda
        self.device = torch.device('cuda') if self.cuda else torch.device('cpu')

        # Flag advertising that derivatives come from autograd.
        # NOTE(review): presumably read by the enclosing Jastrow factor
        # -- confirm against callers.
        self.requires_autograd = True

    def forward(self, r):
        r"""Get the elements of the jastrow matrix.

        Must be overridden by subclasses; the base implementation raises.

        Args:
            r (torch.tensor): matrix of the e-e distances
                Nbatch x Nelec_pair

        Returns:
            torch.tensor: matrix of the jastrow elements
                Nbatch x Nelec_pair

        Raises:
            NotImplementedError: always, in this base class.

        Note:
            The kernel receives a [Nbatch x Npair] tensor.
            The kernel must first reshape that tensor to a [Nbatch*Npair, 1].
            The kernel must process this tensor to another [Nbatch*Npair, 1] tensor.
            The kernel must reshape the output to a [Nbatch x Npair] tensor.

        Example:
            >>> def forward(self, x):
            >>>     nbatch, npairs = x.shape
            >>>     x = x.reshape(-1, 1)
            >>>     x = self.fc1(x)
            >>>     ...
            >>>     return(x.reshape(nbatch, npairs))
        """
        raise NotImplementedError()

    def compute_derivative(self, r, dr):
        """Get the derivative of the jastrow kernels wrt the first electron.

        The kernel derivative d(kernel)/dr is computed by automatic
        differentiation and combined with ``dr`` through the chain rule.

        Args:
            r (torch.tensor): matrix of the e-e distances
                Nbatch x Nelec_pair
            dr (torch.tensor): matrix of the derivative of the e-e distances
                Nbatch x Ndim x Nelec_pair

        Returns:
            torch.tensor: matrix of the derivative of the jastrow elements
                Nbatch x Ndim x Nelec_pair

        Note:
            ``r`` is modified in place: ``requires_grad`` is switched on
            when it was off.
        """
        # autograd needs r to be tracked before the kernel is evaluated
        if not r.requires_grad:
            r.requires_grad = True

        # force graph recording even if the caller is under torch.no_grad()
        with torch.enable_grad():
            kernel = self.forward(r)
            ker_grad = self._grads(kernel, r)

        # insert the Ndim axis so the kernel gradient broadcasts over dr
        return ker_grad.unsqueeze(1) * dr

    def compute_second_derivative(self, r, dr, d2r):
        """Get the pure 2nd derivative of the jastrow kernels wrt the first electron.

        Uses automatic differentiation for the kernel terms and the chain
        rule: ``kernel'' * dr**2 + kernel' * d2r``.

        Args:
            r (torch.tensor): matrix of the e-e distances
                Nbatch x Nelec_pair
            dr (torch.tensor): matrix of the derivative of the e-e distances
                Nbatch x Ndim x Nelec_pair
            d2r (torch.tensor): matrix of the 2nd derivative of
                the e-e distances
                Nbatch x Ndim x Nelec_pair

        Returns:
            torch.tensor: matrix of the pure 2nd derivative of
                the jastrow elements
                Nbatch x Ndim x Nelec_pair

        Note:
            ``r`` is modified in place: ``requires_grad`` is switched on
            when it was off.
        """
        dr2 = dr * dr

        # autograd needs r to be tracked before the kernel is evaluated
        if not r.requires_grad:
            r.requires_grad = True

        # force graph recording even if the caller is under torch.no_grad()
        with torch.enable_grad():
            kernel = self.forward(r)
            ker_hess, ker_grad = self._hess(kernel, r)

            # unsqueeze(1) inserts the Ndim axis for broadcasting
            jhess = ker_hess.unsqueeze(1) * dr2 + ker_grad.unsqueeze(1) * d2r

        return jhess

    @staticmethod
    def _grads(val, pos):
        """Get the gradient of the kernel values wrt their input.

        Args:
            val (torch.tensor): kernel values computed from ``pos``
            pos (torch.tensor): tensor the values are differentiated
                against (must be part of the graph of ``val``)

        Returns:
            torch.tensor: gradient, same shape as ``pos``
        """
        return grad(val, pos, grad_outputs=torch.ones_like(val))[0]

    @staticmethod
    def _hess(val, pos):
        """Get the diagonal second derivative of the kernel values.

        Warning:
            This works only because each kernel term depends on a
            single rij term, i.e. fij = f(rij).

        Args:
            val (torch.tensor): kernel values computed from ``pos``
            pos (torch.tensor): tensor the values are differentiated
                against (must be part of the graph of ``val``)

        Returns:
            tuple(torch.tensor, torch.tensor): second derivative and
                first derivative, in that order, each shaped like ``pos``
        """
        # create_graph=True keeps the first derivative differentiable
        gval = grad(val, pos,
                    grad_outputs=torch.ones_like(val),
                    create_graph=True)[0]

        hval = grad(gval, pos, grad_outputs=torch.ones_like(gval))[0]

        return hval, gval