Source code for qmctorch.wavefunction.jastrows.elec_elec_nuclei.kernels.fully_connected_jastrow_kernel

import torch
from .jastrow_kernel_electron_electron_nuclei_base import JastrowKernelElectronElectronNucleiBase


class FullyConnectedJastrowKernel(JastrowKernelElectronElectronNucleiBase):
    """Small fully connected network used as a Jastrow kernel.

    The network is a 3 -> 9 -> 3 -> 1 stack of linear layers, each followed
    by ``2 * sigmoid``, so every output value lies in the open interval (0, 2).
    """

    def __init__(self, nup, ndown, atomic_pos, cuda):
        """Defines a fully connected jastrow factors.

        Args:
            nup: forwarded unchanged to the base class constructor
                (presumably the number of spin-up electrons -- confirm
                against the base class).
            ndown: forwarded unchanged to the base class constructor
                (presumably the number of spin-down electrons).
            atomic_pos: forwarded unchanged to the base class constructor
                (presumably the atomic positions).
            cuda: forwarded unchanged to the base class constructor
                (presumably a flag selecting GPU execution).
        """
        super().__init__(nup, ndown, atomic_pos, cuda)

        self.fc1 = torch.nn.Linear(3, 9, bias=True)
        self.fc2 = torch.nn.Linear(9, 3, bias=True)
        self.fc3 = torch.nn.Linear(3, 1, bias=True)

        # Draw every weight matrix from U(0, 1) and shrink it to U(0, 1e-3).
        # Looping over the layers keeps them in sync: the previous version
        # initialised fc2 twice and never initialised fc3, which therefore
        # kept torch's default (possibly negative) initial weights.
        for layer in (self.fc1, self.fc2, self.fc3):
            torch.nn.init.uniform_(layer.weight)
            layer.weight.data *= 1E-3

        self.nl_func = torch.nn.Sigmoid()

    def forward(self, x):
        """Evaluate the kernel independently on every 3-component feature.

        Args:
            x (torch.tensor): input whose last dimension has size 3; all
                leading dimensions are treated as batch dimensions.
                (Presumably the three entries are the electron-nucleus and
                electron-electron distances of one pair -- confirm with the
                caller.)

        Returns:
            torch.tensor: kernel values with the shape of ``x`` except that
                the last dimension is reduced to 1.
        """
        # The result keeps every leading dimension and collapses the
        # 3 input features into a single value.
        out_shape = list(x.shape)[:-1] + [1]

        # Flatten so that all 3-vectors are processed independently
        # of each other by the linear layers.
        x = x.reshape(-1, 3)

        x = self.fc1(x)
        x = 2 * self.nl_func(x)

        x = self.fc2(x)
        x = 2 * self.nl_func(x)

        x = self.fc3(x)
        x = 2 * self.nl_func(x)

        return x.reshape(*out_shape)