from copy import deepcopy
from time import time
import torch
from qmctorch.utils import Loss, OrthoReg, add_group_attr, dump_to_hdf5, DataLoader
from .. import log
from .solver_base import SolverBase
class Solver(SolverBase):
    """QMC solver for wave-function and geometry optimization."""

    def __init__(  # pylint: disable=too-many-arguments
        self, wf=None, sampler=None, optimizer=None, scheduler=None, output=None, rank=0
    ):
        """Basic QMC solver

        Args:
            wf (qmctorch.WaveFunction, optional): wave function. Defaults to None.
            sampler (qmctorch.sampler, optional): Sampler. Defaults to None.
            optimizer (torch.optim, optional): optimizer. Defaults to None.
            scheduler (torch.optim, optional): scheduler. Defaults to None.
            output (str, optional): hdf5 filename. Defaults to None.
            rank (int, optional): rank of the process. Defaults to 0.
        """
        SolverBase.__init__(self, wf, sampler, optimizer, scheduler, output, rank)

        # by default: optimize the wave-function parameters, not the geometry
        self.set_params_requires_grad()

        self.configure(
            track=["local_energy"],
            freeze=None,
            loss="energy",
            grad="manual",
            ortho_mo=False,
            clip_loss=False,
            resampling={"mode": "update", "resample_every": 1, "nstep_update": 25},
        )

    def set_params_requires_grad(self, wf_params=True, geo_params=False):
        """Configure which parameters are optimized.

        Args:
            wf_params (bool, optional): requires_grad flag of the wave-function
                parameters (AO exponents/coefficients, MO parameters, CI
                weights and, if present, Jastrow parameters). Defaults to True.
            geo_params (bool, optional): requires_grad flag of the atomic
                coordinates. Defaults to False.
        """
        # wave-function parameters
        self.wf.ao.bas_exp.requires_grad = wf_params
        self.wf.ao.bas_coeffs.requires_grad = wf_params
        for param in self.wf.mo.parameters():
            param.requires_grad = wf_params
        self.wf.fc.weight.requires_grad = wf_params

        # the jastrow factor may be absent or None
        jastrow = getattr(self.wf, "jastrow", None)
        if jastrow is not None:
            for param in jastrow.parameters():
                param.requires_grad = wf_params

        # atomic positions
        self.wf.ao.atom_coords.requires_grad = geo_params

    def freeze_parameters(self, freeze):
        """Freeze the optimization of specified params.

        Args:
            freeze (str, list or None): name(s) of the parameter groups to
                freeze among "ci", "mo", "ao" and "jastrow" (case insensitive).
                None freezes nothing.

        Raises:
            ValueError: if an unknown parameter group is requested.
        """
        if freeze is None:
            return

        if not isinstance(freeze, (list, tuple)):
            freeze = [freeze]

        for name in freeze:
            key = name.lower()
            if key == "ci":
                self.wf.fc.weight.requires_grad = False
            elif key == "mo":
                for param in self.wf.mo.parameters():
                    param.requires_grad = False
            elif key == "ao":
                self.wf.ao.bas_exp.requires_grad = False
                self.wf.ao.bas_coeffs.requires_grad = False
            elif key == "jastrow":
                # nothing to freeze when the wave function has no jastrow
                jastrow = getattr(self.wf, "jastrow", None)
                if jastrow is not None:
                    for param in jastrow.parameters():
                        param.requires_grad = False
            else:
                opt_freeze = ["ci", "mo", "ao", "jastrow"]
                raise ValueError(
                    f"Invalid freeze option {name!r}. "
                    f"Valid arguments for freeze are : {opt_freeze}"
                )

    def save_sampling_parameters(self):
        """Save the sampling params.

        The current ``nstep``/``ntherm`` of the sampler are stored so that
        :meth:`restore_sampling_parameters` can restore them. In the "update"
        resampling mode the sampler is then switched to the
        ``ntherm_update``/``nstep_update`` resampling options.
        """
        self.sampler._nstep_save = self.sampler.nstep
        self.sampler._ntherm_save = self.sampler.ntherm

        if self.resampling_options.mode == "update":
            self.sampler.ntherm = self.resampling_options.ntherm_update
            self.sampler.nstep = self.resampling_options.nstep_update

    def restore_sampling_parameters(self):
        """Restore sampling params to the values stored by
        :meth:`save_sampling_parameters`."""
        self.sampler.nstep = self.sampler._nstep_save
        self.sampler.ntherm = self.sampler._ntherm_save

    def geo_opt(  # pylint: disable=too-many-arguments
        self,
        nepoch,
        geo_lr=1e-2,
        batchsize=None,
        nepoch_wf_init=100,
        nepoch_wf_update=50,
        hdf5_group="geo_opt",
        chkpt_every=None,
        tqdm=False,
    ):
        """Optimize the geometry of the molecule.

        The wave function is first optimized for ``nepoch_wf_init`` epochs.
        Each geometry iteration then performs one SGD step on the atomic
        coordinates (gradients from autograd) followed by
        ``nepoch_wf_update`` epochs of wave-function optimization.

        Args:
            nepoch (int): number of geometry optimization steps
            geo_lr (float, optional): learning rate of the SGD optimizer used
                for the geometry. Defaults to 1e-2.
            batchsize (int, optional): number of samples in a mini batch.
                If None, all samples are used. Defaults to None.
            nepoch_wf_init (int, optional): number of wave-function epochs run
                before the geometry optimization. Defaults to 100.
            nepoch_wf_update (int, optional): number of wave-function epochs
                run after each geometry step. Defaults to 50.
            hdf5_group (str, optional): name of the hdf5 group where to store
                the data. Defaults to 'geo_opt'.
            chkpt_every (int, optional): save a checkpoint every
                ``chkpt_every`` geometry steps. Defaults to None (no checkpoint).
            tqdm (bool, optional): passed as ``with_tqdm`` to the sampler.
                Defaults to False.

        Returns:
            the observable, with the geometry trajectory in ``geometry``.
        """
        # keep a reference to the optimizer of the wf params. It must NOT be
        # deep-copied: a deepcopy of a torch optimizer holds clones of the
        # parameters, and stepping it would never update the wave function.
        opt_wf = self.opt

        # create the optimizer for the geo opt
        opt_geo = torch.optim.SGD(self.wf.parameters(), lr=geo_lr)
        opt_geo.lpos_needed = False

        # save the grad method
        eval_grad_wf = self.evaluate_gradient

        # log data
        self.prepare_optimization(batchsize, None, tqdm)
        self.log_data_opt(nepoch, "geometry optimization")

        # init the traj
        xyz = [self.wf.geometry(None)]

        # initial wf optimization
        self.set_params_requires_grad(wf_params=True, geo_params=False)
        self.freeze_parameters(self.freeze_params_list)
        self.run_epochs(nepoch_wf_init)

        # iterations over geo optim
        for n in range(nepoch):
            # make one step geo optim
            self.set_params_requires_grad(wf_params=False, geo_params=True)
            self.opt = opt_geo
            self.evaluate_gradient = self.evaluate_grad_auto
            self.run_epochs(1)
            xyz.append(self.wf.geometry(None))

            # make a few wf optim
            self.set_params_requires_grad(wf_params=True, geo_params=False)
            self.freeze_parameters(self.freeze_params_list)
            self.opt = opt_wf
            self.evaluate_gradient = eval_grad_wf
            cumulative_loss = self.run_epochs(nepoch_wf_update)

            # save checkpoint file
            if chkpt_every is not None:
                if (n > 0) and (n % chkpt_every == 0):
                    self.save_checkpoint(n, cumulative_loss)

        # restore the sampler number of step
        self.restore_sampling_parameters()

        # dump
        self.observable.geometry = xyz
        self.save_data(hdf5_group)

        return self.observable

    def run(
        self, nepoch, batchsize=None, hdf5_group="wf_opt", chkpt_every=None, tqdm=False
    ):
        """Run a wave function optimization

        Args:
            nepoch (int): number of optimization steps
            batchsize (int, optional): number of samples in a mini batch.
                If None, all samples are used. Defaults to None.
            hdf5_group (str, optional): name of the hdf5 group where to store
                the data. Defaults to 'wf_opt'.
            chkpt_every (int, optional): save a checkpoint every
                ``chkpt_every`` epochs. Defaults to None (no checkpoint).
            tqdm (bool, optional): passed as ``with_tqdm`` to the sampler.
                Defaults to False.

        Returns:
            the observable of the optimization.
        """
        # prepare the optimization
        self.prepare_optimization(batchsize, chkpt_every, tqdm)
        self.log_data_opt(nepoch, "wave function optimization")

        # run the epochs
        self.run_epochs(nepoch)

        # restore the sampler number of step
        self.restore_sampling_parameters()

        # dump
        self.save_data(hdf5_group)

        return self.observable

    def prepare_optimization(self, batchsize, chkpt_every, tqdm=False):
        """Prepare the optimization process

        Samples the wave function, saves (and in "update" mode changes) the
        sampler settings, builds the dataloader and stores the initial
        observables.

        Args:
            batchsize (int or None): batchsize. If None, all samples are used.
            chkpt_every (int or None): save a chkpt file every ``chkpt_every``
                epochs (None disables checkpoints)
            tqdm (bool, optional): passed as ``with_tqdm`` to the sampler.
                Defaults to False.
        """
        # sample the wave function
        pos = self.sampler(self.wf.pdf, with_tqdm=tqdm)

        # handle the batch size
        if batchsize is None:
            batchsize = len(pos)

        # change the number of steps/walker size
        self.save_sampling_parameters()

        # create the data loader
        self.dataloader = DataLoader(pos, batch_size=batchsize, pin_memory=self.cuda)

        for ibatch, data in enumerate(self.dataloader):
            self.store_observable(data, ibatch=ibatch)

        # chkpt
        self.chkpt_every = chkpt_every

    def _snapshot_state_dict(self):
        """Return a copy of the wave-function state dict.

        ``state_dict()`` returns tensors sharing storage with the parameters,
        which optimizers update in place; clone them so the snapshot is not
        modified by later optimization steps.
        """
        return {
            key: val.detach().clone() if torch.is_tensor(val) else val
            for key, val in self.wf.state_dict().items()
        }

    def save_data(self, hdf5_group):
        """Save the data to hdf5.

        Args:
            hdf5_group (str): name of group in the hdf5 file
        """
        self.observable.models.last = self._snapshot_state_dict()

        hdf5_group = dump_to_hdf5(self.observable, self.hdf5file, hdf5_group)

        add_group_attr(self.hdf5file, hdf5_group, {"type": "opt"})

    def run_epochs(self, nepoch):
        """Run a certain number of epochs.

        In each epoch the gradients are accumulated over all the mini batches
        of the dataloader, then a single optimization step is performed.
        Afterwards the best model / checkpoint is saved if required, the
        observables are printed, the data is resampled and the scheduler
        (if any) is stepped.

        Args:
            nepoch (int): number of epoch to run

        Returns:
            the loss accumulated over the batches of the last epoch (0 when
            ``nepoch`` is 0). If a NaN is detected in the local energies the
            method returns early with the loss accumulated so far.
        """
        # init the loss in case we have nepoch=0
        cumulative_loss = 0
        min_loss = 0  # this is set at n=0

        # loop over the epoch
        for n in range(nepoch):
            tstart = time()
            log.info("")
            log.info(
                " epoch %d | %d sampling points" % (n, len(self.dataloader.dataset))
            )

            cumulative_loss = 0
            self.opt.zero_grad()

            # loop over the batches
            for ibatch, data in enumerate(self.dataloader):
                # port data to device
                lpos = data.to(self.device)

                # get the gradient
                loss, eloc = self.evaluate_gradient(lpos)

                # the accumulated loss is only monitored/saved: detach it so
                # it does not keep references to the autograd graph
                cumulative_loss += loss.detach()

                # check for nan
                if torch.isnan(eloc).any():
                    log.info("Error : Nan detected in local energy")
                    return cumulative_loss

                # observable
                self.store_observable(lpos, local_energy=eloc, ibatch=ibatch)

            # optimize the parameters (gradients accumulated over all batches)
            self.optimization_step(lpos)

            # save the model if necessary (copied, not aliased, so later
            # optimizer steps do not overwrite the best model)
            if n == 0 or cumulative_loss < min_loss:
                min_loss = cumulative_loss
                self.observable.models.best = self._snapshot_state_dict()

            # save checkpoint file
            if self.chkpt_every is not None:
                if (n > 0) and (n % self.chkpt_every == 0):
                    self.save_checkpoint(n, cumulative_loss)

            self.print_observable(cumulative_loss, verbose=False)

            # resample the data
            self.dataloader.dataset = self.resample(n, self.dataloader.dataset)

            # scheduler step
            if self.scheduler is not None:
                self.scheduler.step()

            log.info(" epoch done in %1.2f sec." % (time() - tstart))

        return cumulative_loss

    def evaluate_grad_auto(self, lpos):
        """Evaluate the gradient using automatic differentiation

        Args:
            lpos (torch.tensor): sampling points

        Returns:
            tuple: loss values and local energies
        """
        # compute the loss
        loss, eloc = self.loss(lpos)

        # add mo orthogonalization if required
        if self.wf.mo.weight.requires_grad and self.ortho_mo:
            loss += self.ortho_loss(self.wf.mo.weight)

        # compute local gradients
        loss.backward()

        return loss, eloc

    def evaluate_grad_manual(self, lpos):
        """Evaluate the gradient using the low variance expression

            dE/dk = 2 < (dpsi/dk)/psi (E_L - <E_L>) >

        Args:
            lpos (torch.tensor): sampling points

        Returns:
            tuple: mean local energy and local energies

        Raises:
            ValueError: if the loss is not an energy minimization
        """
        if self.loss.method not in ["energy", "weighted-energy"]:
            raise ValueError("Manual gradient only for energy minimization")

        # keep grad enabled for the local energy when the kinetic energy or
        # the (optional) jastrow factor rely on autograd
        jastrow = getattr(self.wf, "jastrow", None)
        no_grad_eloc = not (
            self.wf.kinetic_method == "auto"
            or (jastrow is not None and jastrow.requires_autograd)
        )

        # compute local energy and wf values
        _, eloc = self.loss(lpos, no_grad=no_grad_eloc)
        psi = self.wf(lpos)
        norm = 1.0 / len(psi)

        # evaluate the prefactor of the grads: 2 (E_L - <E_L>) / (psi N)
        weight = eloc.clone()
        weight -= torch.mean(eloc)
        weight /= psi
        weight *= 2.0
        weight *= norm

        # compute the gradients
        psi.backward(weight)

        return torch.mean(eloc), eloc

    def log_data_opt(self, nepoch, task):
        """Log data for the optimization.

        Args:
            nepoch (int): number of epochs of the optimization
            task (str): description of the optimization task
        """
        log.info("")
        log.info(" Optimization")
        log.info(" Task : {0}", task)
        log.info(" Number Parameters : {0}", self.wf.get_number_parameters())
        log.info(" Number of epoch : {0}", nepoch)
        log.info(" Batch size : {0}", self.sampler.get_sampling_size())
        log.info(" Loss function : {0}", self.loss.method)
        log.info(" Clip Loss : {0}", self.loss.clip)
        log.info(" Gradients : {0}", self.grad_method)
        log.info(" Resampling mode : {0}", self.resampling_options.mode)
        log.info(" Resampling every : {0}", self.resampling_options.resample_every)
        log.info(" Resampling steps : {0}", self.resampling_options.nstep_update)
        log.info(" Output file : {0}", self.hdf5file)
        log.info(" Checkpoint every : {0}", self.chkpt_every)
        log.info("")