from types import SimpleNamespace
import h5py
import numpy as np
import torch
from .. import log
def print_insert_error(obj, obj_name):
    """Report a failure to insert an object into the hdf5 file.

    The name and the object are echoed to stdout, then a critical log
    entry mentioning the object's type is emitted.

    Arguments:
        obj {object} -- object that could not be inserted
        obj_name {str} -- name under which it was meant to be stored
    """
    print(obj_name, obj)
    obj_type = str(type(obj))
    log.critical('Issue inserting data {0} of type {type}',
                 obj_name, type=obj_type)
def print_insert_type_error(obj, obj_name):
    """Report a failure to record the type attribute of a stored object.

    Arguments:
        obj {object} -- object whose type could not be recorded
        obj_name {str} -- name of the item in the hdf5 file
    """
    # The logger fills the brace placeholders itself from the positional and
    # keyword arguments. Applying ``%`` to a brace-style message raised
    # TypeError ("not all arguments converted") instead of logging anything.
    log.critical('Issue inserting type of data {0} ({type})',
                 obj_name, type=str(type(obj)))
def print_load_error(grp):
    """Emit a critical log entry for an hdf5 item that could not be loaded.

    Arguments:
        grp {str} -- name of the group or dataset that failed to load
    """
    message = 'Issue loading {grp}'
    log.critical(message, grp=grp)
def load_from_hdf5(obj, fname, obj_name):
    """Load the content of an hdf5 file into an object.

    Every child of the root group ``obj_name`` is attached to ``obj`` as an
    attribute. Sub-groups that do not already exist on ``obj`` become
    nested SimpleNamespace attributes.

    Arguments:
        obj {object} -- object where to load the data
        fname {str} -- name of the hdf5 file
        obj_name {str} -- name of the root group in the hdf5

    Raises:
        KeyError -- if ``obj_name`` is not present in the file
    """
    # The context manager closes the file even when the root group is missing
    # or loading raises, instead of leaking the open handle.
    with h5py.File(fname, 'r') as h5:
        root_grp = h5[obj_name]
        load_object(root_grp, obj, obj_name)
def load_object(grp, parent_obj, grp_name):
    """Walk an hdf5 group and attach each of its children to ``parent_obj``.

    Sub-groups are delegated to load_group (which recurses); datasets are
    delegated to load_data.

    Arguments:
        grp {hdf5 group} -- the current group in the hdf5 architecture
        parent_obj {object} -- parent object
        grp_name {str} -- name of the group (not used by the walk itself)
    """
    for name, item in grp.items():
        loader = load_group if isgroup(item) else load_data
        loader(item, parent_obj, name)
def load_group(grp, parent_obj, grp_name):
    """Load an hdf5 group as an attribute of ``parent_obj``.

    If ``parent_obj`` has no attribute ``grp_name`` yet, an empty
    SimpleNamespace is created for it. The group's children are then loaded
    into that attribute. Failures are reported and not propagated.

    Arguments:
        grp {hdf5 group} -- the current group in the hdf5 architecture
        parent_obj {object} -- parent object
        grp_name {str} -- name of the group
    """
    try:
        if not hasattr(parent_obj, grp_name):
            setattr(parent_obj, grp_name, SimpleNamespace())
        # getattr (not __getattribute__) so the fetch agrees with the hasattr
        # check above, including attributes served through __getattr__.
        child_obj = getattr(parent_obj, grp_name)
        load_object(grp, child_obj, grp_name)
    except Exception:
        # Best effort: report and keep loading siblings. Unlike a bare
        # ``except:``, this does not swallow KeyboardInterrupt/SystemExit.
        print_load_error(grp_name)
def load_data(grp, parent_obj, grp_name):
    """Load an hdf5 dataset as an attribute of ``parent_obj``.

    The whole dataset is read with ``grp[()]`` and passed through
    cast_loaded_data before being assigned. Failures are reported and not
    propagated.

    Arguments:
        grp {hdf5 dataset} -- the dataset to read
        parent_obj {object} -- parent object
        grp_name {str} -- attribute name to assign on ``parent_obj``
    """
    try:
        value = cast_loaded_data(grp[()])
        setattr(parent_obj, grp_name, value)
    except Exception:
        # Best effort, but no longer swallows KeyboardInterrupt/SystemExit.
        print_load_error(grp_name)
def cast_loaded_data(data):
    """Cast a value read from hdf5 before attaching it to an object.

    Only values whose exact type is ``bytes`` are converted (via bytes2str).
    Everything else is returned unchanged.
    """
    casters = {bytes: bytes2str}
    caster = casters.get(type(data))
    return data if caster is None else caster(data)
def bytes2str(bstr):
    """Return ``bstr`` as a str.

    Arguments:
        bstr {bytes or str} -- bytes are decoded as utf-8; str is returned as is

    Raises:
        TypeError -- if the exact type of ``bstr`` is neither bytes nor str
    """
    kind = type(bstr)
    if kind is str:
        return bstr
    if kind is bytes:
        return bstr.decode('utf-8')
    raise TypeError(
        bstr, ' should be a bytes or str but got ', kind, ' instead')
def lookup_cast(ori_type, current_type):
    """Placeholder for casting loaded data back to its recorded type.

    Not implemented yet. The intent, per the message below, is to cast using
    the type stored in the item's ``.attrs['type']``.

    Raises:
        NotImplementedError -- always
    """
    message = "cast the data to the type contained in .attrs['type']"
    raise NotImplementedError(message)
def isgroup(grp):
    """Check whether an hdf5 item is a group (an open file counts as one).

    Arguments:
        grp {hdf5 group} -- hdf5 group or dataset

    Returns:
        bool -- True if the item is a group or a file
    """
    # h5py.File subclasses h5py.Group, so a single isinstance check covers
    # both cases through the public API rather than the private h5py._hl path.
    return isinstance(grp, h5py.Group)
def dump_to_hdf5(obj, fname, root_name=None):
    """Dump the content of an object in a hdf5 file.

    The file is opened in append mode. If ``root_name`` already exists in
    the file, a ``_<n>`` suffix is appended so existing data is never
    overwritten.

    Arguments:
        obj {object} -- object to dump
        fname {str} -- name of the hdf5

    Keyword Arguments:
        root_name {str} -- root group in the hdf5 file
                           (default: {None}, meaning the class name of obj)

    Returns:
        str -- name of the root group actually used in the file
    """
    with h5py.File(fname, 'a') as h5:
        if root_name is None:
            root_name = obj.__class__.__name__

        # change root name if that name is already present in the file
        if root_name in h5:
            log.info('')
            log.info(' Warning : dump to hdf5')
            log.info(
                ' Object {obj} already exists in {parent}', obj=root_name, parent=fname)
            n = sum(1 for name in h5 if name.startswith(root_name)) + 1
            new_name = root_name + '_' + str(n)
            # The prefix count alone can land on a name already in use
            # (e.g. 'Foo' and 'Foo_3' present -> 'Foo_3'). insert_group would
            # then silently keep the old data, so bump until the name is free.
            while new_name in h5:
                n += 1
                new_name = root_name + '_' + str(n)
            root_name = new_name
            log.info(
                ' Object name changed to {obj}', obj=root_name)
            log.info('')

        insert_object(obj, h5, root_name)

    return root_name
def insert_object(obj, parent_grp, obj_name):
    """Dispatch ``obj`` to the appropriate hdf5 writer.

    Objects for which haschildren() is true become hdf5 groups; everything
    else becomes a dataset.

    Arguments:
        obj {object} -- object to save
        parent_grp {hdf5 group} -- group where to dump
        obj_name {str} -- name of the object
    """
    writer = insert_group if haschildren(obj) else insert_data
    writer(obj, parent_grp, obj_name)
def insert_group(obj, parent_grp, obj_name):
    """Store an object with children as an hdf5 group.

    A sub-group named ``obj_name`` is created under ``parent_grp`` and every
    child reported by get_children_names is inserted into it. Names starting
    with an underscore are skipped, and an existing entry with the same name
    is never overwritten.

    Arguments:
        obj {object} -- object to save
        parent_grp {hdf5 group} -- group where to dump
        obj_name {str} -- name of the object
    """
    # Skip private-looking names: a lot of pytorch internals are like that.
    if obj_name.startswith('_'):
        log.debug(
            ' Warning : Object {obj} not stored in {parent}', obj=obj_name, parent=parent_grp)
        log.debug(
            ' : because object name starts with "_"')
        return

    # Keep whatever is already stored under this name.
    if obj_name in parent_grp:
        log.critical(
            ' Warning : Object {obj} already exists in {parent}', obj=obj_name, parent=parent_grp)
        log.critical(
            ' Warning : Keeping original version of the data')
        return

    try:
        own_grp = parent_grp.create_group(obj_name)
        for child_name in get_children_names(obj):
            insert_object(get_child_object(obj, child_name), own_grp, child_name)
    except Exception as failure:
        # Any failure stops the remaining children of this group; it is
        # reported here and not propagated to the caller.
        print(type(failure))
        print(failure)
        print_insert_error(obj, obj_name)
def insert_data(obj, parent_grp, obj_name):
    """Store a leaf object (no children) as an hdf5 dataset.

    The writer is chosen from the object's exact type; any type not listed
    falls back to insert_default. Names starting with an underscore are
    skipped. Writer failures are reported and not propagated.

    Arguments:
        obj {object} -- object to save
        parent_grp {hdf5 group} -- group where to dump
        obj_name {str} -- name of the object
    """
    if obj_name.startswith('_'):
        return

    writers = {
        list: insert_list,
        tuple: insert_tuple,
        np.ndarray: insert_numpy,
        torch.Tensor: insert_torch_tensor,
        torch.nn.parameter.Parameter: insert_torch_parameter,
        torch.device: insert_none,
        type(None): insert_none,
    }
    insert_fn = writers.get(type(obj), insert_default)

    try:
        insert_fn(obj, parent_grp, obj_name)
    except Exception as expt_message:
        print(expt_message)
        print_insert_error(obj, obj_name)
def insert_type(obj, parent_grp, obj_name):
    """Record the type of ``obj`` in the ``'type'`` attribute of the stored item.

    Arguments:
        obj {object} -- object whose type is recorded
        parent_grp {hdf5 group} -- group containing the item ``obj_name``
        obj_name {str} -- name of the item
    """
    try:
        parent_grp[obj_name].attrs['type'] = str(type(obj))
    except Exception:
        # Best effort: a missing item or an unwritable attribute is reported,
        # but KeyboardInterrupt/SystemExit are no longer swallowed.
        print_insert_type_error(obj, obj_name)
def insert_default(obj, parent_grp, obj_name):
    """Fallback writer: let the group build a dataset directly from ``obj``.

    Any failure is printed and reported via print_insert_error instead of
    being raised.

    Arguments:
        obj {object} -- object to save
        parent_grp {hdf5 group} -- group where to dump
        obj_name {str} -- name of the object
    """
    try:
        create = parent_grp.create_dataset
        create(obj_name, data=obj)
    except Exception as failure:
        print(failure)
        print_insert_error(obj, obj_name)
def insert_list(obj, parent_grp, obj_name):
    """Insert a list in the hdf5 file.

    The list is first stored as a single dataset built with ``np.array``. A
    list made only of torch tensors is converted element-wise to numpy
    first. If that fails (e.g. ragged or heterogeneous content), each element
    is inserted on its own under ``<obj_name>_<index>``.

    Arguments:
        obj {list} -- list to save
        parent_grp {hdf5 group} -- group where to dump
        obj_name {str} -- name of the object
    """
    try:
        if all(isinstance(el, torch.Tensor) for el in obj):
            # .numpy() refuses tensors that require grad or live on an
            # accelerator, which used to force the element-by-element path.
            obj = [el.detach().cpu().numpy() for el in obj]
        parent_grp.create_dataset(obj_name, data=np.array(obj))
    except Exception:
        for index, element in enumerate(obj):
            element_name = obj_name + '_' + str(index)
            try:
                insert_object(element, parent_grp, element_name)
            except Exception:
                # Report the element that actually failed, not the whole list.
                print_insert_error(element, element_name)
def insert_tuple(obj, parent_grp, obj_name):
    """Insert a tuple by converting it to a list and delegating to insert_list.

    Arguments:
        obj {tuple} -- tuple to save
        parent_grp {hdf5 group} -- group where to dump
        obj_name {str} -- name of the object
    """
    as_list = [*obj]
    insert_list(as_list, parent_grp, obj_name)
def insert_numpy(obj, parent_grp, obj_name):
    """Insert a numpy array as a dataset.

    Unicode string arrays (dtype kind ``'U'``) are first encoded to utf-8
    byte strings, because h5py does not store numpy's fixed-width unicode
    dtype.

    Arguments:
        obj {np.ndarray} -- array to save
        parent_grp {hdf5 group} -- group where to dump
        obj_name {str} -- name of the object
    """
    if obj.dtype.kind == 'U':
        # Checking the kind also catches big-endian ('>U') arrays, which a
        # '<U' prefix test missed. Encoding as utf-8 works for non-ASCII
        # text, where astype('S') raised UnicodeEncodeError. The trailing
        # astype('S') keeps a byte-string dtype even for empty arrays.
        obj = np.char.encode(obj, 'utf-8').astype('S')
    insert_default(obj, parent_grp, obj_name)
def insert_torch_tensor(obj, parent_grp, obj_name):
    """Insert a torch tensor as a dataset via its numpy representation.

    The tensor is moved to the CPU and detached before conversion.

    Arguments:
        obj {torch.Tensor} -- tensor to save
        parent_grp {hdf5 group} -- group where to dump
        obj_name {str} -- name of the object
    """
    array = obj.cpu().detach().numpy()
    insert_numpy(array, parent_grp, obj_name)
def insert_torch_parameter(obj, parent_grp, obj_name):
    """Insert a torch parameter by storing its underlying data tensor.

    Arguments:
        obj {torch.nn.Parameter} -- parameter to save
        parent_grp {hdf5 group} -- group where to dump
        obj_name {str} -- name of the object
    """
    tensor = obj.data.cpu()
    insert_torch_tensor(tensor, parent_grp, obj_name)
def insert_none(obj, parent_grp, obj_name):
    """No-op writer for values that are intentionally not stored.

    Arguments:
        obj {object} -- ignored
        parent_grp {hdf5 group} -- ignored
        obj_name {str} -- ignored
    """
    return None
def haschildren(obj):
    """Tell whether an object should be stored as a group of children.

    Arguments:
        obj {object} -- the object to check

    Returns:
        bool -- True if the object has a ``__dict__`` or a ``keys`` attribute,
                except for tensors and parameters (exact type match), which
                are always stored as datasets
    """
    if type(obj) in (torch.nn.parameter.Parameter, torch.Tensor):
        return False
    return hasattr(obj, '__dict__') or hasattr(obj, 'keys')
def children(obj):
    """Return the children of an object as ``(name, value)`` items.

    Arguments:
        obj {object} -- the object to inspect

    Returns:
        items view -- ``obj.__dict__.items()`` if available, otherwise
                      ``obj.items()`` for mapping-like objects, otherwise None
    """
    if hasattr(obj, '__dict__'):
        return obj.__dict__.items()
    if hasattr(obj, 'keys'):
        return obj.items()
    return None
def get_children_names(obj):
    """Return the names of the children of an object.

    Names are collected in this order:
      - the instance ``__dict__`` keys (or ``keys()`` for objects without a
        ``__dict__``);
      - the optional ``__extra_attr__`` attribute;
      - the keys of ``state_dict()`` when the object has one.
    Duplicates are dropped, keeping the first occurrence.

    Arguments:
        obj {object} -- the object to inspect

    Returns:
        list -- unique child names (empty if none can be found)
    """
    # Start empty so objects with neither __dict__ nor keys() return []
    # instead of raising UnboundLocalError.
    names = []
    if hasattr(obj, '__dict__'):
        names += list(obj.__dict__.keys())
    elif hasattr(obj, 'keys'):
        names += list(obj.keys())
    if hasattr(obj, '__extra_attr__'):
        names += list(obj.__extra_attr__)
    if hasattr(obj, 'state_dict'):
        names += list(obj.state_dict().keys())
    # dict.fromkeys deduplicates while keeping a deterministic first-seen
    # order; list(set(...)) depended on the string hash seed.
    return list(dict.fromkeys(names))
def get_child_object(obj, child_name):
    """Return the child ``child_name`` of ``obj``, or None if it is not found.

    Lookup order:
      1. ``obj.__getattr__`` (when the object defines one);
      2. ordinary attribute access;
      3. item access ``obj[child_name]``.

    Arguments:
        obj {object} -- parent object
        child_name {str} -- child name

    Returns:
        object -- child object, or None when no lookup succeeds
    """
    if hasattr(obj, '__getattr__'):
        try:
            return obj.__getattr__(child_name)
        except AttributeError:
            pass
    if hasattr(obj, '__getattribute__'):
        try:
            return obj.__getattribute__(child_name)
        except AttributeError:
            pass
    if hasattr(obj, '__getitem__'):
        try:
            return obj.__getitem__(child_name)
        except (KeyError, IndexError, TypeError, AttributeError):
            # Item access reports a miss with KeyError/IndexError, or with
            # TypeError when a sequence (e.g. a list) is indexed with a str.
            # Catching only AttributeError let these escape and abort the
            # whole parent group in the caller.
            pass
    return None
def add_group_attr(filename, grp_name, attr):
    """Add attributes to a given group of an hdf5 file.

    Arguments:
        filename {str} -- name of the file (opened in append mode)
        grp_name {str} -- name of the group
        attr {dict} -- attributes to add, as ``{name: value}``

    Raises:
        KeyError -- if ``attr`` is non-empty and ``grp_name`` is not in the file
    """
    # The context manager closes the file even if a write fails.
    with h5py.File(filename, 'a') as h5:
        for key, value in attr.items():
            h5[grp_name].attrs[key] = value