# Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License version 3 as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
# Contributors:
# INITIAL AUTHORS - initial API and implementation and/or initial
# documentation
# :author: Francois Gallard, Matthias De Lozzo
# OTHER AUTHORS - MACROSCOPIC CHANGES
"""HDF5 file singleton used by the HDF5 cache."""
from __future__ import annotations

from contextlib import contextmanager
from multiprocessing import RLock
from os.path import exists
from pathlib import Path
from typing import TYPE_CHECKING
from typing import ClassVar

import h5py
from h5py import File
from numpy import append
from numpy import array
from numpy import bytes_
from numpy import str_
from scipy.sparse import csr_array
from strenum import StrEnum

from gemseo.core.cache import AbstractCache
from gemseo.core.cache import hash_data_dict
from gemseo.core.cache import to_real
from gemseo.utils.compatibility.scipy import SparseArrayType
from gemseo.utils.compatibility.scipy import sparse_classes
from gemseo.utils.singleton import SingleInstancePerFileAttribute

if TYPE_CHECKING:
    from collections.abc import Iterator
    from multiprocessing.managers import DictProxy
    from multiprocessing.synchronize import RLock as RLockType

    from gemseo.typing import DataMapping
    from gemseo.typing import IntegerArray
# TODO: API: make this module and class protected.
class HDF5FileSingleton(metaclass=SingleInstancePerFileAttribute):
"""Singleton to access an HDF file.
Used for multithreading/multiprocessing access with a lock.
"""
# We create a single instance of cache per HDF5 file
# to allow the multiprocessing lock to work
# Ie we want a single lock even if we instantiate multiple
# caches
    hdf_file_path: str
    """The path to the HDF file."""

    lock: RLockType
    """The lock used for multithreading."""

    HASH_TAG: ClassVar[str] = "hash"
    """The label for the hash."""

    FILE_FORMAT_VERSION: ClassVar[int] = 2
    """The version of the file format."""

    __keep_open: bool
    """Whether to keep the file open when leaving the :meth:`.__open` context
    manager."""

    __file: File | None
    """The HDF5 file handle."""
    class __SparseMatricesAttribute(StrEnum):  # noqa: N801
        """The attribute names required to store a sparse matrix in CSR format."""

        SPARSE = "sparse"
        INDICES = "indices"
        INDPTR = "indptr"
        SHAPE = "shape"
def __init__(
self,
hdf_file_path: str,
) -> None:
"""
Args:
hdf_file_path: The path to the HDF file.
""" # noqa: D205, D212, D415
self.__keep_open = False
self.__file = None
self.hdf_file_path = hdf_file_path
self.__check_file_format_version()
# Attach the lock to the file and NOT the Cache because it is a singleton.
self.lock = RLock()
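    # Minimal usage sketch (comment-only; "cache.h5", "node" and the ``data``
    # mapping are illustrative, and ``AbstractCache.Group.INPUTS`` is assumed
    # to be one of the cache group names):
    #
    #     singleton = HDF5FileSingleton("cache.h5")
    #     # Same path, same instance: the metaclass keys instances on the file.
    #     assert singleton is HDF5FileSingleton("cache.h5")
    #     with singleton.lock:
    #         singleton.write_data(data, AbstractCache.Group.INPUTS, 1, "node")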
def write_data(
self,
data: DataMapping,
group: AbstractCache.Group,
index: int,
hdf_node_path: str,
) -> None:
"""Cache input data to avoid re-evaluation.
Args:
data: The data containing the values of the names to cache.
group: The group.
index: The index of the entry in the cache.
hdf_node_path: The name of the HDF group to store the entries,
possibly passed as a path ``root_name/.../group_name/.../node_name``.
"""
with self.__open(mode="a"):
assert self.__file is not None
if not len(self.__file):
self.__set_file_format_version(self.__file)
root = self.__file.require_group(hdf_node_path)
entry = root.require_group(str(index))
entry_group = entry.require_group(group)
try:
# Write hash if needed
if entry.get(self.HASH_TAG) is None:
data_hash = array([hash_data_dict(data)], dtype="bytes")
entry.create_dataset(self.HASH_TAG, data=data_hash)
                for name, value in data.items():
                    if value is not None:
if value.dtype.type is str_:
entry_group.create_dataset(name, data=value.astype("bytes"))
elif isinstance(value, sparse_classes):
self.__write_sparse_array(entry_group, name, value)
else:
entry_group.create_dataset(name, data=to_real(value))
            except (RuntimeError, OSError, ValueError):
                msg = (
                    f"Failed to cache dataset {hdf_node_path}.{index}.{group} "
                    f"in file: {self.hdf_file_path}"
                )
                raise RuntimeError(msg) from None
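    # Layout written by write_data, schematically (a sketch; "x" stands for an
    # illustrative data name, not something fixed by this module):
    #
    #     /<hdf_node_path>/<index>/hash       -> one-element bytes dataset
    #     /<hdf_node_path>/<index>/<group>/x  -> dataset with the cached value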
def __write_sparse_array(
self,
group: h5py.Group,
dataset_name: str,
value: SparseArrayType,
) -> None:
"""Store sparse array in HDF5 group.
Args:
group: The group.
dataset_name: The name of the dataset to store the sparse array in.
value: The sparse array.
"""
value = value.tocsr()
# Create a dataset containing the non-zero elements of value
dataset = group.create_dataset(dataset_name, data=to_real(value.data))
# Add as attribute the indices, pointers and shape required for reconstruction
dataset.attrs.create(self.__SparseMatricesAttribute.INDICES, value.indices)
dataset.attrs.create(self.__SparseMatricesAttribute.INDPTR, value.indptr)
dataset.attrs.create(self.__SparseMatricesAttribute.SHAPE, value.shape)
# Add a sparse flag
dataset.attrs.create(self.__SparseMatricesAttribute.SPARSE, True)
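    # CSR storage sketch (illustrative values, standard scipy.sparse fields);
    # __read_sparse_array rebuilds the array from these pieces:
    #
    #     m = csr_array([[1.0, 0.0], [0.0, 2.0]])
    #     # dataset        <- m.data    == [1.0, 2.0]
    #     # attrs[INDICES] <- m.indices == [0, 1]
    #     # attrs[INDPTR]  <- m.indptr  == [0, 1, 2]
    #     # attrs[SHAPE]   <- m.shape   == (2, 2)
    #     # attrs[SPARSE]  <- True (flag read back by read_data)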
def read_data(
self,
index: int,
group: AbstractCache.Group,
hdf_node_path: str,
) -> DataMapping:
"""Read the data for given index and group.
Args:
index: The index of the entry.
group: The group.
hdf_node_path: The name of the HDF group where the entries are stored,
possibly passed as a path ``root_name/.../group_name/.../node_name``.
Returns:
The group data and the input data hash.
Raises:
ValueError: If the group cannot be found.
"""
with self.__open():
assert self.__file is not None
root = self.__file[hdf_node_path]
if not self._has_group(index, group, hdf_node_path):
return {}
entry = root[str(index)]
data = {}
for key, dataset in entry[group].items():
if dataset.attrs.get(self.__SparseMatricesAttribute.SPARSE):
data[key] = self.__read_sparse_array(dataset)
else:
data[key] = array(dataset)
for name, value in data.items():
if value.dtype.type is bytes_:
data[name] = value.astype(str_)
return data
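    # Read-side sketch (comment-only; entry index 1 and node "node" match the
    # write sketch above and are assumptions):
    #
    #     with singleton.lock:
    #         inputs = singleton.read_data(1, AbstractCache.Group.INPUTS, "node")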
def __read_sparse_array(self, dataset: h5py.Dataset) -> csr_array:
"""Read sparse array from a HDF5 dataset.
Args:
dataset: The dataset to the read the data from.
Returns:
The sparse array in CSR format.
"""
indices = dataset.attrs.get(self.__SparseMatricesAttribute.INDICES)
indptr = dataset.attrs.get(self.__SparseMatricesAttribute.INDPTR)
shape = dataset.attrs.get(self.__SparseMatricesAttribute.SHAPE)
return csr_array((dataset, indices, indptr), shape)
def _has_group(
self,
index: int,
group: AbstractCache.Group,
hdf_node_path: str,
) -> bool:
"""Return whether a group exists.
Args:
index: The index of the entry.
group: The group.
hdf_node_path: The name of the HDF group where the entries are stored,
possibly passed as a path ``root_name/.../group_name/.../node_name``.
Returns:
Whether a group exists.
"""
assert self.__file is not None
entry = self.__file[hdf_node_path].get(str(index))
if entry is None:
return False
return entry.get(group) is not None
def has_group(
self,
index: int,
group: AbstractCache.Group,
hdf_node_path: str,
) -> bool:
"""Check if an entry has data corresponding to a given group.
Args:
index: The index of the entry.
group: The group.
hdf_node_path: The name of the HDF group where the entries are stored,
possibly passed as a path ``root_name/.../group_name/.../node_name``.
Returns:
Whether the entry has data for this group.
"""
with self.__open():
return self._has_group(index, group, hdf_node_path)
def read_hashes(
self,
hashes_to_indices: DictProxy[int, IntegerArray],
hdf_node_path: str,
) -> int:
"""Read the hashes in the HDF file.
Args:
hashes_to_indices: The indices associated to the hashes.
hdf_node_path: The name of the HDF group where the entries are stored,
possibly passed as a path ``root_name/.../group_name/.../node_name``.
Returns:
The maximum index.
"""
if not exists(self.hdf_file_path):
return 0
        # We must lock so that no data is added to the cache in the meantime.
with self.__open():
assert self.__file is not None
root = self.__file.get(hdf_node_path)
if root is None:
return 0
max_index = 0
for index, entry in root.items():
index = int(index)
hash_ = int(array(entry[self.HASH_TAG])[0])
indices = hashes_to_indices.get(hash_)
if indices is None:
hashes_to_indices[hash_] = array([index])
else:
hashes_to_indices[hash_] = append(indices, array([index]))
max_index = max(max_index, index)
return max_index
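    # Resulting mapping sketch: each hash points to the array of entry indices
    # sharing it (illustrative values):
    #
    #     hashes_to_indices == {123456789: array([1, 3]), 987654321: array([2])}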
def clear(
self,
hdf_node_path: str,
) -> None:
"""Clear a node.
Args:
hdf_node_path: The name of the HDF group to clear,
possibly passed as a path ``root_name/.../group_name/.../node_name``.
"""
with self.__open(mode="a"):
assert self.__file is not None
del self.__file[hdf_node_path]
def __check_file_format_version(self) -> None:
"""Make sure the file can be handled.
Raises:
ValueError: If the version of the file format is missing or
greater than the current one.
"""
if not Path(self.hdf_file_path).exists():
return
with self.__open():
assert self.__file is not None
if not len(self.__file):
return
version = self.__file.attrs.get("version")
if version is None:
msg = (
f"The file {self.hdf_file_path} cannot be used because it has no file "
"format version: see HDFCache.update_file_format to convert it."
)
raise ValueError(msg)
            if version > self.FILE_FORMAT_VERSION:
                msg = (
                    f"The file {self.hdf_file_path} cannot be used because its file "
                    f"format version is {version} "
                    f"while the expected version is {self.FILE_FORMAT_VERSION}: "
                    "see HDFCache.update_file_format to convert it."
                )
                raise ValueError(msg)
@classmethod
def __set_file_format_version(
cls,
h5_file: h5py.File,
) -> None:
"""Change the version of an HDF5 file to :attr:`.FILE_FORMAT_VERSION`.
Args:
h5_file: A HDF5 file object.
"""
h5_file.attrs["version"] = cls.FILE_FORMAT_VERSION
@contextmanager
def __open(self, mode: str = "r") -> Iterator[None]:
"""Open a hdf5 file.
The file handle is used via :attr:`.__file`.
Args:
mode: The opening mode, see :meth:`h5py.File`.
"""
        if self.__keep_open and self.__file is not None:
            yield
        else:
            self.__file = h5py.File(self.hdf_file_path, mode=mode)
            try:
                yield
            finally:
                # Close even if the caller raised, unless keep_open is active.
                if not self.__keep_open:
                    self.__close()
@contextmanager
def keep_open(self) -> Iterator[None]:
"""Keep the file open for all file operations done in this context manager."""
        self.__keep_open = True
        try:
            yield
        finally:
            self.__keep_open = False
            # Only close if a file operation actually opened the handle.
            if self.__file is not None:
                self.__close()
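    # Usage sketch (comment-only; "cache.h5" and "node" are illustrative):
    #
    #     singleton = HDF5FileSingleton("cache.h5")
    #     with singleton.keep_open():
    #         singleton.has_group(1, AbstractCache.Group.INPUTS, "node")
    #         singleton.has_group(2, AbstractCache.Group.INPUTS, "node")
    #     # one file open/close for both calls; the handle closes on exit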
def __close(self) -> None:
"""Close the file handle."""
assert self.__file is not None
self.__file.close()
self.__file = None