Source code for qmctorch.utils.algebra_utils

import torch
import numpy as np
from typing import List
from scipy.special import factorial2 as f2


def btrace(M: torch.Tensor) -> torch.Tensor:
    """Return the trace of every matrix in a batch.

    The trace is taken over the last two dimensions; all leading
    dimensions are treated as batch dimensions and are preserved.

    Args:
        M: matrices of size (Na, Nb, ... Nx, N, N)

    Returns:
        trace of matrices (Na, Nb, ... Nx)
    """
    # view on the main diagonals: shape (Na, Nb, ... Nx, N)
    main_diagonals = M.diagonal(dim1=-2, dim2=-1)
    return main_diagonals.sum(dim=-1)
def bproj(M: torch.Tensor, P: torch.Tensor) -> torch.Tensor:
    """Project batched matrices using P^T M P

    The transpose and the matrix products act on the last two dimensions;
    leading dimensions follow the usual matmul batch broadcasting.

    Args:
        M (torch.Tensor): Batched matrices; the product P^T M P requires the
            last two dimensions to be square, i.e. size (..., N, N)
        P (torch.Tensor): Projectors of size (..., N, K)

    Returns:
        torch.Tensor: Projected matrices of size (..., K, K)
    """
    P_t = P.transpose(-1, -2)
    # same left-to-right association as P^T @ M @ P
    left = torch.matmul(P_t, M)
    return torch.matmul(left, P)
def bdet2(M: torch.Tensor) -> torch.Tensor:
    """Computes the determinant of batched 2x2 matrices

    Only the entries [0, 0], [0, 1], [1, 0] and [1, 1] of the last two
    dimensions are read; leading dimensions are batch dimensions.

    Args:
        M (torch.tensor): input matrices

    Returns:
        torch.tensor: determinants of the matrices
    """
    top_left = M[..., 0, 0]
    top_right = M[..., 0, 1]
    bottom_left = M[..., 1, 0]
    bottom_right = M[..., 1, 1]
    return top_left * bottom_right - top_right * bottom_left
def double_factorial(input: List) -> np.ndarray:
    """Computes the double factorial of an array of int

    Every entry for which scipy's ``factorial2`` returns 0 is replaced by 1.
    NOTE(review): presumably this is so that (-1)!! evaluates to 1 whatever
    convention the installed scipy uses for negative arguments -- confirm
    against the callers.

    The replacement is vectorised, so a scalar or a multi-dimensional
    array-like is accepted as well as a flat list; the shape of the input is
    preserved in the output.

    Args:
        input (List): input numbers

    Returns:
        np.ndarray: values of the double factorial
    """
    values = np.asarray(f2(input))
    # element-wise "1 if value == 0 else value", valid for any shape
    return np.where(values == 0, 1, values)
class BatchDeterminant(torch.autograd.Function):
    """Determinant of batched square matrices with an explicit backward pass.

    The forward pass builds the determinant from a pivoted LU factorization;
    the backward pass applies Jacobi's formula using the inverse of the input.
    The last two dimensions hold the matrices, all leading dimensions are
    batch dimensions.
    """

    @staticmethod
    def forward(ctx, input):
        """Compute the determinants of the input matrices.

        Args:
            ctx: autograd context, used to store ``input`` and the
                determinants for the backward pass
            input (torch.Tensor): square matrices of size (..., N, N)

        Returns:
            torch.Tensor: determinants, one per matrix, of size (...)
        """
        # LUP decompose the matrices
        # NOTE(review): Tensor.lu is flagged as deprecated in recent PyTorch
        # in favour of torch.linalg.lu_factor; kept for compatibility with
        # older releases -- confirm the minimum supported version.
        inp_lu, pivots = input.lu()

        # get the number of permutations: row i is left in place when its
        # pivot equals i (counted from 1), so every mismatch is one exchange.
        # The reference indices live on the same device as the pivots.
        identity_pivots = torch.arange(
            1, input.shape[-1] + 1, dtype=pivots.dtype, device=pivots.device
        )
        s = (pivots != identity_pivots).sum(-1).to(torch.get_default_dtype())

        # get the prod of the diag of U: the packed factorization stores U in
        # its upper triangle (diagonal included), so unpacking is not needed
        d = torch.diagonal(inp_lu, dim1=-2, dim2=-1).prod(-1)

        # assemble: each row exchange flips the sign of the determinant
        det = (-1) ** s * d
        ctx.save_for_backward(input, det)

        return det

    @staticmethod
    def backward(ctx, grad_output):
        """Gradient of the determinant with respect to the input matrices.

        Using Jacobi's formula
            d det(A) / d A_{ij} = adj^T(A)_{ij}
        and the adjugate formula
            d det(A) / d A_{ij} = ( (det(A) A^{-1})^T )_{ij}

        Args:
            ctx: autograd context holding ``input`` and the determinants
            grad_output (torch.Tensor): gradient w.r.t. the determinants,
                of size (...)

        Returns:
            torch.Tensor: gradient w.r.t. the input, of size (..., N, N)
        """
        input, det = ctx.saved_tensors
        # add two trailing axes so the per-matrix scale broadcasts over (N, N)
        scale = (grad_output * det)[..., None, None]
        return scale * torch.inverse(input).transpose(-1, -2)