# Source code for qmctorch.wavefunction.jastrows.elec_elec.kernels.pade_jastrow_polynomial_kernel

import torch
from torch import nn

from .....utils import register_extra_attributes
from .jastrow_kernel_electron_electron_base import JastrowKernelElectronElectronBase


class PadeJastrowPolynomialKernel(JastrowKernelElectronElectronBase):
    """Polynomial Pade-Jastrow electron-electron kernel B = P / Q."""

    def __init__(self, nup, ndown, cuda, order=2, weight_a=None, weight_b=None):
        """Computes a polynomial Pade-Jastrow factor

        .. math::
            B_{ij} = \\frac{P_{ij}}{Q_{ij}}

        with:

        .. math::
            P_{ij} = a_1 r_{ij} + a_2 r_{ij}^2 + ...

        and:

        .. math::
            Q_{ij} = 1 + b_1 r_{ij} + b_2 r_{ij}^2 + ...

        The first-order numerator coefficient :math:`a_1` is not variational:
        it is the fixed spin-dependent weight returned by
        :meth:`get_static_weight`. Consequently ``weight_a[0]`` is stored but
        never used when evaluating the kernel or its derivatives.

        Args:
            nup (int): number of spin up electrons
            ndown (int): number of spin down electrons
            cuda (bool): Turns GPU ON/OFF.
            order (int, optional): degree of the polynomials P and Q.
                Defaults to 2.
            weight_a (torch.tensor, optional): initial numerator weights,
                ``order`` entries. Defaults to None (all set to 1E-5).
            weight_b (torch.tensor, optional): initial denominator weights,
                ``order`` entries. Defaults to None (b_1 = 1, others 1E-5).
        """
        super().__init__(nup, ndown, cuda)
        self.porder = order
        self.set_variational_weights(weight_a, weight_b)
        self.static_weight = self.get_static_weight()
        # NOTE(review): presumably signals that this kernel provides analytic
        # derivatives (compute_derivative / compute_second_derivative) instead
        # of relying on autograd -- confirm in the base class.
        self.requires_autograd = False

    def get_static_weight(self):
        """Get the fixed (non-variational) first-order numerator weights.

        Builds the Nelec x Nelec matrix (Nelec = nup + ndown) holding 0.25 for
        pairs of electrons with parallel spins (up-up, down-down) and 0.5 for
        pairs with antiparallel spins (up-down), then keeps only its strict
        upper triangle (pairs i < j), flattened in row-major order.

        Returns:
            torch.tensor: 1D tensor of static weights on ``self.device``,
            one entry per pair i < j, i.e. Nelec * (Nelec - 1) / 2 entries.
        """
        bup = torch.cat((0.25 * torch.ones(self.nup, self.nup),
                         0.5 * torch.ones(self.nup, self.ndown)), dim=1)
        bdown = torch.cat((0.5 * torch.ones(self.ndown, self.nup),
                           0.25 * torch.ones(self.ndown, self.ndown)), dim=1)
        static_weight = torch.cat((bup, bdown), dim=0).to(self.device)

        # Keep each unordered pair once: strict upper triangle, no diagonal.
        # .bool() keeps the mask on the same device as static_weight, whereas
        # .type(torch.BoolTensor) would round-trip it through the CPU.
        mask_tri_up = torch.triu(torch.ones_like(static_weight),
                                 diagonal=1).bool()
        return static_weight.masked_select(mask_tri_up)

    def set_variational_weights(self, weight_a, weight_b):
        """Define the initial values of the variational weights.

        ``weight_a[k]`` and ``weight_b[k]`` multiply r^(k+1) in the numerator
        and denominator respectively (``weight_a[0]`` is unused, the
        first-order numerator term being the static weight).

        Args:
            weight_a (torch.tensor or None): initial numerator weights with
                ``self.porder`` entries. If None, all entries are 1E-5.
            weight_b (torch.tensor or None): initial denominator weights with
                ``self.porder`` entries. If None, b_1 is 1 and the other
                entries are 1E-5.

        Raises:
            AssertionError: if a provided weight does not have
                ``self.porder`` entries along its first dimension.
        """
        # Default initial weight value (per the original note, a value that
        # is too low can cause NaNs).
        w0 = 1E-5

        if weight_a is None:
            weight_a = w0 * torch.ones(self.porder)
        else:
            assert weight_a.shape[0] == self.porder, \
                f'weight_a must have {self.porder} entries'
        self.weight_a = nn.Parameter(weight_a)

        if weight_b is None:
            weight_b = w0 * torch.ones(self.porder)
            # default denominator starts as Q ~= 1 + r (b_1 = 1)
            weight_b[0] = 1.
        else:
            assert weight_b.shape[0] == self.porder, \
                f'weight_b must have {self.porder} entries'
        self.weight_b = nn.Parameter(weight_b)

        register_extra_attributes(self, ['weight_a'])
        register_extra_attributes(self, ['weight_b'])

    def forward(self, r):
        """Get the jastrow kernel.

        .. math::
            B_{ij} = \\frac{P_{ij}}{Q_{ij}}

        Args:
            r (torch.tensor): e-e distances. It is multiplied elementwise by
                ``self.static_weight`` (one entry per pair i < j), so its last
                dimension must broadcast against the number of pairs.
                NOTE(review): originally documented as Nbatch x Nelec x Nelec;
                presumably Nbatch x Npairs -- confirm against the caller.

        Returns:
            torch.tensor: values of the jastrow kernels, broadcast shape of
            ``r`` and ``self.static_weight``
        """
        num, denom = self._compute_polynoms(r)
        return num / denom

    def compute_derivative(self, r, dr):
        """Get the derivative of the jastrow kernels wrt the first electron.

        The derivative follows from the quotient rule:

        .. math::
            \\text{out}_{k,i,j} = \\frac{P'Q - PQ'}{Q^2}

        with:

        .. math::
            P_{ij} = a_1 r_{ij} + a_2 r_{ij}^2 + ...

            Q_{ij} = 1 + b_1 r_{ij} + b_2 r_{ij}^2 + ...

        and:

        .. math::
            P'_{ij} = a_1 r' + 2 a_2 r r' + 3 a_3 r^2 r' + ...

            Q'_{ij} = b_1 r' + 2 b_2 r r' + 3 b_3 r^2 r' + ...

        where :math:`r'` is the derivative of the distance supplied in ``dr``.
        Since :math:`B_{ij}` depends on the electrons only through the
        distance :math:`r_{ij}`, the derivatives wrt the two electrons of a
        pair are opposite:

        .. math::
            \\frac{d B_{ij}}{d k_i} = -\\frac{d B_{ij}}{d k_j}

        Args:
            r (torch.tensor): e-e distances (same layout as in ``forward``)
            dr (torch.tensor): derivative of the e-e distances, with the
                spatial dimension on axis 1 (documented upstream as
                Nbatch x Ndim x Nelec x Nelec)

        Returns:
            torch.tensor: derivative of the jastrow kernels, broadcast against
            ``dr`` (spatial dimension on axis 1)
        """
        num, denom = self._compute_polynoms(r)
        # insert the spatial axis so P and Q broadcast against dr
        num = num.unsqueeze(1)
        denom = denom.unsqueeze(1)

        der_num, der_denom = self._compute_polynom_derivatives(r, dr)

        return (der_num * denom - num * der_denom) / (denom * denom)

    def compute_second_derivative(self, r, dr, d2r):
        """Get the pure 2nd derivative of the jastrow kernels wrt the first
        electron.

        .. math::
            \\left(\\frac{P}{Q}\\right)'' = \\frac{P''}{Q}
            - \\frac{2 P' Q' + P Q''}{Q^2} + \\frac{2 P Q'^2}{Q^3}

        Since :math:`B_{ij}` depends on the electrons only through the
        distance :math:`r_{ij}`, the pure second derivatives wrt either
        electron of a pair are equal:

        .. math::
            \\frac{d^2 B_{ij}}{d k_i^2} = \\frac{d^2 B_{ij}}{d k_j^2}

        Args:
            r (torch.tensor): e-e distances (same layout as in ``forward``)
            dr (torch.tensor): derivative of the e-e distances, spatial
                dimension on axis 1
            d2r (torch.tensor): 2nd derivative of the e-e distances, same
                layout as ``dr``

        Returns:
            torch.tensor: pure 2nd derivative of the jastrow kernels,
            broadcast against ``dr`` (spatial dimension on axis 1)
        """
        num, denom = self._compute_polynoms(r)
        # insert the spatial axis so P and Q broadcast against dr / d2r
        num = num.unsqueeze(1)
        denom = denom.unsqueeze(1)

        der_num, der_denom = self._compute_polynom_derivatives(r, dr)

        d2_num, d2_denom = self._compute_polynom_second_derivative(
            r, dr, d2r)

        out = d2_num / denom - (2 * der_num * der_denom + num * d2_denom) / (
            denom * denom) + 2 * num * der_denom * der_denom / (denom * denom * denom)

        return out

    def _compute_polynoms(self, r):
        """Compute the numerator and denominator polynomials.

        The first-order numerator coefficient is ``self.static_weight``;
        for k >= 1, ``weight_a[k]`` / ``weight_b[k]`` multiply r^(k+1).

        Args:
            r (torch.tensor): e-e distances (same layout as in ``forward``)

        Returns:
            torch.tensor, torch.tensor: P and Q evaluated at ``r``
        """
        num = self.static_weight * r
        denom = 1.0 + self.weight_b[0] * r

        # riord holds r^(iord + 1). It is only ever rebound, never modified
        # in place, so no copy of r is needed.
        riord = r
        for iord in range(1, self.porder):
            riord = riord * r
            num += self.weight_a[iord] * riord
            denom += self.weight_b[iord] * riord

        return num, denom

    def _compute_polynom_derivatives(self, r, dr):
        """Compute the first derivatives of the polynomials.

        Uses d(r^n) = n r^(n-1) dr with n = iord + 1.

        Args:
            r (torch.tensor): e-e distances (same layout as in ``forward``)
            dr (torch.tensor): derivative of the e-e distances, spatial
                dimension on axis 1

        Returns:
            torch.tensor, torch.tensor: P' and Q', broadcast against ``dr``
        """
        der_num = self.static_weight * dr
        der_denom = self.weight_b[0] * dr

        r_ = r.unsqueeze(1)
        # riord holds r^iord, broadcastable against dr
        riord = r.unsqueeze(1)

        for iord in range(1, self.porder):
            fact = (iord + 1) * dr * riord
            der_num += self.weight_a[iord] * fact
            der_denom += self.weight_b[iord] * fact
            riord = riord * r_

        return der_num, der_denom

    def _compute_polynom_second_derivative(self, r, dr, d2r):
        """Compute the pure second derivatives of the polynomials.

        Uses d2(r^n) = n r^(n-1) d2r + n (n-1) r^(n-2) dr^2 with n = iord + 1.

        Args:
            r (torch.tensor): e-e distances (same layout as in ``forward``)
            dr (torch.tensor): derivative of the e-e distances, spatial
                dimension on axis 1
            d2r (torch.tensor): 2nd derivative of the e-e distances, same
                layout as ``dr``

        Returns:
            torch.tensor, torch.tensor: P'' and Q'', broadcast against ``dr``
        """
        d2_num = self.static_weight * d2r
        d2_denom = self.weight_b[0] * d2r

        dr2 = dr * dr

        r_ = r.unsqueeze(1)
        # rnm1 = r^(n-1) and rnm2 = r^(n-2) for the current power n
        rnm1 = r.unsqueeze(1)
        rnm2 = 1.

        for iord in range(1, self.porder):

            n = iord + 1
            fact = n * (d2r * rnm1 + iord * dr2 * rnm2)
            d2_num += self.weight_a[iord] * fact
            d2_denom += self.weight_b[iord] * fact

            rnm2 = rnm1
            rnm1 = rnm1 * r_

        return d2_num, d2_denom