import torch
from torch import nn
from .scaling import (
get_scaled_distance,
get_der_scaled_distance,
get_second_der_scaled_distance,
)
class ElectronElectronDistance(nn.Module):
    def __init__(self, nelec, ndim=3, scale=False, scale_factor=0.6):
        """Computes the electron-electron distances

        .. math::
            r_{ij} = \\sqrt{ (x_i-x_j)^2 + (y_i-y_j)^2 + (z_i-z_j)^2}

        Args:
            nelec (int): number of electrons
            ndim (int): number of spatial dimensions
            scale (bool, optional): return scaled values. Defaults to False
            scale_factor (float, optional): value of the scale factor.
                Defaults to 0.6

        Examples::
            >>> edist = ElectronElectronDistance(2, 3)
            >>> pos = torch.rand(500, 6)
            >>> r = edist(pos)
            >>> dr = edist(pos, derivative=1)
        """
        super().__init__()
        self.nelec = nelec
        self.ndim = ndim
        self.scale = scale
        self.kappa = scale_factor

        # Regularizer placed on the diagonal (i == j) so that sqrt, 1/r and
        # 1/r^3 stay finite there. float32/float64 keep their historical
        # values; any other default dtype (e.g. float16) used to leave
        # ``self.eps`` undefined and now falls back to the dtype's epsilon.
        _type_ = torch.get_default_dtype()
        if _type_ == torch.float32:
            self.eps = 1e-6
        elif _type_ == torch.float64:
            self.eps = 1e-16
        else:
            self.eps = torch.finfo(_type_).eps

    def forward(self, input, derivative=0):
        """Compute the pairwise distance between the electrons
        or its derivative.

        When required, the derivative is computed wrt the first electron, i.e.

        .. math::
            \\frac{dr_{ij}}{dx_i}

        which is different from:

        .. math::
            \\frac{d r_{ij}}{dx_j} = -\\frac{dr_{ij}}{dx_i}

        Args:
            input (torch.tensor): positions of the electrons,
                size Nbatch x [Nelec x Ndim]
            derivative (int, optional): order of the derivative (0, 1 or 2).
                Defaults to 0.

        Returns:
            torch.tensor: distance (or derivative) matrix
                Nbatch x Nelec x Nelec if derivative = 0
                Nbatch x Ndim x Nelec x Nelec if derivative = 1, 2
            (scaled through the ``.scaling`` helpers when ``self.scale``)

        Raises:
            ValueError: if ``derivative`` is not 0, 1 or 2.
        """
        # previously an unsupported order silently returned None
        if derivative not in (0, 1, 2):
            raise ValueError(f"derivative must be 0, 1 or 2, got {derivative!r}")

        # reshape the flat Nbatch x (Nelec * Ndim) input to Nbatch x Nelec x Ndim
        input_ = input.view(-1, self.nelec, self.ndim)
        dist = self.safe_sqrt(self.get_distance_quadratic(input_))

        if derivative == 0:
            if self.scale:
                return get_scaled_distance(self.kappa, dist)
            return dist

        if derivative == 1:
            der_dist = self.get_der_distance(input_, dist)
            if self.scale:
                return get_der_scaled_distance(self.kappa, dist, der_dist)
            return der_dist

        # derivative == 2
        d2_dist = self.get_second_der_distance(input_, dist)
        if self.scale:
            # the scaled second derivative also needs the first derivative
            der_dist = self.get_der_distance(input_, dist)
            return get_second_der_scaled_distance(
                self.kappa, dist, der_dist, d2_dist
            )
        return d2_dist

    def safe_sqrt(self, dist):
        """Compute the square root of the electron-electron distance matrix.

        The diagonal of ``dist`` is replaced by ``self.eps`` before taking
        the square root, so the result is differentiable on the diagonal.

        Args:
            dist (torch.tensor): ee distances squared
                Nbatch x Nelec x Nelec

        Returns:
            torch.tensor: sqrt of dist, Nbatch x Nelec x Nelec
        """
        # epsilon on the diagonal, needed for backprop through sqrt
        eps_ = self.eps * torch.diag(dist.new_ones(dist.shape[-1])).expand_as(dist)
        # extract the diagonal, as it can sometimes be slightly negative
        # due to numerical noise in the quadratic expansion
        diag = torch.diag_embed(torch.diagonal(dist, dim1=-1, dim2=-2))
        # NOTE(review): off-diagonal entries are not protected the same way;
        # nearly coincident electrons could still yield a tiny negative value.
        # remove the noisy diagonal and add eps for backprop
        dist = torch.sqrt(dist - diag + eps_)
        return dist

    def get_der_distance(self, pos, dist):
        """Get the derivative of the electron-electron distance matrix.

        .. math::
            \\frac{d r_{ij}}{d x_i} = \\frac{x_i - x_j}{r_{ij}}

        Args:
            pos (torch.tensor): positions of the electrons
                Nbatch x Nelec x Ndim
            dist (torch.tensor): distance matrix between the elecs
                Nbatch x Nelec x Nelec

        Returns:
            torch.tensor: derivative, Nbatch x Ndim x Nelec x Nelec
        """
        # eps on the diagonal keeps 1/r finite where r_ii ~ 0
        eps_ = self.eps * torch.diag(dist.new_ones(dist.shape[-1])).expand_as(dist)
        invr = (1.0 / (dist + eps_)).unsqueeze(1)
        # diff_axis[b, d, i, j] = pos[b, i, d] - pos[b, j, d]
        diff_axis = pos.transpose(1, 2).unsqueeze(3)
        diff_axis = diff_axis - diff_axis.transpose(2, 3)
        return diff_axis * invr

    def get_second_der_distance(self, pos, dist):
        """Get the second derivative of the electron-electron distance matrix.

        .. math::
            \\frac{d^2 r_{ij}}{d x_i^2} = \\frac{r_{ij}^2 - (x_i - x_j)^2}{r_{ij}^3}

        Works for any number of spatial dimensions (``pos.shape[-1]``).

        Args:
            pos (torch.tensor): positions of the electrons
                Nbatch x Nelec x Ndim
            dist (torch.tensor): distance matrix between the elecs
                Nbatch x Nelec x Nelec

        Returns:
            torch.tensor: second derivative, Nbatch x Ndim x Nelec x Nelec
        """
        # eps on the diagonal keeps 1/r^3 finite where r_ii ~ 0
        eps_ = self.eps * torch.diag(dist.new_ones(dist.shape[-1])).expand_as(dist)
        invr3 = (1.0 / (dist**3 + eps_)).unsqueeze(1)
        # squared differences, diff_axis[b, d, i, j] = (pos[b, i, d] - pos[b, j, d])**2
        diff_axis = pos.transpose(1, 2).unsqueeze(3)
        diff_axis = (diff_axis - diff_axis.transpose(2, 3)) ** 2
        # numerator for axis d = sum of squared differences along every other
        # axis; built from the actual ndim instead of a hard-coded 3-D table
        ndim = pos.shape[-1]
        other_axes = torch.as_tensor(
            [[k for k in range(ndim) if k != d] for d in range(ndim)],
            dtype=torch.long,
            device=pos.device,
        )
        diff_axis = diff_axis[:, other_axes, ...].sum(2)
        return diff_axis * invr3

    @staticmethod
    def get_distance_quadratic(pos):
        """Compute the squared distances following a quadratic expansion
        |r_i - r_j|^2 = |r_i|^2 + |r_j|^2 - 2 r_i . r_j

        Arguments:
            pos {torch.tensor} -- electron position [nbatch x nelec x ndim]

        Returns:
            torch.tensor -- squared distance matrices [nbatch x nelec x nelec]
        """
        norm = (pos**2).sum(-1).unsqueeze(-1)
        dist = norm + norm.transpose(1, 2) - 2.0 * torch.bmm(pos, pos.transpose(1, 2))
        return dist

    @staticmethod
    def get_difference(pos):
        """Compute the difference ri - rj

        Arguments:
            pos {torch.tensor} -- electron position [nbatch x nelec x ndim]

        Returns:
            torch.tensor -- difference tensor [nbatch x nelec x nelec x ndim],
            out[b, i, j] = pos[b, i] - pos[b, j]
        """
        out = pos[:, :, None, :] - pos[:, None, :, :]
        return out