# Source code for qmctorch.utils.algebra_utils

import torch


def btrace(M):
    """Compute the trace of a batch of square matrices.

    Args:
        M (torch.tensor): matrices of size (Na, Nb, ... Nx, N, N)

    Returns:
        torch.tensor: trace of matrices (Na, Nb, ... Nx)

    Example:
        >>> m = torch.rand(100, 5, 5)
        >>> tr = btrace(m)
    """
    # '...ii->...' sums the main diagonal of the last two dimensions,
    # which is exactly the batched trace.
    return torch.einsum('...ii->...', M)
def bproj(M, P):
    """Project batched matrices using P^T M P.

    Args:
        M (torch.tensor): batched matrices of size (..., N, N)
        P (torch.tensor): projectors of size (..., N, M)

    Returns:
        torch.tensor: projected matrices of size (..., M, M)
    """
    # Transpose the last two (matrix) dimensions so any number of leading
    # batch dimensions is supported; the previous transpose(1, 2) only
    # handled 3D inputs, contradicting the (..., N, M) contract above.
    return P.transpose(-2, -1) @ M @ P
def bdet2(M):
    """Compute the determinant of batched 2x2 matrices.

    Uses the closed form det([[a, b], [c, d]]) = a*d - b*c on the
    trailing 2x2 dimensions.

    Args:
        M (torch.tensor): input matrices

    Returns:
        torch.tensor: determinants of the matrices
    """
    a = M[..., 0, 0]
    b = M[..., 0, 1]
    c = M[..., 1, 0]
    d = M[..., 1, 1]
    return a * d - b * c
class BatchDeterminant(torch.autograd.Function):
    """Determinant of a batch of square matrices, with a custom backward.

    The forward pass uses an LU factorization:
        det(A) = (-1)^s * prod(diag(U))
    where s is the number of row interchanges recorded in the pivots.
    """

    @staticmethod
    def forward(ctx, input):
        """Compute the determinants of the batched input matrices.

        Args:
            input (torch.tensor): batched square matrices of size (B, N, N)

        Returns:
            torch.tensor: determinants, size (B,)
        """
        # PA = LU factorization. torch.linalg.lu_factor replaces the
        # deprecated Tensor.lu() used previously.
        inp_lu, pivots = torch.linalg.lu_factor(input)

        # Number of row swaps: pivots are 1-indexed and pivots[i] != i+1
        # marks an interchange at elimination step i. Build the reference
        # range on the pivots' device/dtype so CUDA inputs work (the old
        # torch.as_tensor(range(...)) was always a CPU tensor).
        ref = torch.arange(1, input.shape[-1] + 1,
                           device=pivots.device, dtype=pivots.dtype)
        nswaps = (pivots != ref).sum(-1).type(torch.get_default_dtype())

        # diag(U) is the diagonal of the packed LU factor (L has a unit
        # diagonal), so no lu_unpack is needed.
        diag_u = torch.diagonal(inp_lu, dim1=-2, dim2=-1).prod(-1)

        # assemble: det = (-1)^s * prod(diag(U))
        det = (-1.) ** nswaps * diag_u

        ctx.save_for_backward(input, det)
        return det

    @staticmethod
    def backward(ctx, grad_output):
        """Gradient of the determinant w.r.t. the input matrices.

        Uses Jacobi's formula
            d det(A) / d A_{ij} = adj^T(A)_{ij}
        with the adjugate identity adj(A) = det(A) A^{-1}, giving
            d det(A) / d A = det(A) * (A^{-1})^T.
        """
        input, det = ctx.saved_tensors
        return (grad_output * det).view(-1, 1, 1) \
            * torch.inverse(input).transpose(-2, -1)