# Source code for qmctorch.wavefunction.jastrows.jastrow_factor_combined_terms

from torch import nn
from functools import reduce

from .elec_elec.jastrow_factor_electron_electron import JastrowFactorElectronElectron
from .elec_nuclei.jastrow_factor_electron_nuclei import JastrowFactorElectronNuclei
from .elec_elec_nuclei.jastrow_factor_electron_electron_nuclei import (
    JastrowFactorElectronElectronNuclei,
)


from .elec_elec.kernels.pade_jastrow_kernel import (
    PadeJastrowKernel as PadeJastrowKernelElecElec,
)
from .elec_nuclei.kernels.pade_jastrow_kernel import (
    PadeJastrowKernel as PadeJastrowKernelElecNuc,
)


class JastrowFactorCombinedTerms(nn.Module):
    """Jastrow factor built as the product of several individual terms.

    Up to three terms can be combined: electron-electron ("ee"),
    electron-nuclei ("en") and electron-electron-nuclei ("een"). The
    value of the combined factor is the product of the values of the
    individual terms; its first and second derivatives are obtained
    from the derivatives of the individual terms via the product rule.
    """

    # keys of the supported terms, in the order in which they are combined
    _TERM_KEYS = ("ee", "en", "een")

    def __init__(
        self,
        mol,
        jastrow_kernel=None,
        jastrow_kernel_kwargs=None,
        cuda=False,
    ):
        """Create the individual Jastrow terms requested by the user.

        Args:
            mol: molecule object; ``mol.nup`` and ``mol.ndown`` are read
                here and ``mol`` is forwarded to every individual term.
            jastrow_kernel (dict, optional): kernel class of each term,
                keyed by "ee", "en" and "een". A missing key or a
                ``None`` value disables the corresponding term.
                Defaults to a Pade kernel for "ee" and "en" and no
                "een" term.
            jastrow_kernel_kwargs (dict, optional): keyword arguments of
                each kernel, keyed like ``jastrow_kernel``. Missing keys
                default to an empty dict.
            cuda (bool, optional): forwarded to every individual term.
                Defaults to False.
        """
        super().__init__()

        self.nup = mol.nup
        self.ndown = mol.ndown
        # NOTE(review): this attribute shadows nn.Module.cuda() on the
        # instance; it is kept because the name is part of the existing
        # interface of this class.
        self.cuda = cuda

        # The defaults are created here rather than in the signature so
        # that they are never shared between instances.
        if jastrow_kernel is None:
            jastrow_kernel = {
                "ee": PadeJastrowKernelElecElec,
                "en": PadeJastrowKernelElecNuc,
                "een": None,
            }
        if jastrow_kernel_kwargs is None:
            jastrow_kernel_kwargs = {}

        # Sanitize shallow copies so that the dicts owned by the caller
        # are left untouched.
        kernels = dict(jastrow_kernel)
        kernel_kwargs = dict(jastrow_kernel_kwargs)
        for key in self._TERM_KEYS:
            kernels.setdefault(key, None)
            kernel_kwargs.setdefault(key, {})

        self.jastrow_kernel_dict = kernels
        self.jastrow_terms = nn.ModuleList()
        self.requires_autograd = True

        if kernels["ee"] is not None:
            self.jastrow_terms.append(
                JastrowFactorElectronElectron(
                    mol, kernels["ee"], kernel_kwargs["ee"], cuda=cuda
                )
            )

        if kernels["en"] is not None:
            self.jastrow_terms.append(
                JastrowFactorElectronNuclei(
                    mol, kernels["en"], kernel_kwargs["en"], cuda=cuda
                )
            )

        if kernels["een"] is not None:
            self.jastrow_terms.append(
                JastrowFactorElectronElectronNuclei(
                    mol, kernels["een"], kernel_kwargs["een"], cuda=cuda
                )
            )

        self.nterms = len(self.jastrow_terms)

    def __repr__(self):
        """Return the active terms as "<key> -> <kernel name>" joined by " + "."""
        active = [
            key + " -> " + self.jastrow_kernel_dict[key].__name__
            for key in self._TERM_KEYS
            if self.jastrow_kernel_dict[key] is not None
        ]
        return " + ".join(active)

    def forward(self, pos, derivative=0, sum_grad=True):
        """Compute the combined Jastrow factor and/or its derivatives.

        Args:
            pos (torch.tensor): positions of the electrons
                Size: Nbatch x (Nelec x Ndim)
            derivative (int or sequence, optional): order of the
                derivative (0, 1 or 2), or ``[0, 1, 2]`` (list or tuple)
                to get the value and both derivatives at once.
                Defaults to 0.
            sum_grad (bool, optional): forwarded to the individual terms
                when ``derivative == 1``; selects the summed form of the
                first derivative. Defaults to True.

        Returns:
            torch.tensor or tuple: per the individual terms,
                derivative = 0  Nbatch x 1
                derivative = 1  Nbatch x Nelec (sum_grad = True)
                derivative = 1  Nbatch x Ndim x Nelec (sum_grad = False)
                derivative = 2  second derivative of the product
                derivative = [0, 1, 2]  tuple (value, first derivative
                with sum_grad = False, second derivative)

        Raises:
            ValueError: if ``derivative`` is none of the values above.
        """
        terms = self.jastrow_terms

        if isinstance(derivative, (list, tuple)) and list(derivative) == [0, 1, 2]:
            jast_vals = [term(pos) for term in terms]
            djast_vals = [term(pos, derivative=1, sum_grad=False) for term in terms]
            d2jast_vals = [term(pos, derivative=2) for term in terms]

            out_jast = self.get_combined_values(jast_vals)
            out_d2jast = self.get_second_derivative_combined_values(
                jast_vals, djast_vals, d2jast_vals
            )

            # the values need a trailing axis to broadcast against the
            # non-summed first derivatives
            expanded_vals = [val.unsqueeze(-1) for val in jast_vals]
            out_djast = self.get_derivative_combined_values(
                expanded_vals, djast_vals
            )

            return (out_jast, out_djast, out_d2jast)

        if derivative == 0:
            return self.get_combined_values([term(pos) for term in terms])

        if derivative == 1:
            jast_vals = [term(pos) for term in terms]
            if not sum_grad:
                # trailing axis so the values broadcast against the
                # non-summed first derivatives
                jast_vals = [val.unsqueeze(-1) for val in jast_vals]
            djast_vals = [
                term(pos, derivative=1, sum_grad=sum_grad) for term in terms
            ]
            return self.get_derivative_combined_values(jast_vals, djast_vals)

        if derivative == 2:
            jast_vals = [term(pos) for term in terms]
            djast_vals = [term(pos, derivative=1, sum_grad=False) for term in terms]
            d2jast_vals = [term(pos, derivative=2) for term in terms]
            return self.get_second_derivative_combined_values(
                jast_vals, djast_vals, d2jast_vals
            )

        raise ValueError("derivative not understood")

    @staticmethod
    def get_combined_values(jast_vals):
        """Compute the product of all terms in jast_vals.

        Args:
            jast_vals (list): values of the individual terms (non empty).

        Returns:
            The element-wise product of the entries of ``jast_vals``; the
            single entry itself when there is only one.
        """
        # reduce() hands back the only element untouched for a list of one
        return reduce(lambda x, y: x * y, jast_vals)

    @staticmethod
    def get_derivative_combined_values(jast_vals, djast_vals):
        r"""Compute the first derivative of the product of the terms.

        .. math::
            J = A * B * C

            \frac{d J}{dx} = \frac{d A}{dx} B C
                             + A \frac{d B}{dx} C
                             + A B \frac{d C}{dx}

        Args:
            jast_vals (list): values of the individual terms, shaped so
                that they broadcast against the entries of ``djast_vals``.
            djast_vals (list): first derivatives of the individual terms.

        Returns:
            The derivative of the product of the terms.
        """
        if len(djast_vals) == 1:
            return djast_vals[0]

        out = 0.0
        for iterm, dval in enumerate(djast_vals):
            # product of all the values with the i-th one replaced by
            # its derivative
            factors = jast_vals.copy()
            factors[iterm] = dval
            out += reduce(lambda x, y: x * y, factors)
        return out

    @staticmethod
    def get_second_derivative_combined_values(jast_vals, djast_vals, d2jast_vals):
        r"""Compute the second derivative of the product of the terms.

        .. math::
            J = A * B * C

            \frac{d^2 J}{dx^2} = \frac{d^2 A}{dx^2} B C
                                 + A \frac{d^2 B}{dx^2} C
                                 + A B \frac{d^2 C}{dx^2}
                                 + 2 ( \frac{d A}{dx} \frac{d B}{dx} C
                                     + \frac{d A}{dx} B \frac{d C}{dx}
                                     + A \frac{d B}{dx} \frac{d C}{dx} )

        Args:
            jast_vals (list): values of the individual terms.
            djast_vals (list): non-summed first derivatives of the
                individual terms; they carry one more axis than the
                values, and axis 1 is summed over here.
            d2jast_vals (list): second derivatives of the individual terms.

        Returns:
            The second derivative of the product of the terms.
        """
        if len(d2jast_vals) == 1:
            return d2jast_vals[0]

        out = 0.0
        nterms = len(jast_vals)

        # pure second derivative contributions: d2a * b * c
        for iterm in range(nterms):
            factors = jast_vals.copy()
            factors[iterm] = d2jast_vals[iterm]
            out = out + reduce(lambda x, y: x * y, factors)

        # values with a trailing axis so that they broadcast against the
        # first derivatives; computed once for all the pairs
        expanded_vals = [val.unsqueeze(-1) for val in jast_vals]

        # cross contributions: 2 * da * db * c
        for iterm in range(nterms - 1):
            for jterm in range(iterm + 1, nterms):
                factors = expanded_vals.copy()
                factors[iterm] = djast_vals[iterm]
                factors[jterm] = djast_vals[jterm]
                out = out + (2.0 * reduce(lambda x, y: x * y, factors)).sum(1)

        return out