Source code for qmctorch.solver.solver_mpi

from time import time
from types import SimpleNamespace

import torch
from qmctorch.utils import DataLoader, Loss, OrthoReg, add_group_attr, dump_to_hdf5

from .. import log
from .solver import Solver

try:
    import horovod.torch as hvd
except ModuleNotFoundError:
    pass


def logd(rank, *args):
    """Forward a message to ``log.info`` only when called from rank 0.

    Args:
        rank (int): rank of the calling process; every rank other than 0
            returns without logging.
        *args: positional arguments passed unchanged to ``log.info``.
    """
    if rank != 0:
        return
    log.info(*args)
class SolverMPI(Solver):
    """QMC solver that distributes sampling and optimization with Horovod."""

    def __init__(  # pylint: disable=too-many-arguments
        self,
        wf=None,
        sampler=None,
        optimizer=None,
        scheduler=None,
        output=None,
        rank=0,
    ):
        """Distributed QMC solver

        Args:
            wf (qmctorch.WaveFunction, optional): wave function. Defaults to None.
            sampler (qmctorch.sampler, optional): Sampler. Defaults to None.
            optimizer (torch.optim, optional): optimizer. Defaults to None.
            scheduler (torch.optim, optional): scheduler. Defaults to None.
            output (str, optional): hdf5 filename. Defaults to None.
            rank (int, optional): rank of the process. Defaults to 0.
        """
        super().__init__(wf, sampler, optimizer, scheduler, output, rank)

        # every process starts from the optimizer state held by rank 0
        hvd.broadcast_optimizer_state(self.opt, root_rank=0)

        # wrap the optimizer so that Horovod combines the gradients of all
        # processes before each step
        self.opt = hvd.DistributedOptimizer(
            self.opt, named_parameters=self.wf.named_parameters()
        )

        # split the walkers between the processes (integer division)
        self.sampler.walkers.nwalkers //= hvd.size()

    def run(  # pylint: disable=too-many-arguments
        self,
        nepoch,
        batchsize=None,
        loss="energy",
        clip_loss=False,
        grad="manual",
        hdf5_group="wf_opt",
        num_threads=1,
        chkpt_every=None,
    ):
        """Run the optimization

        Args:
            nepoch (int): Number of optimization steps
            batchsize (int, optional): Number of samples in a mini batch.
                If None, all samples are used. Defaults to None.
            loss (str, optional): method to compute the loss: variance or
                energy. Defaults to 'energy'.
            clip_loss (bool, optional): Clip the loss values at +/- 5std.
                Defaults to False.
            grad (str, optional): method to compute the gradients: 'auto' or
                'manual'. Defaults to 'manual'.
            hdf5_group (str, optional): name of the hdf5 group where to store
                the data. Defaults to 'wf_opt'
            num_threads (int, optional): number of torch threads used by each
                process. Defaults to 1.
            chkpt_every (int, optional): on rank 0, save a checkpoint every
                ``chkpt_every`` epochs (epoch 0 excluded). Defaults to None
                (no checkpoint).

        Returns:
            the observable container ``self.observable``
        """
        logd(hvd.rank(), "")
        logd(
            hvd.rank(),
            " Distributed Optimization on {num} process".format(num=hvd.size()),
        )
        log.info(
            " - Process {id} using {nw} walkers".format(
                id=hvd.rank(), nw=self.sampler.walkers.nwalkers
            )
        )

        # observable
        if not hasattr(self, "observable"):
            self.track_observable(["local_energy"])

        # select the gradient method; an unknown name raises KeyError
        self.evaluate_gradient = {
            "auto": self.evaluate_grad_auto,
            "manual": self.evaluate_grad_manual,
        }[grad]

        # default the flag when the optimizer instance does not define it
        # (presumably consulted by optimization_step -- confirm)
        if "lpos_needed" not in self.opt.__dict__.keys():
            self.opt.lpos_needed = False

        self.wf.train()

        # every process starts from the wave function parameters of rank 0
        hvd.broadcast_parameters(self.wf.state_dict(), root_rank=0)
        torch.set_num_threads(num_threads)

        # get the loss
        self.loss = Loss(self.wf, method=loss, clip=clip_loss)
        self.loss.use_weight = self.resampling_options.resample_every > 1

        # orthogonalization penalty for the MO coeffs
        self.ortho_loss = OrthoReg()

        self.prepare_optimization(batchsize, chkpt_every)

        # log data
        if hvd.rank() == 0:
            self.log_data_opt(nepoch, "wave function optimization")

        # sample the wave function (progress bar only on rank 0)
        if hvd.rank() == 0:
            pos = self.sampler(self.wf.pdf)
        else:
            pos = self.sampler(self.wf.pdf, with_tqdm=False)

        # required to build the distributed data container
        pos.requires_grad_(False)

        # handle the batch size
        if batchsize is None:
            batchsize = len(pos)

        # get the initial observable
        if hvd.rank() == 0:
            self.store_observable(pos)

        # change the number of steps/walker size; restored after the loop
        _nstep_save = self.sampler.nstep
        _ntherm_save = self.sampler.ntherm
        _nwalker_save = self.sampler.walkers.nwalkers
        if self.resampling_options.mode == "update":
            self.sampler.ntherm = -1
            self.sampler.nstep = self.resampling_options.nstep_update
            self.sampler.walkers.nwalkers = pos.shape[0]

        # create the data loader
        self.dataloader = DataLoader(pos, batch_size=batchsize, pin_memory=self.cuda)

        min_loss = 1e3

        for n in range(nepoch):
            tstart = time()
            logd(hvd.rank(), "")
            logd(hvd.rank(), " epoch %d" % n)

            cumulative_loss = 0.0

            for ibatch, data in enumerate(self.dataloader):
                # get data
                lpos = data.to(self.device)
                lpos.requires_grad = True

                # get the gradient
                loss, eloc = self.evaluate_gradient(lpos)
                cumulative_loss += loss

                # optimize the parameters
                self.optimization_step(lpos)

                # observable: eloc was computed on the mini-batch lpos, so the
                # observable must be stored with those same positions
                if hvd.rank() == 0:
                    self.store_observable(lpos, local_energy=eloc, ibatch=ibatch)

            # average the loss over all processes
            cumulative_loss = self.metric_average(cumulative_loss, "cum_loss")

            if hvd.rank() == 0:
                if n == 0 or cumulative_loss < min_loss:
                    # clone the tensors: state_dict() returns references to
                    # the live parameters, which the optimizer keeps updating
                    # in place, so a shallow copy would not be a snapshot
                    self.observable.models.best = {
                        key: val.clone() for key, val in self.wf.state_dict().items()
                    }
                    min_loss = cumulative_loss

                if self.chkpt_every is not None:
                    if (n > 0) and (n % self.chkpt_every == 0):
                        self.save_checkpoint(n, cumulative_loss)

                self.print_observable(cumulative_loss)

            # resample the data
            pos = self.resample(n, pos)
            pos.requires_grad = False

            # scheduler step
            if self.scheduler is not None:
                self.scheduler.step()

            logd(hvd.rank(), " epoch done in %1.2f sec." % (time() - tstart))

        # restore the sampler number of step
        self.sampler.nstep = _nstep_save
        self.sampler.ntherm = _ntherm_save
        self.sampler.walkers.nwalkers = _nwalker_save

        if hvd.rank() == 0:
            dump_to_hdf5(self.observable, self.hdf5file, hdf5_group)
            add_group_attr(self.hdf5file, hdf5_group, {"type": "opt"})

        return self.observable

    def single_point(self, with_tqdm=True, batchsize=None, hdf5_group="single_point"):
        """Performs a single point calculation

        Args:
            with_tqdm (bool, optional): use tqdm for sampling. Defaults to True.
            batchsize (int, optional): not supported by the MPI solver; when
                given, a message is logged and the value is ignored.
                Defaults to None.
            hdf5_group (str, optional): hdf5 group where to store the data.
                Defaults to 'single_point'.

        Returns:
            SimpleNamespace: contains the local energy, positions, ...
        """
        logd(hvd.rank(), "")
        logd(
            hvd.rank(),
            " Single Point Calculation : {nw} walkers | {ns} steps".format(
                nw=self.sampler.walkers.nwalkers, ns=self.sampler.nstep
            ),
        )

        if batchsize is not None:
            log.info(" Batchsize not supported for MPI solver")

        # check if we have to compute and store the grads
        grad_mode = torch.no_grad()
        if self.wf.kinetic == "auto":
            grad_mode = torch.enable_grad()

        # distribute the calculation: every process uses the parameters of
        # rank 0
        num_threads = 1
        hvd.broadcast_parameters(self.wf.state_dict(), root_rank=0)
        torch.set_num_threads(num_threads)

        with grad_mode:
            # sample the wave function
            pos = self.sampler(self.wf.pdf, with_tqdm=with_tqdm)
            if self.wf.cuda and pos.device.type == "cpu":
                pos = pos.to(self.device)

            # local energies of the walkers owned by this process
            eloc = self.wf.local_energy(pos)

            # gather the local energies of all processes and compute the
            # energy/variance/error over the full set of walkers
            eloc_all = hvd.allgather(eloc, name="local_energies")
            e, s, err = (
                torch.mean(eloc_all),
                torch.var(eloc_all),
                self.wf.sampling_error(eloc_all),
            )

            # print
            if hvd.rank() == 0:
                log.options(style="percent").info(
                    " Energy : %f +/- %f" % (e.detach().item(), err.detach().item())
                )
                log.options(style="percent").info(" Variance : %f" % s.detach().item())

            # dump data to hdf5
            # NOTE(review): pos holds only this process's walkers while
            # local_energy holds the gathered values of every process
            obs = SimpleNamespace(
                pos=pos, local_energy=eloc_all, energy=e, variance=s, error=err
            )

            # dump to file
            if hvd.rank() == 0:
                dump_to_hdf5(obs, self.hdf5file, root_name=hdf5_group)
                add_group_attr(self.hdf5file, hdf5_group, {"type": "single_point"})

        return obs

    @staticmethod
    def metric_average(val, name):
        """Average a given quantity over all processes

        Arguments:
            val {torch.tensor or float} -- data to average
            name {str} -- name of the data

        Returns:
            float -- Averaged quantity
        """
        # torch.as_tensor returns tensors unchanged and also accepts plain
        # Python numbers (e.g. a loss accumulator still at 0.0 when no batch
        # was processed), which have no .clone()
        tensor = torch.as_tensor(val).clone().detach()
        avg_tensor = hvd.allreduce(tensor, name=name)
        return avg_tensor.item()