Source code for qmctorch.wavefunction.orbitals.atomic_orbitals

import torch
from torch import nn

from .norm_orbital import atomic_orbital_norm
from .radial_functions import (radial_gaussian, radial_gaussian_pure,
                               radial_slater, radial_slater_pure)
from .spherical_harmonics import Harmonics


class AtomicOrbitals(nn.Module):
    """Atomic orbital layer: maps electron positions to AO values.

    Each basis function is the product of a normalisation constant, a
    radial part and a harmonics part. When several basis functions share
    the same contraction index they are summed (weighted by
    ``bas_coeffs``) into a single atomic orbital.
    """

    def __init__(self, mol, cuda=False):
        """Computes the value of atomic orbitals

        Args:
            mol (Molecule): Molecule object
            cuda (bool, optional): Turn GPU ON/OFF Defaults to False.

        Raises:
            ValueError: if ``mol.basis.harmonics_type`` is neither
                ``'sph'`` nor ``'cart'``.
            KeyError: if ``mol.basis.radial_type`` is not one of
                ``'sto'``, ``'gto'``, ``'sto_pure'``, ``'gto_pure'``.
        """

        super().__init__()
        dtype = torch.get_default_dtype()

        # wavefunction data
        self.nelec = mol.nelec
        self.norb = mol.basis.nao
        self.ndim = 3

        # make the atomic position optimizable
        self.atom_coords = nn.Parameter(torch.as_tensor(
            mol.basis.atom_coords_internal).type(dtype))
        self.atom_coords.requires_grad = True
        self.natoms = len(self.atom_coords)
        self.atomic_number = mol.atomic_number

        # define the BAS positions.
        self.nshells = torch.as_tensor(mol.basis.nshells)
        self.nao_per_atom = torch.as_tensor(mol.basis.nao_per_atom)
        self.bas_coords = self.atom_coords.repeat_interleave(
            self.nshells, dim=0)
        self.nbas = len(self.bas_coords)

        # index for the contractions
        self.index_ctr = torch.as_tensor(mol.basis.index_ctr)
        self.nctr_per_ao = torch.as_tensor(mol.basis.nctr_per_ao)
        # a contraction is needed as soon as two basis functions
        # share the same contraction index
        self.contract = (len(torch.unique(self.index_ctr))
                         != len(self.index_ctr))

        # get the coeffs of the bas
        self.bas_coeffs = torch.as_tensor(
            mol.basis.bas_coeffs).type(dtype)

        # get the exponents of the bas
        self.bas_exp = nn.Parameter(
            torch.as_tensor(mol.basis.bas_exp).type(dtype))
        self.bas_exp.requires_grad = True

        # harmonics generator
        self.harmonics_type = mol.basis.harmonics_type
        if mol.basis.harmonics_type == 'sph':
            self.bas_n = torch.as_tensor(mol.basis.bas_n).type(dtype)
            self.harmonics = Harmonics(
                mol.basis.harmonics_type,
                bas_l=mol.basis.bas_l,
                bas_m=mol.basis.bas_m,
                cuda=cuda)

        elif mol.basis.harmonics_type == 'cart':
            self.bas_n = torch.as_tensor(mol.basis.bas_kr).type(dtype)
            self.harmonics = Harmonics(
                mol.basis.harmonics_type,
                bas_kx=mol.basis.bas_kx,
                bas_ky=mol.basis.bas_ky,
                bas_kz=mol.basis.bas_kz,
                cuda=cuda)

        else:
            # fail early: without this, bas_n / harmonics would be
            # missing and the first forward pass would crash instead
            raise ValueError(
                "harmonics_type must be 'sph' or 'cart', got %s"
                % (mol.basis.harmonics_type,))

        # select the radial part
        radial_dict = {'sto': radial_slater,
                       'gto': radial_gaussian,
                       'sto_pure': radial_slater_pure,
                       'gto_pure': radial_gaussian_pure}
        self.radial = radial_dict[mol.basis.radial_type]
        self.radial_type = mol.basis.radial_type

        # get the normalisation constants
        # they are always recomputed here; a `bas_norm` attribute on
        # the basis, if present, is not used
        with torch.no_grad():
            self.norm_cst = atomic_orbital_norm(
                mol.basis).type(dtype)

        # NOTE: this instance attribute shadows nn.Module.cuda();
        # kept as is because callers may read `ao.cuda`
        self.cuda = cuda
        self.device = torch.device('cpu')
        if self.cuda:
            self._to_device()

    def __repr__(self):
        name = self.__class__.__name__
        return name + '(%s, %s, %d -> (%d,%d) )' % (self.radial_type,
                                                    self.harmonics_type,
                                                    self.nelec*self.ndim,
                                                    self.nelec,
                                                    self.norb)

    def _to_device(self):
        """Export the non parameter variable to the device."""

        self.device = torch.device('cuda')
        self.to(self.device)

        # plain tensors are not moved by nn.Module.to, do it by hand
        attrs = ['bas_n', 'bas_coeffs',
                 'nshells', 'norm_cst', 'index_ctr',
                 'nctr_per_ao', 'nao_per_atom']
        for at in attrs:
            self.__dict__[at] = self.__dict__[at].to(self.device)

    def forward(self, pos, derivative=[0], sum_grad=True, sum_hess=True,
                one_elec=False):
        r"""Computes the values of the atomic orbitals.

        .. math::
            \phi_i(r_j) = \sum_n c_n \text{Rad}^{i}_n(r_j) \text{Y}^{i}_n(r_j)

        where Rad is the radial part and Y the spherical harmonics part.
        It is also possible to compute the first and second derivatives

        .. math::
            \nabla \phi_i(r_j) = \frac{d}{dx_j} \phi_i(r_j)
                + \frac{d}{dy_j} \phi_i(r_j) + \frac{d}{dz_j} \phi_i(r_j)

            \text{grad} \phi_i(r_j) = (\frac{d}{dx_j} \phi_i(r_j),
                \frac{d}{dy_j} \phi_i(r_j), \frac{d}{dz_j} \phi_i(r_j))

            \Delta \phi_i(r_j) = \frac{d^2}{dx^2_j} \phi_i(r_j)
                + \frac{d^2}{dy^2_j} \phi_i(r_j)
                + \frac{d^2}{dz^2_j} \phi_i(r_j)

        Args:
            pos (torch.tensor): Positions of the electrons
                Size : Nbatch, Nelec x Ndim
            derivative (int or list, optional): order of the derivative.
                Accepted values are 0, 1, 2, 3 (3 selects the mixed
                second derivatives) or the list [0, 1, 2] to get the
                values, gradients and laplacian at once.
                Defaults to [0]. The default list is never mutated.
            sum_grad (bool, optional): Return the sum_grad (i.e. the sum
                of the derivatives) or the individual terms.
                Defaults to True. False only for derivative=1
            sum_hess (bool, optional): Return the sum_hess (i.e. the sum
                of the 2nd derivatives) or the individual terms.
                Defaults to True. False only for derivative=2
            one_elec (bool, optional): if only one electron is in input

        Returns:
            torch.tensor: Value of the AO (or their derivatives)

            size : Nbatch, Nelec, Norb (sum_grad = True)

            size : Nbatch, Nelec, Norb, Ndim (sum_grad = False)

            For derivative=[0, 1, 2] a tuple (ao, grad, laplacian) is
            returned.

        Raises:
            ValueError: if ``derivative`` is not one of the accepted
                values.
            AssertionError: if ``sum_grad``/``sum_hess`` is False while
                the matching derivative order is not requested.

        Examples::
            >>> mol = Molecule('h2.xyz')
            >>> ao = AtomicOrbitals(mol)
            >>> pos = torch.rand(100,6)
            >>> aovals = ao(pos)
            >>> daovals = ao(pos,derivative=1)
        """

        if not isinstance(derivative, list):
            derivative = [derivative]

        if not sum_grad:
            assert(1 in derivative)

        if not sum_hess:
            assert(2 in derivative)

        # temporarily pretend the system has a single electron so that
        # the reshaping of `pos` matches the one-electron input
        if one_elec:
            nelec_save = self.nelec
            self.nelec = 1

        try:
            if derivative == [0]:
                ao = self._compute_ao_values(pos)

            elif derivative == [1]:
                ao = self._compute_first_derivative_ao_values(
                    pos, sum_grad)

            elif derivative == [2]:
                ao = self._compute_second_derivative_ao_values(
                    pos, sum_hess)

            elif derivative == [3]:
                ao = self._compute_mixed_second_derivative_ao_values(pos)

            elif derivative == [0, 1, 2]:
                ao = self._compute_all_ao_values(pos)

            else:
                raise ValueError(
                    'derivative must be 0, 1, 2, 3 or [0, 1, 2], got %s'
                    % (derivative,))
        finally:
            # always restore the electron count, even when the
            # computation above raises
            if one_elec:
                self.nelec = nelec_save

        return ao

    def _compute_ao_values(self, pos):
        """Compute the value of the ao from the xyz and r tensor

        Args:
            pos (torch.tensor): position of each elec size Nbatch, NelecxNdim

        Returns:
            torch.tensor: atomic orbital values size (Nbatch, Nelec, Norb)
        """
        xyz, r = self._process_position(pos)

        R = self.radial(r, self.bas_n, self.bas_exp)

        Y = self.harmonics(xyz)
        return self._ao_kernel(R, Y)

    def _ao_kernel(self, R, Y):
        """Kernel for the ao values

        Args:
            R (torch.tensor): radial part of the AOs
            Y (torch.tensor): harmonics part of the AOs

        Returns:
            torch.tensor: values of the AOs (with contraction)
        """
        ao = self.norm_cst * R * Y
        if self.contract:
            ao = self._contract(ao)
        return ao

    def _compute_first_derivative_ao_values(self, pos, sum_grad):
        """Compute the value of the derivative of the ao from the xyz and r tensor

        Args:
            pos (torch.tensor): position of each elec size Nbatch, Nelec x Ndim
            sum_grad (boolean): return the sum_grad (True) or gradient (False)

        Returns:
            torch.tensor: derivative of atomic orbital values
                          size (Nbatch, Nelec, Norb) if sum_grad
                          size (Nbatch, Nelec, Norb, Ndim) if sum_grad=False
        """
        if sum_grad:
            return self._compute_sum_gradient_ao_values(pos)
        else:
            return self._compute_gradient_ao_values(pos)

    def _compute_sum_gradient_ao_values(self, pos):
        """Compute the jacobian of the ao from the xyz and r tensor

        Args:
            pos (torch.tensor): position of each elec size Nbatch, Nelec x Ndim

        Returns:
            torch.tensor: derivative of atomic orbital values
                          size (Nbatch, Nelec, Norb)
        """

        xyz, r = self._process_position(pos)

        R, dR = self.radial(r, self.bas_n,
                            self.bas_exp, xyz=xyz,
                            derivative=[0, 1])

        Y, dY = self.harmonics(xyz, derivative=[0, 1])

        return self._sum_gradient_kernel(R, dR, Y, dY)

    def _sum_gradient_kernel(self, R, dR, Y, dY):
        """Kernel for the jacobian of the ao values

        Args:
            R (torch.tensor): radial part of the AOs
            dR (torch.tensor): derivative of the radial part of the AOs
            Y (torch.tensor): harmonics part of the AOs
            dY (torch.tensor): derivative of the harmonics part of the AOs

        Returns:
            torch.tensor: values of the jacobian of the AOs (with contraction)
        """
        # product rule on R * Y
        dao = self.norm_cst * (dR * Y + R * dY)
        if self.contract:
            dao = self._contract(dao)
        return dao

    def _compute_gradient_ao_values(self, pos):
        """Compute the gradient of the ao from the xyz and r tensor

        Args:
            pos (torch.tensor): position of each elec size Nbatch, Nelec x Ndim

        Returns:
            torch.tensor: derivative of atomic orbital values
                          size (Nbatch, Nelec, Norb, Ndim)
        """
        xyz, r = self._process_position(pos)

        R, dR = self.radial(r, self.bas_n,
                            self.bas_exp, xyz=xyz,
                            derivative=[0, 1],
                            sum_grad=False)

        Y, dY = self.harmonics(xyz, derivative=[0, 1], sum_grad=False)

        return self._gradient_kernel(R, dR, Y, dY)

    def _gradient_kernel(self, R, dR, Y, dY):
        """Kernel for the gradient of the ao values

        Args:
            R (torch.tensor): radial part of the AOs
            dR (torch.tensor): derivative of the radial part of the AOs
            Y (torch.tensor): harmonics part of the AOs
            dY (torch.tensor): derivative of the harmonics part of the AOs

        Returns:
            torch.tensor: values of the gradient of the AOs (with contraction)
        """
        # product rule, one entry of the last dimension per component
        bas = dR * Y.unsqueeze(-1) + R.unsqueeze(-1) * dY

        bas = self.norm_cst.unsqueeze(-1) * \
            self.bas_coeffs.unsqueeze(-1) * bas

        if self.contract:
            return self._contract_components(bas)
        return bas

    def _compute_second_derivative_ao_values(self, pos, sum_hess):
        """Compute the values of the 2nd derivative of the ao from the xyz and r tensors

        Args:
            pos (torch.tensor): position of each elec size Nbatch, Nelec x Ndim
            sum_hess (boolean): return the sum of the diagonal hessian (True)
                                or its individual elements (False)

        Returns:
            torch.tensor: derivative of atomic orbital values
                          size (Nbatch, Nelec, Norb) if sum_hess
                          size (Nbatch, Nelec, Norb, Ndim) if sum_hess=False
        """
        if sum_hess:
            return self._compute_sum_diag_hessian_ao_values(pos)
        else:
            return self._compute_diag_hessian_ao_values(pos)

    def _compute_sum_diag_hessian_ao_values(self, pos):
        """Compute the laplacian of the ao from the xyz and r tensor

        Args:
            pos (torch.tensor): position of each elec size Nbatch, Nelec x Ndim

        Returns:
            torch.tensor: derivative of atomic orbital values
                          size (Nbatch, Nelec, Norb)
        """
        xyz, r = self._process_position(pos)

        # the individual gradient elements are needed for the cross
        # term of the laplacian, hence sum_grad=False
        R, dR, d2R = self.radial(r, self.bas_n, self.bas_exp,
                                 xyz=xyz, derivative=[0, 1, 2],
                                 sum_grad=False)

        Y, dY, d2Y = self.harmonics(xyz,
                                    derivative=[0, 1, 2],
                                    sum_grad=False)
        return self._sum_diag_hessian_kernel(R, dR, d2R, Y, dY, d2Y)

    def _sum_diag_hessian_kernel(self, R, dR, d2R, Y, dY, d2Y):
        """Kernel for the sum of the diag hessian of the ao values

        Args:
            R (torch.tensor): radial part of the AOs
            dR (torch.tensor): derivative of the radial part of the AOs
            d2R (torch.tensor): 2nd derivative of the radial part of the AOs
            Y (torch.tensor): harmonics part of the AOs
            dY (torch.tensor): derivative of the harmonics part of the AOs
            d2Y (torch.tensor): 2nd derivative of the harmonics part of the AOs

        Returns:
            torch.tensor: values of the laplacian of the AOs (with contraction)
        """

        d2ao = self.norm_cst * \
            (d2R * Y + 2. * (dR * dY).sum(3) + R * d2Y)
        if self.contract:
            d2ao = self._contract(d2ao)
        return d2ao

    def _compute_diag_hessian_ao_values(self, pos):
        """Compute the individual elements of the laplacian of the ao from the xyz and r tensor

        Args:
            pos (torch.tensor): position of each elec size Nbatch, Nelec x Ndim

        Returns:
            torch.tensor: derivative of atomic orbital values
                          size (Nbatch, Nelec, Norb, 3)
        """

        xyz, r = self._process_position(pos)

        R, dR, d2R = self.radial(r, self.bas_n, self.bas_exp,
                                 xyz=xyz, derivative=[0, 1, 2],
                                 sum_grad=False, sum_hess=False)

        Y, dY, d2Y = self.harmonics(xyz,
                                    derivative=[0, 1, 2],
                                    sum_grad=False, sum_hess=False)

        return self._diag_hessian_kernel(R, dR, d2R, Y, dY, d2Y)

    def _diag_hessian_kernel(self, R, dR, d2R, Y, dY, d2Y):
        """Kernel for the diagonal hessian of the ao values

        Args:
            R (torch.tensor): radial part of the AOs
            dR (torch.tensor): derivative of the radial part of the AOs
            d2R (torch.tensor): 2nd derivative of the radial part of the AOs
            Y (torch.tensor): harmonics part of the AOs
            dY (torch.tensor): derivative of the harmonics part of the AOs
            d2Y (torch.tensor): 2nd derivative of the harmonics part of the AOs

        Returns:
            torch.tensor: values of the diagonal hessian elements of the
                          AOs (with contraction)
        """
        bas = self.norm_cst.unsqueeze(-1) * self.bas_coeffs.unsqueeze(-1) * \
            (d2R * Y.unsqueeze(-1) + 2. * (dR * dY) + R.unsqueeze(-1) * d2Y)

        if self.contract:
            return self._contract_components(bas)
        return bas

    def _compute_mixed_second_derivative_ao_values(self, pos):
        """Compute the mixed second derivative of the ao from the xyz and r tensor

        Args:
            pos (torch.tensor): position of each elec size Nbatch, Nelec x Ndim

        Returns:
            torch.tensor: mixed second derivatives of the atomic orbital
                          values, size (Nbatch, Nelec, Norb, 3)
        """
        xyz, r = self._process_position(pos)

        R, dR, d2R, d2mR = self.radial(r, self.bas_n, self.bas_exp,
                                       xyz=xyz, derivative=[
                                           0, 1, 2, 3],
                                       sum_grad=False)

        Y, dY, d2Y, d2mY = self.harmonics(xyz,
                                          derivative=[0, 1, 2, 3],
                                          sum_grad=False)

        return self._off_diag_hessian_kernel(R, dR, d2R, d2mR,
                                             Y, dY, d2Y, d2mY)

    def _off_diag_hessian_kernel(self, R, dR, d2R, d2mR, Y, dY, d2Y, d2mY):
        """Kernel for the off diagonal hessian of the ao values

        Args:
            R (torch.tensor): radial part of the AOs
            dR (torch.tensor): derivative of the radial part of the AOs
            d2R (torch.tensor): 2nd derivative of the radial part of the AOs
            d2mR (torch.tensor): mixed 2nd derivative of the radial part of the AOs
            Y (torch.tensor): harmonics part of the AOs
            dY (torch.tensor): derivative of the harmonics part of the AOs
            d2Y (torch.tensor): 2nd derivative of the harmonics part of the AOs
            d2mY (torch.tensor): 2nd mixed derivative of the harmonics part of the AOs

        Returns:
            torch.tensor: values of the mixed derivative of the AOs (with contraction)
        """
        # cross terms: for each component pair (0,1), (0,2), (1,2) sum
        # dR_a * dY_b + dR_b * dY_a
        cross = (dR[..., [[0, 1], [0, 2], [1, 2]]] *
                 dY[..., [[1, 0], [2, 0], [2, 1]]]).sum(-1)

        bas = self.norm_cst.unsqueeze(-1) * self.bas_coeffs.unsqueeze(-1) * \
            (d2mR * Y.unsqueeze(-1) + cross + R.unsqueeze(-1) * d2mY)

        if self.contract:
            return self._contract_components(bas)
        return bas

    def _compute_all_ao_values(self, pos):
        """Compute the ao, gradient, laplacian of the ao from the xyz and r tensor

        Args:
            pos (torch.tensor): position of each elec size Nbatch, Nelec x Ndim

        Returns:
            tuple(): (ao, grad and laplacian) of atomic orbital values
                     ao size (Nbatch, Nelec, Norb)
                     dao size (Nbatch, Nelec, Norb, Ndim)
                     d2ao size (Nbatch, Nelec, Norb)
        """

        xyz, r = self._process_position(pos)

        # the gradients elements are needed to compute the second der
        # we therefore use sum_grad=False
        R, dR, d2R = self.radial(r, self.bas_n, self.bas_exp,
                                 xyz=xyz, derivative=[0, 1, 2],
                                 sum_grad=False)

        # same here: individual gradient elements are required
        Y, dY, d2Y = self.harmonics(xyz,
                                    derivative=[0, 1, 2],
                                    sum_grad=False)

        return (self._ao_kernel(R, Y),
                self._gradient_kernel(R, dR, Y, dY),
                self._sum_diag_hessian_kernel(R, dR, d2R, Y, dY, d2Y))

    def _process_position(self, pos):
        """Computes the positions/distance between elec/orb

        Args:
            pos (torch.tensor): positions of the walkers Nbat, NelecxNdim

        Returns:
            torch.tensor, torch.tensor:
                positions of the elec wrt the bas (Nbatch, Nelec, Norb, Ndim)
                distance between elec and bas (Nbatch, Nelec, Norb)
        """
        # get the elec-atom vectors/distances
        xyz, r = self._elec_atom_dist(pos)

        # repeat/interleave to get vector and distance between
        # electrons and orbitals
        return (xyz.repeat_interleave(self.nshells, dim=2),
                r.repeat_interleave(self.nshells, dim=2))

    def _elec_atom_dist(self, pos):
        """Computes the positions/distance between elec/atoms

        Args:
            pos (torch.tensor): positions of the walkers : Nbatch x [Nelec*Ndim]

        Returns:
            (torch.tensor, torch.tensor):
                positions of the elec wrt the atoms [Nbatch x Nelec x Natom x Ndim]
                distance between elec and atoms [Nbatch x Nelec x Natom]
        """

        # compute the vectors between electrons and atoms
        xyz = (pos.view(-1, self.nelec, 1, self.ndim) -
               self.atom_coords[None, ...])

        # distance between electrons and atoms
        r = torch.sqrt((xyz*xyz).sum(3))

        return xyz, r

    def _contract(self, bas):
        """Contract the basis set to form the atomic orbitals

        Args:
            bas (torch.tensor): values of the basis function

        Returns:
            torch.tensor: values of the contraction
        """
        nbatch = bas.shape[0]
        bas = self.bas_coeffs * bas
        cbas = torch.zeros(nbatch, self.nelec, self.norb,
                           dtype=torch.get_default_dtype(),
                           device=self.device)
        # sum the basis functions that share a contraction index
        cbas.index_add_(2, self.index_ctr, bas)
        return cbas

    def _contract_components(self, bas):
        """Contract basis values that carry a trailing component dimension

        Unlike ``_contract`` the contraction coefficients are NOT applied
        here: the callers multiply by ``bas_coeffs`` beforehand.

        Args:
            bas (torch.tensor): values of the basis functions, with the
                                basis index on dim 2 and the components
                                on the last dim

        Returns:
            torch.tensor: contracted values, size
                          (Nbatch, Nelec, Norb, bas.shape[-1])
        """
        out = torch.zeros(bas.shape[0], self.nelec, self.norb,
                          bas.shape[-1],
                          dtype=torch.get_default_dtype(),
                          device=self.device)
        out.index_add_(2, self.index_ctr, bas)
        return out

    def update(self, ao, pos, idelec):
        """Update an AO matrix with the new positions of one electron

        Args:
            ao (torch.tensor): initial AO matrix
            pos (torch.tensor): new positions of some electrons
            idelec (int): index of the electron that has moved

        Returns:
            torch.tensor: new AO matrix

        Examples::
            >>> mol = Molecule('h2.xyz')
            >>> ao = AtomicOrbitals(mol)
            >>> pos = torch.rand(100,6)
            >>> aovals = ao(pos)
            >>> idelec = 0
            >>> pos[:,:3] = torch.rand(100,3)
            >>> ao.update(aovals, pos, idelec)
        """

        ao_new = ao.clone()
        # columns of `pos` that hold the coordinates of electron idelec
        ids, ide = idelec * self.ndim, (idelec + 1) * self.ndim
        ao_new[:, idelec, :] = self.forward(
            pos[:, ids:ide], one_elec=True).squeeze(1)
        return ao_new