Source code for gemseo.caches.base_full_cache

# Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License version 3 as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
"""A base class for caches storing all data."""

from __future__ import annotations

from abc import abstractmethod
from itertools import chain
from multiprocessing import RLock
from multiprocessing import Value
from typing import TYPE_CHECKING
from typing import ClassVar
from typing import Literal
from typing import cast
from typing import overload

from numpy import append
from numpy import array
from numpy import concatenate
from numpy import vstack

from gemseo.caches.base_cache import BaseCache
from gemseo.caches.cache_entry import CacheEntry
from gemseo.caches.utils import hash_data
from gemseo.utils.data_conversion import flatten_nested_bilevel_dict
from gemseo.utils.ggobi_export import save_data_arrays_to_xml
from gemseo.utils.locks import synchronized
from gemseo.utils.locks import synchronized_hashes
from gemseo.utils.multiprocessing.manager import get_multi_processing_manager

if TYPE_CHECKING:
    from collections.abc import Iterable
    from collections.abc import Iterator
    from multiprocessing.managers import DictProxy
    from multiprocessing.sharedctypes import Synchronized
    from multiprocessing.synchronize import RLock as RLockType

    from gemseo.typing import IntegerArray
    from gemseo.typing import JacobianData
    from gemseo.typing import StrKeyMapping


class BaseFullCache(BaseCache):
    """Base cache to store all the data, either in memory or on the disk.

    See Also:
        :class:`.MemoryFullCache`: store all the data in memory.
        :class:`.HDF5Cache`: store all the data in an HDF5 file.
    """

    _JACOBIAN_SEPARATOR: ClassVar[str] = "!d$_$d!"
    """The string separating the input and output names in a derivative name.

    E.g. ``"output!d$_$d!input"``.
    """

    lock: RLockType
    """The lock used for both multithreading and multiprocessing.

    Ensure safe multiprocessing and multithreading concurrent access to the cache.
    """

    lock_hashes: RLockType
    """The lock used for both multithreading and multiprocessing.

    Ensure safe multiprocessing and multithreading concurrent access to the hashes
    of the cached input data.
    """

    _hashes_to_indices: DictProxy[int, IntegerArray]
    """The indices associated with the hashes."""

    _max_index: Synchronized[int]
    """The maximum index of the data stored in the cache."""

    _last_accessed_index: Synchronized[int]
    """The index of the last accessed data."""

    def __init__(  # noqa: D107
        self,
        tolerance: float = 0.0,
        name: str = "",
    ) -> None:
        super().__init__(tolerance, name)
        self.lock_hashes = RLock()
        self._hashes_to_indices = get_multi_processing_manager().dict()
        self._max_index = cast("Synchronized[int]", Value("i", 0))
        self._last_accessed_index = cast("Synchronized[int]", Value("i", 0))
        self.lock = self._set_lock()

    @abstractmethod
    def _set_lock(self) -> RLockType:
        """Set a lock for multithreading.

        Either from an external object or internally by using RLock().
        """

    def __ensure_input_data_exists(
        self,
        input_data: StrKeyMapping,
    ) -> bool:
        """Ensure that an entry exists for ``input_data``.

        The hash of ``input_data`` is used to find the candidate entries.
        If one of them contains ``input_data``,
        it becomes the last accessed entry.
        Otherwise,
        a new entry is created at a new maximum index and registered under this hash;
        the caller is then responsible for writing ``input_data`` at this index.

        Args:
            input_data: The input data to cache.

        Returns:
            Whether ``input_data`` was already cached.
        """
        data_hash = hash_data(input_data)
        # Check if there is an entry with this hash in the cache.
        indices = self._hashes_to_indices.get(data_hash)
        if indices is None:
            self.__create_entry(data_hash, None)
            return False

        # Look for a cached input data equal to ``input_data``.
        for index in indices:
            if self.compare_dict_of_arrays(
                input_data, self._read_data(index, self.Group.INPUTS)
            ):
                # The input data is already cached => we don't store it again.
                self._last_accessed_index.value = index
                return True

        # Same hash but different input data (hash collision):
        # register a new entry under this hash.
        self.__create_entry(data_hash, indices)
        return False

    def __create_entry(
        self,
        data_hash: int,
        indices: IntegerArray | None,
    ) -> None:
        """Create a new entry at a new maximum index and register it under a hash.

        The new entry becomes the last accessed one.

        Args:
            data_hash: The hash of the input data of the new entry.
            indices: The indices already registered under ``data_hash``, if any.
        """
        self._max_index.value += 1
        new_index = self._max_index.value
        self._last_accessed_index.value = new_index
        if indices is None:
            self._hashes_to_indices[data_hash] = array([new_index])
        else:
            self._hashes_to_indices[data_hash] = append(indices, new_index)
        self._initialize_entry(new_index)

    def _initialize_entry(
        self,
        index: int,
    ) -> None:
        """Initialize an entry of the cache if needed.

        Args:
            index: The index of the entry.
        """

    @abstractmethod
    def _has_group(
        self,
        index: int,
        group: BaseCache.Group,
    ) -> bool:
        """Check if an entry has data corresponding to a given group.

        Args:
            index: The index of the entry.
            group: The group.

        Returns:
            Whether the entry has data for this group.
        """

    @abstractmethod
    def _write_data(
        self,
        values: StrKeyMapping,
        group: BaseCache.Group,
        index: int,
    ) -> None:
        """Write the data associated with an index and a group.

        Args:
            values: The data containing the values of the names to cache.
            group: The group.
            index: The index of the entry in the cache.
        """

    def _cache_inputs(
        self,
        input_data: StrKeyMapping,
        group: BaseCache.Group,
    ) -> bool:
        """Cache input data if needed and check whether a group exists for them.

        The input data are stored only once,
        so that the different groups of an entry share them.

        Args:
            input_data: The input data to cache.
            group: The group.

        Returns:
            Whether the entry associated with ``input_data``
            already has data for ``group``.
        """
        if not self.__ensure_input_data_exists(input_data):
            # A new entry has just been created at the maximum index.
            self._write_data(input_data, self.Group.INPUTS, self._max_index.value)
            return False
        return bool(self._has_group(self._last_accessed_index.value, group))

    @synchronized
    def cache_outputs(  # noqa: D102
        self,
        input_data: StrKeyMapping,
        output_data: StrKeyMapping,
    ) -> None:
        if self._cache_inputs(input_data, self.Group.OUTPUTS):
            # There is already an output data corresponding to this input data.
            return

        self._write_data(
            output_data,
            self.Group.OUTPUTS,
            self._last_accessed_index.value,
        )

    @synchronized
    def cache_jacobian(  # noqa: D102
        self,
        input_data: StrKeyMapping,
        jacobian_data: JacobianData,
    ) -> None:
        if self._cache_inputs(input_data, self.Group.JACOBIAN):
            # There is already a Jacobian data corresponding to this input data.
            return

        flat_jacobian_data = flatten_nested_bilevel_dict(
            jacobian_data, separator=self._JACOBIAN_SEPARATOR
        )
        self._write_data(
            flat_jacobian_data,
            self.Group.JACOBIAN,
            self._last_accessed_index.value,
        )

    @synchronized
    def clear(self) -> None:  # noqa: D102
        super().clear()
        self._hashes_to_indices.clear()
        self._max_index.value = 0
        self._last_accessed_index.value = 0

    @property
    @synchronized
    def last_entry(self) -> CacheEntry:  # noqa: D102
        if not self:
            return CacheEntry({}, {}, {})

        index = self._last_accessed_index.value
        return CacheEntry(
            self._read_data(index, self.Group.INPUTS),
            self._read_data(index, self.Group.OUTPUTS),
            self._read_data(index, self.Group.JACOBIAN),
        )

    @synchronized
    def __len__(self) -> int:
        return self._max_index.value

    @overload
    def _read_data(
        self,
        index: int,
        group: Literal[BaseCache.Group.INPUTS, BaseCache.Group.OUTPUTS],
    ) -> StrKeyMapping: ...

    @overload
    def _read_data(
        self,
        index: int,
        group: Literal[BaseCache.Group.JACOBIAN],
    ) -> JacobianData: ...

    @abstractmethod
    def _read_data(
        self,
        index: int,
        group: BaseCache.Group,
    ) -> StrKeyMapping | JacobianData:
        """Read the data of an entry.

        Args:
            index: The index of the entry.
            group: The group.

        Returns:
            The data of the entry for this group
            (input, output or Jacobian data).
        """

    @synchronized_hashes
    def __has_hash(
        self,
        data_hash: int,
    ) -> IntegerArray | None:
        """Get the indices corresponding to a data hash.

        Args:
            data_hash: The data hash.

        Returns:
            The indices corresponding to this data hash if any, ``None`` otherwise.
        """
        return self._hashes_to_indices.get(data_hash)

    def _read_input_output_data(
        self,
        indices: Iterable[int],
        input_data: StrKeyMapping,
    ) -> CacheEntry:
        """Read the output and Jacobian data for a given input data.

        Args:
            indices: The indices of the candidate entries.
            input_data: The input data.

        Returns:
            The entry whose input data are equal to ``input_data`` if any,
            otherwise an entry with empty output and Jacobian data.
        """
        for index in indices:
            if self.compare_dict_of_arrays(
                input_data, self._read_data(index, self.Group.INPUTS)
            ):
                output_data = self._read_data(index, self.Group.OUTPUTS)
                jacobian_data = self._read_data(index, self.Group.JACOBIAN)
                return CacheEntry(input_data, output_data, jacobian_data)

        return CacheEntry(input_data, {}, {})

    @synchronized
    def __getitem__(
        self,
        input_data: StrKeyMapping,
    ) -> CacheEntry:
        if self._tolerance == 0.0:
            # Exact lookup: only the entries sharing the hash are candidates.
            data_hash = hash_data(input_data)
            indices = self.__has_hash(data_hash)
            if indices is None:
                return CacheEntry(input_data, {}, {})

            return self._read_input_output_data(indices, input_data)

        # Approximate lookup: the hash cannot be used, scan all the entries.
        for indices in self._hashes_to_indices.values():
            for index in indices:
                cached_input_data = self._read_data(index, self.Group.INPUTS)
                if self.compare_dict_of_arrays(
                    input_data, cached_input_data, self._tolerance
                ):
                    output_data = self._read_data(index, self.Group.OUTPUTS)
                    jacobian_data = self._read_data(index, self.Group.JACOBIAN)
                    return CacheEntry(input_data, output_data, jacobian_data)

        return CacheEntry(input_data, {}, {})

    @property
    def _all_groups(self) -> list[int]:
        """The indices of all the entries, sorted in ascending order."""
        return sorted(
            chain.from_iterable(
                indices.tolist() for indices in self._hashes_to_indices.values()
            )
        )

    def get_all_entries(self) -> Iterator[CacheEntry]:  # noqa: D102
        # The lock is acquired here rather than with a decorator:
        # a decorator holding the lock only during the call would release it
        # as soon as the generator object is created, i.e. before any data is read.
        # The lock is not held while the caller processes a yielded entry.
        with self.lock:
            indices = self._all_groups

        for index in indices:
            with self.lock:
                entry = CacheEntry(
                    self._read_data(index, self.Group.INPUTS),
                    self._read_data(index, self.Group.OUTPUTS),
                    self._read_data(index, self.Group.JACOBIAN),
                )
            yield entry

    def to_ggobi(
        self,
        file_path: str,
        input_names: Iterable[str] = (),
        output_names: Iterable[str] = (),
    ) -> None:
        """Export the cache to an XML file for ggobi tool.

        Only the entries containing all the requested inputs and outputs are exported
        and only the variables shared by all these entries are written,
        in the order of the first exported entry.

        Args:
            file_path: The path of the file to export the cache.
            input_names: The names of the inputs to export.
                If empty, export all of them.
            output_names: The names of the outputs to export.
                If empty, export all of them.

        Raises:
            ValueError: When the cache is empty
                or when no entry contains the requested data.
        """
        if not self._hashes_to_indices:
            msg = "An empty cache cannot be exported to XML file."
            raise ValueError(msg)

        # Materialize the names once:
        # a one-shot iterator would be exhausted after the first entry.
        input_names = tuple(input_names)
        output_names = tuple(output_names)

        # The names shared by all the retained entries,
        # in the order of the first retained entry.
        shared_input_names: list[str] = []
        shared_output_names: list[str] = []
        all_input_data = []
        all_output_data = []
        names_to_sizes: dict[str, int] = {}
        for data in self.get_all_entries():
            input_data = data.inputs or {}
            output_data = data.outputs or {}
            try:
                if input_names:
                    input_data = {name: input_data[name] for name in input_names}
                if output_names:
                    output_data = {name: output_data[name] for name in output_names}
            except KeyError:
                # The data is not in this execution
                continue

            # Compute the size of the data
            names_to_sizes.update({
                name: value.size for name, value in input_data.items()
            })
            names_to_sizes.update({
                name: value.size for name, value in output_data.items()
            })

            if all_input_data:
                # Keep a true running intersection: a name missing from one
                # retained entry must never come back with a later entry.
                shared_input_names = [
                    name for name in shared_input_names if name in input_data
                ]
                shared_output_names = [
                    name for name in shared_output_names if name in output_data
                ]
            else:
                shared_input_names = list(input_data)
                shared_output_names = list(output_data)

            all_input_data.append(input_data)
            all_output_data.append(output_data)

        if not all_output_data:
            msg = "Failed to find outputs in the cache."
            raise ValueError(msg)

        variable_names = []
        for data_name in chain(shared_input_names, shared_output_names):
            data_size = names_to_sizes[data_name]
            if data_size == 1:
                variable_names.append(data_name)
            else:
                variable_names.extend(f"{data_name}_{i + 1}" for i in range(data_size))

        cache_as_array = vstack([
            concatenate(
                [input_data[name].flatten() for name in shared_input_names]
                + [output_data[name].flatten() for name in shared_output_names]
            )
            for input_data, output_data in zip(all_input_data, all_output_data)
        ])
        save_data_arrays_to_xml(variable_names, cache_as_array, file_path)

    def update(
        self,
        other_cache: BaseFullCache,
    ) -> None:
        """Update from another cache.

        Only the entries of ``other_cache`` having output or Jacobian data are copied.

        Args:
            other_cache: The cache to update the current one.
        """
        for input_data, output_data, jacobian_data in other_cache.get_all_entries():
            if output_data or jacobian_data:
                self[input_data] = (output_data, jacobian_data)

    @abstractmethod
    def _copy_empty_cache(self) -> BaseFullCache:
        """Copy a cache without its entries."""

    def __add__(
        self,
        other_cache: BaseFullCache,
    ) -> BaseFullCache:
        """Concatenate a cache to a copy of the current one.

        Args:
            other_cache: A cache to be concatenated to a copy of the current one.

        Returns:
            A new cache concatenating the current one and ``other_cache``.
        """
        new_cache = self._copy_empty_cache()
        new_cache.update(self)
        new_cache.update(other_cache)
        return new_cache