# Source code for qmctorch.wavefunction.pooling.slater_pooling

import torch
from torch import nn
import operator as op

from ...utils import bdet2, btrace
from .orbital_configurations import get_excitation, get_unique_excitation
from .orbital_projector import ExcitationMask, OrbitalProjector


class SlaterPooling(nn.Module):
    """Applies a Slater determinant pooling in the active space.

    The layer takes the matrix of molecular-orbital values and returns,
    for every electronic configuration, the product of the spin-up and
    spin-down Slater determinants.  It can also evaluate an operator
    (e.g. the kinetic operator) on those determinants through the
    trace identity ``Tr(A^{-1} B)``.
    """

    def __init__(self, config_method, configs, mol, cuda=False):
        """Sets up the projectors and masks used to build the determinants.

        Args:
            config_method (str): method used to define the configurations
                (e.g. ``'ground_state'``, ``'single...'``, ``'cas(...)'``)
            configs (tuple): configurations of the electrons,
                ``configs[0]`` for spin up and ``configs[1]`` for spin down
            mol (Molecule): Molecule instance
            cuda (bool, optional): Turns GPU ON/OFF. Defaults to False.
        """
        super().__init__()

        self.config_method = config_method
        self.configs = configs
        self.nconfs = len(configs[0])

        # one past the highest orbital index used by any configuration
        self.index_max_orb_up = self.configs[0].max().item() + 1
        self.index_max_orb_down = self.configs[1].max().item() + 1

        self.excitation_index = get_excitation(configs)
        self.unique_excitation, self.index_unique_excitation = \
            get_unique_excitation(configs)

        self.nmo = mol.basis.nmo
        self.nup = mol.nup
        self.ndown = mol.ndown
        self.nelec = self.nup + self.ndown

        self.orb_proj = OrbitalProjector(configs, mol, cuda=cuda)
        self.exc_mask = ExcitationMask(
            self.unique_excitation, mol,
            (self.index_max_orb_up, self.index_max_orb_down),
            cuda=cuda)

        self.device = torch.device('cuda' if cuda else 'cpu')

    def _excitation_flags(self):
        """Lazily builds the unique-excitation indices on the mask.

        Returns:
            (bool, bool): whether there is at least one unique single,
            resp. double, excitation (judged on the spin-up index list,
            as in the rest of this class).
        """
        if not hasattr(self.exc_mask, 'index_unique_single_up'):
            self.exc_mask.get_index_unique_single()

        if not hasattr(self.exc_mask, 'index_unique_double_up'):
            self.exc_mask.get_index_unique_double()

        do_single = len(self.exc_mask.index_unique_single_up) != 0
        do_double = len(self.exc_mask.index_unique_double_up) != 0
        return do_single, do_double

    @staticmethod
    def _flatten_matrix(mat):
        """Collapses the last two dimensions of ``mat`` into one."""
        return mat.reshape(*mat.shape[:-2], -1)

    def forward(self, input):
        """Computes the values of the determinants

        Args:
            input (torch.tensor): MO matrices nbatch x nelec x nmo

        Returns:
            torch.tensor: slater determinants
        """
        # CAS configurations are evaluated explicitly; every other
        # method goes through the ground state + single/double route
        if self.config_method.startswith('cas('):
            return self.det_explicit(input)
        return self.det_single_double(input)

    def get_slater_matrices(self, input):
        """Computes the slater matrices

        Args:
            input (torch.tensor): MO matrices nbatch x nelec x nmo

        Returns:
            (torch.tensor, torch.tensor): slater matrices of spin up/down
        """
        return self.orb_proj.split_orbitals(input)

    def det_explicit(self, input):
        """Computes the values of the determinants from the slater matrices

        Args:
            input (torch.tensor): MO matrices nbatch x nelec x nmo

        Returns:
            torch.tensor: slater determinants
        """
        mo_up, mo_down = self.get_slater_matrices(input)
        # swap the first two dimensions of the product of determinants
        return (torch.det(mo_up) * torch.det(mo_down)).transpose(0, 1)

    def det_single_double(self, input):
        """Computes the determinant of ground state + single + double

        Args:
            input (torch.tensor): MO matrices nbatch x nelec x nmo

        Returns:
            torch.tensor: slater determinants
        """
        # determinants of the unique excitations only
        det_unique_up, det_unique_down = \
            self.det_unique_single_double(input)

        # product of the spin up/down determinants required by each
        # configuration
        idx_up, idx_down = (self.index_unique_excitation[0],
                            self.index_unique_excitation[1])
        return det_unique_up[:, idx_up] * det_unique_down[:, idx_down]

    def det_ground_state(self, input):
        """Computes the SD of the ground state

        Args:
            input (torch.tensor): MO matrices nbatch x nelec x nmo

        Returns:
            (torch.tensor, torch.tensor): determinants of the occupied
            spin-up and spin-down blocks
        """
        det_up = torch.det(input[:, :self.nup, :self.nup])
        det_down = torch.det(input[:, self.nup:, :self.ndown])
        return det_up, det_down

    def det_unique_single_double(self, input):
        """Computes the SD of single/double excitations

        The determinants of the single excitations are calculated from
        the ground state determinant and the ground state Slater
        matrices with one column modified.
        See : Monte Carlo Methods in ab initio quantum chemistry
        B.L. Hammond, appendix B1

        Note : if the states in configs are specified in order we end up
        with excitations that come from a deep orbital; the resulting
        slater matrix has one column changed (with the new orbital) and
        several permutations. We therefore need to multiply the slater
        determinant by (-1)^nperm.

        .. math::

            MO = [ A | B ]
            det(Exc_{ij}) = (det(A) * A^{-1} * B)_{i,j}

        Args:
            input (torch.tensor): MO matrices nbatch x nelec x nmo

        Returns:
            (torch.tensor, torch.tensor): determinants of the unique
            spin-up / spin-down excitations, ground state first
        """
        nbatch = input.shape[0]
        do_single, do_double = self._excitation_flags()

        # occupied orbital matrices and their determinants
        Aup = input[:, :self.nup, :self.nup]
        Adown = input[:, self.nup:, :self.ndown]
        detAup = torch.det(Aup)
        detAdown = torch.det(Adown)

        # the ground state always comes first in the output
        det_out_up = detAup.unsqueeze(-1).clone()
        det_out_down = detAdown.unsqueeze(-1).clone()

        if self.config_method == 'ground_state':
            return det_out_up, det_out_down

        # virtual orbital matrices spin up/down
        Bup = input[:, :self.nup, self.nup:self.index_max_orb_up]
        Bdown = input[:, self.nup:, self.ndown:self.index_max_orb_down]

        # A^{-1} B, flattened so excitations can be picked by index
        mat_exc_up = (torch.inverse(Aup) @ Bup).view(nbatch, -1)
        mat_exc_down = (torch.inverse(Adown) @ Bdown).view(nbatch, -1)

        if do_single:
            # NOTE(review): the docstring mentions a (-1)^nperm sign for
            # deep excitations; no sign is applied in this method -
            # confirm whether it is handled elsewhere.
            det_single_up = detAup.unsqueeze(-1) * \
                mat_exc_up[:, self.exc_mask.index_unique_single_up]
            det_single_down = detAdown.unsqueeze(-1) * \
                mat_exc_down[:, self.exc_mask.index_unique_single_down]

            det_out_up = torch.cat((det_out_up, det_single_up), dim=1)
            det_out_down = torch.cat(
                (det_out_down, det_single_down), dim=1)

        if do_double:
            # each double excitation is the determinant of a 2x2 block
            det_double_up = bdet2(
                mat_exc_up[:, self.exc_mask.index_unique_double_up]
                .view(nbatch, -1, 2, 2))
            det_double_up = detAup.unsqueeze(-1) * det_double_up

            det_double_down = bdet2(
                mat_exc_down[:, self.exc_mask.index_unique_double_down]
                .view(nbatch, -1, 2, 2))
            det_double_down = detAdown.unsqueeze(-1) * det_double_down

            det_out_up = torch.cat((det_out_up, det_double_up), dim=1)
            det_out_down = torch.cat(
                (det_out_down, det_double_down), dim=1)

        return det_out_up, det_out_down

    def operator(self, mo, bop, op=op.add, op_squared=False):
        """Computes the values of an operator applied to the product
        of determinants

        Args:
            mo (torch.tensor): matrix of MO vals(Nbatch, Nelec, Nmo)
            bop (torch.tensor): operator applied to the MOs
                (Nbatch, Nelec, Nmo)
            op (operator): how to combine the up/down contribution;
                if None the (up, down) pair is returned as is
            op_squared (bool, optional): return the trace of the square
                of the product if True

        Returns:
            torch.tensor: operator values (or a tuple if ``op`` is None)

        Raises:
            ValueError: if ``self.config_method`` is not recognized
        """
        if self.config_method == 'ground_state':
            op_vals = self.operator_ground_state(mo, bop, op_squared)

        elif self.config_method.startswith('single'):
            op_vals = self.operator_single_double(mo, bop, op_squared)

        elif self.config_method.startswith('cas('):
            op_vals = self.operator_explicit(mo, bop, op_squared)

        else:
            raise ValueError(
                'Configuration %s not recognized' % self.config_method)

        # combine the up/down values if requested
        if op is None:
            return op_vals
        return op(*op_vals)

    def operator_ground_state(self, mo, bop, op_squared=False):
        """Computes the values of any operator on gs only

        Args:
            mo (torch.tensor): matrix of molecular orbitals
            bop (torch.tensor): matrix of the operator applied to the MOs
            op_squared (bool, optional): return the trace of the square
                of the product if True

        Returns:
            (torch.tensor, torch.tensor): operator values for spin
            up/down, with a trailing dimension of size 1
        """
        # inverse of the occupied orbital matrices
        invAup = torch.inverse(mo[:, :self.nup, :self.nup])
        invAdown = torch.inverse(mo[:, self.nup:, :self.ndown])

        # A^{-1} B on the occupied block
        op_ground_up = invAup @ bop[..., :self.nup, :self.nup]
        op_ground_down = invAdown @ bop[..., self.nup:, :self.ndown]

        if op_squared:
            op_ground_up = op_ground_up @ op_ground_up
            op_ground_down = op_ground_down @ op_ground_down

        op_ground_up = btrace(op_ground_up).unsqueeze(-1)
        op_ground_down = btrace(op_ground_down).unsqueeze(-1)

        return op_ground_up, op_ground_down

    def operator_explicit(self, mo, bkin, op_squared=False):
        r"""Computes the value of any operator using the trace trick
        for a product of spin up/down determinant.

        .. math::
            -\frac{1}{2} \Delta \Psi = -\frac{1}{2} D_{up} D_{down}
            ( \Delta_{up} D_{up} / D_{up} +
              \Delta_{down} D_{down} / D_{down} )

        Args:
            mo (torch.tensor): matrix of MO vals(Nbatch, Nelec, Nmo)
            bkin (torch.tensor): kinetic operator (Nbatch, Nelec, Nmo)
            op_squared (bool, optional): return the trace of the square
                of the product if True

        Returns:
            (torch.tensor, torch.tensor): operator values for spin
            up/down
        """
        Aup, Adown = self.orb_proj.split_orbitals(mo)
        Bup, Bdown = self.orb_proj.split_orbitals(bkin)

        # a 5-dimensional B means several operators were passed at once
        multiple_op = (Bup.ndim == 5)

        iAup = torch.inverse(Aup)
        iAdown = torch.inverse(Adown)

        if multiple_op:
            # extra dimension so the inverses broadcast over operators
            iAup = iAup.unsqueeze(1)
            iAdown = iAdown.unsqueeze(1)

        op_val_up = iAup @ Bup
        op_val_down = iAdown @ Bdown

        if op_squared:
            op_val_up = op_val_up @ op_val_up
            op_val_down = op_val_down @ op_val_down

        op_val_up = btrace(op_val_up)
        op_val_down = btrace(op_val_down)

        # reorder the dimensions of the traces
        if multiple_op:
            op_val_up = op_val_up.permute(1, 2, 0)
            op_val_down = op_val_down.permute(1, 2, 0)
        else:
            op_val_up = op_val_up.transpose(0, 1)
            op_val_down = op_val_down.transpose(0, 1)

        return (op_val_up, op_val_down)

    def operator_single_double(self, mo, bop, op_squared=False):
        """Computes the value of any operator on gs + single + double

        Args:
            mo (torch.tensor): matrix of molecular orbitals
            bop (torch.tensor): matrix of the operator applied to the MOs
            op_squared (bool, optional): return the trace of the square
                of the product if True

        Returns:
            (torch.tensor, torch.tensor): operator values for spin
            up/down, one entry per configuration
        """
        op_up, op_down = self.operator_unique_single_double(
            mo, bop, op_squared)

        return (op_up[..., self.index_unique_excitation[0]],
                op_down[..., self.index_unique_excitation[1]])

    def operator_unique_single_double(self, mo, bop, op_squared):
        """Compute the operator value of the unique single/double
        configurations

        Args:
            mo (torch.tensor): matrix of molecular orbitals
            bop (torch.tensor): matrix of the operator applied to the MOs
            op_squared (bool): return the trace of the square of the
                product

        Returns:
            (torch.tensor, torch.tensor): operator values of the unique
            spin-up / spin-down excitations, ground state first
        """
        nbatch = mo.shape[0]
        do_single, do_double = self._excitation_flags()
        mask = self.exc_mask

        # inverse of the occupied orbital matrices
        invAup = torch.inverse(mo[:, :self.nup, :self.nup])
        invAdown = torch.inverse(mo[:, self.nup:, :self.ndown])

        # A^{-1} B on the occupied block
        bop_occ_up = bop[..., :self.nup, :self.nup]
        bop_occ_down = bop[..., self.nup:, :self.ndown]
        invAB_up = invAup @ bop_occ_up
        invAB_down = invAdown @ bop_occ_down

        # ground state operator
        if op_squared:
            op_ground_up = btrace(invAB_up @ invAB_up)
            op_ground_down = btrace(invAB_down @ invAB_down)
        else:
            op_ground_up = btrace(invAB_up)
            op_ground_down = btrace(invAB_down)

        op_ground_up = op_ground_up.unsqueeze(-1)
        op_ground_down = op_ground_down.unsqueeze(-1)

        # the ground state always comes first in the output
        op_out_up = op_ground_up.clone()
        op_out_down = op_ground_down.clone()

        # virtual orbital matrices spin up/down
        Avirt_up = mo[:, :self.nup, self.nup:self.index_max_orb_up]
        Avirt_down = mo[:, self.nup:, self.ndown:self.index_max_orb_down]

        # A^{-1} Abar
        mat_exc_up = invAup @ Avirt_up
        mat_exc_down = invAdown @ Avirt_down

        bop_virt_up = bop[..., :self.nup, self.nup:self.index_max_orb_up]
        bop_virt_down = bop[..., self.nup:,
                            self.ndown:self.index_max_orb_down]

        # M = A^{-1} Bbar - A^{-1} B A^{-1} Abar ; the second term reuses
        # A^{-1} B and keeps the left-to-right product order
        Mup = invAup @ bop_virt_up - invAB_up @ invAup @ Avirt_up
        Mdown = invAdown @ bop_virt_down - \
            invAB_down @ invAdown @ Avirt_down

        if not op_squared:
            # plain value of the operator
            Mup = self._flatten_matrix(Mup)
            Mdown = self._flatten_matrix(Mdown)

            if do_single:
                op_sin_up = self.op_single(
                    op_ground_up, mat_exc_up, Mup,
                    mask.index_unique_single_up, nbatch)
                op_sin_down = self.op_single(
                    op_ground_down, mat_exc_down, Mdown,
                    mask.index_unique_single_down, nbatch)

                op_out_up = torch.cat((op_out_up, op_sin_up), dim=-1)
                op_out_down = torch.cat(
                    (op_out_down, op_sin_down), dim=-1)

            if do_double:
                op_dbl_up = self.op_multiexcitation(
                    op_ground_up, mat_exc_up, Mup,
                    mask.index_unique_double_up, 2, nbatch)
                op_dbl_down = self.op_multiexcitation(
                    op_ground_down, mat_exc_down, Mdown,
                    mask.index_unique_double_down, 2, nbatch)

                op_out_up = torch.cat((op_out_up, op_dbl_up), dim=-1)
                op_out_down = torch.cat(
                    (op_out_down, op_dbl_down), dim=-1)

            return op_out_up, op_out_down

        # square of the operator, typically trace(ABAB)
        # Y = A^{-1} B M must be formed before M is flattened
        Yup = self._flatten_matrix(invAB_up @ Mup)
        Ydown = self._flatten_matrix(invAB_down @ Mdown)
        Mup = self._flatten_matrix(Mup)
        Mdown = self._flatten_matrix(Mdown)

        if do_single:
            op_sin_up = self.op_squared_single(
                op_ground_up, mat_exc_up, Mup, Yup,
                mask.index_unique_single_up, nbatch)
            op_sin_down = self.op_squared_single(
                op_ground_down, mat_exc_down, Mdown, Ydown,
                mask.index_unique_single_down, nbatch)

            op_out_up = torch.cat((op_out_up, op_sin_up), dim=-1)
            op_out_down = torch.cat((op_out_down, op_sin_down), dim=-1)

        if do_double:
            # the spin-up term must use the spin-up excitation indices
            op_dbl_up = self.op_squared_multiexcitation(
                op_ground_up, mat_exc_up, Mup, Yup,
                mask.index_unique_double_up, 2, nbatch)
            op_dbl_down = self.op_squared_multiexcitation(
                op_ground_down, mat_exc_down, Mdown, Ydown,
                mask.index_unique_double_down, 2, nbatch)

            op_out_up = torch.cat((op_out_up, op_dbl_up), dim=-1)
            op_out_down = torch.cat((op_out_down, op_dbl_down), dim=-1)

        return op_out_up, op_out_down

    @staticmethod
    def op_single(baseterm, mat_exc, M, index, nbatch):
        r"""Computes the operator values for single excitation

        .. math::
            Tr( \bar{A}^{-1} \bar{B}) = Tr(A^{-1} B) + Tr( T M )
            T = P ( A^{-1} \bar{A})^{-1} P
            M = A^{-1}\bar{B} - A^{-1}BA^{-1}\bar{A}

        Args:
            baseterm (torch.tensor): trace(A B)
            mat_exc (torch.tensor): invA @ Abar
            M (torch.tensor): invA Bbar - inv A B inv A Abar,
                with its last two dimensions flattened
            index (List): list of index of the excitations
            nbatch (int): batch size

        Returns:
            torch.tensor: operator values of the single excitations
        """
        # for a single excitation T is a scalar: plain reciprocal
        T = 1. / mat_exc.reshape(nbatch, -1)[:, index]

        # trace(T M) + base term
        return T * M[..., index] + baseterm

    @staticmethod
    def op_multiexcitation(baseterm, mat_exc, M, index, size, nbatch):
        r"""Computes the operator values for multiple excitation

        .. math::
            Tr( \bar{A}^{-1} \bar{B}) = Tr(A^{-1} B) + Tr( T M )
            T = P ( A^{-1} \bar{A})^{-1} P
            M = A^{-1}\bar{B} - A^{-1}BA^{-1}\bar{A}

        Args:
            baseterm (torch.tensor): trace(A B)
            mat_exc (torch.tensor): invA @ Abar
            M (torch.tensor): invA Bbar - inv A B inv A Abar,
                with its last two dimensions flattened
            index (List): list of index of the excitations
            size (int): number of excitation
            nbatch (int): batch size

        Returns:
            torch.tensor: operator values of the multiple excitations
        """
        # values of the excitation matrix invA Abar
        T = mat_exc.reshape(nbatch, -1)[:, index]

        # shapes as series of size x size matrices
        ext_shape = (*T.shape[:-1], -1, size, size)
        m_shape = (*M.shape[:-1], -1, size, size)

        # inverse of invA Abar
        T = torch.inverse(T.reshape(ext_shape))

        # trace(T M) + base term
        return btrace(T @ M[..., index].reshape(m_shape)) + baseterm

    @staticmethod
    def op_squared_single(baseterm, mat_exc, M, Y, index, nbatch):
        r"""Computes the operator squared for single excitation

        .. math::
            Tr( (\bar{A}^{-1} \bar{B})^2) = Tr((A^{-1} B)^2) +
                Tr( (T M)^2 ) + 2 Tr(T Y)
            T = P ( A^{-1} \bar{A})^{-1} P -> mat_exc in the code
            M = A^{-1}\bar{B} - A^{-1}BA^{-1}\bar{A}
            Y = A^{-1} B M

        Args:
            baseterm (torch.tensor): trace(A B A B)
            mat_exc (torch.tensor): invA @ Abar
            M (torch.tensor): invA Bbar - inv A B inv A Abar,
                with its last two dimensions flattened
            Y (torch.tensor): invA B M, flattened like M
            index (List): list of index of the excitations
            nbatch (int): batch size

        Returns:
            torch.tensor: squared-operator values of the single
            excitations
        """
        # for a single excitation T is a scalar: plain reciprocal
        T = 1. / mat_exc.reshape(nbatch, -1)[:, index]

        tm = T * M[..., index]
        ty = T * Y[..., index]

        # trace((T M)^2) + 2 trace(T Y) + base term
        return tm * tm + 2 * ty + baseterm

    @staticmethod
    def op_squared_multiexcitation(baseterm, mat_exc, M, Y, index,
                                   size, nbatch):
        r"""Computes the operator squared for multiple excitation

        .. math::
            Tr( (\bar{A}^{-1} \bar{B})^2) = Tr((A^{-1} B)^2) +
                Tr( (T M)^2 ) + 2 Tr(T Y)
            T = P ( A^{-1} \bar{A})^{-1} P -> mat_exc in the code
            M = A^{-1}\bar{B} - A^{-1}BA^{-1}\bar{A}
            Y = A^{-1} B M

        Args:
            baseterm (torch.tensor): trace(A B A B)
            mat_exc (torch.tensor): invA @ Abar
            M (torch.tensor): invA Bbar - inv A B inv A Abar,
                with its last two dimensions flattened
            Y (torch.tensor): invA B M, flattened like M
            index (List): list of index of the excitations
            size (int): number of excitation
            nbatch (int): batch size

        Returns:
            torch.tensor: squared-operator values of the multiple
            excitations
        """
        # values of the excitation matrix invA Abar
        T = mat_exc.reshape(nbatch, -1)[:, index]

        # shapes as series of size x size matrices
        ext_shape = (*T.shape[:-1], -1, size, size)
        m_shape = (*M.shape[:-1], -1, size, size)
        y_shape = (*Y.shape[:-1], -1, size, size)

        # inverse of invA Abar
        T = torch.inverse(T.reshape(ext_shape))

        tm = T @ M[..., index].reshape(m_shape)
        ty = T @ Y[..., index].reshape(y_shape)

        # trace((T M)^2) + 2 trace(T Y) + base term
        return btrace(tm @ tm) + 2 * btrace(ty) + baseterm