qmctorch.wavefunction.jastrows package

Subpackages

Submodules

qmctorch.wavefunction.jastrows.combine_jastrow module

class qmctorch.wavefunction.jastrows.combine_jastrow.CombineJastrow(*args: Any, **kwargs: Any)

Bases: Module

Combine several Jastrow factors into a single factor equal to their product.

Parameters:

jastrow (list) – list of the Jastrow factors to combine

forward(pos: torch.Tensor, derivative: int = 0, sum_grad: bool = True) → torch.Tensor | List[torch.Tensor]

Compute the combined Jastrow factor and/or its derivatives.

Parameters:
  • pos (torch.Tensor) – Positions of the electrons. Size: Nbatch x (Nelec * Ndim)

  • derivative (int, optional) – Order of the derivative (0, 1 or 2). Defaults to 0.

  • sum_grad (bool, optional) – If True, return the sum of the derivatives over the spatial dimensions; if False, return each component separately. Defaults to True.

Returns:

Values of the combined Jastrow factor, or of its derivatives, for all configurations, with shapes:

  • derivative = 0: (Nmo) x Nbatch x 1

  • derivative = 1, sum_grad = True: (Nmo) x Nbatch x Nelec

  • derivative = 1, sum_grad = False: (Nmo) x Nbatch x Ndim x Nelec

Return type:

torch.Tensor or List[torch.Tensor]
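Judging from the shapes listed above, sum_grad = True appears to return the first derivatives summed over the Ndim axis (this reading is an assumption based on the documented shapes, not on the source code). Because a Jastrow value does not depend on which coordinate is being differentiated, the product rule commutes with that sum. The sketch below checks this with plain floats; A, B, dA and dB are hypothetical values for one configuration and one electron.

```python
import math

# Hypothetical values of two Jastrow terms A and B at one configuration,
# with their gradients w.r.t. the x, y, z coordinates of one electron.
A, dA = 1.5, [0.2, -0.1, 0.4]
B, dB = 0.8, [0.3, 0.0, -0.5]

# Product rule applied per dimension (Ndim components), then summed.
per_dim = [da * B + A * db for da, db in zip(dA, dB)]
summed_after = sum(per_dim)

# Each term's gradient summed over dimensions first, product rule applied once.
summed_before = sum(dA) * B + A * sum(dB)

print(math.isclose(summed_after, summed_before))  # True
```

This shortcut does not carry over to a summed second derivative (a Laplacian): its cross terms become 2 Σ_i (∂A/∂x_i)(∂B/∂x_i), which need the per-dimension gradients rather than their sums.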

static get_combined_values(jast_vals: List[torch.Tensor]) → torch.Tensor

Compute the product of all terms in jast_vals.
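A minimal sketch of this product, assuming only that the entries of jast_vals support elementwise `*` (true for plain floats and for torch.Tensor objects of broadcast-compatible shapes). The name combined_values is hypothetical and not part of the qmctorch API.

```python
from functools import reduce
import operator
from typing import Sequence, TypeVar

T = TypeVar("T")


def combined_values(jast_vals: Sequence[T]) -> T:
    """Elementwise product J = J_1 * J_2 * ... * J_n of all Jastrow terms."""
    if len(jast_vals) == 0:
        raise ValueError("at least one Jastrow term is required")
    return reduce(operator.mul, jast_vals)


# Plain floats stand in for per-configuration Jastrow values.
print(combined_values([2.0, 3.0, 0.5]))  # 3.0
```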

static get_derivative_combined_values(jast_vals: List[torch.Tensor], djast_vals: List[torch.Tensor]) → torch.Tensor

Compute the derivative of the product of Jastrow terms.

This function calculates the first derivative of a product of Jastrow factors with respect to their input variables. The computation is based on the formula:

\[J = A B C\]
\[\frac{dJ}{dx} = \frac{dA}{dx} B C + A \frac{dB}{dx} C + A B \frac{dC}{dx}\]

Parameters:
  • jast_vals (List[torch.Tensor]) – List of Jastrow values.

  • djast_vals (List[torch.Tensor]) – List of first derivatives of Jastrow values.

Returns:

The derivative of the product of Jastrow terms.

Return type:

torch.Tensor
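The three-factor formula above generalizes to n factors as dJ/dx = Σ_k (dJ_k/dx) Π_{j≠k} J_j. The following is a sketch of that product rule, not the qmctorch implementation; the name derivative_combined_values is hypothetical. It uses only `*` and `+`, so it should also work elementwise on torch.Tensor inputs. The usage part checks it against a central finite difference for three arbitrary test functions.

```python
from functools import reduce
import math
import operator
from typing import Sequence


def derivative_combined_values(jast_vals: Sequence, djast_vals: Sequence):
    """dJ/dx for J = prod_k J_k, using the product rule
    dJ/dx = sum_k (dJ_k/dx) * prod_{j != k} J_j."""
    if len(jast_vals) != len(djast_vals):
        raise ValueError("jast_vals and djast_vals must have the same length")
    total = 0.0
    for k, dj in enumerate(djast_vals):
        # dJ_k/dx multiplied by every other (undifferentiated) term
        others = (v for j, v in enumerate(jast_vals) if j != k)
        total = total + reduce(operator.mul, others, dj)
    return total


# Test functions: A = exp(x), B = x^2 + 1, C = sin(x) + 2.
x = 0.7
vals = [math.exp(x), x**2 + 1, math.sin(x) + 2]
dvals = [math.exp(x), 2 * x, math.cos(x)]
analytic = derivative_combined_values(vals, dvals)


def J(t):
    return math.exp(t) * (t**2 + 1) * (math.sin(t) + 2)


h = 1e-6
numeric = (J(x + h) - J(x - h)) / (2 * h)
print(abs(analytic - numeric) < 1e-6)  # True
```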

static get_second_derivative_combined_values(jast_vals: List[torch.Tensor], djast_vals: List[torch.Tensor], d2jast_vals: List[torch.Tensor]) → torch.Tensor

Compute the second derivative of the product of Jastrow terms.

This function calculates the second derivative of a product of Jastrow factors with respect to their input variables. The computation is based on the formula:

\[J = A B C\]
\[\frac{d^2 J}{dx^2} = \frac{d^2 A}{dx^2} B C + A \frac{d^2 B}{dx^2} C + A B \frac{d^2 C}{dx^2} + 2 \left( \frac{dA}{dx} \frac{dB}{dx} C + \frac{dA}{dx} B \frac{dC}{dx} + A \frac{dB}{dx} \frac{dC}{dx} \right)\]

Parameters:
  • jast_vals (List[torch.Tensor]) – List of Jastrow values.

  • djast_vals (List[torch.Tensor]) – List of first derivatives of Jastrow values.

  • d2jast_vals (List[torch.Tensor]) – List of second derivatives of Jastrow values.

Returns:

The combined second derivative of the Jastrow factors.

Return type:

torch.Tensor
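For n factors and a single variable x, the formula above generalizes to d²J/dx² = Σ_k J_k'' Π_{j≠k} J_j + 2 Σ_{k<l} J_k' J_l' Π_{j≠k,l} J_j, which reduces to the three-factor expression for n = 3. The sketch below implements that rule with hypothetical function names; it is an illustration under these assumptions, not the library's code. The usage part compares it with the explicit three-factor formula and a second-order finite difference.

```python
from functools import reduce
from itertools import combinations
import math
import operator
from typing import Iterable, Sequence


def _times_others(start, vals: Sequence, skip: Iterable[int]):
    """start multiplied by every vals[j] whose index j is not in skip."""
    skip = set(skip)
    return reduce(operator.mul, (v for j, v in enumerate(vals) if j not in skip), start)


def second_derivative_combined_values(jast_vals, djast_vals, d2jast_vals):
    """d2J/dx2 for J = prod_k J_k:
    sum_k J_k'' prod_{j != k} J_j + 2 sum_{k < l} J_k' J_l' prod_{j != k, l} J_j."""
    n = len(jast_vals)
    if not (len(djast_vals) == len(d2jast_vals) == n):
        raise ValueError("all input lists must have the same length")
    total = 0.0
    for k in range(n):  # one term differentiated twice
        total = total + _times_others(d2jast_vals[k], jast_vals, [k])
    for k, l in combinations(range(n), 2):  # two distinct terms, once each
        total = total + 2 * _times_others(djast_vals[k] * djast_vals[l], jast_vals, [k, l])
    return total


# Test functions: A = exp(x), B = x^2 + 1, C = sin(x) + 2.
x = 0.7
vals = [math.exp(x), x**2 + 1, math.sin(x) + 2]
dvals = [math.exp(x), 2 * x, math.cos(x)]
d2vals = [math.exp(x), 2.0, -math.sin(x)]
analytic = second_derivative_combined_values(vals, dvals, d2vals)


def J(t):
    return math.exp(t) * (t**2 + 1) * (math.sin(t) + 2)


h = 1e-4
numeric = (J(x + h) - 2 * J(x) + J(x - h)) / h**2
print(abs(analytic - numeric) < 1e-4)  # True
```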

qmctorch.wavefunction.jastrows.jastrow_factor_combined_terms module

Module contents