qmctorch.utils.torch_utils module

qmctorch.utils.torch_utils.set_torch_double_precision()[source]

Set the default precision to double for all torch tensors.

qmctorch.utils.torch_utils.set_torch_single_precision()[source]

Set the default precision to single for all torch tensors.
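
Example (a minimal sketch; per the descriptions above, tensors created after the call pick up the new default dtype, and set_torch_single_precision works the same way for float32):

>>> import torch
>>> from qmctorch.utils.torch_utils import set_torch_double_precision
>>> set_torch_double_precision()
>>> torch.rand(3).dtype
torch.float64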

qmctorch.utils.torch_utils.fast_power(x, k, mask0=None, mask2=None)[source]

Computes x**k elementwise when the exponents in k take only the values 0, 1 and 2.

Parameters:
  • x (torch.tensor) – input

  • k (torch.tensor) – exponents

  • mask0 (torch.tensor) – precomputed mask of the elements of k that are 0 (default: None, computed internally if not given)

  • mask2 (torch.tensor) – precomputed mask of the elements of k that are 2 (default: None, computed internally if not given)

Returns:

values of x**k

Return type:

torch.tensor
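
Example (a usage sketch assuming elementwise semantics; the masks are an optional optimization that fast_power otherwise computes internally):

>>> import torch
>>> from qmctorch.utils.torch_utils import fast_power
>>> x = torch.tensor([2.0, 3.0, 4.0])
>>> k = torch.tensor([0.0, 1.0, 2.0])
>>> fast_power(x, k)   # elementwise x**k
tensor([ 1.,  3., 16.])
>>> mask0, mask2 = (k == 0), (k == 2)   # masks can be precomputed and reused
>>> fast_power(x, k, mask0=mask0, mask2=mask2)
tensor([ 1.,  3., 16.])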

qmctorch.utils.torch_utils.gradients(out, inp)[source]

Return the gradients of out with respect to inp.

Parameters:
  • out (torch.tensor) – output of a differentiable computation

  • inp (torch.tensor) – input tensor with respect to which the gradients are taken
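
Example (a call sketch on a toy function; the exact return shape is not specified above, so only the call pattern is shown):

>>> import torch
>>> from qmctorch.utils.torch_utils import gradients
>>> x = torch.linspace(0.0, 1.0, 5, requires_grad=True)
>>> y = (x ** 2).sum()
>>> g = gradients(y, x)   # should agree with dy/dx = 2x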

qmctorch.utils.torch_utils.diagonal_hessian(out, inp, return_grads=False)[source]

Return the diagonal of the Hessian of out with respect to inp.

Parameters:
  • out (torch.tensor) – output of a differentiable computation

  • inp (torch.tensor) – input tensor with respect to which the Hessian is taken

Returns:

diagonal elements of the Hessian of out; if return_grads is True, the first-order gradients are returned as well

Return type:

torch.tensor
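
Example (same toy setup as above; the two-value return under return_grads=True is inferred from the flag name):

>>> import torch
>>> from qmctorch.utils.torch_utils import diagonal_hessian
>>> x = torch.linspace(0.0, 1.0, 5, requires_grad=True)
>>> y = (x ** 3).sum()
>>> h = diagonal_hessian(y, x)                        # diagonal terms, d2y/dx2 = 6x
>>> h, g = diagonal_hessian(y, x, return_grads=True)  # gradients returned as well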

class qmctorch.utils.torch_utils.DataSet(*args: Any, **kwargs: Any)[source]

Bases: Dataset

Creates a torch data set

Parameters:

data (torch.tensor) – data to store in the dataset
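
Example (a sketch wiring the dataset into a standard torch DataLoader, assuming DataSet indexes along the first dimension as Dataset subclasses usually do; the shape of pos is illustrative):

>>> import torch
>>> from torch.utils.data import DataLoader
>>> from qmctorch.utils.torch_utils import DataSet
>>> pos = torch.rand(100, 9)   # e.g. 100 walkers, 3 particles x 3 coordinates
>>> loader = DataLoader(DataSet(pos), batch_size=25)
>>> next(iter(loader)).shape
torch.Size([25, 9])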

class qmctorch.utils.torch_utils.Loss(*args: Any, **kwargs: Any)[source]

Bases: Module

Defines the loss to use during the optimization

Parameters:

wf (WaveFunction) – wave function object used during the optimization

Keyword Arguments:
  • method (str) – loss to use: 'energy', 'variance', 'weighted-energy' or 'weighted-variance' (default: 'energy')

  • clip (bool) – clip the local-energy values that lie several standard deviations away from the mean (default: False)
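
Example (a construction sketch; wf stands for any QMCTorch WaveFunction instance and is not defined here):

>>> from qmctorch.utils.torch_utils import Loss
>>> loss_fn = Loss(wf, method='energy', clip=True)   # wf: a WaveFunction object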

forward(pos, no_grad=False, deactivate_weight=False)[source]

Computes the loss

Parameters:

pos (torch.tensor) – positions of the walkers in the batch

Keyword Arguments:
  • no_grad (bool) – if True, evaluate the loss without computing gradients (default: False)

  • deactivate_weight (bool) – if True, deactivate the sampling weights (default: False)

Returns:

(torch.tensor, torch.tensor) – value of the loss and the local energies
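
Example (a call sketch; since Loss is a Module, the instance is called directly on a batch of walker positions):

>>> loss, eloc = loss_fn(pos)                 # loss value and local energies
>>> loss, eloc = loss_fn(pos, no_grad=True)   # skip building the autograd graph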

static get_grad_mode(no_grad)[source]

Return the torch.enable_grad or torch.no_grad context depending on the no_grad flag.

Parameters:

no_grad (bool) – if True, select the no-grad context; otherwise the enable-grad context
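
Example (a sketch assuming the returned object is a ready-to-use torch context manager):

>>> from qmctorch.utils.torch_utils import Loss
>>> with Loss.get_grad_mode(no_grad=True):
...     loss, eloc = loss_fn(pos)   # evaluated without gradient tracking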

get_clipping_mask(local_energies)[source]

Computes the clipping mask.

Parameters:

local_energies (torch.tensor) – values of the local energies
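
Example (a call sketch; eloc are the local energies returned by forward(), and the mask presumably flags the samples kept after clipping):

>>> mask = loss_fn.get_clipping_mask(eloc)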

get_sampling_weights(pos, deactivate_weight)[source]

Get the weights needed when resampling is not performed at every optimization step.
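
Example (a call sketch matching the signature above):

>>> w = loss_fn.get_sampling_weights(pos, deactivate_weight=False)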

class qmctorch.utils.torch_utils.OrthoReg(*args: Any, **kwargs: Any)[source]

Bases: Module

Adds a penalty loss that keeps the molecular orbitals (MOs) orthogonal.

Keyword Arguments:

alpha (float) – strength of the penalty (default: 0.1)

forward(W)[source]

Return the loss: ||W Wᵀ - I||.
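
Example (a sketch adding the penalty to a training loss; W stands for a matrix of MO coefficients):

>>> import torch
>>> from qmctorch.utils.torch_utils import OrthoReg
>>> reg = OrthoReg(alpha=0.1)
>>> W = torch.rand(5, 5, requires_grad=True)
>>> penalty = reg(W)   # vanishes when W is orthogonal
>>> # total = loss + penalty  -> add to the main loss before calling backward()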