from time import time
import numpy as np
import torch
from scipy.interpolate import LinearNDInterpolator, RegularGridInterpolator
class InterpolateMolecularOrbitals:
    """Interpolate molecular orbital (MO) values from a precomputed 3D grid.

    The MOs are evaluated once on a grid through ``wf.ao``, ``wf.mo_scf``
    and ``wf.mo``; subsequent calls only interpolate the tabulated values.
    Two grids are supported: an irregular grid made of log-spaced points
    around each atom (``'irreg'``) and a regular cartesian grid (``'reg'``).
    """

    def __init__(self, wf):
        """Store the wave function whose MOs are interpolated.

        Args:
            wf: wave function object. This class reads ``wf.ao``,
                ``wf.mo_scf``, ``wf.mo``, ``wf.configs``,
                ``wf.mol.atom_coords``, ``wf.mol.nelec`` and
                ``wf.mol.basis.nmo``.
        """
        self.wf = wf

    def __call__(self, pos, method='irreg', orb='occupied', **kwargs):
        """Interpolate the MOs at the positions ``pos``.

        Args:
            pos (torch.tensor): sampling points (Nbatch, 3*Nelec)
            method (str, optional): 'irreg' or 'reg'. Defaults to 'irreg'.
            orb (str, optional): 'occupied' or 'all'. Defaults to 'occupied'.
            **kwargs: ``n`` (default 6) for the irregular grid;
                ``resolution`` (default 0.1) and ``border_length``
                (default 2.) for the regular grid.

        Raises:
            ValueError: if method is neither 'irreg' nor 'reg'

        Returns:
            torch.tensor: mo values Nbatch, Nelec, Nmo
        """
        if method == 'irreg':
            return self.interpolate_mo_irreg_grid(
                pos, n=kwargs.get('n', 6), orb=orb)
        if method == 'reg':
            res = kwargs.get('resolution', 0.1)
            blength = kwargs.get('border_length', 2.)
            return self.interpolate_mo_reg_grid(pos, res, blength, orb)
        # The original fell through and returned None for a mistyped method.
        raise ValueError("method must be 'irreg' or 'reg'")

    def get_mo_max_index(self, orb):
        """Set ``self.mo_max_index``, the number of leading MOs to interpolate.

        Args:
            orb (str): occupied or all

        Raises:
            ValueError: if orb not valid
        """
        if orb == 'occupied':
            # largest orbital index found in the configurations, plus one
            # so that it can be used as an exclusive slice bound
            self.mo_max_index = torch.stack(
                self.wf.configs).max().item() + 1
        elif orb == 'all':
            # Exclusive bound covering every MO. The previous ``nmo + 1``
            # only worked because slicing clamps the bound to ``nmo``.
            self.mo_max_index = self.wf.mol.basis.nmo
        else:
            raise ValueError(
                'orb must occupied or all')

    def _mo_on_grid(self, x):
        """Evaluate the first ``mo_max_index`` MOs at the grid points ``x``."""
        x = torch.as_tensor(x).type(torch.get_default_dtype())
        ao = self.wf.ao(x, one_elec=True)
        mo = self.wf.mo(self.wf.mo_scf(ao)).squeeze(1)
        return mo[:, :self.mo_max_index].detach()

    def _has_interpolator(self, key):
        """Tell whether the cached interpolator matches the settings ``key``."""
        if not hasattr(self, 'interp_mo_func'):
            return False
        # An interpolator installed without a recorded key is trusted as is.
        return getattr(self, '_interp_mo_key', key) == key

    def _pad_to_all_mos(self, values, nbatch):
        """Place the interpolated MOs in a zero tensor (Nbatch, Nelec, Nmo)."""
        mos = torch.zeros(nbatch, self.wf.mol.nelec,
                          self.wf.mol.basis.nmo)
        mos[:, :, :self.mo_max_index] = values
        return mos

    def interpolate_mo_irreg_grid(self, pos, n, orb):
        """Interpolate the MOs using the irregular atom-centered log grid.

        Args:
            pos (torch.tensor): sampling points (Nbatch, 3*Nelec)
            n (int): number of points requested on each log axis
            orb (str): occupied or all

        Returns:
            torch.tensor: mo values Nbatch, Nelec, Nmo
        """
        self.get_mo_max_index(orb)

        # Rebuild when the grid type, its settings or the number of
        # interpolated orbitals changed since the interpolator was built.
        key = ('irreg', n, self.mo_max_index)
        if not self._has_interpolator(key):
            grid_pts = get_log_grid(self.wf.mol.atom_coords, n=n)
            self.interp_mo_func = interpolator_irreg_grid(
                self._mo_on_grid, grid_pts)
            self._interp_mo_key = key

        values = interpolate_irreg_grid(self.interp_mo_func, pos)
        return self._pad_to_all_mos(values, pos.shape[0])

    def interpolate_mo_reg_grid(self, pos, res, blength, orb):
        """Interpolate the MOs using the regular cartesian grid.

        Args:
            pos (torch.tensor): sampling points (Nbatch, 3*Nelec)
            res (float): distance between two grid points
            blength (float): border added around the atomic positions
            orb (str): occupied or all

        Returns:
            torch.tensor: mo values Nbatch, Nelec, Nmo
        """
        self.get_mo_max_index(orb)

        key = ('reg', res, blength, self.mo_max_index)
        if not self._has_interpolator(key):
            x, y, z = get_reg_grid(
                self.wf.mol.atom_coords, resolution=res,
                border_length=blength)
            self.interp_mo_func = interpolator_reg_grid(
                self._mo_on_grid, x, y, z)
            self._interp_mo_key = key

        values = interpolate_reg_grid(self.interp_mo_func, pos)
        return self._pad_to_all_mos(values, pos.shape[0])
class InterpolateAtomicOrbitals:
    """Interpolate atomic orbital (AO) values from a log grid.

    Each orbital is tabulated once on a cube of log-spaced points centered
    on the origin; it is then interpolated at the electron positions taken
    relative to the coordinates returned by ``wf.ao.atom_coords``.
    """

    def __init__(self, wf):
        """Interpolation of the AO using a log grid centered on each atom.

        Args:
            wf: wave function object; only ``wf.ao`` is used here.
        """
        self.wf = wf

    def __call__(self, pos, n=6, length=2):
        """Interpolate the AO.

        Args:
            pos (torch.tensor): positions of the walkers
            n (int, optional): number of points on each log axis.
                Only used when the interpolator is built, i.e. on the
                first call. Defaults to 6.
            length (int, optional): half length of the grid.
                Only used when the interpolator is built. Defaults to 2.

        Returns:
            torch.tensor: Interpolated values
        """
        if not hasattr(self, 'interp_func'):
            # forward the requested grid settings; they used to be ignored
            self.get_interpolator(n=n, length=length)

        ao = self.wf.ao

        # one center per orbital:
        # we need the number of AO per atom not the number of BAS per atom
        bas_coords = ao.atom_coords.repeat_interleave(
            ao.nao_per_atom, dim=0)

        # electron positions relative to the center of each orbital
        xyz = (pos.view(-1, ao.nelec, 1, ao.ndim) -
               bas_coords[None, ...]).detach().cpu().numpy()

        data = np.array([self.interp_func[iorb](xyz[:, :, iorb, :])
                         for iorb in range(ao.norb)])

        # (norb, nbatch, nelec) -> (nbatch, nelec, norb)
        return torch.as_tensor(data.transpose(1, 2, 0))

    def get_interpolator(self, n=6, length=2):
        """Tabulate the AOs on the log grid and build the interpolators.

        Stores one ``RegularGridInterpolator`` per orbital in
        ``self.interp_func``; points outside the grid evaluate to 0.

        Args:
            n (int, optional): number of points on each log axis. Defaults to 6.
            length (int, optional): half length of the grid. Defaults to 2.
        """
        xpts = logspace(n, length)
        nxpts = len(xpts)

        # Rows ordered with the first coordinate varying slowest, so that
        # the values can be reshaped to (nxpts, nxpts, nxpts, ...) below.
        grid = np.stack(np.meshgrid(
            xpts, xpts, xpts, indexing='ij'), axis=-1).reshape(-1, 3)

        def func(x):
            x = torch.as_tensor(x).type(torch.get_default_dtype())
            nbatch = x.shape[0]

            # the same displacement is used for every basis function
            xyz = x.view(-1, 1, 1, 3).expand(-1,
                                             1, self.wf.ao.nbas, 3)
            r = torch.sqrt((xyz**2).sum(3))

            R = self.wf.ao.radial(r, self.wf.ao.bas_n,
                                  self.wf.ao.bas_exp)
            Y = self.wf.ao.harmonics(xyz)

            bas = R * Y
            bas = self.wf.ao.norm_cst * self.wf.ao.bas_coeffs * bas

            # Size the buffer from ``bas`` rather than ``nelec``: the grid
            # is evaluated for a single probe point per row, and
            # ``index_add_`` requires matching shapes outside dim 2.
            # NOTE(review): assumes ``bas`` is (npts, 1, nbas) - confirm.
            ao = torch.zeros(nbatch, bas.shape[1],
                             self.wf.ao.norb, device=self.wf.ao.device)

            # contract the basis functions into orbitals
            ao.index_add_(2, self.wf.ao.index_ctr, bas)
            return ao

        data = func(grid).detach().cpu().numpy()
        data = data.reshape(nxpts, nxpts, nxpts, -1)

        self.interp_func = [RegularGridInterpolator((xpts, xpts, xpts),
                                                    data[..., i],
                                                    method='linear',
                                                    bounds_error=False,
                                                    fill_value=0.)
                            for i in range(self.wf.ao.norb)]
def get_boundaries(atomic_positions, border_length=2.):
    """Computes the boundaries of the structure

    Args:
        atomic_positions (torch.Tensor, np.ndarray, list): atomic positions
        border_length (float, optional): length of the border. Defaults to 2.

    Raises:
        ValueError: if type of positions not recognized

    Returns:
        (np.ndarray, np.ndarray): min and max values along each cartesian
            direction, extended by border_length
    """
    if isinstance(atomic_positions, torch.Tensor):
        coords = atomic_positions.detach().cpu().numpy()
    elif isinstance(atomic_positions, (np.ndarray, list)):
        coords = np.asarray(atomic_positions)
    else:
        raise ValueError(
            'atomic_positions must be either a torch.tensor, np.ndarray, or list')

    # Out-of-place arithmetic: the former ``pmin -= border_length`` raised a
    # casting error when the positions were integer valued.
    pmin = coords.min(0) - border_length
    pmax = coords.max(0) + border_length
    return pmin, pmax
def get_reg_grid(atomic_positions, resolution=0.1, border_length=2.):
    """Build the axes of a regular grid enclosing the atoms.

    Args:
        atomic_positions (torch.Tensor, np.ndarray, list): atomic positions
        resolution (float, optional): target distance between two points.
            Defaults to 0.1.
        border_length (float, optional): length of the border added around
            the atoms. Defaults to 2.

    Returns:
        (np.ndarray, np.ndarray, np.ndarray): grid points along the x, y
            and z axis
    """
    lower, upper = get_boundaries(
        atomic_positions, border_length=border_length)

    # number of points per axis: box length over resolution, rounded up
    counts = np.ceil((upper - lower) / resolution).astype('int')

    return tuple(np.linspace(lower[axis], upper[axis], counts[axis])
                 for axis in range(3))
def interpolator_reg_grid(func, x, y, z):
    """Tabulate a function on a regular grid and return its interpolator.

    Args:
        func (callable): takes an (Npts, 3) array of points and returns a
            torch tensor with one row of values per point
        x (np.ndarray): grid points in the x direction
        y (np.ndarray): grid points in the y direction
        z (np.ndarray): grid points in the z direction

    Returns:
        callable: linear interpolation function, 0 outside of the grid
    """
    # (len(x), len(y), len(z), 3) array of coordinates flattened to rows;
    # x varies slowest and z fastest, matching the reshape of the values.
    mesh = np.meshgrid(x, y, z, indexing='ij')
    points = np.stack(mesh, axis=-1).reshape(-1, 3)

    values = func(points).detach().numpy()
    values = values.reshape(len(x), len(y), len(z), -1)

    return RegularGridInterpolator(
        (x, y, z), values,
        method='linear', bounds_error=False, fill_value=0.)
def interpolate_reg_grid(interpfunc, pos):
    """Evaluate a regular-grid interpolator at the walker positions.

    Args:
        interpfunc (callable): interpolation function taking a numpy array
            of points whose last axis holds the 3 coordinates
        pos (torch.tensor): positions of the walkers Nbatch x 3*Nelec

    Returns:
        torch.tensor: interpolated values of the function evaluated at pos
    """
    nwalkers, ncoords = pos.shape[0], pos.shape[1]

    # one 3D point per electron
    points = pos.reshape(nwalkers, ncoords // 3, 3)

    values = interpfunc(points.detach().numpy())
    return torch.as_tensor(values)
def is_even(x):
    """Return True when x is divisible by 2."""
    return x % 2 == 0
def logspace(n, length):
    """Return a symmetric 1d array of n log-spaced points in [-length, length].

    The points are denser close to 0. An odd n includes 0 exactly once; an
    even n has no point at 0 and starts at +/-(10**0.01 - 1) instead.

    Args:
        n (int): number of points
        length (float): absolute value of the largest point

    Returns:
        np.ndarray: points sorted in increasing order
    """
    # exponent such that 10**k - 1 == length
    k = np.log10(length + 1)

    if n % 2 == 0:
        half = np.logspace(0.01, k, n // 2) - 1
        # No point sits at 0, so nothing is duplicated by the mirroring:
        # keep the whole positive half. Dropping half[0] (as before) gave
        # n-1 points and an asymmetric axis.
        return np.concatenate((-half[::-1], half))

    half = np.logspace(0.0, k, n // 2 + 1) - 1
    # half[0] is 0 and is already provided by the mirrored part
    return np.concatenate((-half[::-1], half[1:]))
def get_log_grid(atomic_positions, n=6, length=2., border_length=2.):
    """Computes a logarithmic grid

    The grid contains the 8 corners of the bounding box of the structure,
    followed by one cube of log-spaced points centered on each atom.

    Args:
        atomic_positions (list, np.ndarray, torch.tensor): positions of the atoms
        n (int, optional): number of points in each axis around each atom. Defaults to 6.
        length (float, optional): absolute value of the max distance from the atom. Defaults to 2.
        border_length (float, optional): length of the border. Defaults to 2.

    Raises:
        ValueError: if the type of atomic_positions is not supported
            (raised by get_boundaries)

    Returns:
        np.ndarray: grid points (Npts,3)
    """
    # [min, max] pair along each axis -> corners of the bounding box
    x, y, z = np.stack(get_boundaries(
        atomic_positions, border_length=border_length)).T
    corners = np.stack(np.meshgrid(
        x, y, z, indexing='ij')).T.reshape(-1, 3)

    # Work on a numpy array so that the per-atom shift below does not rely
    # on mixed numpy/torch arithmetic.
    if isinstance(atomic_positions, torch.Tensor):
        centers = atomic_positions.detach().cpu().numpy()
    else:
        centers = np.asarray(atomic_positions)

    # cube of log-spaced points around the origin
    axis = logspace(n, length)
    local_pts = np.stack(np.meshgrid(axis, axis, axis, indexing='ij')
                         ).T.reshape(-1, 3)

    # shift the cube onto every atom and join everything in one pass
    return np.concatenate(
        [corners] + [local_pts + center for center in centers])
def interpolator_irreg_grid(func, grid_pts):
    """Tabulate a function on scattered points and return its interpolator.

    Args:
        func (callable): returns the values of the function to interpolate
            at the given points
        grid_pts (np.ndarray): grid points (Npts, 3)

    Returns:
        callable: piecewise linear interpolation function, 0 outside of
            the convex hull of the grid points
    """
    values = func(grid_pts)
    return LinearNDInterpolator(grid_pts, values, fill_value=0.)
def interpolate_irreg_grid(interpfunc, pos):
    """Interpolate the function at the walker positions.

    Args:
        interpfunc (callable): function to interpolate the data points
        pos (torch.tensor): positions of the walkers Nbatch x 3*Nelec

    Returns:
        torch.tensor: interpolated values of the function evaluated at pos
    """
    nbatch, nelec, ndim = pos.shape[0], pos.shape[1] // 3, 3

    # Hand a plain numpy array to the interpolator, as interpolate_reg_grid
    # does: a tensor that requires grad cannot be converted implicitly.
    points = pos.reshape(nbatch, nelec, ndim).detach().cpu().numpy()
    return torch.as_tensor(interpfunc(points))