Source code for qmctorch.sampler.sampler_base

import torch
from typing import Dict, Callable
from .. import log
from .walkers import Walkers


[docs] class SamplerBase: def __init__( self, nwalkers: int, nstep: int, step_size: float, ntherm: int, ndecor: int, nelec: int, ndim: int, init: Dict, cuda: bool, ) -> None: """Base class for the sampler Args: nwalkers (int): number of walkers nstep (int): number of MC steps step_size (float): size of the steps in bohr ntherm (int): number of MC steps to thermalize ndecor (int): number of MC steps to decorellate nelec (int): number of electrons in the system ndim (int): number of cartesian dimension init (dict): method to initialize the walkers cuda (bool): turn CUDA ON/OFF """ # self.nwalkers = nwalkers self.nelec = nelec self.ndim = ndim self.nstep = nstep self.step_size = step_size self.ntherm = ntherm self.ndecor = ndecor self.cuda = cuda if cuda: self.device = torch.device("cuda") else: self.device = torch.device("cpu") self.walkers = Walkers( nwalkers=nwalkers, nelec=nelec, ndim=ndim, init=init, cuda=cuda, ) log.info("") log.info(" Monte-Carlo Sampler") log.info(" Number of walkers : {0}", self.walkers.nwalkers) log.info(" Number of steps : {0}", self.nstep) log.info(" Step size : {0}", self.step_size) log.info(" Thermalization steps: {0}", self.ntherm) log.info(" Decorelation steps : {0}", self.ndecor) log.info(" Walkers init pos : {0}", init["method"]) def __call__( self, pdf: Callable[[torch.Tensor], torch.Tensor], *args, **kwargs ) -> torch.Tensor: """ Evaluate the sampling algorithm. Args: pdf (Callable[[torch.Tensor], torch.Tensor]): the function to sample *args: additional positional arguments **kwargs: additional keyword arguments Returns: torch.Tensor: the samples """ raise NotImplementedError("Sampler must have a __call__ method") def __repr__(self) -> str: return ( self.__class__.__name__ + " sampler with %d walkers" % self.walkers.nwalkers )
[docs] def get_sampling_size(self) -> int: """evaluate the number of sampling point we'll have.""" if self.ntherm == -1: return self.walkers.nwalkers else: return self.walkers.nwalkers * int((self.nstep - self.ntherm) / self.ndecor)