# Source code for qmctorch.utils.hdf5_utils

from types import SimpleNamespace

import h5py
import numpy as np
import torch

from .. import log














def load_from_hdf5(obj, fname, obj_name):
    """Load the content of an hdf5 file in an object.

    Arguments:
        obj {object} -- object where to load the data
        fname {str} -- name of the hdf5 file
        obj_name {str} -- name of the root group in the hdf5
    """
    # the context manager closes the file even when the root group is
    # missing or when loading raises
    with h5py.File(fname, "r") as h5:
        root_grp = h5[obj_name]
        load_object(root_grp, obj, obj_name)
def load_object(grp, parent_obj, grp_name):
    """Populate an object from the children of an hdf5 group.

    Each child that ``isgroup`` reports as a group is handed to
    ``load_group``; every other child is handed to ``load_data``.

    Arguments:
        grp {hdf5 group} -- the current group in the hdf5 architecture
        parent_obj {object} -- object receiving the loaded attributes
        grp_name {str} -- name of the group (not used by this function)
    """
    for name, node in grp.items():
        loader = load_group if isgroup(node) else load_data
        loader(node, parent_obj, name)
def load_group(grp, parent_obj, grp_name):
    """Load object attribute from the hdf5 group

    If the parent object has no attribute called ``grp_name`` an empty
    ``SimpleNamespace`` is created to receive the content of the group.
    Any error is printed and reported through ``print_load_error``
    instead of being propagated.

    Arguments:
        grp {hdf5 group} -- the current group in the hdf5 architecture
        parent_obj {object} -- parent object
        grp_name {str} -- name of the group
    """
    try:
        if not hasattr(parent_obj, grp_name):
            setattr(parent_obj, grp_name, SimpleNamespace())
        # getattr() follows the same lookup as the hasattr() check above,
        # including the __getattr__ fallback; calling __getattribute__
        # directly would miss attributes that only __getattr__ provides
        load_object(grp, getattr(parent_obj, grp_name), grp_name)
    except Exception as expt_message:
        print(expt_message)
        print_load_error(grp_name)
def load_data(grp, parent_obj, grp_name):
    """Load data from the hdf5 data

    The full content of the dataset is read with ``grp[()]``, passed
    through ``cast_loaded_data`` and stored on the parent object under
    ``grp_name``. Any error is printed and reported through
    ``print_load_error`` instead of being propagated.

    Arguments:
        grp {hdf5 dataset} -- the dataset to read
        parent_obj {object} -- parent object
        grp_name {str} -- name of the attribute to set on the parent
    """
    try:
        setattr(parent_obj, grp_name, cast_loaded_data(grp[()]))
    # Exception (not a bare except) so that KeyboardInterrupt and
    # SystemExit are not swallowed
    except Exception as expt_message:
        print(expt_message)
        print_load_error(grp_name)
def cast_loaded_data(data):
    """Cast a value read from hdf5 before it is stored on the object.

    A value whose exact type is ``bytes`` is converted with ``bytes2str``;
    any other value is returned unchanged.

    Arguments:
        data {object} -- value read from the hdf5 file

    Returns:
        object -- the converted value
    """
    converters = {bytes: bytes2str}
    converter = converters.get(type(data))
    if converter is None:
        return data
    return converter(data)
def bytes2str(bstr):
    """Convert a bytes into string.

    Arguments:
        bstr {bytes or str} -- value to convert (subclasses are accepted)

    Returns:
        str -- the utf-8 decoded text for a bytes input, the input itself
               for a str input

    Raises:
        TypeError -- if bstr is neither a bytes nor a str
    """
    if isinstance(bstr, bytes):
        return bstr.decode("utf-8")
    if isinstance(bstr, str):
        return bstr
    # a single formatted message instead of a tuple of fragments
    raise TypeError(
        "{!r} should be a bytes or str but got {} instead".format(bstr, type(bstr))
    )
def lookup_cast(ori_type, current_type):
    """Placeholder for casting loaded data back to its stored type.

    Arguments:
        ori_type {object} -- not used yet
        current_type {object} -- not used yet

    Raises:
        NotImplementedError -- always, the cast is not implemented
    """
    message = "cast the data to the type contained in .attrs['type']"
    raise NotImplementedError(message)
def isgroup(grp):
    """Check if current hdf5 group is a group

    Arguments:
        grp {hdf5 group} -- hdf5 group or dataset

    Returns:
        bool -- True if the group is a group (an open file counts as one)
    """
    # public h5py API instead of the private h5py._hl modules;
    # h5py.File derives from h5py.Group so files are covered as well
    return isinstance(grp, h5py.Group)
def dump_to_hdf5(obj, fname, root_name=None):
    """Dump the content of an object in a hdf5 file.

    The file is opened in append mode. If ``root_name`` is already present
    in the file a numeric suffix is appended so that the existing data is
    never overwritten.

    Arguments:
        obj {object} -- object to dump
        fname {str} -- name of the hdf5

    Keyword Arguments:
        root_name {str} -- root group in the hdf5 file; the class name of
                           the object is used when None (default: {None})

    Returns:
        str -- name of the root group actually used in the file
    """
    if root_name is None:
        root_name = obj.__class__.__name__

    # the context manager closes the file even if the dump raises
    with h5py.File(fname, "a") as h5:
        # change root name if that name is already present in the file
        if root_name in h5:
            log.info("")
            log.info(" Warning : dump to hdf5")
            log.info(
                " Object {obj} already exists in {parent}", obj=root_name, parent=fname
            )

            base_name = root_name
            index = sum(1 for name in h5 if name.startswith(base_name)) + 1
            root_name = base_name + "_" + str(index)
            # counting prefixes can land on a name that is already taken,
            # in which case the dump would be dropped; keep looking
            while root_name in h5:
                index += 1
                root_name = base_name + "_" + str(index)

            log.info(" Object name changed to {obj}", obj=root_name)
            log.info("")

        insert_object(obj, h5, root_name)

    return root_name
def insert_object(obj, parent_grp, obj_name):
    """Store an object in the hdf5 file as a group or as a dataset.

    Objects for which ``haschildren`` is true go through ``insert_group``,
    all the others through ``insert_data``.

    Arguments:
        obj {object} -- object to save
        parent_grp {hdf5 group} -- group where to dump
        obj_name {str} -- name of the object
    """
    inserter = insert_group if haschildren(obj) else insert_data
    inserter(obj, parent_grp, obj_name)
def insert_group(obj, parent_grp, obj_name):
    """Store an object and all its children as a new hdf5 group.

    Nothing is written when the name starts with an underscore or when an
    entry with that name already exists in the parent group. Errors raised
    while writing are printed and reported through ``print_insert_error``
    instead of being propagated.

    Arguments:
        obj {object} -- object to save
        parent_grp {hdf5 group} -- group where to dump
        obj_name {str} -- name of the object
    """
    # names with a leading underscore are skipped;
    # a lot of pytorch internals are named like that
    if obj_name.startswith("_"):
        log.debug(
            " Warning : Object {obj} not stored in {parent}",
            obj=obj_name,
            parent=parent_grp,
        )
        log.debug(' : because object name starts with "_"')
        return

    # never overwrite an entry that is already in the parent group
    if obj_name in parent_grp:
        log.critical(
            " Warning : Object {obj} already exists in {parent}",
            obj=obj_name,
            parent=parent_grp,
        )
        log.critical(" Warning : Keeping original version of the data")
        return

    try:
        new_grp = parent_grp.create_group(obj_name)
        for name in get_children_names(obj):
            insert_object(get_child_object(obj, name), new_grp, name)
    except Exception as error:
        print(type(error))
        print(error)
        print_insert_error(obj, obj_name)
def insert_data(obj, parent_grp, obj_name):
    """Store an object as a hdf5 dataset using a type specific writer.

    Nothing is written when the name starts with an underscore. Errors
    raised by the writer are printed and reported through
    ``print_insert_error`` instead of being propagated.

    Arguments:
        obj {object} -- object to save
        parent_grp {hdf5 group} -- group where to dump
        obj_name {str} -- name of the object
    """
    if obj_name.startswith("_"):
        return

    # dispatch on the exact type; any type not listed uses insert_default
    writers = {
        list: insert_list,
        tuple: insert_tuple,
        np.ndarray: insert_numpy,
        torch.Tensor: insert_torch_tensor,
        torch.nn.parameter.Parameter: insert_torch_parameter,
        torch.device: insert_none,
        type(None): insert_none,
    }
    writer = writers.get(type(obj), insert_default)

    # note: recording the type with insert_type is currently disabled
    try:
        writer(obj, parent_grp, obj_name)
    except Exception as error:
        print(error)
        print_insert_error(obj, obj_name)
def insert_type(obj, parent_grp, obj_name):
    """Insert the content of the type object in an attribute

    The string representation of ``type(obj)`` is stored in the ``type``
    attribute of the entry ``obj_name`` of the parent group. Errors are
    reported through ``print_insert_type_error`` instead of being
    propagated.

    Arguments:
        obj {object} -- object to save
        parent_grp {hdf5 group} -- group where to dump
        obj_name {str} -- name of the object
    """
    try:
        parent_grp[obj_name].attrs["type"] = str(type(obj))
    # Exception (not a bare except) so that KeyboardInterrupt and
    # SystemExit are not swallowed
    except Exception:
        print_insert_type_error(obj, obj_name)
def insert_default(obj, parent_grp, obj_name):
    """Default function to insert a dataset.

    The object is passed as is to ``create_dataset``. Errors are printed
    and reported through ``print_insert_error`` instead of being
    propagated.

    Arguments:
        obj {object} -- object to save
        parent_grp {hdf5 group} -- group where to dump
        obj_name {str} -- name of the object
    """
    try:
        parent_grp.create_dataset(obj_name, data=obj)
    except Exception as error:
        print(error)
        print_insert_error(obj, obj_name)
def insert_list(obj, parent_grp, obj_name):
    """function to insert a list as a dataset

    The list is first written as one dataset built with ``np.array``. If
    that fails, each element is inserted on its own under the name
    ``<obj_name>_<index>``.

    Arguments:
        obj {list} -- object to save
        parent_grp {hdf5 group} -- group where to dump
        obj_name {str} -- name of the object
    """
    try:
        if all(isinstance(el, torch.Tensor) for el in obj):
            # detach and move to the cpu first: .numpy() alone raises for
            # tensors that require grad or that are not on the cpu
            obj = [el.detach().cpu().numpy() for el in obj]
        parent_grp.create_dataset(obj_name, data=np.array(obj))
    except Exception:
        # the list cannot be stored as a single array, store item by item
        for il, item in enumerate(obj):
            try:
                insert_object(item, parent_grp, obj_name + "_" + str(il))
            except Exception:
                print_insert_error(obj, obj_name)
def insert_tuple(obj, parent_grp, obj_name):
    """function to insert a tuple as a dataset

    Tensors contained in the tuple are converted to numpy arrays, then the
    elements are stored with ``insert_list``.

    Arguments:
        obj {tuple} -- object to save
        parent_grp {hdf5 group} -- group where to dump
        obj_name {str} -- name of the object
    """
    # detach before converting: .numpy() raises for tensors that require grad
    converted = [
        o.detach().cpu().numpy() if isinstance(o, torch.Tensor) else o for o in obj
    ]
    insert_list(converted, parent_grp, obj_name)
def insert_numpy(obj, parent_grp, obj_name):
    """function to insert a numpy array as a dataset

    Arrays of unicode strings are converted to byte strings before being
    passed to ``insert_default``.

    Arguments:
        obj {np.ndarray} -- object to save
        parent_grp {hdf5 group} -- group where to dump
        obj_name {str} -- name of the object
    """
    # dtype.kind identifies unicode arrays whatever their byte order,
    # whereas a "<U" prefix test only matches little-endian ones
    if obj.dtype.kind == "U":
        try:
            obj = obj.astype("S")
        except UnicodeEncodeError:
            # astype("S") only handles ascii; utf-8 matches what
            # bytes2str decodes when the file is loaded
            obj = np.char.encode(obj, "utf-8")
    insert_default(obj, parent_grp, obj_name)
def insert_torch_tensor(obj, parent_grp, obj_name):
    """function to insert a torch tensor as a dataset

    The tensor is moved to the cpu, detached and converted to a numpy
    array that is then stored with ``insert_numpy``.

    Arguments:
        obj {torch.Tensor} -- object to save
        parent_grp {hdf5 group} -- group where to dump
        obj_name {str} -- name of the object
    """
    array = obj.cpu().detach().numpy()
    insert_numpy(array, parent_grp, obj_name)
def insert_torch_parameter(obj, parent_grp, obj_name):
    """function to insert a torch parameter as a dataset

    The ``data`` tensor of the parameter is moved to the cpu and stored
    with ``insert_torch_tensor``.

    Arguments:
        obj {torch.nn.Parameter} -- object to save
        parent_grp {hdf5 group} -- group where to dump
        obj_name {str} -- name of the object
    """
    cpu_values = obj.data.cpu()
    insert_torch_tensor(cpu_values, parent_grp, obj_name)
def insert_none(obj, parent_grp, obj_name):
    """function to insert a None Type as a dataset

    Nothing is written: the function ignores its arguments.

    Arguments:
        obj {object} -- object to save (ignored)
        parent_grp {hdf5 group} -- group where to dump (ignored)
        obj_name {str} -- name of the object (ignored)
    """
    return None
def haschildren(obj):
    """Tell whether an object must be stored as a group.

    Objects whose exact type is a torch tensor or a torch parameter never
    have children. Any other object has children when it exposes a
    ``__dict__`` or a ``keys`` attribute.

    Arguments:
        obj {object} -- the object to check

    Returns:
        bool -- True if the object has children
    """
    leaf_types = (torch.nn.parameter.Parameter, torch.Tensor)
    if type(obj) in leaf_types:
        return False
    if hasattr(obj, "__dict__"):
        return True
    return hasattr(obj, "keys")
def children(obj):
    """Return the children of an object as (name, value) items.

    The ``__dict__`` of the object is used when there is one, otherwise
    the ``items()`` of the object when it has a ``keys`` attribute.

    Arguments:
        obj {object} -- the object to check

    Returns:
        items view -- the children, or None if the object has none
    """
    if hasattr(obj, "__dict__"):
        return vars(obj).items()
    if hasattr(obj, "keys"):
        return obj.items()
    return None
def get_children_names(obj):
    """Returns the children names of the object

    The names are collected from the ``__dict__`` of the object (or from
    its ``keys()`` when it has no ``__dict__``), then from its
    ``__extra_attr__`` list and from the keys of its ``state_dict()`` when
    those exist. Duplicates are removed and the first-seen order is kept.

    Arguments:
        obj {object} -- the object to check

    Returns:
        list -- unique children names, empty if the object has none
    """
    # start from an empty list so that an object with neither __dict__ nor
    # keys gives [] instead of raising UnboundLocalError
    names = []
    if hasattr(obj, "__dict__"):
        names.extend(obj.__dict__.keys())
    elif hasattr(obj, "keys"):
        names.extend(obj.keys())

    if hasattr(obj, "__extra_attr__"):
        names.extend(obj.__extra_attr__)

    if hasattr(obj, "state_dict"):
        names.extend(obj.state_dict().keys())

    # dict.fromkeys removes duplicates like set() but keeps the order
    # deterministic from one run to the next
    return list(dict.fromkeys(names))
def get_child_object(obj, child_name):
    """Return the child object

    The child is looked up with ``__getattr__``, then ``__getattribute__``
    and finally ``__getitem__``, using the first of those that exists on
    the object and succeeds.

    Arguments:
        obj {object} -- parent object
        child_name {str} -- child name

    Returns:
        object -- child object, or None if no lookup succeeded
    """
    # attribute style lookups, in the same order as before
    for accessor in ("__getattr__", "__getattribute__"):
        if hasattr(obj, accessor):
            try:
                return getattr(obj, accessor)(child_name)
            except AttributeError:
                pass

    if hasattr(obj, "__getitem__"):
        try:
            return obj.__getitem__(child_name)
        # a missing item raises KeyError on mappings and TypeError or
        # IndexError on sequences, never AttributeError alone
        except (AttributeError, KeyError, IndexError, TypeError):
            pass

    return None
def add_group_attr(filename, grp_name, attr):
    """Add attribute to a given group

    The file is opened in append mode and every key/value pair of ``attr``
    is written in the attributes of the group.

    Arguments:
        filename {str} -- name of the file
        grp_name {str} -- name of the group
        attr {dict} -- attributes to add
    """
    # the context manager closes the file even if the group is missing or
    # a value cannot be stored
    with h5py.File(filename, "a") as h5:
        for key, value in attr.items():
            h5[grp_name].attrs[key] = value
def register_extra_attributes(obj, attr_names):
    """Register extra attribute names so that they can be dumped.

    The names are appended to the ``__extra_attr__`` list of the object;
    the list is created when the object does not have one yet.

    Arguments:
        obj {object} -- the object where we want to add attr
        attr_names {list} -- a list of attr names
    """
    if not hasattr(obj, "__extra_attr__"):
        setattr(obj, "__extra_attr__", [])
    registered = obj.__extra_attr__
    registered += attr_names
    obj.__extra_attr__ = registered