"""Base solver class for QMCTorch (module ``qmctorch.solver.solver_base``)."""

from types import SimpleNamespace
import os
import numpy as np
import torch
from tqdm import tqdm

from .. import log
from ..utils import add_group_attr, dump_to_hdf5

class SolverBase:
    """Base class shared by the QMC solvers.

    Holds the wave function, sampler, optimizer and scheduler, and provides the
    bookkeeping common to all solvers: resampling configuration, observable
    tracking/storage, single-point energies, sampling trajectories,
    checkpointing and logging. Child classes implement :meth:`run`.
    """

    def __init__(self, wf=None, sampler=None, optimizer=None,
                 scheduler=None, output=None, rank=0):
        """Base Class for QMC solver.

        Args:
            wf (qmctorch.WaveFunction, optional): wave function. Defaults to None.
            sampler (qmctorch.sampler, optional): sampler. Defaults to None.
            optimizer (torch.optim, optional): optimizer. Defaults to None.
            scheduler (torch.optim, optional): scheduler. Defaults to None.
            output (str, optional): hdf5 filename. When None, the name is
                derived from ``wf.mol.hdf5file``. Defaults to None.
            rank (int, optional): rank of the process; only rank 0 dumps the
                solver to the hdf5 file. Defaults to 0.
        """
        self.wf = wf
        self.sampler = sampler
        self.opt = optimizer
        self.scheduler = scheduler
        self.cuda = False
        self.device = torch.device('cpu')

        # members defined in the child class and/or its methods
        self.dataloader = None
        self.loss = None
        self.obs_dict = None

        # if pos are needed for the optimizer (obsolete ?)
        if self.opt is not None and 'lpos_needed' not in self.opt.__dict__.keys():
            self.opt.lpos_needed = False

        # distributed model
        self.save_model = 'model.pth'

        # handles GPU availability
        # NOTE(review): wf/sampler are dereferenced unconditionally here, so the
        # default wf=None raises AttributeError — confirm whether None is supported.
        if self.wf.cuda:
            self.device = torch.device('cuda')
            self.sampler.cuda = True
            self.sampler.walkers.cuda = True
        else:
            self.device = torch.device('cpu')

        self.hdf5file = output
        if output is None:
            basename = os.path.basename(self.wf.mol.hdf5file).split('.')[0]
            self.hdf5file = basename + '_QMCTorch.hdf5'

        if rank == 0:
            dump_to_hdf5(self, self.hdf5file)

        self.log_data()

    def configure_resampling(self, mode='update', resample_every=1,
                             nstep_update=25):
        """Configure the resampling.

        Args:
            mode (str, optional): method to resample: 'full', 'update' or
                'never'. Defaults to 'update'.
            resample_every (int, optional): number of optimization steps
                between resampling. Defaults to 1.
            nstep_update (int, optional): number of MC steps in update mode.
                Defaults to 25.

        Raises:
            ValueError: if ``mode`` is not one of the valid modes.
        """
        self.resampling_options = SimpleNamespace()
        valid_mode = ['never', 'full', 'update']
        if mode not in valid_mode:
            raise ValueError(
                mode, 'not a valid update method : ', valid_mode)

        self.resampling_options.mode = mode
        self.resampling_options.resample_every = resample_every
        self.resampling_options.nstep_update = nstep_update

    def track_observable(self, obs_name):
        """Define the observables to record.

        Args:
            obs_name (str or list of str): observable name(s). Each must be one
                of 'energy', 'local_energy', 'geometry', 'parameters',
                'gradients', or an attribute/method of the wave function.
                'energy' and 'geometry' are always added.

        Raises:
            ValueError: if a name is not recognized.
        """
        # accept a single name, and always work on a copy so that appending
        # 'energy'/'geometry' below never mutates the caller's list
        if isinstance(obs_name, str):
            obs_name = [obs_name]
        else:
            obs_name = list(obs_name)

        # sanity check
        valid_obs_name = ['energy', 'local_energy',
                          'geometry', 'parameters', 'gradients']
        for name in obs_name:
            if name in valid_obs_name or hasattr(self.wf, name):
                continue
            log.info(' Error : Observable %s not recognized' % name)
            log.info(' : Possible observable')
            for n in valid_obs_name:
                log.info(' : - %s' % n)
            log.info(' : - or any method of the wave function')
            raise ValueError('Observable not recognized')

        # reset the namespace
        self.observable = SimpleNamespace()

        # always track the energy and the geometry of the system
        if 'energy' not in obs_name:
            obs_name.append('energy')
        if 'geometry' not in obs_name:
            obs_name.append('geometry')

        for k in obs_name:
            if k == 'parameters':
                # named_parameters keeps names and tensors aligned (zipping
                # state_dict keys with parameters() does not when buffers exist)
                for key, p in self.wf.named_parameters():
                    if p.requires_grad:
                        setattr(self.observable, key, [])
            elif k == 'gradients':
                # NOTE: store_observable only fills '<name>.grad' while it is
                # storing '<name>', i.e. when 'parameters' is tracked as well
                for key, p in self.wf.named_parameters():
                    if p.requires_grad:
                        setattr(self.observable, key + '.grad', [])
            else:
                setattr(self.observable, k, [])

        self.observable.models = SimpleNamespace()

    def store_observable(self, pos, local_energy=None, ibatch=None, **kwargs):
        """Append the current value of every tracked observable.

        Args:
            pos (torch.tensor): positions of the walkers
            local_energy (torch.tensor, optional): precomputed values of the
                local energy. Defaults to None.
            ibatch (int, optional): index of the current batch. For ibatch > 0
                the data is merged into the last stored entry instead of
                appended. Defaults to None.
        """
        if self.wf.cuda and pos.device.type == 'cpu':
            pos = pos.to(self.device)

        # keep_vars=True returns the live tensors, so .grad is available; the
        # default state_dict() returns detached tensors whose .grad is None
        state = self.wf.state_dict(keep_vars=True)
        tracked = self.observable.__dict__

        for obs in tracked:

            # store the energy
            if obs == 'energy' and local_energy is not None:
                data = local_energy.cpu().detach().numpy()
                if ibatch is None or ibatch == 0:
                    self.observable.energy.append(np.mean(data))
                else:
                    # running average of the per-batch means (batches 0..ibatch)
                    self.observable.energy[-1] *= ibatch / (ibatch + 1)
                    self.observable.energy[-1] += np.mean(data) / (ibatch + 1)

            # store local energy
            elif obs == 'local_energy' and local_energy is not None:
                data = local_energy.cpu().detach().numpy()
                if ibatch is None or ibatch == 0:
                    self.observable.local_energy.append(data)
                else:
                    self.observable.local_energy[-1] = np.append(
                        self.observable.local_energy[-1], data)

            # store variational parameter (and its gradient if tracked)
            elif obs in state:
                p = state[obs]
                # clone so later in-place updates do not alter the stored copy
                getattr(self.observable, obs).append(
                    p.detach().clone().cpu().numpy())

                grad_key = obs + '.grad'
                if grad_key in tracked:
                    if p.grad is not None:
                        grad = p.grad.detach().clone().cpu().numpy()
                    else:
                        grad = torch.zeros_like(p.detach()).cpu().numpy()
                    getattr(self.observable, grad_key).append(grad)

            # store any other defined method
            elif hasattr(self.wf, obs):
                func = getattr(self.wf, obs)
                data = func(pos)
                if isinstance(data, torch.Tensor):
                    data = data.cpu().detach().numpy()
                if isinstance(data, list):
                    data = np.array(data)
                if ibatch is None or ibatch == 0:
                    getattr(self.observable, obs).append(data)
                else:
                    values = getattr(self.observable, obs)
                    values[-1] = np.append(values[-1], data)

    def print_observable(self, cumulative_loss, verbose=False):
        """Log the latest observable values and the loss.

        Args:
            cumulative_loss (float): current loss value
            verbose (bool, optional): also log the last stored value of every
                other (non-empty, list-valued) observable. Defaults to False.
        """
        for key, values in self.observable.__dict__.items():
            if key == 'local_energy':
                if not values:
                    continue
                eloc = values[-1]
                e = np.mean(eloc)
                v = np.var(eloc)
                # standard error of the mean
                err = np.sqrt(v / len(eloc))
                log.options(style='percent').info(
                    ' energy : %f +/- %f' % (e, err))
                # log the variance itself (not its square root), as single_point does
                log.options(style='percent').info(
                    ' variance : %f' % v)
            elif verbose and isinstance(values, list) and values:
                # skips 'models' (a SimpleNamespace) and not-yet-filled entries
                log.options(style='percent').info(
                    '%s : %s' % (key, values[-1]))

        log.options(style='percent').info('loss %f' % (cumulative_loss))

    def resample(self, n, pos):
        """Resample the walkers according to ``self.resampling_options``.

        Args:
            n (int): current epoch value
            pos (torch.tensor): positions of the walkers

        Returns:
            torch.tensor: new positions of the walkers (``pos`` unchanged when
            no resampling happens).
        """
        options = self.resampling_options
        if options.mode == 'never':
            return pos

        if n % options.resample_every == 0:
            if options.mode == 'update':
                # restart the sampler from a detached copy of the current walkers
                pos = pos.clone().detach().to(self.device)
            else:
                # 'full': let the sampler start from scratch
                pos = None

            # NOTE(review): options.nstep_update is not used in this method
            pos = self.sampler(self.wf.pdf, pos=pos, with_tqdm=False)
            self.dataloader.dataset.data = pos

        # clear the stored 'psi0' weight (presumably forces the loss to
        # recompute it — confirm in the loss class)
        if self.loss.use_weight:
            self.loss.weight['psi0'] = None

        return pos

    def single_point(self, with_tqdm=True, hdf5_group='single_point'):
        """Perform a single point calculation.

        Args:
            with_tqdm (bool, optional): use tqdm for sampling. Defaults to True.
            hdf5_group (str, optional): hdf5 group where to store the data.
                Defaults to 'single_point'.

        Returns:
            SimpleNamespace: positions, local energy, energy, variance, error.
        """
        log.info('')
        log.info(' Single Point Calculation : {nw} walkers | {ns} steps',
                 nw=self.sampler.nwalkers, ns=self.sampler.nstep)

        # gradients are only needed when the kinetic energy uses autograd
        grad_mode = torch.no_grad()
        if self.wf.kinetic == 'auto':
            grad_mode = torch.enable_grad()

        with grad_mode:

            # get the position and put to gpu if necessary
            pos = self.sampler(self.wf.pdf, with_tqdm=with_tqdm)
            if self.wf.cuda and pos.device.type == 'cpu':
                pos = pos.to(self.device)

            # compute energy/variance/error
            el = self.wf.local_energy(pos)
            e, s, err = torch.mean(el), torch.var(
                el), self.wf.sampling_error(el)

            # print data
            log.options(style='percent').info(
                ' Energy : %f +/- %f' % (e.detach().item(), err.detach().item()))
            log.options(style='percent').info(
                ' Variance : %f' % s.detach().item())

            # dump data to hdf5
            obs = SimpleNamespace(
                pos=pos,
                local_energy=el,
                energy=e,
                variance=s,
                error=err
            )
            dump_to_hdf5(obs, self.hdf5file, root_name=hdf5_group)
            add_group_attr(self.hdf5file, hdf5_group,
                           {'type': 'single_point'})

        return obs

    def save_checkpoint(self, epoch, loss):
        """Save the model and optimizer state to ``checkpoint_epoch<epoch>.pth``.

        Args:
            epoch (int): epoch
            loss (float): current value of the loss
        """
        filename = 'checkpoint_epoch%d.pth' % epoch
        # the misspelled 'optimzier_state_dict' key is kept on purpose:
        # load_checkpoint (and existing checkpoint files) rely on it
        torch.save({
            'epoch': epoch,
            'model_state_dict': self.wf.state_dict(),
            'optimzier_state_dict': self.opt.state_dict(),
            'loss': loss
        }, filename)

    def load_checkpoint(self, filename):
        """Load a model/optimizer state written by :meth:`save_checkpoint`.

        Args:
            filename (str): filename

        Returns:
            tuple: epoch number and loss
        """
        # map tensors onto this solver's device so a checkpoint written on a
        # GPU can be restored on a CPU-only solver
        data = torch.load(filename, map_location=self.device)
        self.wf.load_state_dict(data['model_state_dict'])
        self.opt.load_state_dict(data['optimzier_state_dict'])
        epoch = data['epoch']
        loss = data['loss']
        return epoch, loss

    def _append_observable(self, key, data):
        """Append a new data point to observable ``key`` of ``self.obs_dict``.

        Arguments:
            key {str} -- name of the observable
            data {} -- data
        """
        self.obs_dict.setdefault(key, []).append(data)

    def sampling_traj(self, pos=None, with_tqdm=True,
                      hdf5_group='sampling_trajectory'):
        """Compute the local energy along a sampling trajectory.

        Args:
            pos (torch.tensor, optional): positions of the walkers along the
                trajectory, viewed as (-1, nwalkers, ndim). Sampled when None.
            with_tqdm (bool, optional): show progress bars. Defaults to True.
            hdf5_group (str, optional): name of the group where to store the
                data. Defaults to 'sampling_trajectory'.

        Returns:
            SimpleNamespace: local energy per frame and positions.
        """
        log.info('')
        log.info(' Sampling trajectory')

        if pos is None:
            pos = self.sampler(self.wf.pdf, with_tqdm=with_tqdm)

        ndim = pos.shape[-1]
        frames = pos.view(-1, self.sampler.nwalkers, ndim)
        el = []
        rng = tqdm(frames, desc='INFO:QMCTorch| Energy ',
                   disable=not with_tqdm)
        for frame in rng:
            el.append(self.wf.local_energy(frame).cpu().detach().numpy())

        # presumably drops a trailing singleton axis of local_energy — confirm
        el = np.array(el).squeeze(-1)
        obs = SimpleNamespace(local_energy=el, pos=pos)
        dump_to_hdf5(obs, self.hdf5file, hdf5_group)
        add_group_attr(self.hdf5file, hdf5_group, {'type': 'sampling_traj'})
        return obs

    def print_parameters(self, grad=False):
        """Print the trainable parameters (or their gradients).

        Args:
            grad (bool, optional): print the gradient instead of the value.
                Defaults to False.
        """
        for p in self.wf.parameters():
            if p.requires_grad:
                if grad:
                    print(p.grad)
                else:
                    print(p)

    def optimization_step(self, lpos):
        """Perform one optimization step.

        Arguments:
            lpos {torch.tensor} -- positions of the walkers, forwarded to the
                optimizer only when ``self.opt.lpos_needed`` is set.
        """
        if self.opt.lpos_needed:
            self.opt.step(lpos)
        else:
            self.opt.step()

    def save_traj(self, fname, obs):
        """Save the trajectory of a geometry optimization as XYZ frames.

        Args:
            fname (str): file name
            obs (SimpleNamespace): observables; ``obs.geometry`` is a sequence
                of snapshots, each a sequence of per-atom (x, y, z) positions.
        """
        xyz = obs.geometry
        natom = len(xyz[0])
        # 1 Angstrom = 1.88973 bohr; dividing converts bohr to Angstrom
        # (assumes the stored geometry is in bohr)
        bohr_per_angstrom = 1.88973
        with open(fname, 'w') as f:
            for snap in xyz:
                f.write('%d \n\n' % natom)
                for i, pos in enumerate(snap):
                    at = self.wf.atoms[i]
                    # NOTE(review): only at[0] is written; if atoms are plain
                    # symbol strings, two-letter symbols get truncated — confirm
                    f.write('%s % 7.5f % 7.5f %7.5f\n' % (at[0],
                                                          pos[0] / bohr_per_angstrom,
                                                          pos[1] / bohr_per_angstrom,
                                                          pos[2] / bohr_per_angstrom))
                f.write('\n')

    def run(self, nepoch, batchsize=None, loss='variance'):
        """Run the solver; implemented by child classes."""
        raise NotImplementedError()

    def log_data(self):
        """Log basic information about the solver components."""
        log.info('')
        log.info(' QMC Solver ')

        if self.wf is not None:
            log.info(' WaveFunction : {0}', self.wf.__class__.__name__)
            for x in self.wf.__repr__().split('\n'):
                log.debug(' ' + x)

        if self.sampler is not None:
            log.info(' Sampler : {0}', self.sampler.__class__.__name__)
            for x in self.sampler.__repr__().split('\n'):
                log.debug(' ' + x)

        if self.opt is not None:
            log.info(' Optimizer : {0}', self.opt.__class__.__name__)
            for x in self.opt.__repr__().split('\n'):
                log.debug(' ' + x)

        if self.scheduler is not None:
            log.info(' Scheduler : {0}', self.scheduler.__class__.__name__)
            for x in self.scheduler.__repr__().split('\n'):
                log.debug(' ' + x)