from time import time
from types import SimpleNamespace
import torch
from torch.utils.data import DataLoader
from qmctorch.utils import (DataSet, Loss, OrthoReg, add_group_attr,
dump_to_hdf5)
from .. import log
from .solver import Solver
try:
import horovod.torch as hvd
except ModuleNotFoundError:
pass
def logd(rank, *args):
    """Log a message through the package logger on the root process only.

    Args:
        rank (int): rank of the calling process; nothing is logged
            unless it is equal to 0.
        *args: positional arguments forwarded unchanged to ``log.info``.
    """
    # every process other than the root one stays silent
    if rank != 0:
        return
    log.info(*args)
class SolverMPI(Solver):
    """QMC solver that distributes the walkers and the optimization
    over several processes through Horovod (``hvd``)."""

    def __init__(self, wf=None, sampler=None, optimizer=None,
                 scheduler=None, output=None, rank=0):
        """Distributed QMC solver

        Args:
            wf (qmctorch.WaveFunction, optional): wave function. Defaults to None.
            sampler (qmctorch.sampler, optional): Sampler. Defaults to None.
            optimizer (torch.optim, optional): optimizer. Defaults to None.
            scheduler (torch.optim, optional): scheduler. Defaults to None.
            output (str, optional): hdf5 filename. Defaults to None.
            rank (int, optional): rank of the process. Defaults to 0.
        """
        super().__init__(wf, sampler,
                         optimizer, scheduler, output, rank)

        # start every process from the optimizer state of rank 0 and
        # wrap the optimizer in its Horovod distributed counterpart
        hvd.broadcast_optimizer_state(self.opt, root_rank=0)
        self.opt = hvd.DistributedOptimizer(
            self.opt, named_parameters=self.wf.named_parameters())

        # each process only handles its share of the walkers
        # (integer division by the number of processes)
        self.sampler.nwalkers //= hvd.size()
        self.sampler.walkers.nwalkers //= hvd.size()

    def run(self, nepoch, batchsize=None, loss='energy',
            clip_loss=False, grad='manual', hdf5_group='wf_opt',
            num_threads=1, chkpt_every=None):
        """Run the optimization

        Args:
            nepoch (int): Number of optimization step
            batchsize (int, optional): Number of sample in a mini batch.
                                       If None, all samples are used.
                                       Defaults to None.
            loss (str, optional): method to compute the loss: variance or energy.
                                  Defaults to 'energy'.
            clip_loss (bool, optional): Clip the loss values at +/- 5std.
                                        Defaults to False.
            grad (str, optional): method to compute the gradients: 'auto' or 'manual'.
                                  Defaults to 'manual'.
            hdf5_group (str, optional): name of the hdf5 group where to store the data.
                                        Defaults to 'wf_opt'
            num_threads (int, optional): value given to ``torch.set_num_threads``
                                         and used as ``num_workers`` of the data loader.
                                         Defaults to 1.
            chkpt_every (int, optional): forwarded to ``prepare_optimization``;
                                         when ``self.chkpt_every`` is not None, rank 0
                                         saves a checkpoint at every epoch ``n > 0``
                                         that is a multiple of this value.
                                         Defaults to None.

        Returns:
            the ``self.observable`` container filled during the optimization.

        Raises:
            KeyError: if ``grad`` is neither 'auto' nor 'manual'.
        """

        logd(hvd.rank(), '')
        logd(hvd.rank(),
             ' Distributed Optimization on {num} process'.format(num=hvd.size()))
        # not restricted to rank 0: every process reports its own walkers
        log.info(' - Process {id} using {nw} walkers'.format(
            id=hvd.rank(), nw=self.sampler.nwalkers))

        # observable
        if not hasattr(self, 'observable'):
            self.track_observable(['local_energy'])

        # select the gradient evaluation method
        self.evaluate_gradient = {
            'auto': self.evaluate_grad_auto,
            'manual': self.evaluate_grad_manual}[grad]

        # only instance attributes are inspected, as in the __dict__ lookup
        if 'lpos_needed' not in vars(self.opt):
            self.opt.lpos_needed = False

        self.wf.train()

        # all processes start from the parameters of rank 0
        hvd.broadcast_parameters(self.wf.state_dict(), root_rank=0)
        torch.set_num_threads(num_threads)

        # get the loss
        self.loss = Loss(self.wf, method=loss, clip=clip_loss)
        self.loss.use_weight = (
            self.resampling_options.resample_every > 1)

        # orthogonalization penalty for the MO coeffs
        self.ortho_loss = OrthoReg()

        self.prepare_optimization(batchsize, chkpt_every)

        # log data
        if hvd.rank() == 0:
            self.log_data_opt(nepoch, 'wave function optimization')

        # sample the wave function (progress bar on rank 0 only)
        if hvd.rank() == 0:
            pos = self.sampler(self.wf.pdf)
        else:
            pos = self.sampler(self.wf.pdf, with_tqdm=False)

        # required to build the distributed data container
        pos.requires_grad_(False)

        # handle the batch size
        if batchsize is None:
            batchsize = len(pos)

        # get the initial observable
        if hvd.rank() == 0:
            self.store_observable(pos)

        # remember the sampler settings that are modified below
        _nstep_save = self.sampler.nstep
        _ntherm_save = self.sampler.ntherm
        _nwalker_save = self.sampler.walkers.nwalkers

        try:
            # change the number of steps/walker size
            if self.resampling_options.mode == 'update':
                self.sampler.ntherm = -1
                self.sampler.nstep = self.resampling_options.nstep_update
                self.sampler.walkers.nwalkers = pos.shape[0]
                self.sampler.nwalkers = pos.shape[0]

            # create the data loader
            self.dataset = DataSet(pos)
            kwargs = {'num_workers': num_threads}
            if self.cuda:
                kwargs['pin_memory'] = True

            self.dataloader = DataLoader(self.dataset,
                                         batch_size=batchsize,
                                         **kwargs)
            min_loss = 1E3

            for n in range(nepoch):

                tstart = time()
                logd(hvd.rank(), '')
                logd(hvd.rank(), ' epoch %d' % n)

                cumulative_loss = 0.

                for ibatch, data in enumerate(self.dataloader):

                    # get data
                    lpos = data.to(self.device)
                    lpos.requires_grad = True

                    # get the gradient
                    loss, eloc = self.evaluate_gradient(lpos)
                    cumulative_loss += loss

                    # optimize the parameters
                    self.optimization_step(lpos)

                    # observable
                    if hvd.rank() == 0:
                        self.store_observable(
                            pos, local_energy=eloc, ibatch=ibatch)

                # combine the loss of all the processes
                cumulative_loss = self.metric_average(cumulative_loss,
                                                      'cum_loss')

                if hvd.rank() == 0:
                    # keep a copy of the best model seen so far
                    if n == 0 or cumulative_loss < min_loss:
                        self.observable.models.best = dict(
                            self.wf.state_dict())
                        min_loss = cumulative_loss

                    if self.chkpt_every is not None:
                        if (n > 0) and (n % chkpt_every == 0):
                            self.save_checkpoint(n, cumulative_loss)

                    self.print_observable(cumulative_loss)

                # resample the data
                pos = self.resample(n, pos)
                pos.requires_grad = False

                # scheduler step
                if self.scheduler is not None:
                    self.scheduler.step()

                logd(hvd.rank(), ' epoch done in %1.2f sec.' %
                     (time()-tstart))

        finally:
            # restore the sampler settings, even if the optimization
            # raised, so that the sampler is not left in a modified state
            self.sampler.nstep = _nstep_save
            self.sampler.ntherm = _ntherm_save
            self.sampler.walkers.nwalkers = _nwalker_save
            self.sampler.nwalkers = _nwalker_save

        # only rank 0 writes the results to file
        if hvd.rank() == 0:
            dump_to_hdf5(self.observable, self.hdf5file, hdf5_group)
            add_group_attr(self.hdf5file, hdf5_group, {'type': 'opt'})

        return self.observable

    def single_point(self, with_tqdm=True, hdf5_group='single_point'):
        """Performs a single point calculation

        Args:
            with_tqdm (bool, optional): use tqdm for sampling; forwarded
                                        to the sampler. Defaults to True.
            hdf5_group (str, optional): hdf5 group where to store the data.
                                        Defaults to 'single_point'.

        Returns:
            SimpleNamespace: contains the local energy, positions, ...
        """

        logd(hvd.rank(), '')
        logd(hvd.rank(), ' Single Point Calculation : {nw} walkers | {ns} steps'.format(
            nw=self.sampler.nwalkers, ns=self.sampler.nstep))

        # check if we have to compute and store the grads
        if self.wf.kinetic == 'auto':
            grad_mode = torch.enable_grad()
        else:
            grad_mode = torch.no_grad()

        # distribute the calculation
        num_threads = 1
        hvd.broadcast_parameters(self.wf.state_dict(), root_rank=0)
        torch.set_num_threads(num_threads)

        with grad_mode:

            # sample the wave function
            pos = self.sampler(self.wf.pdf, with_tqdm=with_tqdm)
            if self.wf.cuda and pos.device.type == 'cpu':
                pos = pos.to(self.device)

            # local energies of the walkers of this process
            eloc = self.wf.local_energy(pos)

            # gather the local energies of all the processes and
            # compute energy/variance/error on the gathered values
            eloc_all = hvd.allgather(eloc, name='local_energies')
            e = torch.mean(eloc_all)
            s = torch.var(eloc_all)
            err = self.wf.sampling_error(eloc_all)

            # print
            if hvd.rank() == 0:
                log.options(style='percent').info(
                    ' Energy : %f +/- %f' % (e.detach().item(), err.detach().item()))
                log.options(style='percent').info(
                    ' Variance : %f' % s.detach().item())

            # NOTE: pos is the output of this process' sampler call only,
            # whereas local_energy holds the values gathered from all processes
            obs = SimpleNamespace(
                pos=pos,
                local_energy=eloc_all,
                energy=e,
                variance=s,
                error=err
            )

            # dump to file
            if hvd.rank() == 0:
                dump_to_hdf5(obs,
                             self.hdf5file,
                             root_name=hdf5_group)
                add_group_attr(self.hdf5file, hdf5_group,
                               {'type': 'single_point'})

        return obs

    @staticmethod
    def metric_average(val, name):
        """Average a given quantity over all processes

        Arguments:
            val {torch.tensor or float} -- data to average
            name {str} -- name of the data, passed to ``hvd.allreduce``

        Returns:
            Python scalar -- result of ``hvd.allreduce`` extracted with ``.item()``
        """
        # as_tensor also accepts a plain Python number, e.g. a loss
        # accumulator that is still 0. because no batch was processed
        tensor = torch.as_tensor(val).clone().detach()
        avg_tensor = hvd.allreduce(tensor, name=name)
        return avg_tensor.item()