Wave Function Optimization

We present here a complete example of how to use QMCTorch on an H2 molecule. We first import all the relevant modules:

[1]:
from torch import optim
from qmctorch.scf import Molecule
from qmctorch.wavefunction import SlaterJastrow
from qmctorch.solver import Solver
from qmctorch.sampler import Metropolis
from qmctorch.utils import set_torch_double_precision
from qmctorch.utils import (plot_energy, plot_data)
set_torch_double_precision()
INFO:QMCTorch|  ____    __  ______________             _
INFO:QMCTorch| / __ \  /  |/  / ___/_  __/__  ________/ /
INFO:QMCTorch|/ /_/ / / /|_/ / /__  / / / _ \/ __/ __/ _ \
INFO:QMCTorch|\___\_\/_/  /_/\___/ /_/  \___/_/  \__/_//_/

Creating the system

The first step is to define a molecule. Here we use an H2 molecule with both hydrogen atoms on the z-axis, separated by 1.38 atomic units (bohr). We use Slater orbitals, which can be obtained via ADF. Here we simply reload the results of a previous calculation to create the molecule:

[2]:
mol = Molecule(load='./hdf5/H2_adf_dzp.hdf5')
INFO:QMCTorch|
INFO:QMCTorch| SCF Calculation
INFO:QMCTorch|  Loading data from ./hdf5/H2_adf_dzp.hdf5
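
The geometry described above can be written out explicitly. The sketch below builds an H2 coordinate string in bohr and converts the bond length to angstrom. The `h2_geometry` helper is hypothetical. Passing such a string to `Molecule(atom=..., unit='bohr')` instead of using `load=` is an assumption about the QMCTorch API, and it is not what this notebook does.

```python
# Hypothetical helper: H2 geometry with both atoms on the z-axis (coordinates in bohr).
BOHR_TO_ANGSTROM = 0.529177210903  # CODATA 2018 Bohr radius, in angstrom


def h2_geometry(bond_length: float) -> str:
    """Place two H atoms symmetrically about the origin along the z-axis."""
    half = bond_length / 2.0
    return f"H 0 0 {-half}; H 0 0 {half}"


bond_bohr = 1.38
print(h2_geometry(bond_bohr))
print(f"bond length: {bond_bohr * BOHR_TO_ANGSTROM:.4f} angstrom")
```

The separator format `"H x y z; H x y z"` is assumed here for illustration only.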

We then define the wave function for this molecule. We also specify which determinants to use in the CI expansion. Here we include all single and double excitations of 2 electrons in 2 orbitals:

[3]:
wf = SlaterJastrow(mol, configs='single_double(2,2)')
INFO:QMCTorch|
INFO:QMCTorch| Wave Function
INFO:QMCTorch|  Jastrow factor      : True
INFO:QMCTorch|  Jastrow kernel      : PadeJastrowKernel
INFO:QMCTorch|  Highest MO included : 10
INFO:QMCTorch|  Configurations      : single_double(2,2)
INFO:QMCTorch|  Number of confs     : 4
INFO:QMCTorch|  Kinetic energy      : jacobi
INFO:QMCTorch|  Number var  param   : 121
INFO:QMCTorch|  Cuda support        : False
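
To see where the 4 configurations in the log come from, the sketch below enumerates them for the simplest case. It assumes one spin-up and one spin-down electron, an active space made of the two lowest MOs (0 and 1), and a reference determinant with both electrons in MO 0. The `single_double_h2` helper is hypothetical and is not QMCTorch's implementation. With only 2 electrons, every occupation is at most a double excitation, so no filtering is needed.

```python
from itertools import product

LEVEL_NAMES = {0: "ground state", 1: "single excitation", 2: "double excitation"}


def single_double_h2(norb: int):
    """Hypothetical helper: list (up_orbital, down_orbital) occupations for one
    spin-up and one spin-down electron distributed over the `norb` lowest MOs."""
    return list(product(range(norb), repeat=2))


confs = single_double_h2(2)
for up, down in confs:
    # excitation level = number of electrons moved out of MO 0
    level = int(up != 0) + int(down != 0)
    print(f"up -> MO {up}, down -> MO {down}: {LEVEL_NAMES[level]}")
print("number of configurations:", len(confs))
```

The count of 4 matches the `Number of confs` line reported above.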

As a sampler we use a simple Metropolis-Hastings algorithm with 5000 walkers. The walkers are initially localized around the atoms. Each walker performs 200 steps of size 0.2 atomic units, and only the last position of each walker is kept (ntherm=-1). During each move, all the electrons are moved simultaneously, with displacements drawn from a normal distribution centered on their current location.

[4]:
sampler = Metropolis(nwalkers=5000,
                     nstep=200, step_size=0.2,
                     ntherm=-1, ndecor=100,
                     nelec=wf.nelec, init=mol.domain('atomic'),
                     move={'type': 'all-elec', 'proba': 'normal'})
INFO:QMCTorch|
INFO:QMCTorch| Monte-Carlo Sampler
INFO:QMCTorch|  Number of walkers   : 5000
INFO:QMCTorch|  Number of steps     : 200
INFO:QMCTorch|  Step size           : 0.2
INFO:QMCTorch|  Thermalization steps: -1
INFO:QMCTorch|  Decorelation steps  : 100
INFO:QMCTorch|  Walkers init pos    : atomic
INFO:QMCTorch|  Move type           : all-elec
INFO:QMCTorch|  Move proba          : normal
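
The all-electron move described above can be sketched in pure Python. Assumptions: a toy wave function in which each electron sits in a hydrogen-like 1s orbital at the origin (not the H2 Slater-Jastrow wave function), 200 walkers instead of 5000 for speed, and hypothetical helper names. This is not QMCTorch's `Metropolis` class.

```python
import math
import random


def log_psi2(walker):
    """Toy log|psi|^2: each electron in a hydrogen-like 1s orbital at the origin."""
    return -2.0 * sum(math.sqrt(x * x + y * y + z * z) for x, y, z in walker)


def all_elec_step(walkers, step_size, rng):
    """One Metropolis-Hastings step that moves all electrons of a walker at once.

    Every coordinate gets N(0, step_size^2) noise. The proposal is symmetric, so
    the acceptance probability is min(1, |psi_new|^2 / |psi_old|^2).
    Returns the fraction of accepted moves.
    """
    accepted = 0
    for k, old in enumerate(walkers):
        new = [tuple(c + rng.gauss(0.0, step_size) for c in elec) for elec in old]
        log_ratio = log_psi2(new) - log_psi2(old)
        if log_ratio >= 0.0 or rng.random() < math.exp(log_ratio):
            walkers[k] = new
            accepted += 1
    return accepted / len(walkers)


rng = random.Random(0)
nwalkers, nelec, nstep, step_size = 200, 2, 200, 0.2
# start every electron close to the nucleus, loosely mimicking an 'atomic' initialisation
walkers = [[tuple(rng.gauss(0.0, 0.5) for _ in range(3)) for _ in range(nelec)]
           for _ in range(nwalkers)]
rates = [all_elec_step(walkers, step_size, rng) for _ in range(nstep)]

# ntherm=-1 analogue: keep only the final position of each walker
samples = walkers
mean_acceptance = sum(rates) / len(rates)
mean_r = sum(math.sqrt(sum(c * c for c in e)) for w in samples for e in w) / (nwalkers * nelec)
print(f"acceptance {mean_acceptance:.2f}, <r> = {mean_r:.2f} bohr (exact for this toy: 1.5)")
```

Because the Gaussian proposal is symmetric, the proposal densities cancel in the Metropolis-Hastings ratio. Only the ratio of |psi|^2 values is needed, which is why the sketch works with log|psi|^2.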

We use the Adam optimizer implemented in PyTorch, with a custom learning rate for each group of parameters.

[5]:
lr_dict = [{'params': wf.jastrow.parameters(), 'lr': 1E-2},
           {'params': wf.ao.parameters(), 'lr': 1E-6},
           {'params': wf.mo.parameters(), 'lr': 2E-3},
           {'params': wf.fc.parameters(), 'lr': 2E-3}]
opt = optim.Adam(lr_dict, lr=1E-3)
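
To make the per-group learning rates concrete, here is a minimal pure-Python re-implementation of the Adam update over parameter groups. `TinyAdam` is a hypothetical class written for illustration only. It is not part of PyTorch or QMCTorch. As in `torch.optim.Adam`, a group's own `'lr'` overrides the default, so the `lr=1E-3` passed above applies only to groups that do not set one (none do here).

```python
import math


class TinyAdam:
    """Minimal Adam over parameter groups, each with its own learning rate."""

    def __init__(self, groups, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
        # copy each group dict, filling in the default lr where missing;
        # the 'params' lists are shared, so updates modify the caller's lists
        self.groups = [dict(g, lr=g.get("lr", lr)) for g in groups]
        self.b1, self.b2 = betas
        self.eps = eps
        self.t = 0
        self.m = [[0.0] * len(g["params"]) for g in self.groups]
        self.v = [[0.0] * len(g["params"]) for g in self.groups]

    def step(self, grads):
        """grads: one list of gradients per group, aligned with group['params']."""
        self.t += 1
        for gi, (group, g_list) in enumerate(zip(self.groups, grads)):
            for i, g in enumerate(g_list):
                self.m[gi][i] = self.b1 * self.m[gi][i] + (1 - self.b1) * g
                self.v[gi][i] = self.b2 * self.v[gi][i] + (1 - self.b2) * g * g
                m_hat = self.m[gi][i] / (1 - self.b1 ** self.t)  # bias correction
                v_hat = self.v[gi][i] / (1 - self.b2 ** self.t)
                group["params"][i] -= group["lr"] * m_hat / (math.sqrt(v_hat) + self.eps)


jastrow = [0.5, 0.5]
ao = [1.0, 1.0]
opt_sketch = TinyAdam([{"params": jastrow, "lr": 1e-2},
                       {"params": ao, "lr": 1e-6}], lr=1e-3)
opt_sketch.step([[0.3, -0.3], [0.3, -0.3]])
print(jastrow)  # each entry moved by about 1e-2
print(ao)       # each entry moved by about 1e-6
```

On the first step, Adam's bias-corrected update is close to `lr * sign(grad)`. Each group therefore moves by roughly its own learning rate, whatever the magnitude of the gradient.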

A scheduler can also be used to progressively decrease the learning rate during the optimization. Here StepLR multiplies the learning rate of every parameter group by gamma=0.9 every 100 steps.

[6]:
scheduler = optim.lr_scheduler.StepLR(opt, step_size=100, gamma=0.90)
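
Under StepLR, the learning rate after k scheduler steps follows the closed form lr0 * gamma**(k // step_size). The hypothetical helper below evaluates this for the Jastrow group (lr0 = 1E-2), assuming `scheduler.step()` is called once per epoch.

```python
def step_lr(lr0: float, k: int, step_size: int = 100, gamma: float = 0.9) -> float:
    """Learning rate after k scheduler steps: multiplied by gamma every step_size steps."""
    return lr0 * gamma ** (k // step_size)


for k in (0, 50, 99, 100, 250):
    print(k, step_lr(1e-2, k))
```

The optimization in this notebook runs for 50 epochs. Even if the scheduler were passed to the solver, a step_size of 100 would therefore leave the learning rate unchanged.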

We can now assemble the solver. Note that scheduler=None is passed here, so the scheduler defined above is not used in this run.

[7]:
solver = Solver(wf=wf, sampler=sampler, optimizer=opt, scheduler=None)
INFO:QMCTorch|
INFO:QMCTorch| Warning : dump to hdf5
INFO:QMCTorch| Object Solver already exists in H2_adf_dzp_QMCTorch.hdf5
INFO:QMCTorch| Object name changed to SolverSlaterJastrow_7
INFO:QMCTorch|
INFO:QMCTorch|
INFO:QMCTorch| QMC Solver
INFO:QMCTorch|  WaveFunction        : SlaterJastrow
INFO:QMCTorch|  Sampler             : Metropolis
INFO:QMCTorch|  Optimizer           : Adam

Configuring the solver

Many parameters of the optimization can be controlled. We can specify which observables to track during the optimization. Here both the local energies and the variational parameters are recorded.

[8]:
solver.configure(track=['local_energy', 'parameters'])

Some variational parameters can be frozen so that they are not optimized. Here we freeze the AO parameters, so only the Jastrow parameters, the MO coefficients, and the CI coefficients are optimized.

[9]:
solver.configure(freeze=['ao'])

Either the mean or the variance of the local energies can be used as the loss function. Here we minimize the energy to optimize the wave function.

[10]:
solver.configure(loss='energy')
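
The difference between the two losses is easiest to see on a toy problem where everything is known analytically. Take the hydrogen atom with trial function psi = exp(-alpha r). For this trial function, E(alpha) = alpha^2/2 - alpha and the local energy is E_L = -alpha^2/2 + (alpha - 1)/r. The sketch below evaluates both losses on a batch of local energies. It uses exact sampling of |psi|^2 instead of Metropolis, and its helper names are hypothetical. It illustrates the idea only and is not QMCTorch's loss implementation.

```python
import random


def local_energy(r: float, alpha: float) -> float:
    """Local energy of psi = exp(-alpha * r) for the hydrogen atom (atomic units)."""
    return -0.5 * alpha ** 2 + (alpha - 1.0) / r


def sample_r(alpha: float, n: int, rng: random.Random):
    """Exact samples of r from r^2 |psi|^2, i.e. r^2 exp(-2 alpha r): Gamma(3, 1/(2 alpha))."""
    return [rng.gammavariate(3.0, 1.0 / (2.0 * alpha)) for _ in range(n)]


def loss(eloc, kind: str) -> float:
    """'energy' -> mean of the local energies, 'variance' -> their variance."""
    mean = sum(eloc) / len(eloc)
    if kind == "energy":
        return mean
    if kind == "variance":
        return sum((e - mean) ** 2 for e in eloc) / len(eloc)
    raise ValueError(f"unknown loss: {kind}")


rng = random.Random(1)
results = {}
for alpha in (0.8, 1.0):
    eloc = [local_energy(r, alpha) for r in sample_r(alpha, 20000, rng)]
    results[alpha] = (loss(eloc, "energy"), loss(eloc, "variance"))
    print(alpha, results[alpha])
```

At alpha = 0.8 the energy loss should land near -0.48, above the exact -0.5 as the variational principle requires. At alpha = 1 every local energy equals -0.5, so the variance loss is exactly zero. This zero-variance property is why the variance is a sensible loss: it is minimized by the exact eigenstate.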

The gradients of the wave function with respect to the variational parameters can be computed either directly via automatic differentiation (grad='auto') or manually (grad='manual') via a reduced-noise formula. Here we use the manual calculation.

[11]:
solver.configure(grad='manual')
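
For a real wave function, a standard VMC estimator of the energy gradient is dE/dtheta = 2 < (E_L - <E_L>) dln(psi)/dtheta >, averaged over samples of |psi|^2. We assume, without having checked QMCTorch's source, that the 'manual' mode relies on an expression of this kind. The sketch applies it to the hydrogen atom with trial function psi = exp(-alpha r). For this function, dln(psi)/dalpha = -r, E_L = -alpha^2/2 + (alpha - 1)/r, and the exact gradient is alpha - 1. The helper names are hypothetical.

```python
import random


def local_energy(r: float, alpha: float) -> float:
    """Local energy of psi = exp(-alpha * r) for the hydrogen atom (atomic units)."""
    return -0.5 * alpha ** 2 + (alpha - 1.0) / r


def energy_gradient(alpha: float, n: int, rng: random.Random) -> float:
    """Estimate dE/dalpha = 2 < (E_L - <E_L>) * dln(psi)/dalpha > with dln(psi)/dalpha = -r."""
    # exact samples of r from r^2 exp(-2 alpha r), a Gamma(3, 1/(2 alpha)) distribution
    rs = [rng.gammavariate(3.0, 1.0 / (2.0 * alpha)) for _ in range(n)]
    eloc = [local_energy(r, alpha) for r in rs]
    e_mean = sum(eloc) / n
    return 2.0 * sum((e - e_mean) * (-r) for e, r in zip(eloc, rs)) / n


rng = random.Random(2)
grad_08 = energy_gradient(0.8, 100000, rng)  # close to the exact value alpha - 1 = -0.2
grad_10 = energy_gradient(1.0, 1000, rng)    # exactly 0: E_L is constant for the exact psi
print(grad_08, grad_10)
```

Every term is multiplied by E_L - <E_L>, so the fluctuations of the estimate shrink as the wave function approaches an eigenstate. At alpha = 1 the estimate is exactly zero for any sample size, which is one sense in which such formulas have reduced noise.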

We also configure the resampling so that the positions of the walkers are updated by performing 25 MC steps from their previous positions after each optimization step.

[12]:
solver.configure(resampling={'mode': 'update',
                            'resample_every': 1,
                            'nstep_update': 25})
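
A toy end-to-end loop shows what the 'update' resampling mode amounts to. After each parameter update, the walkers are not re-initialized. Instead, they take a few Metropolis steps (here 25, as with nstep_update) under the new wave function, starting from their previous positions. The sketch uses the hydrogen atom with psi = exp(-alpha r), 300 walkers, and the gradient estimator 2 < (E_L - <E_L>) dln(psi)/dalpha >. It replaces Adam with plain gradient descent for brevity. All names are hypothetical and none of this is QMCTorch code.

```python
import math
import random


def log_psi2(pos, alpha):
    """log|psi|^2 for psi = exp(-alpha * r) (hydrogen atom, atomic units)."""
    return -2.0 * alpha * math.sqrt(sum(c * c for c in pos))


def metropolis_update(walkers, alpha, nstep, step_size, rng):
    """Evolve the existing walkers for nstep Metropolis steps under the current psi."""
    for _ in range(nstep):
        for k, old in enumerate(walkers):
            new = tuple(c + rng.gauss(0.0, step_size) for c in old)
            log_ratio = log_psi2(new, alpha) - log_psi2(old, alpha)
            if log_ratio >= 0.0 or rng.random() < math.exp(log_ratio):
                walkers[k] = new


def energy_and_gradient(walkers, alpha):
    """Mean local energy and 2 < (E_L - <E_L>) * dln(psi)/dalpha >, with dln(psi)/dalpha = -r."""
    rs = [math.sqrt(sum(c * c for c in w)) for w in walkers]
    eloc = [-0.5 * alpha ** 2 + (alpha - 1.0) / r for r in rs]
    e_mean = sum(eloc) / len(eloc)
    grad = 2.0 * sum((e - e_mean) * (-r) for e, r in zip(eloc, rs)) / len(rs)
    return e_mean, grad


rng = random.Random(3)
alpha, lr, step_size = 0.6, 0.5, 0.5
# initial sampling: random start, then equilibrate once under the initial psi
walkers = [tuple(rng.gauss(0.0, 1.0) for _ in range(3)) for _ in range(300)]
metropolis_update(walkers, alpha, 100, step_size, rng)

for epoch in range(40):
    energy, grad = energy_and_gradient(walkers, alpha)
    alpha -= lr * grad  # plain gradient descent instead of Adam, for brevity
    # 'update' resampling: 25 MC steps from the previous positions, no re-initialization
    metropolis_update(walkers, alpha, 25, step_size, rng)

print(f"alpha = {alpha:.4f}, energy = {energy:.4f} hartree")
```

Starting from alpha = 0.6, the loop should converge to the exact alpha = 1 with an energy of -0.5 hartree. Reusing the previous walkers is cheap: a small parameter change only slightly shifts |psi|^2, so a few steps are enough to re-equilibrate.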

Running the wave function optimization

We can now run the optimization. Here we use 50 optimization steps (epochs), with all the walkers in a single mini-batch.

[13]:
obs = solver.run(50)
INFO:QMCTorch|
INFO:QMCTorch|  Optimization
INFO:QMCTorch|  Task                :
INFO:QMCTorch|  Number Parameters   : 115
INFO:QMCTorch|  Number of epoch     : 50
INFO:QMCTorch|  Batch size          : 5000
INFO:QMCTorch|  Loss function       : energy
INFO:QMCTorch|  Clip Loss           : False
INFO:QMCTorch|  Gradients           : manual
INFO:QMCTorch|  Resampling mode     : update
INFO:QMCTorch|  Resampling every    : 1
INFO:QMCTorch|  Resampling steps    : 25
INFO:QMCTorch|  Output file         : H2_adf_dzp_QMCTorch.hdf5
INFO:QMCTorch|  Checkpoint every    : None
INFO:QMCTorch|
INFO:QMCTorch|
INFO:QMCTorch|  epoch 0
INFO:QMCTorch|  energy   : -1.155363 +/- 0.003267
INFO:QMCTorch|  variance : 0.231010
INFO:QMCTorch|  epoch done in 0.49 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 1
INFO:QMCTorch|  energy   : -1.149161 +/- 0.003279
INFO:QMCTorch|  variance : 0.231844
INFO:QMCTorch|  epoch done in 0.59 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 2
INFO:QMCTorch|  energy   : -1.150710 +/- 0.003106
INFO:QMCTorch|  variance : 0.219625
INFO:QMCTorch|  epoch done in 0.94 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 3
INFO:QMCTorch|  energy   : -1.156548 +/- 0.003170
INFO:QMCTorch|  variance : 0.224149
INFO:QMCTorch|  epoch done in 0.50 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 4
INFO:QMCTorch|  energy   : -1.155115 +/- 0.003221
INFO:QMCTorch|  variance : 0.227777
INFO:QMCTorch|  epoch done in 0.51 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 5
INFO:QMCTorch|  energy   : -1.156112 +/- 0.003083
INFO:QMCTorch|  variance : 0.217972
INFO:QMCTorch|  epoch done in 0.51 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 6
INFO:QMCTorch|  energy   : -1.155542 +/- 0.003070
INFO:QMCTorch|  variance : 0.217062
INFO:QMCTorch|  epoch done in 0.94 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 7
INFO:QMCTorch|  energy   : -1.157297 +/- 0.003046
INFO:QMCTorch|  variance : 0.215387
INFO:QMCTorch|  epoch done in 0.48 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 8
INFO:QMCTorch|  energy   : -1.150183 +/- 0.003147
INFO:QMCTorch|  variance : 0.222538
INFO:QMCTorch|  epoch done in 0.50 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 9
INFO:QMCTorch|  energy   : -1.155700 +/- 0.003062
INFO:QMCTorch|  variance : 0.216530
INFO:QMCTorch|  epoch done in 0.50 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 10
INFO:QMCTorch|  energy   : -1.154875 +/- 0.003005
INFO:QMCTorch|  variance : 0.212476
INFO:QMCTorch|  epoch done in 0.60 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 11
INFO:QMCTorch|  energy   : -1.154984 +/- 0.003024
INFO:QMCTorch|  variance : 0.213820
INFO:QMCTorch|  epoch done in 0.50 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 12
INFO:QMCTorch|  energy   : -1.154497 +/- 0.002974
INFO:QMCTorch|  variance : 0.210262
INFO:QMCTorch|  epoch done in 0.50 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 13
INFO:QMCTorch|  energy   : -1.157227 +/- 0.003000
INFO:QMCTorch|  variance : 0.212123
INFO:QMCTorch|  epoch done in 0.57 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 14
INFO:QMCTorch|  energy   : -1.156778 +/- 0.002914
INFO:QMCTorch|  variance : 0.206054
INFO:QMCTorch|  epoch done in 0.75 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 15
INFO:QMCTorch|  energy   : -1.152052 +/- 0.003022
INFO:QMCTorch|  variance : 0.213717
INFO:QMCTorch|  epoch done in 0.49 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 16
INFO:QMCTorch|  energy   : -1.158149 +/- 0.002847
INFO:QMCTorch|  variance : 0.201333
INFO:QMCTorch|  epoch done in 0.50 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 17
INFO:QMCTorch|  energy   : -1.158337 +/- 0.002852
INFO:QMCTorch|  variance : 0.201654
INFO:QMCTorch|  epoch done in 0.48 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 18
INFO:QMCTorch|  energy   : -1.158138 +/- 0.002793
INFO:QMCTorch|  variance : 0.197500
INFO:QMCTorch|  epoch done in 0.89 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 19
INFO:QMCTorch|  energy   : -1.157327 +/- 0.002869
INFO:QMCTorch|  variance : 0.202897
INFO:QMCTorch|  epoch done in 0.99 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 20
INFO:QMCTorch|  energy   : -1.155671 +/- 0.002901
INFO:QMCTorch|  variance : 0.205139
INFO:QMCTorch|  epoch done in 0.52 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 21
INFO:QMCTorch|  energy   : -1.156606 +/- 0.002863
INFO:QMCTorch|  variance : 0.202470
INFO:QMCTorch|  epoch done in 0.48 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 22
INFO:QMCTorch|  energy   : -1.164993 +/- 0.002852
INFO:QMCTorch|  variance : 0.201661
INFO:QMCTorch|  epoch done in 0.51 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 23
INFO:QMCTorch|  energy   : -1.157040 +/- 0.002765
INFO:QMCTorch|  variance : 0.195510
INFO:QMCTorch|  epoch done in 0.50 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 24
INFO:QMCTorch|  energy   : -1.163667 +/- 0.002707
INFO:QMCTorch|  variance : 0.191386
INFO:QMCTorch|  epoch done in 0.57 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 25
INFO:QMCTorch|  energy   : -1.159113 +/- 0.002700
INFO:QMCTorch|  variance : 0.190943
INFO:QMCTorch|  epoch done in 0.51 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 26
INFO:QMCTorch|  energy   : -1.162071 +/- 0.002661
INFO:QMCTorch|  variance : 0.188190
INFO:QMCTorch|  epoch done in 0.53 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 27
INFO:QMCTorch|  energy   : -1.158837 +/- 0.002642
INFO:QMCTorch|  variance : 0.186836
INFO:QMCTorch|  epoch done in 0.49 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 28
INFO:QMCTorch|  energy   : -1.155956 +/- 0.002649
INFO:QMCTorch|  variance : 0.187284
INFO:QMCTorch|  epoch done in 0.50 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 29
INFO:QMCTorch|  energy   : -1.162127 +/- 0.002609
INFO:QMCTorch|  variance : 0.184491
INFO:QMCTorch|  epoch done in 0.73 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 30
INFO:QMCTorch|  energy   : -1.163752 +/- 0.002560
INFO:QMCTorch|  variance : 0.181025
INFO:QMCTorch|  epoch done in 0.52 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 31
INFO:QMCTorch|  energy   : -1.159163 +/- 0.002590
INFO:QMCTorch|  variance : 0.183165
INFO:QMCTorch|  epoch done in 0.56 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 32
INFO:QMCTorch|  energy   : -1.163472 +/- 0.002603
INFO:QMCTorch|  variance : 0.184072
INFO:QMCTorch|  epoch done in 0.50 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 33
INFO:QMCTorch|  energy   : -1.165384 +/- 0.002563
INFO:QMCTorch|  variance : 0.181214
INFO:QMCTorch|  epoch done in 0.51 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 34
INFO:QMCTorch|  energy   : -1.163774 +/- 0.002527
INFO:QMCTorch|  variance : 0.178661
INFO:QMCTorch|  epoch done in 0.50 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 35
INFO:QMCTorch|  energy   : -1.161995 +/- 0.002472
INFO:QMCTorch|  variance : 0.174763
INFO:QMCTorch|  epoch done in 0.51 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 36
INFO:QMCTorch|  energy   : -1.161698 +/- 0.002521
INFO:QMCTorch|  variance : 0.178254
INFO:QMCTorch|  epoch done in 0.50 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 37
INFO:QMCTorch|  energy   : -1.162856 +/- 0.002532
INFO:QMCTorch|  variance : 0.179051
INFO:QMCTorch|  epoch done in 0.50 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 38
INFO:QMCTorch|  energy   : -1.157138 +/- 0.002535
INFO:QMCTorch|  variance : 0.179220
INFO:QMCTorch|  epoch done in 0.50 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 39
INFO:QMCTorch|  energy   : -1.163320 +/- 0.002536
INFO:QMCTorch|  variance : 0.179332
INFO:QMCTorch|  epoch done in 0.74 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 40
INFO:QMCTorch|  energy   : -1.161880 +/- 0.002464
INFO:QMCTorch|  variance : 0.174239
INFO:QMCTorch|  epoch done in 0.48 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 41
INFO:QMCTorch|  energy   : -1.158324 +/- 0.002542
INFO:QMCTorch|  variance : 0.179777
INFO:QMCTorch|  epoch done in 0.51 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 42
INFO:QMCTorch|  energy   : -1.158298 +/- 0.002442
INFO:QMCTorch|  variance : 0.172696
INFO:QMCTorch|  epoch done in 0.50 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 43
INFO:QMCTorch|  energy   : -1.160970 +/- 0.002371
INFO:QMCTorch|  variance : 0.167662
INFO:QMCTorch|  epoch done in 0.79 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 44
INFO:QMCTorch|  energy   : -1.159741 +/- 0.002362
INFO:QMCTorch|  variance : 0.166993
INFO:QMCTorch|  epoch done in 0.51 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 45
INFO:QMCTorch|  energy   : -1.162254 +/- 0.002349
INFO:QMCTorch|  variance : 0.166119
INFO:QMCTorch|  epoch done in 0.73 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 46
INFO:QMCTorch|  energy   : -1.160540 +/- 0.002314
INFO:QMCTorch|  variance : 0.163611
INFO:QMCTorch|  epoch done in 0.49 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 47
INFO:QMCTorch|  energy   : -1.162938 +/- 0.002316
INFO:QMCTorch|  variance : 0.163749
INFO:QMCTorch|  epoch done in 0.49 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 48
INFO:QMCTorch|  energy   : -1.163674 +/- 0.002214
INFO:QMCTorch|  variance : 0.156522
INFO:QMCTorch|  epoch done in 0.51 sec.
INFO:QMCTorch|
INFO:QMCTorch|  epoch 49
INFO:QMCTorch|  energy   : -1.163112 +/- 0.002278
INFO:QMCTorch|  variance : 0.161065
INFO:QMCTorch|  epoch done in 0.50 sec.
INFO:QMCTorch|
INFO:QMCTorch| Warning : dump to hdf5
INFO:QMCTorch| Object wf_opt already exists in H2_adf_dzp_QMCTorch.hdf5
INFO:QMCTorch| Object name changed to wf_opt_7
INFO:QMCTorch|
[14]:
plot_energy(obs.local_energy, e0=-1.1645, show_variance=True)
[Figure: output of plot_energy, showing the local energy and its variance over the optimization epochs]