{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Single GPU Support\n", "\n", "> **Warning** \n", "> The use of GPU and multi-GPU is under development and has not been thoroughly tested yet. Proceed with caution!\n", "\n", "Using PyTorch as a backend, QMCTorch can leverage the GPU cards available on your hardware.\n", "You must of course have the CUDA version of PyTorch installed (https://pytorch.org/).\n", "\n", "\n", "Let's first import everything and create a molecule:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from torch import optim\n", "from qmctorch.scf import Molecule\n", "from qmctorch.wavefunction import SlaterJastrow\n", "from qmctorch.sampler import Metropolis\n", "from qmctorch.utils import (plot_energy, plot_data)\n", "mol = Molecule(atom='H 0. 0. 0; H 0. 0. 1.', unit='bohr', redo_scf=True)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "GPU acceleration has been streamlined in QMCTorch: the only change\n", "you need to make to your code is to pass `cuda=True` when declaring the wave function and the sampler. This automatically moves all the necessary tensors to the GPU and offloads the corresponding operations there." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "if torch.cuda.is_available():\n", "    wf = SlaterJastrow(mol, cuda=True)\n", "    sampler = Metropolis(nwalkers=100, nstep=500, step_size=0.25,\n", "                         nelec=wf.nelec, ndim=wf.ndim,\n", "                         init=mol.domain('atomic'),\n", "                         move={'type': 'all-elec', 'proba': 'normal'},\n", "                         cuda=True)\n", "else:\n", "    print('CUDA not available, install torch with cuda support to proceed')" ] } ], "metadata": { "kernelspec": { "display_name": "qmctorch", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.0" } }, "nbformat": 4, "nbformat_minor": 2 }