cross_validation module
A cross-validation tool for resampling and surrogate modeling.
- class gemseo.mlearning.resampling.cross_validation.CrossValidation(sample_indices, n_folds=5, randomize=False, seed=0)
Bases: Resampler
A cross-validation tool for resampling and surrogate modeling.
- Parameters:
sample_indices (NDArray[int]) – The original indices of the samples.
n_folds (int) – The number of folds. By default it is set to 5.
randomize (bool) – Whether the sample indices are shuffled before splitting. By default it is set to False.
seed (int | None) – The seed to initialize the random generator. If None, then fresh, unpredictable entropy will be pulled from the OS. By default it is set to 0.
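To illustrate how these parameters could interact, here is a minimal NumPy-only sketch of a k-fold split. It is not GEMSEO's implementation: the helper name make_folds and the exact splitting rule (contiguous blocks from np.array_split, with an optional shuffle first) are assumptions made for illustration.

```python
import numpy as np


def make_folds(sample_indices, n_folds=5, randomize=False, seed=0):
    """Return one (learning_indices, test_indices) pair per fold.

    Hypothetical helper; it only mirrors the constructor arguments above.
    """
    indices = np.array(sample_indices)  # copy, so the caller's array is untouched
    if randomize:
        # seed=None would draw fresh entropy from the OS.
        np.random.default_rng(seed).shuffle(indices)
    # Split into n_folds nearly equal blocks; each block is a test set.
    test_folds = np.array_split(indices, n_folds)
    # The learning set of a fold is every index not in its test set.
    return [(np.setdiff1d(indices, test), test) for test in test_folds]


folds = make_folds(np.arange(10), n_folds=5)
for learning, test in folds:
    print("learn:", learning, "test:", test)
```

With randomize=False the folds follow the original order of the indices. With randomize=True and a fixed integer seed, the shuffle, and therefore the folds, are reproducible from run to run.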
- execute(model, return_models, predict, stack_predictions, fit_transformers, store_sampling_result, input_data, output_data_shape)
Apply the resampling technique to a machine learning model.
- Parameters:
model (MLAlgo) – The machine learning model.
return_models (bool) – Whether the sub-models resulting from resampling are returned.
predict (bool) – Whether the sub-models resulting from resampling make predictions on their corresponding learning data.
stack_predictions (bool) – Whether the sub-predictions are stacked.
fit_transformers (bool) – Whether to re-fit the transformers.
store_sampling_result (bool) – Whether to store the sampling results in the attribute resampling_results of the original model.
input_data (ndarray) – The input data.
output_data_shape (tuple[int, ...]) – The shape of the output data array.
- Returns:
First, the sub-models resulting from resampling if return_models is True; then the predictions, either per fold or stacked.
- Raises:
ValueError – When the model is neither a supervised algorithm nor a clustering one.
- Return type: