qmctorch.wavefunction.jastrows.elec_elec.kernels.jastrow_kernel_electron_electron_base module

class qmctorch.wavefunction.jastrows.elec_elec.kernels.jastrow_kernel_electron_electron_base.JastrowKernelElectronElectronBase(nup, ndown, cuda, **kwargs)[source]

Bases: torch.nn.Module

Base class for the electron-electron Jastrow kernels.

Parameters:
  • nup (int) – number of spin-up electrons
  • ndown (int) – number of spin-down electrons
  • cuda (bool, optional) – run the computation on the GPU. Defaults to False.
forward(r)[source]

Get the elements of the Jastrow matrix.

Parameters: r (torch.tensor) – matrix of the e-e distances, Nbatch x Nelec_pair
Returns: matrix of the Jastrow elements, Nmo x Nbatch x Nelec_pair
Return type: torch.tensor

Note

The kernel receives a [Nbatch x Npair] tensor. It must first reshape that tensor to [Nbatch*Npair, 1], then process it into another [Nbatch*Npair, 1] tensor, and finally reshape the output back to [Nbatch x Npair].

Example

>>> def forward(self, x):
...     nbatch, npairs = x.shape
...     x = x.reshape(-1, 1)  # [Nbatch*Npair, 1]
...     x = self.fc1(x)
...     # further layers; the last one must output shape [Nbatch*Npair, 1]
...     return x.reshape(nbatch, npairs)
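To check the reshape contract end to end, here is a minimal runnable sketch. The class name `MLPKernel`, its hidden width and the tanh activation are hypothetical, and it subclasses `torch.nn.Module` directly rather than `JastrowKernelElectronElectronBase`, so it runs without QMCTorch installed and does not take the nup, ndown and cuda arguments.

```python
import torch
from torch import nn


class MLPKernel(nn.Module):
    """Hypothetical e-e kernel illustrating the reshape contract of forward().

    Not part of QMCTorch: it subclasses torch.nn.Module directly so the
    example runs on its own.
    """

    def __init__(self, nhidden=8):
        super().__init__()
        # Layers act on one distance per row: [N, 1] -> [N, nhidden] -> [N, 1]
        self.fc1 = nn.Linear(1, nhidden)
        self.fc2 = nn.Linear(nhidden, 1)

    def forward(self, x):
        nbatch, npairs = x.shape
        # [Nbatch, Npair] -> [Nbatch*Npair, 1]: one row per distance
        x = x.reshape(-1, 1)
        x = torch.tanh(self.fc1(x))
        x = self.fc2(x)  # back to [Nbatch*Npair, 1]
        # [Nbatch*Npair, 1] -> [Nbatch, Npair]
        return x.reshape(nbatch, npairs)


kernel = MLPKernel()
r = torch.rand(4, 6)  # 4 walkers, 6 electron pairs
out = kernel(r)
print(out.shape)  # torch.Size([4, 6])
```

Because every distance is processed as an independent row, each output entry depends only on its own input distance; this elementwise structure is what makes the chain-rule derivatives below straightforward.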
compute_derivative(r, dr)[source]

Get the elements of the derivative of the Jastrow kernels with respect to the first electron, using automatic differentiation.

Parameters:
  • r (torch.tensor) – matrix of the e-e distances, Nbatch x Nelec_pair
  • dr (torch.tensor) – matrix of the derivative of the e-e distances, Nbatch x Ndim x Nelec_pair
Returns: matrix of the derivative of the Jastrow elements, Nmo x Nbatch x Ndim x Nelec_pair
Return type: torch.tensor
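Because the kernel acts on each distance independently, its derivative with respect to the first electron follows from the chain rule, dK/dx = dK/dr * dr/dx, where dK/dr comes from automatic differentiation and dr is supplied by the caller. The sketch below illustrates this under stated assumptions: the Padé-type kernel a*r/(1 + b*r) with a = 0.5, b = 1 is hypothetical, the function is a standalone stand-in rather than QMCTorch's implementation, and it returns Nbatch x Ndim x Nelec_pair, omitting the leading Nmo axis of the documented return shape.

```python
import torch


def kernel(r):
    # Hypothetical Pade-type kernel a*r / (1 + b*r), applied elementwise
    a, b = 0.5, 1.0
    return a * r / (1.0 + b * r)


def compute_derivative(kernel, r, dr):
    """Sketch: dK/dx of an elementwise kernel via autograd and the chain rule.

    r  : [Nbatch, Npair]        e-e distances
    dr : [Nbatch, Ndim, Npair]  derivatives of the distances wrt the first electron
    returns [Nbatch, Ndim, Npair]
    """
    r = r.detach().requires_grad_(True)
    kval = kernel(r)
    # Elementwise kernel: the gradient of the sum is dK/dr for each pair
    dk_dr = torch.autograd.grad(kval.sum(), r)[0]
    # Chain rule dK/dx = dK/dr * dr/dx, broadcast over the Ndim axis
    return dk_dr.unsqueeze(1) * dr


r = torch.rand(2, 3) + 0.1  # 2 walkers, 3 electron pairs
dr = torch.rand(2, 3, 3)    # Ndim = 3
print(compute_derivative(kernel, r, dr).shape)  # torch.Size([2, 3, 3])
```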

compute_second_derivative(r, dr, d2r)[source]

Get the elements of the pure second derivative of the Jastrow kernels with respect to the first electron, using automatic differentiation.

Parameters:
  • r (torch.tensor) – matrix of the e-e distances, Nbatch x Nelec_pair
  • dr (torch.tensor) – matrix of the derivative of the e-e distances, Nbatch x Ndim x Nelec_pair
  • d2r (torch.tensor) – matrix of the second derivative of the e-e distances, Nbatch x Ndim x Nelec_pair
Returns: matrix of the pure second derivative of the Jastrow elements, Nmo x Nbatch x Ndim x Nelec_pair
Return type: torch.tensor
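For an elementwise kernel, differentiating the chain rule a second time gives d2K/dx2 = d2K/dr2 * (dr/dx)^2 + dK/dr * d2r/dx2. The following is a minimal sketch using a double backward pass; the Padé-type kernel a*r/(1 + b*r) with a = 0.5, b = 1 is hypothetical, the function is a standalone stand-in rather than QMCTorch's implementation, and its output omits the leading Nmo axis of the documented return shape.

```python
import torch


def kernel(r):
    # Hypothetical Pade-type kernel a*r / (1 + b*r), applied elementwise
    a, b = 0.5, 1.0
    return a * r / (1.0 + b * r)


def compute_second_derivative(kernel, r, dr, d2r):
    """Sketch: pure second derivative d2K/dx2 of an elementwise kernel.

    r   : [Nbatch, Npair]        e-e distances
    dr  : [Nbatch, Ndim, Npair]  first derivatives of the distances
    d2r : [Nbatch, Ndim, Npair]  pure second derivatives of the distances
    returns [Nbatch, Ndim, Npair]
    """
    r = r.detach().requires_grad_(True)
    kval = kernel(r)
    # create_graph=True keeps the graph so dK/dr can be differentiated again
    dk_dr = torch.autograd.grad(kval.sum(), r, create_graph=True)[0]
    # Elementwise again: the gradient of sum(dK/dr) is d2K/dr2 per pair
    d2k_dr2 = torch.autograd.grad(dk_dr.sum(), r)[0]
    # Chain rule: d2K/dx2 = d2K/dr2 * (dr/dx)^2 + dK/dr * d2r/dx2
    return d2k_dr2.unsqueeze(1) * dr**2 + dk_dr.detach().unsqueeze(1) * d2r
```

Usage mirrors the first-derivative case: pass r of shape [Nbatch, Npair] together with dr and d2r of shape [Nbatch, Ndim, Npair].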