qmctorch.wavefunction.jastrows.jastrow_factor_combined_terms module
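- class qmctorch.wavefunction.jastrows.jastrow_factor_combined_terms.JastrowFactorCombinedTerms(*args: Any, **kwargs: Any)
Bases: torch.nn.Module
Jastrow factor formed as the product of several Jastrow terms (J = A B C), each evaluated by its own kernel.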
- Parameters:
nup (int) – number of spin-up electrons
ndown (int) – number of spin-down electrons
atomic_pos (torch.tensor) – atomic positions
jastrow_kernel (dict) – kernels of the individual Jastrow terms
jastrow_kernel_kwargs (dict) – keyword arguments passed to the kernels
cuda (bool, optional) – whether to run the computation on the GPU. Defaults to False.
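As a minimal sketch of the idea, and not QMCTorch's actual implementation, a combined Jastrow factor can be modeled as a torch.nn.Module that evaluates several independent terms and multiplies their values. The classes ToyCombinedJastrow and ToyTerm below are hypothetical, and each toy term simply returns exp(-w * |r|^2) per configuration.

```python
import torch
from torch import nn


class ToyCombinedJastrow(nn.Module):
    """Hypothetical product of Jastrow terms, J = J_1 * J_2 * ... * J_n."""

    def __init__(self, terms):
        super().__init__()
        # nn.ModuleList registers the trainable parameters of every term
        self.terms = nn.ModuleList(terms)

    def forward(self, pos):
        # each term returns Nbatch x 1; the combined factor is their product
        out = pos.new_ones((pos.shape[0], 1))
        for term in self.terms:
            out = out * term(pos)
        return out


class ToyTerm(nn.Module):
    """Hypothetical term exp(-w * sum(pos**2)) with one trainable weight w."""

    def __init__(self, w):
        super().__init__()
        self.w = nn.Parameter(torch.tensor(w))

    def forward(self, pos):
        return torch.exp(-self.w * (pos ** 2).sum(dim=1, keepdim=True))
```

For example, ToyCombinedJastrow([ToyTerm(0.1), ToyTerm(0.2)]) applied to a 5 x 6 batch of positions returns a 5 x 1 tensor equal to exp(-0.3 * |r|^2), since the product of exponentials adds their exponents.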
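- forward(pos, derivative=0, sum_grad=True)
Compute the Jastrow factors.
- Parameters:
pos (torch.tensor) – positions of the electrons
derivative (int, optional) – order of the derivative to compute (0 returns the value of the factor). Defaults to 0.
sum_grad (bool, optional) – if True, sum the gradient over the spatial dimensions. Defaults to True.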
- Returns:
values of the Jastrow factor (or of its derivative) for all configurations, with shape
- derivative = 0: (Nmo) x Nbatch x 1
- derivative = 1 and sum_grad = True: (Nmo) x Nbatch x Nelec
- derivative = 1 and sum_grad = False: (Nmo) x Nbatch x Ndim x Nelec
- Return type:
torch.tensor
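The shapes listed above can be illustrated with autograd on a toy factor. This sketch assumes that positions are stored as Nbatch x (Nelec * Ndim) with the coordinates of each electron contiguous; that layout, the function toy_jastrow, and the omission of the optional Nmo dimension are all assumptions made for illustration.

```python
import torch

nbatch, nelec, ndim = 4, 3, 3


def toy_jastrow(pos):
    # hypothetical Jastrow factor, one value per configuration: Nbatch x 1
    return torch.exp(-(pos ** 2).sum(dim=1, keepdim=True))


# assumed layout: Nbatch x (Nelec * Ndim), coordinates of one electron contiguous
pos = torch.rand(nbatch, nelec * ndim, requires_grad=True)
jast = toy_jastrow(pos)  # analogous to derivative = 0 -> Nbatch x 1

# each configuration depends only on its own row, so grad of the sum is per-row
(grad,) = torch.autograd.grad(jast.sum(), pos)  # Nbatch x (Nelec * Ndim)

# analogous to sum_grad = False -> Nbatch x Ndim x Nelec
grad_full = grad.reshape(nbatch, nelec, ndim).transpose(1, 2)

# analogous to sum_grad = True -> Nbatch x Nelec (summed over spatial dims)
grad_sum = grad_full.sum(dim=1)
```

Summing over dimension 1 of the Nbatch x Ndim x Nelec tensor is what turns the full gradient into the per-electron Nbatch x Nelec form.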
- static get_derivative_combined_values(jast_vals, djast_vals)
Compute the derivative of the product of the individual Jastrow terms:
J = A B C
\frac{dJ}{dx} = \frac{dA}{dx} B C + A \frac{dB}{dx} C + A B \frac{dC}{dx}
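The product rule above generalizes to any number of terms: the derivative is a sum in which each term is differentiated once while the others are kept as values. The helper product_derivative below is a hypothetical sketch of that rule, assuming the values are Nbatch x 1 and the derivatives Nbatch x Nelec, so that broadcasting aligns them; it is not the library's code.

```python
import torch


def product_derivative(vals, dvals):
    """Hypothetical helper: d(prod_i X_i) = sum_i dX_i * prod_{j != i} X_j.

    vals:  list of Nbatch x 1 tensors (the term values A, B, C, ...)
    dvals: list of Nbatch x Nelec tensors (their first derivatives)
    For Nbatch x Ndim x Nelec derivatives, unsqueeze the values to
    Nbatch x 1 x 1 first so that broadcasting stays correct.
    """
    total = torch.zeros_like(dvals[0])
    for i, dval in enumerate(dvals):
        term = dval
        for j, val in enumerate(vals):
            if j != i:
                # Nbatch x 1 broadcasts over the Nelec columns
                term = term * val
        total = total + term
    return total
```

With three terms this reduces exactly to dA B C + A dB C + A B dC, which can be checked against torch.autograd on a scalar example.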
- static get_second_derivative_combined_values(jast_vals, djast_vals, d2jast_vals)
Compute the second derivative of the product of the individual Jastrow terms:
J = A B C
\frac{d^2 J}{dx^2} = \frac{d^2 A}{dx^2} B C + A \frac{d^2 B}{dx^2} C + A B \frac{d^2 C}{dx^2} + 2 \left( \frac{dA}{dx} \frac{dB}{dx} C + \frac{dA}{dx} B \frac{dC}{dx} + A \frac{dB}{dx} \frac{dC}{dx} \right)
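The three-term second-derivative formula can be written directly in PyTorch and checked against a double application of autograd. The helper product_second_derivative below is a hypothetical sketch that follows the formula term by term for exactly three factors; it makes no claim about how QMCTorch arranges its tensors.

```python
import torch


def product_second_derivative(vals, dvals, d2vals):
    """Hypothetical helper: second derivative of J = A * B * C."""
    a, b, c = vals
    da, db, dc = dvals
    d2a, d2b, d2c = d2vals
    # pure second-derivative terms
    pure = d2a * b * c + a * d2b * c + a * b * d2c
    # cross terms, each pair of first derivatives appears twice
    cross = 2 * (da * db * c + da * b * dc + a * db * dc)
    return pure + cross
```

The cross terms carry the factor 2 because the second derivative of a product differentiates each pair of factors once in either order.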