Source code for gemseo.caches.memory_full_cache

# -*- coding: utf-8 -*-
# Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License version 3 as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

# Contributors:
#    INITIAL AUTHORS - initial API and implementation and/or initial
#                         documentation
#        :author: Francois Gallard, Matthias De Lozzo
#    OTHER AUTHORS   - MACROSCOPIC CHANGES
"""
Caching module to avoid multiple evaluations of a discipline
************************************************************
"""
from __future__ import division, unicode_literals

import logging

from gemseo.core.cache import AbstractFullCache
from gemseo.utils.data_conversion import DataConversion
from gemseo.utils.locks import synchronized
from gemseo.utils.multi_processing import RLock

LOGGER = logging.getLogger(__name__)


class MemoryFullCache(AbstractFullCache):
    """Cache using memory to cache all data."""

    def __init__(self, tolerance=0.0, name=None, is_memory_shared=True):
        """Initialize a dictionary to cache data.

        Initialize the cache tolerance.
        By default, the cache is exact: no approximate matching of inputs is used.
        The user may set a small tolerance, for instance 2 * finfo(float).eps,
        to save CPU time by reusing cached data for nearly identical inputs.

        Parameters
        ----------
        tolerance : float
            Tolerance that defines if two input vectors are equal
            and cached data shall be returned.
            If 0, no approximation is made. Default: 0.
        name : str
            Name of the cache.
        is_memory_shared : bool
            If True, a shared memory dict is used to store the data,
            which makes the cache compatible with multiprocessing.
            WARNING: if set to False, and multiple disciplines point to the
            same cache or the process is multiprocessed, there may be
            duplicate computations because the cache will not be shared
            among the processes.

        Examples
        --------
        >>> from gemseo.caches.memory_full_cache import MemoryFullCache
        >>> cache = MemoryFullCache()
        """
        self.__is_memory_shared = is_memory_shared
        super(MemoryFullCache, self).__init__(tolerance, name)
        self.__init_data()

    def __init_data(self):
        """Initialize the local dict that stores the data,
        either a shared memory dict or a basic dict.
        """
        if self.__is_memory_shared:
            self._data = self._manager.dict()
        else:
            self._data = {}

    def _duplicate_from_scratch(self):
        """Duplicate the cache settings into a new empty MemoryFullCache."""
        return MemoryFullCache(self.tolerance, self.name, self.__is_memory_shared)

    def _initialize_entry(self, sample_id):
        """Initialize an entry of the dataset if needed.

        :param int sample_id: sample ID.
        """
        self._data[sample_id] = {}

    def _set_lock(self):
        """Set a lock for multithreading,
        either from an external object or internally by using RLock().
        """
        return RLock()

    def _has_group(self, sample_id, var_group):
        """Check if the dataset has the particular variables group filled in.

        :param int sample_id: sample ID.
        :param str var_group: name of the variables group.
        :return: True if the variables group is filled in.
        :rtype: bool
        """
        return var_group in self._data.get(sample_id)
    @synchronized
    def clear(self):
        """Clear the cache.

        Examples
        --------
        >>> from gemseo.caches.memory_full_cache import MemoryFullCache
        >>> from numpy import array
        >>> cache = MemoryFullCache()
        >>> for index in range(5):
        ...     data = {'x': array([1.])*index, 'y': array([.2])*index}
        ...     cache.cache_outputs(data, ['x'], data, ['y'])
        >>> cache.get_length()
        5
        >>> cache.clear()
        >>> cache.get_length()
        0
        """
        super(MemoryFullCache, self).clear()
        self.__init_data()
    def _read_data(self, group_number, group_name):
        """Read a data dict from the in-memory dataset.

        :param group_name: name of the group where data is written.
        :param group_number: number of the group.
        :returns: data dict of the group, converted to a Jacobian dict
            if the group is the Jacobian group.
        """
        result = self._data[group_number].get(group_name)
        if group_name == self.JACOBIAN_GROUP and result is not None:
            result = DataConversion.dict_to_jac_dict(result)
        return result

    def _write_data(self, values, names, var_group, sample_id):
        """Write data associated with a variables group and a sample ID
        into the dataset.

        :param dict values: data dictionary where keys are variables names
            and values are variables values (numpy arrays).
        :param list(str) names: list of input data names to write.
        :param str var_group: name of the variables group, either
            AbstractCache.INPUTS_GROUP, AbstractCache.OUTPUTS_GROUP
            or AbstractCache.JACOBIAN_GROUP.
        :param int sample_id: sample ID.
        """
        data = self._data[sample_id]
        data[var_group] = {name: values[name] for name in names}
        # Re-assign the entry so that the update is propagated
        # when a shared memory dict is used.
        self._data[sample_id] = data

    @property
    def copy(self):
        """Copy the current cache and return the copy."""
        cache = self._duplicate_from_scratch()
        cache.merge(self)
        return cache
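A minimal usage sketch follows. It relies only on calls that appear in the listing above (cache_outputs, get_length, copy, clear) and assumes a working GEMSEO installation with numpy; it is illustrative, not part of the module.

# Illustrative usage sketch (not part of the module).
from numpy import array

from gemseo.caches.memory_full_cache import MemoryFullCache

# Shared-memory cache (default), usable with multiprocessing.
cache = MemoryFullCache(tolerance=0.0, name="my_cache")
for index in range(3):
    data = {"x": array([1.0]) * index, "y": array([0.2]) * index}
    cache.cache_outputs(data, ["x"], data, ["y"])
print(cache.get_length())  # 3

# 'copy' duplicates the cache settings and merges the cached entries,
# so clearing the original does not affect the copy.
cache_copy = cache.copy
cache.clear()
print(cache.get_length())       # 0
print(cache_copy.get_length())  # 3

# A plain local dict can be used instead when multiprocessing is not needed.
local_cache = MemoryFullCache(is_memory_shared=False)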