qmctorch.utils.torch_utils module
-
qmctorch.utils.torch_utils.set_torch_double_precision()
Set the default precision to double for all torch tensors.
-
qmctorch.utils.torch_utils.set_torch_single_precision()
Set the default precision to single for all torch tensors.
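A minimal sketch of the effect of the two precision helpers above, assuming they amount to changing torch's default floating-point dtype with torch.set_default_dtype. The function names below are hypothetical stand-ins, not the QMCTorch functions.

```python
import torch


def set_double_precision_sketch():
    """Hypothetical stand-in: make float64 the default floating-point dtype."""
    torch.set_default_dtype(torch.float64)


def set_single_precision_sketch():
    """Hypothetical stand-in: make float32 the default floating-point dtype."""
    torch.set_default_dtype(torch.float32)


set_double_precision_sketch()
print(torch.zeros(3).dtype)  # torch.float64: new float tensors use double precision

set_single_precision_sketch()
print(torch.zeros(3).dtype)  # torch.float32: back to single precision
```

Only tensors created after the call are affected; existing tensors keep their dtype.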
-
qmctorch.utils.torch_utils.fast_power(x, k, mask0=None, mask2=None)
Computes x**k elementwise when the elements of k are 0, 1 or 2.
Parameters: - x (torch.tensor) – input
- k (torch.tensor) – exponents
- mask0 (torch.tensor) – precomputed mask of the elements of k that are 0 (defaults to None, in which case it is computed here)
- mask2 (torch.tensor) – precomputed mask of the elements of k that are 2 (defaults to None, in which case it is computed here)
Returns: values of x**k
Return type: torch.tensor
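The masking idea behind fast_power can be sketched as follows. This is a hypothetical re-implementation (fast_power_sketch is not part of QMCTorch) that assumes x and k broadcast against each other, for example one exponent per column of x. Each allowed exponent is handled by a mask instead of a general power: k == 1 keeps x, k == 2 uses x * x, and k == 0 fills with 1.

```python
import torch


def fast_power_sketch(x, k, mask0=None, mask2=None):
    """Elementwise x**k, assuming every entry of k is 0, 1 or 2.

    mask0 / mask2 are boolean masks of the entries of k equal to 0 / 2.
    They are built here when not supplied and can be reused across calls.
    """
    if mask0 is None:
        mask0 = k == 0
    if mask2 is None:
        mask2 = k == 2
    # k == 1 keeps x unchanged; k == 2 replaces x by x * x
    out = torch.where(mask2, x * x, x)
    # k == 0 gives 1 regardless of x
    return torch.where(mask0, torch.ones_like(out), out)


x = torch.tensor([[2.0, 3.0, 4.0], [0.5, -1.0, 5.0]])
k = torch.tensor([0, 1, 2])     # one exponent per column, broadcast over rows
mask0, mask2 = k == 0, k == 2   # precompute once when k is reused many times
print(fast_power_sketch(x, k, mask0, mask2))  # same values as x ** k
```

Precomputing the masks pays off when the same exponent tensor is applied to many batches of inputs.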
-
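qmctorch.utils.torch_utils.gradients(out, inp)
Returns the gradients of out with respect to inp.
Parameters: - out (torch.tensor) – output tensor to differentiate
- inp (torch.tensor) – input tensor with respect to which the gradients are computed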
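A sketch of how such a gradient can be obtained with torch.autograd.grad, assuming each entry of out depends only on the matching row of inp (one walker per row). Passing grad_outputs=ones_like(out) sums the contributions of all entries of out, which under that assumption yields the per-row gradients. gradients_sketch is a hypothetical name, and it returns the gradient tensor itself rather than the tuple returned by torch.autograd.grad.

```python
import torch


def gradients_sketch(out, inp):
    """Gradient of out with respect to inp (hypothetical helper).

    grad_outputs=ones_like(out) weights every entry of out by 1, so the
    result is the gradient of out.sum() with respect to inp.
    """
    return torch.autograd.grad(out, inp, grad_outputs=torch.ones_like(out))[0]


pos = torch.tensor([[1.0, 2.0], [3.0, -1.0]], requires_grad=True)
out = (pos ** 2).sum(dim=1)   # one scalar per walker
g = gradients_sketch(out, pos)
print(g)                      # expected values: [[2, 4], [6, -2]], i.e. 2 * pos
```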
-
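qmctorch.utils.torch_utils.diagonal_hessian(out, inp, return_grads=False)
Returns the diagonal of the Hessian of out with respect to inp.
Parameters: - out (torch.tensor) – output tensor to differentiate
- inp (torch.tensor) – input tensor with respect to which the derivatives are computed
- return_grads (bool) – if True, also return the first derivatives (default: False)
Returns: the diagonal Hessian, followed by the gradients when return_grads is True
Return type: torch.tensor or (torch.tensor, torch.tensor)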
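A sketch of one way to compute the diagonal Hessian with two levels of torch.autograd.grad, assuming inp has shape (nbatch, ndim) and out[i] depends only on inp[i]. The first derivatives are computed with create_graph=True so they can be differentiated again; then one backward pass per dimension keeps only the diagonal column. diagonal_hessian_sketch is a hypothetical name, not the QMCTorch function.

```python
import torch


def diagonal_hessian_sketch(out, inp, return_grads=False):
    """Diagonal of the Hessian of out with respect to inp (hypothetical helper)."""
    # First derivatives; create_graph=True keeps them differentiable
    jacob = torch.autograd.grad(out, inp, grad_outputs=torch.ones_like(out),
                                create_graph=True)[0]
    hess = torch.zeros_like(jacob)
    for idim in range(jacob.shape[1]):
        # Differentiate d out / d inp[:, idim] once more ...
        tmp = torch.autograd.grad(jacob[:, idim], inp,
                                  grad_outputs=torch.ones_like(jacob[:, idim]),
                                  create_graph=True)[0]
        # ... and keep only the second derivative along the same dimension
        hess[:, idim] = tmp[:, idim]
    if return_grads:
        return hess, jacob.detach()
    return hess


pos = torch.tensor([[1.0, 2.0], [3.0, -1.0]], requires_grad=True)
out = (pos ** 3).sum(dim=1)
hess, grads = diagonal_hessian_sketch(out, pos, return_grads=True)
print(hess)   # expected values: 6 * pos
print(grads)  # expected values: 3 * pos**2
```

The loop costs one extra backward pass per input dimension, which is the price of extracting only the diagonal.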
-
class qmctorch.utils.torch_utils.DataSet(data)
Bases: torch.utils.data.Dataset
Creates a torch data set.
Parameters: data (torch.tensor) – data
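A minimal sketch of a data set of this kind and how it pairs with a DataLoader to split walker positions into batches. DataSetSketch is hypothetical, and it assumes each row of data is one sample (here one walker with 2 electrons x 3 coordinates flattened to 6 values, purely for illustration).

```python
import torch
from torch.utils.data import DataLoader, Dataset


class DataSetSketch(Dataset):
    """Hypothetical minimal data set: each item is one row of data."""

    def __init__(self, data):
        self.data = data

    def __len__(self):
        # number of samples = number of rows
        return self.data.shape[0]

    def __getitem__(self, index):
        return self.data[index]


pos = torch.rand(10, 6)  # 10 walkers, 6 coordinates each
loader = DataLoader(DataSetSketch(pos), batch_size=4)
for batch in loader:
    print(batch.shape)  # torch.Size([4, 6]) twice, then torch.Size([2, 6])
```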
-
class qmctorch.utils.torch_utils.Loss(wf, method='energy', clip=False)
Bases: torch.nn.Module
Defines the loss to use during the optimization.
Parameters: wf (WaveFunction) – wave function object used
Keyword Arguments: - method (str) – method to use: energy, variance, weighted-energy or weighted-variance (default: 'energy')
- clip (bool) – clip the values that lie more than a given number of standard deviations (sigma) away from the mean (default: False)
-
forward(pos, no_grad=False, deactivate_weight=False)
Computes the loss.
Parameters: pos (torch.tensor) – positions of the walkers in the batch
Keyword Arguments: - no_grad (bool) – if True, compute the loss without tracking gradients (default: False)
Returns: the value of the loss and the local energies
Return type: torch.tensor, torch.tensor
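A hypothetical sketch of how the unweighted energy and variance methods could reduce a batch of local energies to a scalar loss, with optional clipping around the mean. The helper name loss_from_local_energies and the clipping width nsigma are assumptions made for illustration (the clipping width used by QMCTorch is not stated here), and the weighted variants are omitted.

```python
import torch


def loss_from_local_energies(eloc, method="energy", clip=False, nsigma=5.0):
    """Reduce local energies to a scalar loss (hypothetical sketch).

    nsigma is an assumed clipping width, not taken from QMCTorch.
    """
    if clip:
        # keep only the values within nsigma standard deviations of the mean
        mean, std = eloc.mean(), eloc.std()
        eloc = eloc[(eloc - mean).abs() < nsigma * std]
    if method == "energy":
        return eloc.mean()
    if method == "variance":
        return eloc.var()
    raise ValueError(f"method not covered by this sketch: {method}")


eloc = torch.tensor([-1.0, -0.5, -1.5, -1.0])
print(loss_from_local_energies(eloc, "energy"))    # mean local energy, -1.0
print(loss_from_local_energies(eloc, "variance"))  # unbiased variance of eloc
```

Clipping removes outliers (for example near nodes of the wave function) that would otherwise dominate the variance.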
-
static get_grad_mode(no_grad)
Returns torch.no_grad() if no_grad is True, and torch.enable_grad() otherwise.
Parameters: no_grad (bool) – whether to disable gradient tracking
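A sketch of this pattern: returning one of torch's grad-mode context managers lets the same with-block run with or without building the autograd graph. get_grad_mode_sketch is a hypothetical free-function version of the static method.

```python
import torch


def get_grad_mode_sketch(no_grad):
    """Return a context manager that disables or enables gradient tracking."""
    return torch.no_grad() if no_grad else torch.enable_grad()


x = torch.ones(3, requires_grad=True)

with get_grad_mode_sketch(no_grad=True):
    y = (x ** 2).sum()
print(y.requires_grad)  # False: no graph was recorded

with get_grad_mode_sketch(no_grad=False):
    z = (x ** 2).sum()
print(z.requires_grad)  # True: z can be backpropagated
```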
-
class qmctorch.utils.torch_utils.OrthoReg(alpha=0.1)
Bases: torch.nn.Module
Adds a penalty loss that keeps the molecular orbitals (MOs) orthogonal.
Keyword Arguments: alpha (float) – strength of the penalty (default: 0.1)
-
forward(W)
Returns the loss |W W^T - I|, where I is the identity matrix.
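A sketch of an orthogonality penalty of this form as an nn.Module. It assumes that alpha multiplies the norm and that the norm is the Frobenius norm; the text above only gives |W W^T - I|. OrthoRegSketch is a hypothetical name.

```python
import torch
from torch import nn


class OrthoRegSketch(nn.Module):
    """Hypothetical sketch: penalty alpha * ||W W^T - I|| (Frobenius norm)."""

    def __init__(self, alpha=0.1):
        super().__init__()
        self.alpha = alpha

    def forward(self, W):
        eye = torch.eye(W.shape[0], dtype=W.dtype, device=W.device)
        # vanishes exactly when the rows of W are orthonormal
        return self.alpha * torch.linalg.norm(W @ W.T - eye)


reg = OrthoRegSketch(alpha=0.1)
Q, _ = torch.linalg.qr(torch.randn(4, 4))  # an orthogonal matrix
print(reg(Q))                  # close to 0
print(reg(2 * torch.eye(3)))   # 0.1 * |3 I| = 0.1 * 3 * sqrt(3), about 0.52
```

In training, such a term would be added to the main loss so that the optimizer is discouraged from letting the MO coefficient matrix drift away from orthogonality.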
-