cross_validation module

A cross-validation tool for resampling and surrogate modeling.

class gemseo.mlearning.resampling.cross_validation.CrossValidation(sample_indices, n_folds=5, randomize=False, seed=0)

Bases: BaseResampler

A cross-validation tool for resampling and surrogate modeling.

Parameters:
  • sample_indices (NDArray[int]) – The original indices of the samples.

  • n_folds (int) –

    The number of folds.

    By default it is set to 5.

  • randomize (bool) –

    Whether the sample indices are shuffled before splitting.

    By default it is set to False.

  • seed (int | None) –

    The seed to initialize the random generator. If None, then fresh, unpredictable entropy will be pulled from the OS.

    By default it is set to 0.
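How these parameters shape the folds can be sketched in plain Python. This is a minimal sketch assuming standard k-fold partitioning into nearly equal, consecutive folds. The helper split_into_folds is hypothetical and not part of gemseo, and gemseo draws its shuffles from its own random generator, so the folds it builds for a given seed will generally differ from these.

```python
import random
from typing import List, Optional, Sequence, Tuple


def split_into_folds(
    sample_indices: Sequence[int],
    n_folds: int = 5,
    randomize: bool = False,
    seed: Optional[int] = 0,
) -> Tuple[List[int], List[List[int]]]:
    """Partition sample indices into n_folds nearly equal test folds.

    Hypothetical helper mirroring the CrossValidation parameters; not gemseo code.
    """
    if not 1 < n_folds <= len(sample_indices):
        raise ValueError("n_folds must be between 2 and the number of samples.")
    indices = list(sample_indices)
    if randomize:
        # An integer seed makes the shuffle reproducible; random.Random(None)
        # seeds itself from OS randomness (or the current time if unavailable).
        random.Random(seed).shuffle(indices)
    # The first len(indices) % n_folds folds receive one extra sample.
    base, extra = divmod(len(indices), n_folds)
    folds, start = [], 0
    for k in range(n_folds):
        stop = start + base + (1 if k < extra else 0)
        folds.append(indices[start:stop])
        start = stop
    return indices, folds


indices, folds = split_into_folds(range(10), n_folds=3)
print(folds)  # [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]
```

Each fold serves once as the test set while the remaining folds form the training set; with randomize=True and a fixed seed, repeated calls yield identical folds.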

execute(model, return_models=False, input_data=None, stack_predictions=True, fit_transformers=True, store_sampling_result=False)

Apply the resampling technique to a machine learning model.

Parameters:
  • model (MLAlgo) – The machine learning model.

  • return_models (bool) –

    Whether the sub-models resulting from resampling are returned.

    By default it is set to False.

  • input_data (ndarray | None) – The input data for the prediction, if any.

  • stack_predictions (bool) –

    Whether the sub-predictions are stacked in the order of the sample_indices passed at instantiation (first the prediction at index sample_indices[0], then the prediction at index sample_indices[1], etc.). This argument is ignored when input_data is None.

    By default it is set to True.

  • fit_transformers (bool) –

    Whether to re-fit the transformers.

    By default it is set to True.

  • store_sampling_result (bool) –

    Whether to store the sampling results in the attribute resampling_results of the original model.

    By default it is set to False.

Returns:

A tuple whose first item holds the sub-models resulting from resampling if return_models is True, and whose second item holds the predictions, either per fold or stacked.

Raises:

ValueError – When the model is neither a supervised algorithm nor a clustering one.

Return type:

tuple[list[MLAlgo], list[ndarray] | ndarray]
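The fold loop behind a cross-validation resampler, and the two shapes of predictions described under Returns, can be sketched in plain Python. This is a minimal sketch, not gemseo's implementation: fit_mean_model and cross_validate are hypothetical names, the "model" simply predicts the mean output of its training samples (keyed by sample index rather than by input data), and only predictions on each held-out fold are covered. With stacking, entry i of the result corresponds to sample_indices[i]; without it, predictions are returned per fold. Filling the sub-model list only when return_models is True is an assumption of this sketch.

```python
from typing import Callable, Dict, List, Sequence, Tuple


def fit_mean_model(outputs: Sequence[float]) -> Callable[[int], float]:
    """Hypothetical stand-in for training: always predicts the training mean."""
    mean = sum(outputs) / len(outputs)
    return lambda sample_index: mean


def cross_validate(
    sample_indices: Sequence[int],
    folds: Sequence[Sequence[int]],
    outputs: Dict[int, float],
    return_models: bool = False,
    stack_predictions: bool = True,
) -> Tuple[list, list]:
    models = []
    per_fold = []
    for test_fold in folds:
        held_out = set(test_fold)
        # Train on every sample outside the current test fold.
        train = [i for i in sample_indices if i not in held_out]
        model = fit_mean_model([outputs[i] for i in train])
        # Predict only the samples the sub-model has not seen.
        per_fold.append([model(i) for i in test_fold])
        if return_models:
            models.append(model)
    if not stack_predictions:
        return models, per_fold
    # Put each held-out prediction back at the position of its sample.
    by_index = {}
    for test_fold, predictions in zip(folds, per_fold):
        by_index.update(zip(test_fold, predictions))
    return models, [by_index[i] for i in sample_indices]


sample_indices = [0, 1, 2, 3]
outputs = {0: 1.0, 1: 2.0, 2: 3.0, 3: 4.0}
folds = [[2, 3], [0, 1]]  # e.g. folds obtained after shuffling
models, stacked = cross_validate(sample_indices, folds, outputs)
print(stacked)  # [3.5, 3.5, 1.5, 1.5]
_, per_fold = cross_validate(sample_indices, folds, outputs, stack_predictions=False)
print(per_fold)  # [[1.5, 1.5], [3.5, 3.5]]
```

Because every sample is predicted exactly once by a sub-model that did not see it, the stacked vector can be compared element-wise with the original outputs to estimate the generalization error.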

property n_folds: int

The number of folds.

name: str

The name of the resampler.

By default, the class name is used.
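A class-name default like this is commonly implemented by reading self.__class__.__name__ in the base class, which picks up the name of whichever subclass is instantiated. The sketch below uses hypothetical classes and is not gemseo's code.

```python
class ResamplerSketch:
    """Hypothetical base class; not gemseo's BaseResampler."""

    def __init__(self) -> None:
        # Default the name to the name of the concrete (sub)class.
        self.name = self.__class__.__name__


class CrossValidationSketch(ResamplerSketch):
    """Hypothetical subclass standing in for CrossValidation."""


print(CrossValidationSketch().name)  # CrossValidationSketch
```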

property randomize: bool

Whether the sample indices are shuffled before splitting.

property shuffled_sample_indices: NDArray[int]

The original indices of the samples, shuffled if randomize is True.