gemseo.mlearning.resampling.cross_validation module
A cross-validation tool for resampling and surrogate modeling.
- class CrossValidation(sample_indices, n_folds=5, randomize=False, seed=0)
Bases: BaseResampler
A cross-validation tool for resampling and surrogate modeling.
- Parameters:
sample_indices (NDArray[int]) -- The original indices of the samples.
n_folds (int) --
The number of folds.
By default it is set to 5.
randomize (bool) --
Whether the sample indices are shuffled before splitting.
By default it is set to False.
seed (int | None) --
The seed to initialize the random generator. If None, then fresh, unpredictable entropy will be pulled from the OS.
By default it is set to 0.
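The roles of n_folds, randomize and seed can be illustrated with a minimal NumPy sketch. This is not the GEMSEO implementation: the helper name make_folds and the use of numpy.array_split are assumptions made for illustration only.

```python
import numpy as np


def make_folds(sample_indices, n_folds=5, randomize=False, seed=0):
    """Return a list of (train_indices, test_indices) pairs (hypothetical helper)."""
    indices = np.asarray(sample_indices)
    if randomize:
        # With seed=None, NumPy pulls fresh entropy from the OS,
        # so the shuffle is no longer reproducible.
        indices = np.random.default_rng(seed).permutation(indices)
    # Split the (possibly shuffled) indices into n_folds nearly equal parts.
    test_folds = np.array_split(indices, n_folds)
    # For each fold, the training indices are all the indices not in the test fold.
    return [(np.setdiff1d(indices, test), test) for test in test_folds]


folds = make_folds(np.arange(10), n_folds=5)
shuffled_a = make_folds(np.arange(10), n_folds=5, randomize=True, seed=0)
shuffled_b = make_folds(np.arange(10), n_folds=5, randomize=True, seed=0)
```

In this sketch, each sample index appears in exactly one test fold, and a fixed integer seed makes the shuffled split reproducible from one call to the next.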
- execute(model, return_models=False, input_data=None, stack_predictions=True, fit_transformers=True, store_sampling_result=False)
Apply the resampling technique to a machine learning model.
- Parameters:
model (MLAlgo) -- The machine learning model.
return_models (bool) --
Whether the sub-models resulting from resampling are returned.
By default it is set to False.
input_data (ndarray | None) -- The input data for the prediction, if any.
stack_predictions (bool) --
Whether the sub-predictions are stacked in the order of the sample_indices passed at instantiation (first the prediction at index sample_indices[0], then the prediction at index sample_indices[1], etc.). This argument is ignored when input_data is None.
By default it is set to True.
fit_transformers (bool) --
Whether to re-fit the transformers.
By default it is set to True.
store_sampling_result (bool) --
Whether to store the sampling results in the attribute resampling_results of the original model.
By default it is set to False.
- Returns:
First the sub-models resulting from resampling if return_models is True, then the predictions, either per fold or stacked.
- Raises:
ValueError -- When the model is neither a supervised algorithm nor a clustering one.
- Return type:
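The ordering described for stack_predictions (first the prediction at index sample_indices[0], then the one at sample_indices[1], etc.) can be sketched in plain NumPy. This is an illustration only, not GEMSEO code: the helper stack_in_sample_order, the fold layout and the prediction values are all hypothetical.

```python
import numpy as np


def stack_in_sample_order(sample_indices, test_folds, fold_predictions):
    """Stack per-fold predictions in the order of sample_indices (hypothetical helper)."""
    # Concatenate the test indices and the predictions fold after fold.
    fold_order = np.concatenate(test_folds)
    stacked = np.concatenate(fold_predictions)
    # Map each sample index to its row in the concatenated predictions.
    position = {int(index): row for row, index in enumerate(fold_order)}
    rows = [position[int(index)] for index in np.asarray(sample_indices)]
    return stacked[rows]


sample_indices = np.array([3, 0, 2, 1])
test_folds = [np.array([0, 1]), np.array([2, 3])]
# Made-up sub-model outputs: the prediction at index i is 10 * i.
fold_predictions = [np.array([[0.0], [10.0]]), np.array([[20.0], [30.0]])]
stacked = stack_in_sample_order(sample_indices, test_folds, fold_predictions)
```

Here the per-fold results are two arrays of shape (2, 1), while the stacked result is a single array of shape (4, 1) whose rows follow sample_indices.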