Source code for qmctorch.wavefunction.jastrows.distance.electron_nuclei_distance

import torch
from torch import nn
from .scaling import (get_scaled_distance,
                      get_der_scaled_distance,
                      get_second_der_scaled_distance)


class ElectronNucleiDistance(nn.Module):

    def __init__(self, nelec, atomic_pos, ndim=3, scale=False, scale_factor=0.6):
        """Computes the electron-nuclei distances

        .. math::
            r_{iA} = \\sqrt{ (x_i-x_A)^2 + (y_i-y_A)^2 + (z_i-z_A)^2}

        Args:
            nelec (int): number of electrons
            atomic_pos (tensor): positions of the atoms, Natom x Ndim
            ndim (int): number of spatial dimensions
            scale(bool, optional): return scaled values, Defaults to False
            scale_factor(float, optional): value of the scale factor,
                Defaults to 0.6

        Examples::
            >>> atomic_pos = torch.zeros(2, 3)
            >>> endist = ElectronNucleiDistance(2, atomic_pos)
            >>> epos = torch.rand(500, 6)
            >>> r = endist(epos)
            >>> dr = endist(epos, derivative=1)
        """

        super().__init__()
        self.nelec = nelec
        self.atoms = atomic_pos
        self.ndim = ndim
        self.scale = scale
        self.kappa = scale_factor

    def forward(self, input, derivative=0):
        """Compute the pairwise distances between electrons and atoms
        or their derivative.

        Args:
            input (torch.tensor): position of the electrons,
                size : Nbatch x [Nelec x Ndim]
            derivative (int, optional): degree of the derivative
                (0, 1 or 2). Defaults to 0.

        Returns:
            torch.tensor: distance (or derivative) matrix,
                Nbatch x Nelec x Natom if derivative = 0,
                Nbatch x Ndim x Nelec x Natom if derivative = 1,2.
                When ``scale`` is True the values are passed through the
                corresponding helper imported from ``.scaling``.

        Raises:
            ValueError: if derivative is not 0, 1 or 2.
        """
        # fail loudly instead of silently returning None
        if derivative not in (0, 1, 2):
            raise ValueError(
                f"derivative must be 0, 1 or 2, got {derivative!r}")

        # get the distance matrices
        input_ = input.view(-1, self.nelec, self.ndim)
        dist = self._get_distance_quadratic(input_, self.atoms)
        dist = torch.sqrt(dist)

        if derivative == 0:
            if self.scale:
                return get_scaled_distance(self.kappa, dist)
            return dist

        if derivative == 1:
            der_dist = self.get_der_distance(input_, dist)
            if self.scale:
                return get_der_scaled_distance(self.kappa, dist, der_dist)
            return der_dist

        # derivative == 2
        d2_dist = self.get_second_der_distance(input_, dist)
        if self.scale:
            # the scaled second derivative also needs the first derivative
            der_dist = self.get_der_distance(input_, dist)
            return get_second_der_scaled_distance(self.kappa, dist,
                                                  der_dist, d2_dist)
        return d2_dist

    def get_der_distance(self, pos, dist):
        """Get the derivative of the electron-nuclei distance matrix

        .. math::
            \\frac{d r_{iA}}{d x_i} = \\frac{x_i - x_A}{r_{iA}}

        Args:
            pos (torch.tensor): positions of the electrons
                Nbatch x Nelec x Ndim
            dist (torch.tensor): electron-nuclei distance matrix
                Nbatch x Nelec x Natom

        Returns:
            torch.tensor: first derivative of the distances
                Nbatch x Ndim x Nelec x Natom

        Note:
            No guard is applied against a vanishing distance: an electron
            located exactly on a nucleus yields non-finite values.
        """
        invr = (1. / dist).unsqueeze(-1)
        # (Nbatch, Nelec, Ndim, Natom) -> (Nbatch, Nelec, Natom, Ndim)
        diff_axis = (pos.unsqueeze(-1) - self.atoms.T).transpose(2, 3)
        return (diff_axis * invr).permute(0, 3, 1, 2)

    def get_second_der_distance(self, pos, dist):
        """Get the second derivative of the electron-nuclei distance matrix

        .. math::
            \\frac{d^2 r_{iA}}{d x_i^2}
            = \\frac{\\sum_{k \\neq x} (k_i - k_A)^2}{r_{iA}^3}

        Args:
            pos (torch.tensor): positions of the electrons
                Nbatch x Nelec x Ndim
            dist (torch.tensor): electron-nuclei distance matrix
                Nbatch x Nelec x Natom

        Returns:
            torch.tensor: second derivative of the distances
                Nbatch x Ndim x Nelec x Natom

        Note:
            No guard is applied against a vanishing distance: an electron
            located exactly on a nucleus yields non-finite values.
        """
        invr3 = (1. / (dist**3)).unsqueeze(1)

        # squared displacement along each axis: Nbatch x Ndim x Nelec x Natom
        diff_axis = pos.transpose(1, 2).unsqueeze(3)
        diff_axis = (diff_axis - self.atoms.T.unsqueeze(1))**2

        # for every axis, the indices of all the *other* axes; this equals
        # the pairs {1,2}, {2,0}, {0,1} when ndim == 3 but also works for
        # any other number of spatial dimensions
        ndim = pos.shape[-1]
        other_axes = torch.tensor(
            [[j for j in range(ndim) if j != i] for i in range(ndim)],
            dtype=torch.long, device=pos.device)

        diff_axis = diff_axis[:, other_axes, ...].sum(2)
        return (diff_axis * invr3)

    @staticmethod
    def _get_distance_quadratic(elec_pos, atom_pos):
        """Compute the squared distance following a quadratic expansion

        .. math::
            r_{iA}^2 = |r_i|^2 + |R_A|^2 - 2 r_i \\cdot R_A

        Arguments:
            elec_pos {torch.tensor} -- electron positions
                [nbatch x nelec x ndim]
            atom_pos {torch.tensor} -- atomic positions [natom x ndim]

        Returns:
            torch.tensor -- squared distance matrix
                [nbatch x nelec x natom]
        """
        norm = (elec_pos**2).sum(-1).unsqueeze(-1)
        norm_atom = (atom_pos**2).sum(-1).unsqueeze(-1).T
        dist = (norm + norm_atom - 2.0 * elec_pos@atom_pos.T)
        # round-off in the expansion can give tiny negative values when an
        # electron is (nearly) on top of a nucleus; clamp them so that the
        # subsequent square root does not produce NaN
        return dist.clamp(min=0)