Wave Function Optimization

We present here a complete example of how to use QMCTorch on an H2 molecule. We first need to import all the relevant modules:

[1]:
from torch import optim
from qmctorch.scf import Molecule
from qmctorch.wavefunction import SlaterJastrow
from qmctorch.solver import Solver
from qmctorch.sampler import Metropolis
from qmctorch.utils import set_torch_double_precision
from qmctorch.utils.plot_data import plot_energy
set_torch_double_precision()
INFO:QMCTorch|  ____    __  ______________             _
INFO:QMCTorch| / __ \  /  |/  / ___/_  __/__  ________/ /
INFO:QMCTorch|/ /_/ / / /|_/ / /__  / / / _ \/ __/ __/ _ \
INFO:QMCTorch|\___\_\/_/  /_/\___/ /_/  \___/_/  \__/_//_/

Creating the system

The first step is to define a molecule. Here we use an H2 molecule with both hydrogen atoms on the z-axis, separated by 1.38 atomic units. We choose Slater orbitals, which can be obtained via ADF, and simply reload a previous calculation to create the molecule.

[2]:
mol = Molecule(load='./hdf5/H2_adf_dzp.hdf5')
INFO:QMCTorch|
INFO:QMCTorch| SCF Calculation
INFO:QMCTorch|  Loading data from ./hdf5/H2_adf_dzp.hdf5

We then define the wave function for this molecule. We also specify the determinants to use in the CI expansion. Here we include all the single and double excitations with 2 electrons and 2 orbitals.

[3]:
wf = SlaterJastrow(mol, configs='single_double(2,2)')
INFO:QMCTorch|
INFO:QMCTorch| Wave Function
INFO:QMCTorch|  Jastrow factor      : True
INFO:QMCTorch|  Jastrow kernel      : ee -> PadeJastrowKernel
INFO:QMCTorch|  Highest MO included : 10
INFO:QMCTorch|  Configurations      : single_double(2,2)
INFO:QMCTorch|  Number of confs     : 4
INFO:QMCTorch|  Kinetic energy      : jacobi
INFO:QMCTorch|  Number var  param   : 121
INFO:QMCTorch|  Cuda support        : False
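
The count of 4 configurations reported above can be reproduced with a simple enumeration. The sketch below is only a counting illustration and not QMCTorch's internal code: it assumes one spin-up and one spin-down electron, each placed in one of 2 orbitals, and the helper name enumerate_configs is hypothetical.

```python
from itertools import product

def enumerate_configs(norb=2):
    """List (up_orbital, down_orbital, label) for one up and one down electron.

    Orbital 0 is taken as the occupied orbital of the reference determinant
    (an assumption made for this illustration).
    """
    confs = []
    for up, down in product(range(norb), repeat=2):
        # count how many electrons left the reference orbital
        n_excited = (up != 0) + (down != 0)
        label = ("ground", "single", "double")[n_excited]
        confs.append((up, down, label))
    return confs

confs = enumerate_configs(2)
for conf in confs:
    print(conf)
```

With 2 orbitals this gives the reference determinant, two single excitations (one per spin) and one double excitation.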

As a sampler we use a simple Metropolis-Hastings algorithm with 5000 walkers. The walkers are initially localized around the atoms. Each walker performs 200 steps of size 0.2 atomic units, and only the last position of each walker is kept (ntherm=-1). During each move all the electrons are moved simultaneously, with displacements drawn from a normal distribution centered on their current location.

[5]:
sampler = Metropolis(nwalkers=5000,
                     nstep=200, step_size=0.2,
                     ntherm=-1, ndecor=100,
                     nelec=wf.nelec, init=mol.domain('atomic'),
                     move={'type': 'all-elec', 'proba': 'normal'})
INFO:QMCTorch|
INFO:QMCTorch| Monte-Carlo Sampler
INFO:QMCTorch|  Number of walkers   : 5000
INFO:QMCTorch|  Number of steps     : 200
INFO:QMCTorch|  Step size           : 0.2
INFO:QMCTorch|  Thermalization steps: -1
INFO:QMCTorch|  Decorelation steps  : 100
INFO:QMCTorch|  Walkers init pos    : atomic
INFO:QMCTorch|  Move type           : all-elec
INFO:QMCTorch|  Move proba          : normal
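
To make the sampling procedure concrete, here is a minimal Metropolis sketch in pure Python. It is not the QMCTorch implementation: the target is a toy 1D density |psi(x)|^2 with psi(x) = exp(-x^2 / 2), the function names are hypothetical, and treating step_size as the standard deviation of the normal proposal is an assumption.

```python
import math
import random

def log_density(x):
    # log |psi(x)|^2 for the toy trial wave function psi(x) = exp(-x^2 / 2)
    return -x * x

def metropolis(nwalkers=500, nstep=200, step_size=0.2, seed=0):
    """Return the final walker positions and the overall acceptance rate."""
    rng = random.Random(seed)
    # walkers start in a small region, similar in spirit to an 'atomic' init
    walkers = [rng.uniform(-1.0, 1.0) for _ in range(nwalkers)]
    accepted = 0
    for _ in range(nstep):
        for i, x in enumerate(walkers):
            trial = x + rng.gauss(0.0, step_size)
            delta = log_density(trial) - log_density(x)
            # symmetric proposal: accept with probability min(1, p(trial) / p(x))
            if rng.random() < math.exp(min(0.0, delta)):
                walkers[i] = trial
                accepted += 1
    return walkers, accepted / (nwalkers * nstep)

positions, rate = metropolis()
mean = sum(positions) / len(positions)
var = sum((x - mean) ** 2 for x in positions) / len(positions)
print(f"acceptance rate: {rate:.2f}, mean: {mean:.3f}, variance: {var:.3f}")
```

Only the last position of each walker is returned, which mirrors the ntherm=-1 setting used above. For this toy density the exact mean is 0 and the exact variance is 0.5.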

We will use the Adam optimizer implemented in PyTorch with a custom learning rate for each layer. We also define a step scheduler that multiplies the learning rate by 0.90 every 100 steps.

[6]:
lr_dict = [{'params': wf.jastrow.parameters(), 'lr': 1E-2},
           {'params': wf.ao.parameters(), 'lr': 1E-6},
           {'params': wf.mo.parameters(), 'lr': 2E-3},
           {'params': wf.fc.parameters(), 'lr': 2E-3}]
opt = optim.Adam(lr_dict, lr=1E-3)

A scheduler can also be used to progressively decrease the value of the learning rate during the optimization.

[7]:
scheduler = optim.lr_scheduler.StepLR(opt, step_size=100, gamma=0.90)
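
The effect of a step schedule with these settings can be written in closed form: the learning rate is the initial rate multiplied by gamma once per completed block of step_size epochs. The helper below is a plain Python sketch of that arithmetic (the name step_lr is hypothetical), not a call into PyTorch.

```python
def step_lr(initial_lr, epoch, step_size=100, gamma=0.90):
    """Learning rate after `epoch` scheduler steps of a step decay schedule."""
    return initial_lr * gamma ** (epoch // step_size)

# Jastrow learning rate from the parameter groups above
for epoch in (0, 49, 100, 250):
    print(epoch, step_lr(1e-2, epoch))
```

With step_size=100, the learning rates would stay at their initial values for the whole of a 50-epoch run.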

We can now assemble the solver. Note that scheduler=None is passed here, so no scheduler is used in this run.

[8]:
solver = Solver(wf=wf, sampler=sampler, optimizer=opt, scheduler=None)
INFO:QMCTorch|
INFO:QMCTorch| QMC Solver
INFO:QMCTorch|  WaveFunction        : SlaterJastrow
INFO:QMCTorch|  Sampler             : Metropolis
INFO:QMCTorch|  Optimizer           : Adam

Configuring the solver

Many parameters of the optimization can be controlled. We can specify which observables to track during the optimization. Here both the local energies and the variational parameters are recorded.

[9]:
solver.configure(track=['local_energy', 'parameters'])

Some variational parameters can be frozen and therefore not optimized. Here we freeze the AO parameters, so only the Jastrow parameters, the MO coefficients and the CI coefficients are optimized.

[10]:
solver.configure(freeze=['ao'])

Either the mean or the variance of the local energies can be used as a loss function. Here we minimize the energy to optimize the wave function.

[11]:
solver.configure(loss='energy')
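
As an illustration of the two choices, the sketch below computes both candidate losses from a list of local energies. The function name qmc_loss and the sample values are hypothetical and do not come from QMCTorch; they only show which statistic each option refers to.

```python
from statistics import fmean, pvariance

def qmc_loss(local_energies, kind="energy"):
    """Return the mean ('energy') or the variance ('variance') of the local energies."""
    if kind == "energy":
        return fmean(local_energies)
    if kind == "variance":
        return pvariance(local_energies)
    raise ValueError(f"unknown loss: {kind}")

# made-up local energies, for illustration only
eloc = [-1.20, -1.10, -1.15, -1.15]
print(qmc_loss(eloc, "energy"))
print(qmc_loss(eloc, "variance"))
```

For an exact eigenstate the local energy is the same at every point, so the variance loss reaches zero there, while the energy loss reaches the corresponding eigenvalue.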

The gradients of the wave function w.r.t. the variational parameters can be computed directly via automatic differentiation (grad='auto') or manually (grad='manual') via a reduced-noise formula. Here we pick the manual calculation.

[12]:
solver.configure(grad='manual')
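
A common reduced-noise estimator in variational Monte Carlo writes the energy gradient as twice the covariance between the local energy and the derivative of ln(psi) with respect to the parameter. Whether grad='manual' uses exactly this expression is an assumption here. The sketch below applies it to a toy problem, a 1D harmonic oscillator with trial function psi(x) = exp(-a x^2), for which the exact gradient is 1/2 - 1/(8 a^2); all names are hypothetical.

```python
import math
import random

def energy_gradient(a, nsample=20000, seed=0):
    """Estimate dE/da as 2 * < (E_L - <E_L>) * d ln(psi)/da > over samples of |psi|^2."""
    rng = random.Random(seed)
    # |psi|^2 = exp(-2 a x^2) is a Gaussian with variance 1 / (4 a)
    sigma = math.sqrt(1.0 / (4.0 * a))
    xs = [rng.gauss(0.0, sigma) for _ in range(nsample)]
    # local energy of H = -1/2 d^2/dx^2 + 1/2 x^2 for psi = exp(-a x^2)
    eloc = [a + x * x * (0.5 - 2.0 * a * a) for x in xs]
    # derivative of ln(psi) with respect to the parameter a
    dlogpsi = [-x * x for x in xs]
    emean = sum(eloc) / nsample
    return 2.0 * sum((e - emean) * d for e, d in zip(eloc, dlogpsi)) / nsample

grad_at_1 = energy_gradient(1.0)
grad_at_half = energy_gradient(0.5)
print(grad_at_1, "exact:", 0.5 - 1.0 / 8.0)
print(grad_at_half, "exact: 0.0")
```

At a = 0.5 the trial function is the exact ground state, the local energy is constant, and the estimator returns zero without any statistical noise.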

We also configure the resampling so that the positions of the walkers are updated by performing 25 MC steps from their previous positions after each optimization step.

[13]:
solver.configure(resampling={'mode': 'update',
                            'resample_every': 1,
                            'nstep_update': 25})

Running the wave function optimization

We can now run the optimization. Here we use 50 optimization steps (epochs), with all the points in a single mini-batch.

[14]:
obs = solver.run(50)
INFO:QMCTorch|
INFO:QMCTorch|  Optimization
INFO:QMCTorch|  Task                :
INFO:QMCTorch|  Number Parameters   : 115
INFO:QMCTorch|  Number of epoch     : 50
INFO:QMCTorch|  Batch size          : 5000
INFO:QMCTorch|  Loss function       : energy
INFO:QMCTorch|  Clip Loss           : False
INFO:QMCTorch|  Gradients           : manual
INFO:QMCTorch|  Resampling mode     : update
INFO:QMCTorch|  Resampling every    : 1
INFO:QMCTorch|  Resampling steps    : 25
INFO:QMCTorch|  Output file         : H2_adf_dzp_QMCTorch.hdf5
INFO:QMCTorch|  Checkpoint every    : None
INFO:QMCTorch|
INFO:QMCTorch|
INFO:QMCTorch|  epoch 0 | 5000 sampling points
INFO:QMCTorch|  energy   : -1.155820 +/- 0.003248
INFO:QMCTorch|  variance : 0.229678
INFO:QMCTorch|  epoch done in 0.17 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 1 | 5000 sampling points
INFO:QMCTorch|  energy   : -1.154966 +/- 0.003175
INFO:QMCTorch|  variance : 0.224524
INFO:QMCTorch|  epoch done in 0.23 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 2 | 5000 sampling points
INFO:QMCTorch|  energy   : -1.153755 +/- 0.003123
INFO:QMCTorch|  variance : 0.220823
INFO:QMCTorch|  epoch done in 0.25 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 3 | 5000 sampling points
INFO:QMCTorch|  energy   : -1.152865 +/- 0.003169
INFO:QMCTorch|  variance : 0.224055
INFO:QMCTorch|  epoch done in 0.22 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 4 | 5000 sampling points
INFO:QMCTorch|  energy   : -1.155440 +/- 0.003123
INFO:QMCTorch|  variance : 0.220856
INFO:QMCTorch|  epoch done in 0.20 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 5 | 5000 sampling points
INFO:QMCTorch|  energy   : -1.152281 +/- 0.003132
INFO:QMCTorch|  variance : 0.221490
INFO:QMCTorch|  epoch done in 0.19 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 6 | 5000 sampling points
INFO:QMCTorch|  energy   : -1.155656 +/- 0.003057
INFO:QMCTorch|  variance : 0.216128
INFO:QMCTorch|  epoch done in 0.21 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 7 | 5000 sampling points
INFO:QMCTorch|  energy   : -1.155033 +/- 0.003072
INFO:QMCTorch|  variance : 0.217255
INFO:QMCTorch|  epoch done in 0.21 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 8 | 5000 sampling points
INFO:QMCTorch|  energy   : -1.156729 +/- 0.003055
INFO:QMCTorch|  variance : 0.216035
INFO:QMCTorch|  epoch done in 0.20 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 9 | 5000 sampling points
INFO:QMCTorch|  energy   : -1.157059 +/- 0.003023
INFO:QMCTorch|  variance : 0.213726
INFO:QMCTorch|  epoch done in 0.21 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 10 | 5000 sampling points
INFO:QMCTorch|  energy   : -1.155099 +/- 0.003046
INFO:QMCTorch|  variance : 0.215355
INFO:QMCTorch|  epoch done in 0.46 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 11 | 5000 sampling points
INFO:QMCTorch|  energy   : -1.157807 +/- 0.002978
INFO:QMCTorch|  variance : 0.210545
INFO:QMCTorch|  epoch done in 0.19 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 12 | 5000 sampling points
INFO:QMCTorch|  energy   : -1.155917 +/- 0.002920
INFO:QMCTorch|  variance : 0.206467
INFO:QMCTorch|  epoch done in 0.25 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 13 | 5000 sampling points
INFO:QMCTorch|  energy   : -1.160233 +/- 0.002908
INFO:QMCTorch|  variance : 0.205608
INFO:QMCTorch|  epoch done in 0.20 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 14 | 5000 sampling points
INFO:QMCTorch|  energy   : -1.155051 +/- 0.003021
INFO:QMCTorch|  variance : 0.213640
INFO:QMCTorch|  epoch done in 0.19 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 15 | 5000 sampling points
INFO:QMCTorch|  energy   : -1.157552 +/- 0.002922
INFO:QMCTorch|  variance : 0.206606
INFO:QMCTorch|  epoch done in 0.20 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 16 | 5000 sampling points
INFO:QMCTorch|  energy   : -1.150777 +/- 0.002986
INFO:QMCTorch|  variance : 0.211157
INFO:QMCTorch|  epoch done in 0.21 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 17 | 5000 sampling points
INFO:QMCTorch|  energy   : -1.153752 +/- 0.002926
INFO:QMCTorch|  variance : 0.206869
INFO:QMCTorch|  epoch done in 0.18 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 18 | 5000 sampling points
INFO:QMCTorch|  energy   : -1.154157 +/- 0.002893
INFO:QMCTorch|  variance : 0.204567
INFO:QMCTorch|  epoch done in 0.20 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 19 | 5000 sampling points
INFO:QMCTorch|  energy   : -1.156157 +/- 0.002824
INFO:QMCTorch|  variance : 0.199705
INFO:QMCTorch|  epoch done in 0.43 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 20 | 5000 sampling points
INFO:QMCTorch|  energy   : -1.161703 +/- 0.002866
INFO:QMCTorch|  variance : 0.202681
INFO:QMCTorch|  epoch done in 0.22 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 21 | 5000 sampling points
INFO:QMCTorch|  energy   : -1.156807 +/- 0.002791
INFO:QMCTorch|  variance : 0.197351
INFO:QMCTorch|  epoch done in 0.22 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 22 | 5000 sampling points
INFO:QMCTorch|  energy   : -1.156593 +/- 0.002774
INFO:QMCTorch|  variance : 0.196173
INFO:QMCTorch|  epoch done in 0.18 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 23 | 5000 sampling points
INFO:QMCTorch|  energy   : -1.155829 +/- 0.002814
INFO:QMCTorch|  variance : 0.199004
INFO:QMCTorch|  epoch done in 0.20 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 24 | 5000 sampling points
INFO:QMCTorch|  energy   : -1.158552 +/- 0.002720
INFO:QMCTorch|  variance : 0.192327
INFO:QMCTorch|  epoch done in 0.19 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 25 | 5000 sampling points
INFO:QMCTorch|  energy   : -1.157268 +/- 0.002651
INFO:QMCTorch|  variance : 0.187444
INFO:QMCTorch|  epoch done in 0.20 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 26 | 5000 sampling points
INFO:QMCTorch|  energy   : -1.160739 +/- 0.002627
INFO:QMCTorch|  variance : 0.185774
INFO:QMCTorch|  epoch done in 0.19 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 27 | 5000 sampling points
INFO:QMCTorch|  energy   : -1.156840 +/- 0.002650
INFO:QMCTorch|  variance : 0.187409
INFO:QMCTorch|  epoch done in 0.19 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 28 | 5000 sampling points
INFO:QMCTorch|  energy   : -1.160052 +/- 0.002668
INFO:QMCTorch|  variance : 0.188656
INFO:QMCTorch|  epoch done in 0.19 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 29 | 5000 sampling points
INFO:QMCTorch|  energy   : -1.161738 +/- 0.002561
INFO:QMCTorch|  variance : 0.181082
INFO:QMCTorch|  epoch done in 0.19 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 30 | 5000 sampling points
INFO:QMCTorch|  energy   : -1.163425 +/- 0.002620
INFO:QMCTorch|  variance : 0.185234
INFO:QMCTorch|  epoch done in 0.21 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 31 | 5000 sampling points
INFO:QMCTorch|  energy   : -1.167101 +/- 0.002546
INFO:QMCTorch|  variance : 0.180017
INFO:QMCTorch|  epoch done in 0.19 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 32 | 5000 sampling points
INFO:QMCTorch|  energy   : -1.157386 +/- 0.002624
INFO:QMCTorch|  variance : 0.185573
INFO:QMCTorch|  epoch done in 0.19 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 33 | 5000 sampling points
INFO:QMCTorch|  energy   : -1.163439 +/- 0.002552
INFO:QMCTorch|  variance : 0.180488
INFO:QMCTorch|  epoch done in 0.20 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 34 | 5000 sampling points
INFO:QMCTorch|  energy   : -1.161475 +/- 0.002532
INFO:QMCTorch|  variance : 0.179012
INFO:QMCTorch|  epoch done in 0.20 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 35 | 5000 sampling points
INFO:QMCTorch|  energy   : -1.162207 +/- 0.002483
INFO:QMCTorch|  variance : 0.175559
INFO:QMCTorch|  epoch done in 0.21 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 36 | 5000 sampling points
INFO:QMCTorch|  energy   : -1.160581 +/- 0.002549
INFO:QMCTorch|  variance : 0.180231
INFO:QMCTorch|  epoch done in 0.21 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 37 | 5000 sampling points
INFO:QMCTorch|  energy   : -1.160479 +/- 0.002471
INFO:QMCTorch|  variance : 0.174733
INFO:QMCTorch|  epoch done in 0.19 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 38 | 5000 sampling points
INFO:QMCTorch|  energy   : -1.162336 +/- 0.002502
INFO:QMCTorch|  variance : 0.176936
INFO:QMCTorch|  epoch done in 0.19 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 39 | 5000 sampling points
INFO:QMCTorch|  energy   : -1.163931 +/- 0.002522
INFO:QMCTorch|  variance : 0.178305
INFO:QMCTorch|  epoch done in 0.19 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 40 | 5000 sampling points
INFO:QMCTorch|  energy   : -1.162977 +/- 0.002429
INFO:QMCTorch|  variance : 0.171763
INFO:QMCTorch|  epoch done in 0.21 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 41 | 5000 sampling points
INFO:QMCTorch|  energy   : -1.162803 +/- 0.002436
INFO:QMCTorch|  variance : 0.172285
INFO:QMCTorch|  epoch done in 0.18 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 42 | 5000 sampling points
INFO:QMCTorch|  energy   : -1.162275 +/- 0.002436
INFO:QMCTorch|  variance : 0.172260
INFO:QMCTorch|  epoch done in 0.21 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 43 | 5000 sampling points
INFO:QMCTorch|  energy   : -1.161893 +/- 0.002421
INFO:QMCTorch|  variance : 0.171202
INFO:QMCTorch|  epoch done in 0.21 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 44 | 5000 sampling points
INFO:QMCTorch|  energy   : -1.161169 +/- 0.002388
INFO:QMCTorch|  variance : 0.168865
INFO:QMCTorch|  epoch done in 0.18 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 45 | 5000 sampling points
INFO:QMCTorch|  energy   : -1.163101 +/- 0.002353
INFO:QMCTorch|  variance : 0.166373
INFO:QMCTorch|  epoch done in 1.37 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 46 | 5000 sampling points
INFO:QMCTorch|  energy   : -1.161588 +/- 0.002367
INFO:QMCTorch|  variance : 0.167382
INFO:QMCTorch|  epoch done in 0.18 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 47 | 5000 sampling points
INFO:QMCTorch|  energy   : -1.164853 +/- 0.002345
INFO:QMCTorch|  variance : 0.165804
INFO:QMCTorch|  epoch done in 0.22 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 48 | 5000 sampling points
INFO:QMCTorch|  energy   : -1.167272 +/- 0.002287
INFO:QMCTorch|  variance : 0.161681
INFO:QMCTorch|  epoch done in 0.20 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 49 | 5000 sampling points
INFO:QMCTorch|  energy   : -1.157956 +/- 0.002319
INFO:QMCTorch|  variance : 0.164008
INFO:QMCTorch|  epoch done in 0.22 sec.
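
The error bars printed in the log above are numerically consistent with the value logged as "variance" divided by the square root of the number of sampling points. This is an observation from the logged numbers, not a statement about how QMCTorch computes them internally; the helper name below is hypothetical.

```python
import math

def error_bar(logged_spread, npoints):
    """Error on the mean energy, taken as the logged spread over sqrt(npoints)."""
    return logged_spread / math.sqrt(npoints)

# values copied from epoch 0 and epoch 49 of the log above
first = error_bar(0.229678, 5000)
last = error_bar(0.164008, 5000)
print(f"epoch 0 : {first:.6f}")
print(f"epoch 49: {last:.6f}")
```

Both values agree with the logged errors, 0.003248 and 0.002319, to the printed precision.
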
[15]:
plot_energy(obs.local_energy, e0=-1.1645, show_variance=True)
[Figure: output of plot_energy, ../_images/notebooks_wfopt_26_0.png]