Source code for qmctorch.wavefunction.jastrows.distance.electron_electron_distance

import torch
from torch import nn
from .scaling import (get_scaled_distance,
                      get_der_scaled_distance,
                      get_second_der_scaled_distance)


class ElectronElectronDistance(nn.Module):

    def __init__(self, nelec, ndim=3, scale=False, scale_factor=0.6):
        """Computes the electron-electron distances

        .. math::
            r_{ij} = \\sqrt{ (x_i-x_j)^2 + (y_i-y_j)^2 + (z_i-z_j)^2}

        Args:
            nelec (int): number of electrons
            ndim (int): number of spatial dimensions
            scale (bool, optional): return scaled values, Defaults to False
            scale_factor (float, optional): value of the scale factor,
                Defaults to 0.6

        Examples::
            >>> edist = ElectronElectronDistance(2, 3)
            >>> pos = torch.rand(500, 6)
            >>> r = edist(pos)
            >>> dr = edist(pos, derivative=1)
        """
        super().__init__()
        self.nelec = nelec
        self.ndim = ndim
        self.scale = scale
        self.kappa = scale_factor

        # regularisation added on the diagonal (i == j) of the distance
        # matrices; its magnitude depends on the default floating point type
        dtype = torch.get_default_dtype()
        if dtype == torch.float32:
            self.eps = 1E-6
        elif dtype == torch.float64:
            self.eps = 1E-16
        else:
            # any other default dtype: fall back to the machine epsilon
            # of that dtype so that ``self.eps`` is always defined
            self.eps = torch.finfo(dtype).eps

    def forward(self, input, derivative=0):
        """Compute the pairwise distance between the electrons
        or its derivative.

        When required, the derivative is computed wrt to the first
        electron i.e.

        .. math::
            \\frac{dr_{ij}}{dx_i}

        which is different from :

        .. math::
            \\frac{d r_{ij}}{dx_j} = -\\frac{dr_{ij}}{dx_i}

        Args:
            input (torch.tensor): position of the electron
                size : Nbatch x [Nelec x Ndim]
            derivative (int, optional): degree of the derivative
                (0, 1 or 2). Defaults to 0.

        Returns:
            torch.tensor: distance (or derivative) matrix
                Nbatch x Nelec x Nelec if derivative = 0
                Nbatch x Ndim x Nelec x Nelec if derivative = 1,2
                When ``scale`` is True the values are passed through the
                scaling helpers imported from ``.scaling``.

        Raises:
            ValueError: if ``derivative`` is not 0, 1 or 2.
        """
        if derivative not in (0, 1, 2):
            raise ValueError(
                'derivative must be 0, 1 or 2, got {}'.format(derivative))

        # get the distance matrices
        pos = input.reshape(-1, self.nelec, self.ndim)
        dist = self.safe_sqrt(self.get_distance_quadratic(pos))

        if derivative == 0:
            if self.scale:
                return get_scaled_distance(self.kappa, dist)
            return dist

        if derivative == 1:
            der_dist = self.get_der_distance(pos, dist)
            if self.scale:
                return get_der_scaled_distance(self.kappa, dist, der_dist)
            return der_dist

        # derivative == 2
        d2_dist = self.get_second_der_distance(pos, dist)
        if self.scale:
            # the scaled second derivative also needs the first derivative
            der_dist = self.get_der_distance(pos, dist)
            return get_second_der_scaled_distance(self.kappa,
                                                  dist,
                                                  der_dist,
                                                  d2_dist)
        return d2_dist

    def _diag_eps(self, dist):
        """Return a matrix with ``self.eps`` on the diagonal, shaped as dist.

        Args:
            dist (torch.tensor): matrix whose last two dimensions are
                Nelec x Nelec

        Returns:
            torch.tensor: eps * identity, expanded to the shape of dist
        """
        eye = torch.eye(dist.shape[-1],
                        dtype=dist.dtype,
                        device=dist.device)
        return (self.eps * eye).expand_as(dist)

    def safe_sqrt(self, dist):
        """Compute the square root of the electron electron distance matrix.

        Args:
            dist (torch.tensor): ee distances squared
                Nbatch x Nelec x Nelec

        Returns:
            torch.tensor: sqrt of dist Nbatch x Nelec x Nelec
        """
        # epsilon on the diag needed for back prop
        eps_ = self._diag_eps(dist)

        # extract the diagonal as diag can be negative sometimes
        # due to numerical noise
        diag = torch.diag_embed(
            torch.diagonal(dist, dim1=-1, dim2=-2))

        # remove diagonal and add eps for backprop
        return torch.sqrt(dist - diag + eps_)

    def get_der_distance(self, pos, dist):
        """Get the derivative of the electron electron distance matrix.

        .. math::
            \\frac{d r_{ij}}{d x_i}

        Args:
            pos (torch.tensor): positions of the electrons
                Nbatch x Nelec x Ndim
            dist (torch.tensor): distance matrix between the elecs
                Nbatch x Nelec x Nelec

        Returns:
            torch.tensor: (x_i - x_j) / r_ij for each spatial dimension
                Nbatch x Ndim x Nelec x Nelec
        """
        # eps on the diagonal of the denominator (i == j)
        invr = (1. / (dist + self._diag_eps(dist))).unsqueeze(1)

        # coordinate differences x_i - x_j : Nbatch x Ndim x Nelec x Nelec
        diff_axis = pos.transpose(1, 2).unsqueeze(3)
        diff_axis = diff_axis - diff_axis.transpose(2, 3)
        return diff_axis * invr

    def get_second_der_distance(self, pos, dist):
        """Get the second derivative of the electron electron distance matrix.

        .. math::
            \\frac{d^2 r_{ij}}{d x_i^2}

        For each spatial dimension the result is the sum of the squared
        coordinate differences along all the *other* dimensions,
        divided by r_ij^3.

        Args:
            pos (torch.tensor): positions of the electrons
                Nbatch x Nelec x Ndim
            dist (torch.tensor): distance matrix between the elecs
                Nbatch x Nelec x Nelec

        Returns:
            torch.tensor: second derivative of the distance
                Nbatch x Ndim x Nelec x Nelec
        """
        ndim = pos.shape[-1]

        # eps on the diagonal of the denominator (i == j)
        invr3 = (1. / (dist**3 + self._diag_eps(dist))).unsqueeze(1)

        # squared differences (x_i - x_j)^2 : Nbatch x Ndim x Nelec x Nelec
        diff2 = pos.transpose(1, 2).unsqueeze(3)
        diff2 = (diff2 - diff2.transpose(2, 3))**2

        # for every dimension, the indices of all the other dimensions;
        # gives [[1, 2], [2, 0], [0, 1]] for ndim = 3
        others = torch.tensor(
            [[(idim + shift) % ndim for shift in range(1, ndim)]
             for idim in range(ndim)],
            dtype=torch.long,
            device=pos.device).reshape(ndim, ndim - 1)

        return diff2[:, others, ...].sum(2) * invr3

    @staticmethod
    def get_distance_quadratic(pos):
        """Compute the squared distance following a quadratic expansion

        .. math::
            r_{ij}^2 = |r_i|^2 + |r_j|^2 - 2 r_i \\cdot r_j

        Arguments:
            pos {torch.tensor} -- electron position [nbatch x nelec x ndim]

        Returns:
            torch.tensor -- squared distance matrices
                [nbatch x nelec x nelec]
        """
        norm = (pos**2).sum(-1).unsqueeze(-1)
        dist = (norm + norm.transpose(1, 2) - 2.0 *
                torch.bmm(pos, pos.transpose(1, 2)))
        return dist

    @staticmethod
    def get_difference(pos):
        """Compute the difference ri - rj

        Arguments:
            pos {torch.tensor} -- electron position [nbatch x nelec x ndim]

        Returns:
            torch.tensor -- difference matrices
                [nbatch x nelec x nelec x ndim]
        """
        out = pos[:, :, None, :] - pos[:, None, :, :]
        return out