Source code for qmctorch.wavefunction.jastrows.elec_elec.orbital_dependent_jastrow_kernel


import torch
from torch import nn
from torch.autograd import grad
from .kernels.jastrow_kernel_electron_electron_base import JastrowKernelElectronElectronBase


[docs]class OrbitalDependentJastrowKernel(JastrowKernelElectronElectronBase): def __init__(self, nup, ndown, nmo, cuda, jastrow_kernel, kernel_kwargs={}): """Transform a kernel into a orbital dependent kernel Args: nup (int): number of spin up electrons ndown (int): number of spin down electron nmo (int): number of orbital cuda (bool): use GPUs jastrow_kernel (kernel function): kernel to be used kernel_kwargs (dict): keyword arguments of the kernel """ super().__init__(nup, ndown, cuda) self.nmo = nmo self.jastrow_functions = nn.ModuleList( [jastrow_kernel(nup, ndown, cuda, **kernel_kwargs) for _ in range(self.nmo)])
[docs] def forward(self, r): """ Get the jastrow kernel. Args: r (torch.tensor): matrix of the e-e distances Nbatch x Nelec_pair Returns: torch.tensor: matrix of the jastrow kernels Nmo x Nbatch x Nelec_pair """ out = None for jast in self.jastrow_functions: jvals = jast(r).unsqueeze(0) if out is None: out = jvals else: out = torch.cat((out, jvals), axis=0) return out
[docs] def compute_derivative(self, r, dr): """Get the elements of the derivative of the jastrow kernels wrt to the first electrons Args: r (torch.tensor): matrix of the e-e distances Nbatch x Nelec_pair dr (torch.tensor): matrix of the derivative of the e-e distances Nbatch x Ndim x Nelec_pair Returns: torch.tensor: matrix fof the derivative of the jastrow elements Nmo x Nbatch x Ndim x Nelec_pair """ out = None if r.requires_grad == False: r.requires_grad = True with torch.enable_grad(): for jast in self.jastrow_functions: kernel = jast(r) ker_grad = self._grads(kernel, r) ker_grad = ker_grad.unsqueeze(1) * dr ker_grad = ker_grad.unsqueeze(0).detach().clone() if out is None: out = ker_grad else: out = torch.cat((out, ker_grad), axis=0) return out
[docs] def compute_second_derivative(self, r, dr, d2r): """Get the elements of the pure 2nd derivative of the jastrow kernels wrt to the first electron Args: r (torch.tensor): matrix of the e-e distances Nbatch x Nelec_pair dr (torch.tensor): matrix of the derivative of the e-e distances Nbatch x Ndim x Nelec_pair d2r (torch.tensor): matrix of the 2nd derivative of the e-e distances Nbatch x Ndim x Nelec_pair Returns: torch.tensor: matrix fof the pure 2nd derivative of the jastrow elements Nmo x Nbatch x Ndim x Nelec_pair """ dr2 = dr * dr out = None if r.requires_grad == False: r.requires_grad = True with torch.enable_grad(): for jast in self.jastrow_functions: kernel = jast(r) ker_hess, ker_grad = self._hess(kernel, r) jhess = (ker_hess).unsqueeze(1) * \ dr2 + ker_grad.unsqueeze(1) * d2r jhess = jhess.unsqueeze(0) if out is None: out = jhess else: out = torch.cat((out, jhess)) return out
@staticmethod def _grads(val, r): """Get the gradients of the jastrow values of a given orbital terms Args: r (torch.tensor): matrix of the e-e distances Nbatch x Nelec_pair Returns: torch.tensor: gradients of the values wrt to ee distance """ return grad(val, r, grad_outputs=torch.ones_like(val))[0] @staticmethod def _hess(val, r): """get the hessian of the jastrow values. of a given orbital terms Warning thos work only because the orbital term are dependent of a single rij term, i.e. fij = f(rij) Args: r (torch.tensor): matrix of the e-e distances Nbatch x Nelec_pair Returns: torch.tensor: second derivative of the values wrt to ee distance """ gval = grad(val, r, grad_outputs=torch.ones_like(val), create_graph=True)[0] hval = grad(gval, r, grad_outputs=torch.ones_like(gval))[0] return hval, gval