Source code for gemseo.mlearning.data_formatters.moe_data_formatters

# Copyright 2021 IRT Saint Exupéry, https://www.irt-saintexupery.com
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License version 3 as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program; if not, write to the Free Software Foundation,
# Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
"""Data formatters for mixture of experts."""

from __future__ import annotations

import functools
from collections.abc import Mapping
from typing import TYPE_CHECKING

from gemseo.mlearning.data_formatters.regression_data_formatters import (
    RegressionDataFormatters,
)

if TYPE_CHECKING:
    from typing import Any
    from typing import Callable

    from numpy import ndarray

    from gemseo.mlearning.core.ml_algo import DataType
    from gemseo.mlearning.regression.moe import MOERegressor
    from gemseo.typing import RealArray
from gemseo.utils.data_conversion import concatenate_dict_of_arrays_to_array


class MOEDataFormatters(RegressionDataFormatters):
    """Data formatters dedicated to the mixture of experts."""

    @classmethod
    def format_predict_class_dict(
        cls,
        func: Callable[[MOERegressor, RealArray, Any, ...], ndarray],
    ) -> Callable[[MOERegressor, DataType, Any, ...], DataType]:
        """Let an array-based function accept a dictionary of NumPy arrays.

        Args:
            func: The function to decorate;
                its data argument is a NumPy array
                and it returns a NumPy array.

        Returns:
            A function accepting the data
            either as a NumPy array
            or as a dictionary of NumPy arrays indexed by variable names.
            The returned data have the same type as the input data.
        """

        @functools.wraps(func)
        def formatted_func(
            algo: MOERegressor,
            input_data: DataType,
            *args: Any,
            **kwargs: Any,
        ) -> DataType:
            """Call ``func`` from array-based or dictionary-based input data.

            When the input data are not a mapping,
            they are forwarded to ``func`` as is
            and the result of ``func`` is returned unchanged.

            When the input data are a mapping,
            they are first converted to a single NumPy array
            with ``concatenate_dict_of_arrays_to_array``
            using ``algo.input_names``;
            then ``func`` is evaluated from this array
            and its result is returned in a dictionary
            whose unique key is ``algo.LABELS``.

            Args:
                algo: The mixture of experts.
                input_data: The input data.
                *args: The positional arguments passed to ``func``.
                **kwargs: The keyword arguments passed to ``func``.

            Returns:
                The output data, with the same type as the input data.
            """
            # Array-like input: nothing to convert, neither before nor after.
            if not isinstance(input_data, Mapping):
                return func(algo, input_data, *args, **kwargs)

            # Dictionary input: array in, dictionary out.
            input_array = concatenate_dict_of_arrays_to_array(
                input_data, algo.input_names
            )
            result = func(algo, input_array, *args, **kwargs)
            return {algo.LABELS: result}

        return formatted_func