Gaussian process (GP) regression#
A GaussianProcessRegressor
is a GP regression model
based on scikit-learn.
See also
You can find more information about building GP models with scikit-learn on this page.
from __future__ import annotations
from matplotlib import pyplot as plt
from numpy import array
from sklearn.gaussian_process.kernels import RBF
from sklearn.gaussian_process.kernels import Matern
from gemseo import configure_logger
from gemseo import create_design_space
from gemseo import create_discipline
from gemseo import sample_disciplines
from gemseo.mlearning import create_regression_model
configure_logger()
<RootLogger root (INFO)>
Problem#
In this example,
we represent the function \(f(x)=(6x-2)^2\sin(12x-4)\) [FSK08]
by the AnalyticDiscipline
discipline = create_discipline(
"AnalyticDiscipline",
name="f",
expressions={"y": "(6*x-2)**2*sin(12*x-4)"},
)
and seek to approximate it over the input space
input_space = create_design_space()
input_space.add_variable("x", lower_bound=0.0, upper_bound=1.0)
To do this, we create a training dataset with 6 equispaced points:
training_dataset = sample_disciplines(
[discipline], input_space, "y", algo_name="PYDOE_FULLFACT", n_samples=6
)
WARNING - 12:47:33: No coupling in MDA, switching chain_linearize to True.
INFO - 12:47:33:
INFO - 12:47:33: *** Start Sampling execution ***
INFO - 12:47:33: Sampling
INFO - 12:47:33: Disciplines: f
INFO - 12:47:33: MDO formulation: MDF
INFO - 12:47:33: Running the algorithm PYDOE_FULLFACT:
INFO - 12:47:33: 100%|██████████| 6/6 [00:00<00:00, 1594.69 it/sec]
INFO - 12:47:33: *** End Sampling execution (time: 0:00:00.005070) ***
Basics#
Training#
Then, we train a GP regression model from these samples:
model = create_regression_model("GaussianProcessRegressor", training_dataset)
model.learn()
Prediction#
Once it is built, we can predict the output value of \(f\) at a new input point:
input_value = {"x": array([0.65])}
output_value = model.predict(input_value)
output_value
{'y': array([2.20380214])}
but cannot predict its Jacobian value:
try:
model.predict_jacobian(input_value)
except NotImplementedError:
print("The derivatives are not available for GaussianProcessRegressor.")
The derivatives are not available for GaussianProcessRegressor.
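If a derivative is nevertheless needed, it can be approximated numerically from the predictions alone. Here is a minimal sketch using central finite differences; the step size is a hypothetical value to be tuned to the problem:
step = 1e-6  # hypothetical step size; adjust it to the problem at hand
x = input_value["x"]
forward = model.predict({"x": x + step})["y"]
backward = model.predict({"x": x - step})["y"]
# Central finite-difference approximation of dy/dx at x
approximate_derivative = (forward - backward) / (2 * step)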
Uncertainty#
GP models are often valued for their ability to quantify prediction uncertainty. Indeed, a GP model is a random process fully characterized by its mean function and its covariance structure. Given an input point \(x\), the prediction is the mean at \(x\) and the uncertainty is the standard deviation at \(x\):
standard_deviation = model.predict_std(input_value)
standard_deviation
array([[0.3140468]])
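Combining predict and predict_std, one can derive an approximate confidence interval. The sketch below assumes Gaussian predictive uncertainty and uses the usual 95% quantile:
mean = model.predict(input_value)["y"]
std = model.predict_std(input_value).ravel()
# Approximate 95% confidence interval around the prediction
lower, upper = mean - 1.96 * std, mean + 1.96 * std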
Plotting#
As the plot below shows, the GP model interpolates the training points but is very inaccurate elsewhere. This case-dependent problem is due to poor auto-tuning of the kernel's length scales. We will look at how to correct this in the next sections. First, let us build a fine test dataset to compare the model with the reference function:
test_dataset = sample_disciplines(
[discipline], input_space, "y", algo_name="PYDOE_FULLFACT", n_samples=100
)
input_data = test_dataset.get_view(variable_names=model.input_names).to_numpy()
reference_output_data = test_dataset.get_view(variable_names="y").to_numpy().ravel()
predicted_output_data = model.predict(input_data).ravel()
plt.plot(input_data.ravel(), reference_output_data, label="Reference")
plt.plot(input_data.ravel(), predicted_output_data, label="Regression - Basics")
plt.grid()
plt.legend()
plt.show()
[Figure: reference function and predictions of the basic GP model]
WARNING - 12:47:33: No coupling in MDA, switching chain_linearize to True.
INFO - 12:47:33:
INFO - 12:47:33: *** Start Sampling execution ***
INFO - 12:47:33: Sampling
INFO - 12:47:33: Disciplines: f
INFO - 12:47:33: MDO formulation: MDF
INFO - 12:47:33: Running the algorithm PYDOE_FULLFACT:
INFO - 12:47:33: 100%|██████████| 100/100 [00:00<00:00, 3478.87 it/sec]
INFO - 12:47:33: *** End Sampling execution (time: 0:00:00.030153) ***
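Beyond visual inspection, the mismatch can be quantified. Here is a minimal sketch computing the root-mean-square error of the basic model from the arrays defined above:
from numpy import sqrt

# RMSE of the basic model over the 100 test points
rmse = sqrt(((predicted_output_data - reference_output_data) ** 2).mean())
print(f"RMSE of the basic model: {rmse:.3f}")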
Settings#
The GaussianProcessRegressor
has many options
defined in the GaussianProcessRegressor_Settings
Pydantic model.
Here are the main ones.
Kernel#
The kernel
option defines the kernel function
parametrizing the Gaussian process regressor
and must be passed as a scikit-learn object.
The default kernel is the Matérn 5/2 covariance function
with input length scales belonging to the interval \([0.01,100]\),
initialized at 1
and optimized by the L-BFGS-B algorithm.
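For reference, a scikit-learn kernel resembling this default could be written explicitly as follows; this is only a sketch based on the description above, not the exact object created by GEMSEO:
# Matérn 5/2 kernel with length scales initialized at 1 and bounded in [0.01, 100]
default_like_kernel = Matern(
    length_scale=1.0, length_scale_bounds=(0.01, 100.0), nu=2.5
)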
We can replace this kernel with a Matérn 5/2 kernel
whose input length scales are fixed at 1:
model = create_regression_model(
"GaussianProcessRegressor",
training_dataset,
kernel=Matern(length_scale=1.0, length_scale_bounds="fixed", nu=2.5),
)
model.learn()
predicted_output_data_1 = model.predict(input_data).ravel()
or with a squared exponential (RBF) covariance kernel, also with input length scales fixed at 1:
model = create_regression_model(
"GaussianProcessRegressor",
training_dataset,
kernel=RBF(length_scale=1.0, length_scale_bounds="fixed"),
)
model.learn()
predicted_output_data_2 = model.predict(input_data).ravel()
Both models are much better than the initial one, especially the one with the Matérn 5/2 kernel. This confirms that the problem with the initial model lies in the length scale values found by numerical optimization:
plt.plot(input_data.ravel(), reference_output_data, label="Reference")
plt.plot(input_data.ravel(), predicted_output_data, label="Regression - Basics")
plt.plot(
input_data.ravel(), predicted_output_data_1, label="Regression - Kernel(Matern 2.5)"
)
plt.plot(input_data.ravel(), predicted_output_data_2, label="Regression - Kernel(RBF)")
plt.grid()
plt.legend()
plt.show()
[Figure: reference function, basic GP model, and fixed-length-scale Matérn 5/2 and RBF models]
Bounds#
The bounds
option defines the bounds of the input length scales:
model = create_regression_model(
"GaussianProcessRegressor", training_dataset, bounds=(1e-1, 1e2)
)
model.learn()
Increasing the lower bound can make the training easier, as in this example:
predicted_output_data_ = model.predict(input_data).ravel()
plt.plot(input_data.ravel(), reference_output_data, label="Reference")
plt.plot(input_data.ravel(), predicted_output_data, label="Regression - Basics")
plt.plot(input_data.ravel(), predicted_output_data_, label="Regression - Bounds")
plt.grid()
plt.legend()
plt.show()
[Figure: reference function, basic GP model, and GP model with increased length scale bounds]
Alpha#
The alpha
parameter (default: 1e-10),
often called nugget effect,
is the value added to the diagonal of the training kernel matrix
to avoid overfitting.
When alpha
is equal to zero,
the GP model interpolates the training points
at which the standard deviation is equal to zero.
The larger alpha
is, the less interpolating the GP model is.
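To see the interpolation property in action, the following sketch keeps the default near-zero alpha and checks that the predicted standard deviation nearly vanishes at a training point; it assumes the 6-point full-factorial design includes \(x=0\):
interpolating_model = create_regression_model(
    "GaussianProcessRegressor", training_dataset
)
interpolating_model.learn()
training_point = {"x": array([0.0])}  # first point of the full-factorial design
print(interpolating_model.predict(training_point))  # close to the training output
print(interpolating_model.predict_std(training_point))  # close to zero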
For example, we can increase the value to 0.1:
predicted_output_data_1 = predicted_output_data_
model = create_regression_model(
"GaussianProcessRegressor", training_dataset, bounds=(1e-1, 1e2), alpha=0.1
)
model.learn()
and see that the model moves away from the training points:
predicted_output_data_2 = model.predict(input_data).ravel()
plt.plot(input_data.ravel(), reference_output_data, label="Reference")
plt.plot(input_data.ravel(), predicted_output_data_1, label="Regression - Alpha(1e-10)")
plt.plot(input_data.ravel(), predicted_output_data_2, label="Regression - Alpha(1e-1)")
plt.grid()
plt.legend()
plt.show()
[Figure: reference function and GP models with alpha=1e-10 and alpha=0.1]
Total running time of the script: (0 minutes 0.514 seconds)