Source code for qmctorch.utils.torch_utils

from typing import Optional, Tuple
import torch
from torch.autograd import grad, Variable
from torch.utils.data import Dataset
from math import ceil


def set_torch_double_precision() -> None:
    """Switch torch to float64 by default and turn TF32 math off.

    The TF32 switch is cleared on both the CUDA matmul backend and the
    cuDNN backend.
    """
    torch.set_default_dtype(torch.float64)
    for tf32_backend in (torch.backends.cuda.matmul, torch.backends.cudnn):
        tf32_backend.allow_tf32 = False
def set_torch_single_precision() -> None:
    """Switch torch to float32 by default and turn TF32 math off.

    The TF32 switch is cleared on both the CUDA matmul backend and the
    cuDNN backend.
    """
    torch.set_default_dtype(torch.float32)
    for tf32_backend in (torch.backends.cuda.matmul, torch.backends.cudnn):
        tf32_backend.allow_tf32 = False
def fast_power(
    x: torch.Tensor,
    k: torch.Tensor,
    mask0: Optional[torch.Tensor] = None,
    mask2: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Compute ``x**k``, with a fast path when ``k`` only holds 0, 1 and 2.

    When every exponent lies in [0, 2], the result is built from a clone of
    ``x``:
      - entries where ``k == 0`` are set to 1;
      - entries where ``k == 2`` are squared;
      - all other entries are left as they are (i.e. treated as ``k == 1``).
    The fast path therefore assumes integer-valued exponents. Any other
    exponent range (negative values, or 3 and above) falls back to ``x**k``.

    Args:
        x (torch.Tensor): input values. The masks are applied to its trailing
            dimensions (``masked_fill_`` broadcasting and ``out[..., mask]``).
        k (torch.Tensor): exponents
        mask0 (torch.Tensor, optional): precomputed mask of the elements of
            ``k`` that are 0. Defaults to None, in which case it is computed
            here.
        mask2 (torch.Tensor, optional): precomputed mask of the elements of
            ``k`` that are 2. Defaults to None, in which case it is computed
            here.

    Returns:
        torch.Tensor: values of x**k (``x`` itself is not modified)
    """
    kmax = 3
    k_largest = k.max()
    # The mask-based path only reproduces x**k for exponents 0, 1 and 2.
    # Without the lower bound, negative exponents would silently be treated
    # as 1.
    if k_largest < kmax and k.min() >= 0:
        out = x.clone()
        if mask0 is None:
            mask0 = k == 0
        out.masked_fill_(mask0, 1)
        if k_largest > 1:
            if mask2 is None:
                mask2 = k == 2
            out[..., mask2] *= out[..., mask2]
    else:
        out = x**k
    return out
def gradients(out: torch.Tensor, inp: torch.Tensor) -> torch.Tensor:
    """Differentiate ``out`` with respect to ``inp``.

    ``out`` is backpropagated with a tensor of ones as ``grad_outputs``. For a
    non-scalar ``out`` the result is therefore the gradient of ``out.sum()``.

    Args:
        out (torch.Tensor): tensor to differentiate
        inp (torch.Tensor): tensor to differentiate against

    Returns:
        torch.Tensor: gradient of ``out`` w.r.t. ``inp``, detached from the graph
    """
    seed = torch.ones_like(out)
    derivatives = grad(out, inp, grad_outputs=seed)
    return derivatives[0].detach()
def hessian(
    out: torch.Tensor, inp: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute second derivatives of ``out`` together with its gradient.

    Both backward passes are seeded with ones. The first pass gives the
    gradient of ``out.sum()`` and keeps it in the graph. The second pass
    differentiates the sum of that gradient. The first returned tensor is
    therefore the Hessian of ``out.sum()`` multiplied by a vector of ones
    (the row sums of the Hessian), not the full Hessian matrix.

    Args:
        out (torch.Tensor): The output tensor.
        inp (torch.Tensor): The input tensor.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: the second-derivative term
        described above, and the (detached) gradient.
    """
    out_seed = torch.ones_like(out)
    first_order = grad(out, inp, grad_outputs=out_seed, create_graph=True)[0]
    grad_seed = torch.ones_like(first_order)
    second_order = grad(first_order, inp, grad_outputs=grad_seed)[0]
    return second_order, first_order.detach()
def diagonal_hessian(
    out: torch.Tensor,
    inp: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return the diagonal Hessian of `out` w.r.t. `inp` and the gradient.

    The gradient (of ``out.sum()``, since it is seeded with ones) is kept in
    the graph. Each diagonal element is then obtained by differentiating one
    gradient column and keeping the matching column of the result. This costs
    one extra backward pass per column of ``inp``.

    Args:
        out (torch.Tensor): The output tensor.
        inp (torch.Tensor): The input tensor. The ``[:, idim]`` indexing
            below assumes it is 2-D (presumably points x dimensions;
            TODO confirm with callers).

    Returns:
        Tuple[torch.Tensor, torch.Tensor]:
            - the diagonal Hessian elements, with the same shape as the
              gradient. They are computed with ``create_graph=True`` so they
              can be differentiated again.
            - the gradient of `out` with respect to `inp`, detached.
    """
    # compute the jacobian, keeping it in the graph for the second pass
    jacob = grad(out, inp, grad_outputs=torch.ones_like(out), create_graph=True)[0]

    # Seed for backpropagating one gradient column. It must live on the same
    # device (and dtype) as the gradient; a bare torch.ones(n) is always
    # created on the CPU with the default dtype and fails for CUDA inputs.
    z = torch.ones(jacob.shape[0], dtype=jacob.dtype, device=jacob.device)

    # compute the diagonal elements of the Hessian, one column at a time
    hess = torch.zeros_like(jacob)
    for idim in range(jacob.shape[1]):
        tmp = grad(jacob[:, idim], inp, grad_outputs=z, create_graph=True)[0]
        hess[:, idim] = tmp[:, idim]
    return hess, jacob.detach()
class DataSet(Dataset):
    """Map-style torch dataset whose items are rows of a single tensor."""

    def __init__(self, data: torch.Tensor) -> None:
        """Creates a torch data set

        Arguments:
            data (torch.Tensor): data. Its first dimension indexes the data
                points, and ``__getitem__`` slices it as ``[index, :]``, so it
                must be at least 2-D.
        """
        # __len__ and __getitem__ both read self.data; without this
        # assignment they raise AttributeError.
        self.data = data

    def __len__(self) -> int:
        """get the number of data points

        Returns:
            int -- number of data points (size of the first dimension)
        """
        return self.data.shape[0]

    def __getitem__(self, index: int) -> torch.Tensor:
        """returns a given data point

        Arguments:
            index {int} -- index of the point

        Returns:
            torch.tensor -- data of that point
        """
        return self.data[index, :]
class DataLoader:
    """Lightweight stand-in for ``torch.utils.data.DataLoader``.

    Iterating yields consecutive slices of ``batch_size`` entries along the
    first dimension of the data. The last slice holds whatever entries
    remain. Each call to ``iter()`` restarts from the first batch.
    """

    def __init__(
        self, data: torch.Tensor, batch_size: int, pin_memory: bool = False
    ) -> None:
        """Simple DataLoader to replace torch data loader

        Args:
            data (torch.Tensor): data to load [Nbatch,Nelec*3]
            batch_size (int): size of the minibatch
            pin_memory (bool, optional): copy the data to pinned memory. Defaults to False.
        """
        self.dataset = data.pin_memory() if pin_memory else data
        self.len = len(data)
        self.nbatch = ceil(self.len / batch_size)
        self.count = 0
        self.batch_size = batch_size

    def __iter__(self):
        """Rewind to the first batch and return this loader as its own iterator."""
        self.count = 0
        return self

    def __next__(self) -> torch.Tensor:
        """Return the next batch of data points.

        Raises:
            StopIteration: once all ``nbatch`` batches have been returned.
        """
        if self.count >= self.nbatch:
            raise StopIteration

        begin = self.count * self.batch_size
        is_last = self.count == self.nbatch - 1
        # the final batch takes everything that is left
        batch = (
            self.dataset[begin:]
            if is_last
            else self.dataset[begin : begin + self.batch_size]
        )
        self.count += 1
        return batch