Source code for qmctorch.wavefunction.orbitals.backflow.orbital_dependent_backflow_kernel

import torch
from torch import nn


class OrbitalDependentBackFlowKernel(nn.Module):

    def __init__(self, backflow_kernel, backflow_kernel_kwargs, mol, cuda):
        r"""Compute orbital dependent back flow kernel, i.e. the functions
        f(rij) where rij is the distance between electron i and j.

        This kernel is used in the backflow transformation

        .. math::
            q^{\alpha}_i = r_i + \sum_{j\neq i} f^{\alpha}(r_{ij}) (r_i-r_j)

        where :math:`f^{\alpha}(r_{ij})` is the kernel for orbital
        :math:`\alpha`.

        Args:
            backflow_kernel (type): kernel class. One instance is created per
                orbital (``mol.basis.nao`` of them) as
                ``backflow_kernel(mol, cuda, **backflow_kernel_kwargs)``.
            backflow_kernel_kwargs (dict or None): extra keyword arguments
                forwarded to every kernel instance; ``None`` means none.
            mol: molecule object; ``mol.nelec`` and ``mol.basis.nao`` are read.
            cuda (bool): if True, ``self.device`` is set to ``cuda``,
                otherwise it is ``cpu``.
        """
        super().__init__()
        self.nelec = mol.nelec
        self.nao = mol.basis.nao

        kernel_kwargs = backflow_kernel_kwargs or {}
        # one independent kernel per orbital, registered as submodules so
        # their parameters are tracked by the parent module
        self.orbital_dependent_kernel = nn.ModuleList(
            [backflow_kernel(mol, cuda, **kernel_kwargs)
             for _ in range(self.nao)])

        # NOTE(review): this instance attribute shadows nn.Module.cuda();
        # kept as-is because other code presumably reads ``self.cuda``.
        self.cuda = cuda
        self.device = torch.device('cpu')
        if self.cuda:
            self.device = torch.device('cuda')

        # dimension along which the different orbitals are stacked
        # with stack_axis = 1 the resulting tensors will have dimension
        # Nbatch x Nao x ...
        self.stack_axis = 1

    def forward(self, ree, derivative=0):
        """Compute the desired values of the kernels.

        Args:
            ree (torch.tensor): e-e distance Nbatch x Nelec x Nelec
            derivative (int): derivative required 0, 1, 2

        Returns:
            torch.tensor : f(r) Nbatch x Nao x Nelec x Nelec
            (``None`` if there are no orbital kernels)
        """
        # Evaluate every orbital's kernel, then stack once. Concatenating
        # inside the loop would re-copy the growing result at every step.
        values = [ker(ree, derivative)
                  for ker in self.orbital_dependent_kernel]
        if not values:
            # no orbitals: keep the historical behavior of returning None
            return None
        return torch.stack(values, dim=self.stack_axis)