Source code for qmctorch.wavefunction.orbitals.backflow.backflow_transformation

import torch
from torch import nn
from typing import Dict, Optional, Callable, Tuple
from ....scf import Molecule
from .kernels.backflow_kernel_base import BackFlowKernelBase
from ...jastrows.distance.electron_electron_distance import ElectronElectronDistance


class BackFlowTransformation(nn.Module):
    def __init__(
        self,
        mol: Molecule,
        backflow_kernel: BackFlowKernelBase,
        backflow_kernel_kwargs: Optional[Dict] = None,
        cuda: Optional[bool] = False,
    ):
        r"""Transform the electron coordinates into backflow coordinates.

        See: Orbital-dependent backflow wave functions for real-space
        quantum Monte Carlo, https://arxiv.org/abs/1910.07167

        .. math::
            \mathbf{q}_i = \mathbf{r}_i
            + \sum_{j\neq i} \eta(r_{ij})(\mathbf{r}_i - \mathbf{r}_j)

        Args:
            mol (Molecule): molecule; its ``basis.nao`` and ``nelec`` are read.
            backflow_kernel (BackFlowKernelBase): kernel class, instantiated
                here as ``backflow_kernel(mol, cuda, **backflow_kernel_kwargs)``.
            backflow_kernel_kwargs (dict, optional): extra keyword arguments
                forwarded to the kernel constructor. ``None`` (default) means
                no extra arguments.
            cuda (bool, optional): if True, ``self.device`` is set to the
                CUDA device. Defaults to False.
        """
        super().__init__()
        self.nao = mol.basis.nao
        self.nelec = mol.nelec
        self.ndim = 3

        # ``None`` default instead of a shared mutable ``{}`` default
        if backflow_kernel_kwargs is None:
            backflow_kernel_kwargs = {}
        self.backflow_kernel = backflow_kernel(mol, cuda, **backflow_kernel_kwargs)

        self.edist = ElectronElectronDistance(mol.nelec)

        # NOTE(review): this instance attribute shadows ``nn.Module.cuda()``;
        # kept under this name for compatibility with code reading ``.cuda``.
        self.cuda = cuda
        self.device = torch.device("cuda") if self.cuda else torch.device("cpu")

    def forward(self, pos: torch.Tensor, derivative: Optional[int] = 0) -> torch.Tensor:
        """Compute the backflow positions or their first/second derivative.

        Args:
            pos (torch.Tensor): electron positions, Nbatch x [Nelec*Ndim]
            derivative (int, optional): 0 for the transformed positions,
                1 for the first derivative, 2 for the second derivative.
                Defaults to 0.

        Returns:
            torch.Tensor: result of ``_get_backflow``,
            ``_get_backflow_derivative`` or
            ``_get_backflow_second_derivative`` respectively.

        Raises:
            ValueError: if ``derivative`` is not 0, 1 or 2.
        """
        if derivative == 0:
            return self._get_backflow(pos)
        if derivative == 1:
            return self._get_backflow_derivative(pos)
        if derivative == 2:
            return self._get_backflow_second_derivative(pos)
        raise ValueError(
            "derivative of the backflow transformation must be 0, 1 or 2"
        )

    def _eye_ndim(self, like: torch.Tensor) -> torch.Tensor:
        """Return the Ndim x Ndim identity viewed as 1 x Ndim x Ndim x 1 x 1.

        The identity uses the dtype and device of ``like`` so that it can be
        broadcast against batched Nbatch x Ndim x Ndim x Nelec x Nelec tensors
        without a device transfer.
        """
        return torch.eye(self.ndim, dtype=like.dtype, device=like.device).view(
            1, self.ndim, self.ndim, 1, 1
        )

    def _get_backflow(self, pos: torch.Tensor) -> torch.Tensor:
        r"""Compute the backflow transformation.

        .. math::
            \mathbf{q}_i = \mathbf{r}_i
            + \sum_{j\neq i} \eta(r_{ij})(\mathbf{r}_i - \mathbf{r}_j)

        Args:
            pos (torch.Tensor): original positions, Nbatch x [Nelec*Ndim]

        Returns:
            torch.Tensor: transformed positions, Nbatch x [Nelec*Ndim]
        """
        # differences between electron positions: Nbatch x Nelec x Nelec x 3
        delta_ee = self.edist.get_difference(pos.reshape(-1, self.nelec, self.ndim))

        # backflow kernel eta(r_ij): Nbatch x Nelec x Nelec
        bf_kernel = self.backflow_kernel(self.edist(pos))

        # q_i = r_i + sum_j eta(r_ij) (r_i - r_j)
        pos = pos.reshape(-1, self.nelec, self.ndim) + (
            bf_kernel.unsqueeze(-1) * delta_ee
        ).sum(2)

        return pos.reshape(-1, self.nelec * self.ndim)

    def _get_backflow_derivative(self, pos: torch.Tensor) -> torch.Tensor:
        r"""Compute the first derivative of the backflow transformation
        with respect to the original electron positions.

        .. math::
            \mathbf{q}_i = \mathbf{r}_i
            + \sum_{j\neq i} \eta(r_{ij})(\mathbf{r}_i - \mathbf{r}_j)

        .. math::
            \frac{\partial q_i^\alpha}{\partial r_k^\beta} =
            \delta_{ik}\Big[\delta_{\alpha\beta}\big(1 + \sum_{j\neq i}\eta(r_{ij})\big)
            + \sum_{j\neq i}\frac{\partial \eta(r_{ij})}{\partial r_i^\beta}
            (r_i^\alpha - r_j^\alpha)\Big]
            - (1-\delta_{ik})\Big[\delta_{\alpha\beta}\,\eta(r_{ik})
            + \frac{\partial \eta(r_{ik})}{\partial r_i^\beta}
            (r_i^\alpha - r_k^\alpha)\Big]

        Args:
            pos (torch.Tensor): original electron positions,
                Nbatch x [Nelec*Ndim]

        Returns:
            torch.Tensor: Nbatch x Ndim(alpha) x Ndim(beta) x Nelec(i) x
            Nelec(k) x 1, entry ``[b, alpha, beta, i, k]`` being
            d q_i^alpha / d r_k^beta.
        """
        # ee distance matrix: Nbatch x Nelec x Nelec
        ree = self.edist(pos)
        nbatch, nelec, _ = ree.shape

        # derivative of the ee distance matrix: Nbatch x 3 x Nelec x Nelec
        # dr_ij / dx_i = - dr_ij / dx_j
        dree = self.edist(pos, derivative=1)

        # differences between electron positions: Nbatch x 3 x Nelec x Nelec
        delta_ee = self.edist.get_difference(
            pos.reshape(nbatch, nelec, self.ndim)
        ).permute(0, 3, 1, 2)

        # backflow kernel eta(r_ij): Nbatch x Nelec x Nelec
        bf = self.backflow_kernel(ree)

        # (d eta(r_ij) / d r_ij) (d r_ij / d beta_i): Nbatch x 3 x Nelec x Nelec
        dbf = self.backflow_kernel(ree, derivative=1).unsqueeze(1) * dree

        # (d eta(r_ij) / d beta_i) (alpha_i - alpha_j)
        # Nbatch x 3(alpha) x 3(beta) x Nelec x Nelec
        dbf_delta_ee = dbf.unsqueeze(1) * delta_ee.unsqueeze(2)

        eye_mat = self._eye_ndim(bf)

        # delta_ab * delta_ij * (1 + sum_k eta(r_ik))
        # the k == i entry is cancelled below by the - delta_ab * eta(r_ij) term
        delta_ij_bf = torch.diag_embed(1 + bf.sum(-1), dim1=-1, dim2=-2)
        delta_ab_delta_ij_bf = eye_mat * delta_ij_bf.view(nbatch, 1, 1, nelec, nelec)

        # delta_ij * sum_k (d eta(r_ik) / d beta_i) (alpha_i - alpha_k)
        # Nbatch x Ndim x Ndim x Nelec x Nelec
        delta_ij_sum = torch.diag_embed(dbf_delta_ee.sum(-1), dim1=-1, dim2=-2)

        # delta_ab * eta(r_ij)
        delta_ab_bf = eye_mat * bf.view(nbatch, 1, 1, nelec, nelec)

        # Nbatch x Ndim(alpha) x Ndim(beta) x Nelec(i) x Nelec(j)
        out = delta_ab_delta_ij_bf + delta_ij_sum - dbf_delta_ee - delta_ab_bf
        return out.unsqueeze(-1)

    def _get_backflow_second_derivative(self, pos: torch.Tensor) -> torch.Tensor:
        r"""Compute the (pure) second derivative of the backflow transformation
        with respect to the original electron positions.

        .. math::
            \mathbf{q}_i = \mathbf{r}_i
            + \sum_{j\neq i} \eta(r_{ij})(\mathbf{r}_i - \mathbf{r}_j)

        .. math::
            \frac{\partial^2 q_i^\alpha}{\partial (r_k^\beta)^2} =
            \delta_{ik}\Big[2\delta_{\alpha\beta}\sum_{j\neq i}
            \frac{\partial \eta(r_{ij})}{\partial r_i^\beta}
            + \sum_{j\neq i}\frac{\partial^2 \eta(r_{ij})}{\partial (r_i^\beta)^2}
            (r_i^\alpha - r_j^\alpha)\Big]
            + (1-\delta_{ik})\Big[2\delta_{\alpha\beta}
            \frac{\partial \eta(r_{ik})}{\partial r_i^\beta}
            + \frac{\partial^2 \eta(r_{ik})}{\partial (r_i^\beta)^2}
            (r_i^\alpha - r_k^\alpha)\Big]

        Args:
            pos (torch.Tensor): original electron positions,
                Nbatch x [Nelec*Ndim]

        Returns:
            torch.Tensor: Nbatch x Ndim(alpha) x Ndim(beta) x Nelec(i) x
            Nelec(k) x 1, entry ``[b, alpha, beta, i, k]`` being
            d^2 q_i^alpha / d (r_k^beta)^2.
        """
        # ee distance matrix: Nbatch x Nelec x Nelec
        ree = self.edist(pos)
        nbatch, nelec, _ = ree.shape

        # differences between electron positions: Nbatch x 3 x Nelec x Nelec
        delta_ee = self.edist.get_difference(
            pos.reshape(nbatch, nelec, self.ndim)
        ).permute(0, 3, 1, 2)

        # d r_ij / d x_i: Nbatch x 3 x Nelec x Nelec
        dree = self.edist(pos, derivative=1)

        # d2 r_ij / d2 x_i: Nbatch x 3 x Nelec x Nelec
        d2ree = self.edist(pos, derivative=2)

        # d eta(r_ij) / d r_ij: Nbatch x 1 x Nelec x Nelec
        dbf = self.backflow_kernel(ree, derivative=1).unsqueeze(1)

        # d2 eta(r_ij) / d2 r_ij: Nbatch x 1 x Nelec x Nelec
        d2bf = self.backflow_kernel(ree, derivative=2).unsqueeze(1)

        # chain rule: (d2 eta / d r_ij^2)(d r_ij / d x_i)^2
        #             + (d eta / d r_ij)(d2 r_ij / d x_i^2)
        # Nbatch x 3 x Nelec x Nelec
        d2bf = (d2bf * dree * dree) + (dbf * d2ree)

        # (d eta(r_ij) / d r_ij)(d r_ij / d x_i): Nbatch x 3 x Nelec x Nelec
        dbf = dbf * dree

        eye_mat = self._eye_ndim(dbf)

        # delta_ij * delta_ab * 2 sum_k d eta(r_ik) / d beta_i
        term1 = (
            2
            * eye_mat
            * torch.diag_embed(dbf.sum(-1), dim1=-1, dim2=-2).reshape(
                nbatch, 1, self.ndim, nelec, nelec
            )
        )

        # (d2 eta(r_ij) / d2 beta_i)(alpha_i - alpha_j)
        # Nbatch x 3(alpha) x 3(beta) x Nelec x Nelec
        d2bf_delta_ee = d2bf.unsqueeze(1) * delta_ee.unsqueeze(2)

        # delta_ij * sum_k (d2 eta(r_ik) / d2 beta_i)(alpha_i - alpha_k)
        term2 = torch.diag_embed(d2bf_delta_ee.sum(-1), dim1=-1, dim2=-2)

        # delta_ab * 2 d eta(r_ij) / d beta_i
        # (= -2 d eta(r_ij) / d beta_j since dr_ij/dx_j = -dr_ij/dx_i)
        # NOTE(review): on the i == j diagonal this adds 2 * dbf[..., i, i];
        # assumes that entry is zero — confirm against edist/kernel.
        term3 = 2 * eye_mat * dbf.reshape(nbatch, 1, self.ndim, nelec, nelec)

        # Nbatch x Ndim(alpha) x Ndim(beta) x Nelec(i) x Nelec(j)
        out = term1 + term2 + d2bf_delta_ee + term3
        return out.unsqueeze(-1)

    def fit_kernel(
        self,
        lambda_func: Callable,
        xmin: float = 0.01,
        xmax: float = 1.0,
        npts: int = 100,
        lr: float = 0.001,
        num_epochs: int = 1000,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Fit the backflow kernel parameters to a given function with Adam
        on an MSE loss.

        Args:
            lambda_func (Callable): target function; called once on the
                sample points tensor.
            xmin (float): minimum x value
            xmax (float): maximum x value
            npts (int): number of points sampled in [xmin, xmax]
            lr (float): Adam learning rate
            num_epochs (int): number of optimization steps; the loss is
                printed every 100 steps.

        Returns:
            xpts (torch.Tensor): x values used for fitting
            ground_truth (torch.Tensor): values of ``lambda_func`` at ``xpts``
            fit_values (torch.Tensor): kernel values at ``xpts`` after fitting
        """
        params = list(self.backflow_kernel.parameters())
        # sample on the device holding the kernel parameters so the forward
        # pass does not mix devices (default device if there are none)
        device = params[0].device if params else None
        xpts = torch.linspace(xmin, xmax, npts, device=device)
        ground_truth = lambda_func(xpts)

        criterion = torch.nn.MSELoss()
        optimizer = torch.optim.Adam(params, lr=lr)

        for epoch in range(num_epochs):
            optimizer.zero_grad()
            outputs = self.backflow_kernel(xpts.unsqueeze(1))
            loss = criterion(outputs.squeeze(), ground_truth)
            loss.backward()
            optimizer.step()

            if epoch % 100 == 0:
                # .cpu() is required before .numpy() for CUDA tensors
                print("Epoch {}: Loss = {}".format(epoch, loss.detach().cpu().numpy()))

        fit_values = self.backflow_kernel(xpts.unsqueeze(1)).squeeze()
        return xpts, ground_truth, fit_values

    def __repr__(self):
        """Represent the transformation by the class name of its kernel."""
        return self.backflow_kernel.__class__.__name__