Source code for qmctorch.utils.hdf5_utils

from types import SimpleNamespace

import h5py
import numpy as np
import torch

from .. import log






    



def load_from_hdf5(obj, fname, obj_name):
    """Load the content of an hdf5 file in an object.

    Arguments:
        obj {object} -- object where to load the data
        fname {str} -- name of the hdf5 file
        obj_name {str} -- name of the root group in the hdf5
    """
    # The context manager closes the file even when the root group is
    # missing or when loading raises.
    with h5py.File(fname, 'r') as h5:
        root_grp = h5[obj_name]
        load_object(root_grp, obj, obj_name)
def load_object(grp, parent_obj, grp_name):
    """Load every child of an hdf5 group into an object.

    Each child that is itself a group is handed to ``load_group``; any
    other child is handed to ``load_data``.

    Arguments:
        grp {hdf5 group} -- the current group in the hdf5 architecture
        parent_obj {object} -- object receiving the loaded attributes
        grp_name {str} -- name of the group (not used by this function)
    """
    for name, item in grp.items():
        loader = load_group if isgroup(item) else load_data
        loader(item, parent_obj, name)
def load_group(grp, parent_obj, grp_name):
    """Load object attribute from the hdf5 group

    If the parent object has no attribute called ``grp_name`` an empty
    ``SimpleNamespace`` is attached under that name first; the content
    of the group is then loaded into that attribute. Failures are
    reported through ``print_load_error`` instead of being raised.

    Arguments:
        grp {hdf5 group} -- the current group in the hdf5 architecture
        parent_obj {object} -- parent object
        grp_name {str} -- name of the group
    """
    try:
        if not hasattr(parent_obj, grp_name):
            setattr(parent_obj, grp_name, SimpleNamespace())
        # __getattribute__ is called directly on purpose: it keeps the
        # original lookup, which does not fall back to __getattr__.
        load_object(grp, parent_obj.__getattribute__(grp_name), grp_name)
    except Exception:
        # Exception (not a bare except) so that KeyboardInterrupt and
        # SystemExit still propagate.
        print_load_error(grp_name)
def load_data(grp, parent_obj, grp_name):
    """Load data from the hdf5 data

    The full content of the dataset is read with ``grp[()]``, passed
    through ``cast_loaded_data`` and stored on the parent object under
    ``grp_name``. Failures are reported through ``print_load_error``
    instead of being raised.

    Arguments:
        grp {hdf5 dataset} -- the dataset to read
        parent_obj {object} -- parent object
        grp_name {str} -- name of the attribute to set
    """
    try:
        setattr(parent_obj, grp_name, cast_loaded_data(grp[()]))
    except Exception:
        # Exception (not a bare except) so that KeyboardInterrupt and
        # SystemExit still propagate.
        print_load_error(grp_name)
def cast_loaded_data(data):
    """Cast a value read from the hdf5 file before it is stored.

    Values whose exact type is ``bytes`` are converted with
    ``bytes2str``; every other value is returned unchanged.

    Arguments:
        data {object} -- value read from the file

    Returns:
        object -- the converted value, or the input itself
    """
    converters = {bytes: bytes2str}
    convert = converters.get(type(data))
    if convert is None:
        return data
    return convert(data)
def bytes2str(bstr):
    """Convert a bytes into string.

    Arguments:
        bstr {bytes or str} -- value to convert

    Returns:
        str -- the utf-8 decoded text for a bytes input, the input
               itself for a str input

    Raises:
        TypeError -- if bstr is neither a bytes nor a str
    """
    # isinstance also accepts subclasses of bytes / str.
    if isinstance(bstr, bytes):
        return bstr.decode('utf-8')
    if isinstance(bstr, str):
        return bstr
    # A single formatted message; passing several positional arguments
    # to TypeError would render the message as a tuple.
    raise TypeError(
        '{!r} should be a bytes or str but got {} instead'.format(
            bstr, type(bstr)))
def lookup_cast(ori_type, current_type):
    """Placeholder for casting data back to its recorded type.

    Not implemented: calling this function always raises.

    Arguments:
        ori_type -- not used (presumably the type recorded at dump
                    time -- to be confirmed)
        current_type -- not used (presumably the type of the loaded
                        data -- to be confirmed)

    Raises:
        NotImplementedError -- always
    """
    message = "cast the data to the type contained in .attrs['type']"
    raise NotImplementedError(message)
def isgroup(grp):
    """Check if current hdf5 object is a group

    Arguments:
        grp {hdf5 group} -- hdf5 group or dataset

    Returns:
        bool -- True if the object is a group (or a file, which is the
                root group), False otherwise
    """
    # h5py.File derives from h5py.Group, so one isinstance check covers
    # both and avoids the private ``h5py._hl`` module path.
    return isinstance(grp, h5py.Group)
def dump_to_hdf5(obj, fname, root_name=None):
    """Dump the content of an object in a hdf5 file.

    The file is opened in append mode. If the requested root name is
    already present in the file a numbered name is used instead, so
    existing data is never replaced.

    Arguments:
        obj {object} -- object to dump
        fname {str} -- name of the hdf5

    Keyword Arguments:
        root_name {str} -- root group in the hdf5 file; the class name
                           of obj is used when None (default: {None})

    Returns:
        str -- name of the root group actually used
    """
    if root_name is None:
        root_name = obj.__class__.__name__

    # The context manager closes the file even if the dump raises.
    with h5py.File(fname, 'a') as h5:

        # change root name if that name is already present in the file
        if root_name in h5:

            log.info('')
            log.info(' Warning : dump to hdf5')
            log.info(
                ' Object {obj} already exists in {parent}',
                obj=root_name, parent=fname)

            base_name = root_name
            index = sum(
                1 for key in h5 if key.startswith(base_name)) + 1
            root_name = base_name + '_' + str(index)

            # The counted index can still name an existing entry (for
            # instance after a manual rename); keep increasing it,
            # otherwise insert_group would refuse to write the data.
            while root_name in h5:
                index += 1
                root_name = base_name + '_' + str(index)

            log.info(
                ' Object name changed to {obj}', obj=root_name)
            log.info('')

        insert_object(obj, h5, root_name)

    return root_name
def insert_object(obj, parent_grp, obj_name):
    """Store an object in the hdf5 file as a group or as a dataset.

    Objects for which ``haschildren`` is true are stored with
    ``insert_group``; all others are stored with ``insert_data``.

    Arguments:
        obj {object} -- object to save
        parent_grp {hdf5 group} -- group where to dump
        obj_name {str} -- name of the object
    """
    writer = insert_group if haschildren(obj) else insert_data
    writer(obj, parent_grp, obj_name)
def insert_group(obj, parent_grp, obj_name):
    """Store an object and all of its children as a hdf5 group.

    Nothing is written when the name starts with an underscore or when
    an entry with that name already exists in the parent group. Errors
    raised while writing are printed and reported through
    ``print_insert_error`` rather than propagated.

    Arguments:
        obj {object} -- object to save
        parent_grp {hdf5 group} -- group where to dump
        obj_name {str} -- name of the object
    """
    # Names with a leading underscore are skipped; a lot of pytorch
    # internals are named like that.
    if obj_name.startswith('_'):
        log.debug(
            ' Warning : Object {obj} not stored in {parent}',
            obj=obj_name, parent=parent_grp)
        log.debug(' : because object name starts with "_"')
        return

    # An existing entry is kept as it is.
    if obj_name in parent_grp:
        log.critical(
            ' Warning : Object {obj} already exists in {parent}',
            obj=obj_name, parent=parent_grp)
        log.critical(' Warning : Keeping original version of the data')
        return

    try:
        new_grp = parent_grp.create_group(obj_name)
        for name in get_children_names(obj):
            insert_object(get_child_object(obj, name), new_grp, name)
    except Exception as error:
        print(type(error))
        print(error)
        print_insert_error(obj, obj_name)
def insert_data(obj, parent_grp, obj_name):
    """Store an object as a hdf5 dataset.

    The writer is selected from the exact type of the object;
    ``insert_default`` is used for types without a dedicated writer.
    Nothing is written when the name starts with an underscore. Errors
    raised by the writer are printed and reported through
    ``print_insert_error`` rather than propagated.

    Arguments:
        obj {object} -- object to save
        parent_grp {hdf5 group} -- group where to dump
        obj_name {str} -- name of the object
    """
    if obj_name.startswith('_'):
        return

    writers = {
        list: insert_list,
        tuple: insert_tuple,
        np.ndarray: insert_numpy,
        torch.Tensor: insert_torch_tensor,
        torch.nn.parameter.Parameter: insert_torch_parameter,
        torch.device: insert_none,
        type(None): insert_none,
    }
    writer = writers.get(type(obj), insert_default)

    # NOTE: recording the type with insert_type is not done here.
    try:
        writer(obj, parent_grp, obj_name)
    except Exception as error:
        print(error)
        print_insert_error(obj, obj_name)
def insert_type(obj, parent_grp, obj_name):
    """Insert the type of the object in an attribute

    The string representation of ``type(obj)`` is stored in the
    ``'type'`` attribute of the entry ``obj_name`` of the parent group.
    Failures are reported through ``print_insert_type_error`` instead
    of being raised.

    Arguments:
        obj {object} -- object whose type is recorded
        parent_grp {hdf5 group} -- group containing the entry
        obj_name {str} -- name of the entry
    """
    try:
        parent_grp[obj_name].attrs['type'] = str(type(obj))
    except Exception:
        # Exception (not a bare except) so that KeyboardInterrupt and
        # SystemExit still propagate.
        print_insert_type_error(obj, obj_name)
def insert_default(obj, parent_grp, obj_name):
    """Default function to insert a dataset

    The object is handed as it is to ``create_dataset``. Errors are
    printed and reported through ``print_insert_error`` rather than
    propagated.

    Arguments:
        obj {object} -- object to save
        parent_grp {hdf5 group} -- group where to dump
        obj_name {str} -- name of the dataset
    """
    try:
        parent_grp.create_dataset(obj_name, data=obj)
    except Exception as error:
        print(error)
        print_insert_error(obj, obj_name)
def insert_list(obj, parent_grp, obj_name):
    """function to insert a list as a dataset

    The list is first written as a single dataset built with
    ``np.array``; a list made only of torch tensors is converted to
    numpy arrays beforehand. If that fails, every element is stored
    separately under the name ``<obj_name>_<index>``.

    Arguments:
        obj {list} -- object to save
        parent_grp {hdf5 group} -- group where to dump
        obj_name {str} -- name of the object
    """
    try:
        if all(isinstance(el, torch.Tensor) for el in obj):
            obj = [el.numpy() for el in obj]
        parent_grp.create_dataset(obj_name, data=np.array(obj))

    except Exception:
        # Exception (not a bare except) so that KeyboardInterrupt and
        # SystemExit still propagate.
        # Fallback: the list could not be stored as one array, store
        # the elements one by one.
        for index, element in enumerate(obj):
            try:
                insert_object(
                    element, parent_grp, obj_name + '_' + str(index))
            except Exception:
                print_insert_error(obj, obj_name)
def insert_tuple(obj, parent_grp, obj_name):
    """function to insert a tuple as a dataset

    The tuple is converted to a list and stored with ``insert_list``.

    Arguments:
        obj {tuple} -- object to save
        parent_grp {hdf5 group} -- group where to dump
        obj_name {str} -- name of the object
    """
    as_list = list(obj)
    insert_list(as_list, parent_grp, obj_name)
def insert_numpy(obj, parent_grp, obj_name):
    """function to insert a numpy array as a dataset

    Unicode string arrays are converted to byte strings before being
    stored with ``insert_default``.

    Arguments:
        obj {np.ndarray} -- object to save
        parent_grp {hdf5 group} -- group where to dump
        obj_name {str} -- name of the object
    """
    # dtype.kind is 'U' for unicode arrays of any byte order, whereas a
    # '<U' prefix test on dtype.str only matches little-endian ones.
    if obj.dtype.kind == 'U':
        obj = obj.astype('S')

    insert_default(obj, parent_grp, obj_name)
def insert_torch_tensor(obj, parent_grp, obj_name):
    """function to insert a torch tensor as a dataset

    The tensor is moved to the cpu, detached and converted to a numpy
    array, which is then stored with ``insert_numpy``.

    Arguments:
        obj {torch.Tensor} -- object to save
        parent_grp {hdf5 group} -- group where to dump
        obj_name {str} -- name of the object
    """
    on_cpu = obj.cpu()
    array = on_cpu.detach().numpy()
    insert_numpy(array, parent_grp, obj_name)
def insert_torch_parameter(obj, parent_grp, obj_name):
    """function to insert a torch parameter as a dataset

    The ``data`` tensor of the parameter is moved to the cpu and stored
    with ``insert_torch_tensor``.

    Arguments:
        obj {torch.nn.Parameter} -- object to save
        parent_grp {hdf5 group} -- group where to dump
        obj_name {str} -- name of the object
    """
    values = obj.data.cpu()
    insert_torch_tensor(values, parent_grp, obj_name)
def insert_none(obj, parent_grp, obj_name):
    """function to insert a None Type as a dataset

    Nothing is written: the function returns immediately.

    Arguments:
        obj {object} -- object to save (ignored)
        parent_grp {hdf5 group} -- group where to dump (ignored)
        obj_name {str} -- name of the object (ignored)
    """
    return None
def haschildren(obj):
    """Tell whether an object has children.

    Objects whose exact type is a torch tensor or a torch parameter
    never have children. Any other object has children when it has a
    ``__dict__`` or a ``keys`` attribute.

    Arguments:
        obj {object} -- the object to check

    Returns:
        bool -- True if the object has children
    """
    leaf_types = (torch.nn.parameter.Parameter, torch.Tensor)
    if type(obj) in leaf_types:
        return False
    return any(hasattr(obj, attr) for attr in ('__dict__', 'keys'))
def children(obj):
    """Return the children of an object as (name, value) items.

    The items of ``obj.__dict__`` are used when the object has a
    ``__dict__``; otherwise ``obj.items()`` is used when the object has
    a ``keys`` attribute.

    Arguments:
        obj {object} -- the object to check

    Returns:
        items view -- the children, or None when the object has
                      neither a ``__dict__`` nor ``keys``
    """
    if hasattr(obj, '__dict__'):
        return obj.__dict__.items()
    if hasattr(obj, 'keys'):
        return obj.items()
    return None
def get_children_names(obj):
    """Returns the names of the children of the object

    The names are collected from ``obj.__dict__`` (or from
    ``obj.keys()`` when there is no ``__dict__``), then from
    ``obj.__extra_attr__`` and from the keys of ``obj.state_dict()``
    when those attributes exist.

    Arguments:
        obj {object} -- the object to check

    Returns:
        list -- the names without duplicates, in the order they were
                first encountered; empty when nothing is found
    """
    # Start from an empty list so that objects with neither __dict__
    # nor keys do not leave the name list undefined.
    names = []
    if hasattr(obj, '__dict__'):
        names.extend(obj.__dict__.keys())
    elif hasattr(obj, 'keys'):
        names.extend(obj.keys())

    if hasattr(obj, '__extra_attr__'):
        names.extend(obj.__extra_attr__)

    if hasattr(obj, 'state_dict'):
        names.extend(obj.state_dict().keys())

    # dict.fromkeys removes duplicates while keeping a reproducible
    # order, unlike a round trip through set().
    return list(dict.fromkeys(names))
def get_child_object(obj, child_name):
    """Return the child object

    The child is looked up with ``__getattr__``, then with
    ``__getattribute__`` and finally with ``__getitem__``, using
    whichever of those the object provides.

    Arguments:
        obj {object} -- parent object
        child_name {str} -- child name

    Returns:
        object -- child object, or None when no lookup succeeds
    """
    if hasattr(obj, '__getattr__'):
        try:
            return obj.__getattr__(child_name)
        except AttributeError:
            pass

    if hasattr(obj, '__getattribute__'):
        try:
            return obj.__getattribute__(child_name)
        except AttributeError:
            pass

    if hasattr(obj, '__getitem__'):
        try:
            return obj.__getitem__(child_name)
        # Item access signals a missing entry with KeyError or
        # IndexError and an unsupported key type with TypeError;
        # treat all of them as "not found".
        except (KeyError, IndexError, TypeError, AttributeError):
            pass

    return None
def add_group_attr(filename, grp_name, attr):
    """Add attribute to a given group

    The file is opened in append mode and every key/value pair of
    ``attr`` is written to the attributes of the group.

    Arguments:
        filename {str} -- name of the file
        grp_name {str} -- name of the group
        attr {dict} -- attributes to add
    """
    # The context manager closes the file even if the group is missing
    # or an attribute value cannot be stored.
    with h5py.File(filename, 'a') as h5:
        for key, value in attr.items():
            h5[grp_name].attrs[key] = value
def register_extra_attributes(obj, attr_names):
    """Register extra attribute names so that they can be dumped.

    The names are appended to ``obj.__extra_attr__``, which is created
    as an empty list when the object does not have it yet.

    Arguments:
        obj {object} -- the object where we want to add attr
        attr_names {list} -- a list of attr names
    """
    if not hasattr(obj, '__extra_attr__'):
        setattr(obj, '__extra_attr__', [])
    registered = obj.__extra_attr__
    registered += attr_names
    obj.__extra_attr__ = registered