Source code for qmctorch.wavefunction.orbitals.backflow.backflow_transformation

import numpy
import torch
from torch import nn
from ...jastrows.distance.electron_electron_distance import ElectronElectronDistance


class BackFlowTransformation(nn.Module):
    r"""Transform the electron coordinates into backflow coordinates.

    See: Orbital-dependent backflow wave functions for real-space
    quantum Monte Carlo, https://arxiv.org/abs/1910.07167

    .. math::
        \mathbf{q}_i = \mathbf{r}_i + \sum_{j \neq i}
            \eta(r_{ij}) (\mathbf{r}_i - \mathbf{r}_j)
    """

    def __init__(self, mol, backflow_kernel, backflow_kernel_kwargs=None,
                 cuda=False):
        """Build the backflow kernel and the electron-electron distance.

        Args:
            mol: molecule object; only its ``nelec`` attribute is read
                here, the object itself is forwarded to the kernel.
            backflow_kernel (callable): kernel factory, invoked as
                ``backflow_kernel(mol, cuda, **backflow_kernel_kwargs)``.
                The object it returns is later called as ``kernel(ree)``
                and ``kernel(ree, derivative=1 or 2)``.
            backflow_kernel_kwargs (dict, optional): extra keyword
                arguments for the kernel factory. ``None`` (default)
                means no extra arguments.
            cuda (bool, optional): select the ``'cuda'`` device instead
                of ``'cpu'`` for tensors created by this module.
        """
        super().__init__()

        # A literal ``{}`` default would be one dict shared by every
        # instance; create a fresh one per call instead.
        if backflow_kernel_kwargs is None:
            backflow_kernel_kwargs = {}

        self.backflow_kernel = backflow_kernel(
            mol, cuda, **backflow_kernel_kwargs)
        self.edist = ElectronElectronDistance(mol.nelec)
        self.nelec = mol.nelec
        self.ndim = 3

        # NOTE: this instance attribute shadows ``nn.Module.cuda()``.
        # It is kept as is because callers may read ``.cuda``.
        self.cuda = cuda
        self.device = torch.device('cuda' if self.cuda else 'cpu')

    def forward(self, pos, derivative=0):
        """Evaluate the backflow transformation or one of its derivatives.

        Args:
            pos (torch.tensor): original positions, reshaped internally
                to Nbatch x Nelec x Ndim.
            derivative (int, optional): 0 for the transformed positions,
                1 for the first derivative, 2 for the second derivative.

        Returns:
            torch.tensor: result of ``_backflow``, ``_backflow_derivative``
            or ``_backflow_second_derivative`` respectively.

        Raises:
            ValueError: if ``derivative`` is not 0, 1 or 2.
        """
        if derivative == 0:
            return self._backflow(pos)
        if derivative == 1:
            return self._backflow_derivative(pos)
        if derivative == 2:
            return self._backflow_second_derivative(pos)
        raise ValueError(
            'derivative of the backflow transformation must be 0, 1 or 2')

    def _identity_ab(self):
        """Return the Ndim x Ndim identity shaped 1 x Ndim x Ndim x 1 x 1.

        The tensor is created directly on ``self.device`` so that it
        broadcasts against Nbatch x Ndim x Ndim x Nelec x Nelec tensors.
        """
        return torch.eye(self.ndim, device=self.device).reshape(
            1, self.ndim, self.ndim, 1, 1)

    def _backflow(self, pos):
        r"""Compute the backflow transformation.

        .. math::
            \mathbf{q}_i = \mathbf{r}_i + \sum_{j \neq i}
                \eta(r_{ij}) (\mathbf{r}_i - \mathbf{r}_j)

        Args:
            pos (torch.tensor): original positions, Nbatch x [Nelec*Ndim]

        Returns:
            torch.tensor: transformed positions, Nbatch x [Nelec*Ndim]
        """
        xyz = pos.reshape(-1, self.nelec, self.ndim)

        # pairwise differences r_i - r_j
        # expected Nbatch x Nelec x Nelec x Ndim (summed over j below)
        delta_ee = self.edist.get_difference(xyz)

        # kernel eta(r_ij) : expected Nbatch x Nelec x Nelec
        bf_kernel = self.backflow_kernel(self.edist(pos))

        # q_i = r_i + sum_j eta(r_ij) (r_i - r_j)
        bf_pos = xyz + (bf_kernel.unsqueeze(-1) * delta_ee).sum(2)

        return bf_pos.reshape(-1, self.nelec * self.ndim)

    def _backflow_derivative(self, pos):
        r"""Compute the first derivative of the backflow transformation
        with respect to the original electron positions.

        With :math:`\alpha, \beta` the Cartesian components:

        .. math::
            \frac{d q_i^\alpha}{d r_k^\beta} =
                \delta_{ik} \left( \delta_{\alpha\beta}
                    \left(1 + \sum_j \eta(r_{ij})\right)
                    + \sum_j \frac{d \eta(r_{ij})}{d r_i^\beta}
                      (r_i^\alpha - r_j^\alpha) \right)
                - \frac{d \eta(r_{ik})}{d r_i^\beta}
                  (r_i^\alpha - r_k^\alpha)
                - \delta_{\alpha\beta} \eta(r_{ik})

        Args:
            pos (torch.tensor): original positions, Nbatch x [Nelec*Ndim]

        Returns:
            torch.tensor: Nbatch x Ndim(alpha) x Ndim(beta) x Nelec(i)
            x Nelec(k), holding d q_i^alpha / d r_k^beta.
        """
        ndim = self.ndim

        # ee dist matrix : Nbatch x Nelec x Nelec
        ree = self.edist(pos)
        nbatch, nelec, _ = ree.shape

        # derivative ee dist matrix, expected Nbatch x Ndim x Nelec x Nelec
        # dr_ij / dx_i = - dr_ij / dx_j
        dree = self.edist(pos, derivative=1)

        # difference between elec pos, permuted to
        # Nbatch x Ndim x Nelec x Nelec
        delta_ee = self.edist.get_difference(
            pos.reshape(nbatch, nelec, ndim)).permute(0, 3, 1, 2)

        # backflow kernel eta(r_ij) : Nbatch x Nelec x Nelec
        bf = self.backflow_kernel(ree)

        # (d eta(r_ij) / d r_ij) (d r_ij / d beta_i)
        # Nbatch x Ndim x Nelec x Nelec
        dbf = self.backflow_kernel(ree, derivative=1).unsqueeze(1) * dree

        # (d eta(r_ij) / d beta_i) (alpha_i - alpha_j)
        # Nbatch x Ndim(alpha) x Ndim(beta) x Nelec x Nelec
        dbf_delta_ee = dbf.unsqueeze(1) * delta_ee.unsqueeze(2)

        # delta_ij * (1 + sum_k eta(r_ik))
        # Nbatch x Nelec x Nelec (diagonal matrix)
        delta_ij_bf = torch.diag_embed(1 + bf.sum(-1), dim1=-1, dim2=-2)

        # delta_ab as 1 x Ndim x Ndim x 1 x 1
        eye_mat = self._identity_ab()

        # delta_ab * delta_ij * (1 + sum_k eta(r_ik))
        delta_ab_delta_ij_bf = eye_mat * \
            delta_ij_bf.reshape(nbatch, 1, 1, nelec, nelec)

        # delta_ij * sum_k d eta(r_ik)/d beta_i (alpha_i - alpha_k)
        delta_ij_sum = torch.diag_embed(
            dbf_delta_ee.sum(-1), dim1=-1, dim2=-2)

        # delta_ab * eta(r_ij)
        delta_ab_bf = eye_mat * bf.reshape(nbatch, 1, 1, nelec, nelec)

        return delta_ab_delta_ij_bf + delta_ij_sum \
            - dbf_delta_ee - delta_ab_bf

    def _backflow_second_derivative(self, pos):
        r"""Compute the second derivative of the backflow transformation
        with respect to the original electron positions.

        Only the pure second derivatives are computed, i.e. twice with
        respect to the same coordinate :math:`r_k^\beta`:

        .. math::
            \frac{d^2 q_i^\alpha}{d (r_k^\beta)^2} =
                \delta_{ik} \left( 2 \delta_{\alpha\beta}
                    \sum_j \frac{d \eta(r_{ij})}{d r_i^\beta}
                    + \sum_j \frac{d^2 \eta(r_{ij})}{d (r_i^\beta)^2}
                      (r_i^\alpha - r_j^\alpha) \right)
                + \frac{d^2 \eta(r_{ik})}{d (r_i^\beta)^2}
                  (r_i^\alpha - r_k^\alpha)
                + 2 \delta_{\alpha\beta}
                  \frac{d \eta(r_{ik})}{d r_i^\beta}

        Args:
            pos (torch.tensor): original positions, Nbatch x [Nelec*Ndim]

        Returns:
            torch.tensor: Nbatch x Ndim(alpha) x Ndim(beta) x Nelec(i)
            x Nelec(k), holding d^2 q_i^alpha / d (r_k^beta)^2.
        """
        ndim = self.ndim

        # ee dist matrix : Nbatch x Nelec x Nelec
        ree = self.edist(pos)
        nbatch, nelec, _ = ree.shape

        # difference between elec pos, permuted to
        # Nbatch x Ndim x Nelec x Nelec
        delta_ee = self.edist.get_difference(
            pos.reshape(nbatch, nelec, ndim)).permute(0, 3, 1, 2)

        # d r_ij / d x_i, expected Nbatch x Ndim x Nelec x Nelec
        dree = self.edist(pos, derivative=1)

        # d2 r_ij / d x_i^2, expected Nbatch x Ndim x Nelec x Nelec
        d2ree = self.edist(pos, derivative=2)

        # d eta(r_ij) / d r_ij : Nbatch x 1 x Nelec x Nelec
        dbf_dr = self.backflow_kernel(ree, derivative=1).unsqueeze(1)

        # d2 eta(r_ij) / d r_ij^2 : Nbatch x 1 x Nelec x Nelec
        d2bf_dr2 = self.backflow_kernel(ree, derivative=2).unsqueeze(1)

        # chain rule:
        # (d2 eta / d r_ij^2) (d r_ij / d x_i)^2
        #   + (d eta / d r_ij) (d2 r_ij / d x_i^2)
        # Nbatch x Ndim x Nelec x Nelec
        d2bf = (d2bf_dr2 * dree * dree) + (dbf_dr * d2ree)

        # (d eta / d r_ij) (d r_ij / d x_i)
        # Nbatch x Ndim x Nelec x Nelec
        dbf = dbf_dr * dree

        # delta_ab as 1 x Ndim x Ndim x 1 x 1
        eye_mat = self._identity_ab()

        # delta_ij delta_ab 2 sum_k d eta(r_ik) / d beta_i
        term1 = 2 * eye_mat * torch.diag_embed(
            dbf.sum(-1), dim1=-1, dim2=-2
        ).reshape(nbatch, 1, ndim, nelec, nelec)

        # (d2 eta(r_ij) / d beta_i^2) (alpha_i - alpha_j)
        # Nbatch x Ndim(alpha) x Ndim(beta) x Nelec x Nelec
        d2bf_delta_ee = d2bf.unsqueeze(1) * delta_ee.unsqueeze(2)

        # delta_ij sum_k d2 eta(r_ik) / d beta_i^2 (alpha_i - alpha_k)
        term2 = torch.diag_embed(
            d2bf_delta_ee.sum(-1), dim1=-1, dim2=-2)

        # 2 delta_ab d eta(r_ij) / d beta_i
        term3 = 2 * eye_mat * dbf.reshape(nbatch, 1, ndim, nelec, nelec)

        return term1 + term2 + d2bf_delta_ee + term3