Gaussian process (GP) regression#

A GaussianProcessRegressor is a GP regression model based on scikit-learn.

See also

You can find more information about building GP models with scikit-learn in the scikit-learn documentation on Gaussian processes.

from __future__ import annotations

from matplotlib import pyplot as plt
from numpy import array
from sklearn.gaussian_process.kernels import RBF
from sklearn.gaussian_process.kernels import Matern

from gemseo import configure_logger
from gemseo import create_design_space
from gemseo import create_discipline
from gemseo import sample_disciplines
from gemseo.mlearning import create_regression_model

configure_logger()
<RootLogger root (INFO)>

Problem#

In this example, we consider the function \(f(x)=(6x-2)^2\sin(12x-4)\) [FSK08], implemented as an AnalyticDiscipline:

discipline = create_discipline(
    "AnalyticDiscipline",
    name="f",
    expressions={"y": "(6*x-2)**2*sin(12*x-4)"},
)

and seek to approximate it over the input space

input_space = create_design_space()
input_space.add_variable("x", lower_bound=0.0, upper_bound=1.0)

To do this, we create a training dataset with 6 equispaced points:

training_dataset = sample_disciplines(
    [discipline], input_space, "y", algo_name="PYDOE_FULLFACT", n_samples=6
)
WARNING - 12:47:33: No coupling in MDA, switching chain_linearize to True.
   INFO - 12:47:33:
   INFO - 12:47:33: *** Start Sampling execution ***
   INFO - 12:47:33: Sampling
   INFO - 12:47:33:    Disciplines: f
   INFO - 12:47:33:    MDO formulation: MDF
   INFO - 12:47:33: Running the algorithm PYDOE_FULLFACT:
   INFO - 12:47:33:    100%|██████████| 6/6 [00:00<00:00, 1594.69 it/sec]
   INFO - 12:47:33: *** End Sampling execution (time: 0:00:00.005070) ***
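
To check the design, you can inspect the sampled input points directly; this quick sketch reuses the get_view method that also appears later in this example:

# The six equispaced input points produced by the full-factorial DOE.
print(training_dataset.get_view(variable_names="x").to_numpy().ravel())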

Basics#

Training#

Then, we train a GP regression model from these samples:

model = create_regression_model("GaussianProcessRegressor", training_dataset)
model.learn()

Prediction#

Once it is built, we can predict the output value of \(f\) at a new input point:

input_value = {"x": array([0.65])}
output_value = model.predict(input_value)
output_value
{'y': array([2.20380214])}

but cannot predict its Jacobian value:

try:
    model.predict_jacobian(input_value)
except NotImplementedError:
    print("The derivatives are not available for GaussianProcessRegressor.")
The derivatives are not available for GaussianProcessRegressor.
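
If derivative information is nevertheless needed, a finite-difference approximation built on predict can serve as a workaround. Here is a minimal sketch, where the step size h is an arbitrary choice; note that it approximates the derivative of the GP model, not that of the original function:

# Approximate dy/dx at the input point by central finite differences,
# using only the model's predict method.
h = 1e-6
x = input_value["x"]
y_plus = model.predict({"x": x + h})["y"]
y_minus = model.predict({"x": x - h})["y"]
approximate_jacobian = (y_plus - y_minus) / (2 * h)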

Uncertainty#

GP models are often valued for their ability to quantify their own uncertainty. Indeed, a GP model is a random process fully characterized by its mean function and its covariance function. Given an input point \(x\), the prediction is the mean at \(x\) and the uncertainty is the standard deviation at \(x\):

standard_deviation = model.predict_std(input_value)
standard_deviation
array([[0.3140468]])
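
Combining predict and predict_std gives a confidence band around the mean prediction. Here is a minimal sketch over a fine grid, assuming, as in the plotting code below, that both methods accept a plain NumPy array; the factor 1.96 corresponds to a Gaussian 95% interval:

from numpy import linspace

# Evaluate the mean prediction and its standard deviation on a fine grid.
grid = linspace(0.0, 1.0, 200).reshape(-1, 1)
mean = model.predict(grid).ravel()
std = model.predict_std(grid).ravel()

# Plot the mean with a 95% confidence band (mean ± 1.96 std).
plt.plot(grid.ravel(), mean, label="GP mean")
plt.fill_between(
    grid.ravel(),
    mean - 1.96 * std,
    mean + 1.96 * std,
    alpha=0.3,
    label="95% confidence band",
)
plt.legend()
plt.show()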

Plotting#

You can see that the GP model interpolates the training points but is very inaccurate elsewhere. This case-dependent problem is due to a poor automatic tuning of the kernel's length scales. We will see below how to correct this.

test_dataset = sample_disciplines(
    [discipline], input_space, "y", algo_name="PYDOE_FULLFACT", n_samples=100
)
input_data = test_dataset.get_view(variable_names=model.input_names).to_numpy()
reference_output_data = test_dataset.get_view(variable_names="y").to_numpy().ravel()
predicted_output_data = model.predict(input_data).ravel()
plt.plot(input_data.ravel(), reference_output_data, label="Reference")
plt.plot(input_data.ravel(), predicted_output_data, label="Regression - Basics")
plt.grid()
plt.legend()
plt.show()
[Figure: reference function vs. initial GP model prediction]
WARNING - 12:47:33: No coupling in MDA, switching chain_linearize to True.
   INFO - 12:47:33:
   INFO - 12:47:33: *** Start Sampling execution ***
   INFO - 12:47:33: Sampling
   INFO - 12:47:33:    Disciplines: f
   INFO - 12:47:33:    MDO formulation: MDF
   INFO - 12:47:33: Running the algorithm PYDOE_FULLFACT:
   INFO - 12:47:33:    100%|██████████| 100/100 [00:00<00:00, 3478.87 it/sec]
   INFO - 12:47:33: *** End Sampling execution (time: 0:00:00.030153) ***

Settings#

The GaussianProcessRegressor has many options defined in the GaussianProcessRegressor_Settings Pydantic model. Here are the main ones.

Kernel#

The kernel option defines the kernel function parametrizing the Gaussian process regressor and must be passed as a scikit-learn kernel object. The default kernel is the Matérn 5/2 covariance function with input length scales in the interval \([0.01,100]\), initialized at 1 and optimized by the L-BFGS-B algorithm. We can replace this kernel with a Matérn 5/2 kernel whose input length scales are fixed at 1:

model = create_regression_model(
    "GaussianProcessRegressor",
    training_dataset,
    kernel=Matern(length_scale=1.0, length_scale_bounds="fixed", nu=2.5),
)
model.learn()
predicted_output_data_1 = model.predict(input_data).ravel()

or with a squared exponential (RBF) covariance kernel whose input length scales are fixed at 1:

model = create_regression_model(
    "GaussianProcessRegressor",
    training_dataset,
    kernel=RBF(length_scale=1.0, length_scale_bounds="fixed"),
)
model.learn()
predicted_output_data_2 = model.predict(input_data).ravel()

These two models are much better than the previous one, especially the one with the Matérn 5/2 kernel. This highlights that the problem with the initial model is the length-scale values found by numerical optimization:

plt.plot(input_data.ravel(), reference_output_data, label="Reference")
plt.plot(input_data.ravel(), predicted_output_data, label="Regression - Basics")
plt.plot(
    input_data.ravel(), predicted_output_data_1, label="Regression - Kernel(Matern 2.5)"
)
plt.plot(input_data.ravel(), predicted_output_data_2, label="Regression - Kernel(RBF)")
plt.grid()
plt.legend()
plt.show()
[Figure: reference function vs. GP models with the default kernel, the fixed Matérn 5/2 kernel and the RBF kernel]
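
To back the visual comparison with a number, a simple root-mean-square error against the reference data can be computed with NumPy; a minimal sketch:

from numpy import sqrt

# Root-mean-square error of each model over the test grid.
for label, data in [
    ("Basics", predicted_output_data),
    ("Matern 2.5", predicted_output_data_1),
    ("RBF", predicted_output_data_2),
]:
    rmse = sqrt(((data - reference_output_data) ** 2).mean())
    print(f"{label}: RMSE = {rmse:.3f}")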

Bounds#

The bounds option defines the bounds of the input length scales:

model = create_regression_model(
    "GaussianProcessRegressor", training_dataset, bounds=(1e-1, 1e2)
)
model.learn()

Increasing the lower bound can facilitate the training, as in this example:

predicted_output_data_ = model.predict(input_data).ravel()
plt.plot(input_data.ravel(), reference_output_data, label="Reference")
plt.plot(input_data.ravel(), predicted_output_data, label="Regression - Basics")
plt.plot(input_data.ravel(), predicted_output_data_, label="Regression - Bounds")
plt.grid()
plt.legend()
plt.show()
[Figure: reference function vs. GP models before and after adjusting the length-scale bounds]
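
To see what the optimizer actually found, you can inspect the trained kernel's hyperparameters. This sketch assumes that the wrapped scikit-learn estimator is exposed through the model's algo attribute (an assumption about GEMSEO's internals); scikit-learn itself stores the optimized kernel in kernel_:

# Hypothetical access to the underlying scikit-learn estimator; kernel_
# holds the kernel with the length scales found by the optimizer.
print(model.algo.kernel_)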

Alpha#

The alpha option (default: 1e-10), often called the nugget effect, is the value added to the diagonal of the training kernel matrix to avoid overfitting. When alpha is zero, the GP model interpolates the training points, and its standard deviation vanishes there. The larger alpha is, the further the GP model moves from interpolation. For example, we can increase the value to 0.1:

predicted_output_data_1 = predicted_output_data_
model = create_regression_model(
    "GaussianProcessRegressor", training_dataset, bounds=(1e-1, 1e2), alpha=0.1
)
model.learn()

and see that the model moves away from the training points:

predicted_output_data_2 = model.predict(input_data).ravel()
plt.plot(input_data.ravel(), reference_output_data, label="Reference")
plt.plot(input_data.ravel(), predicted_output_data_1, label="Regression - Alpha(1e-10)")
plt.plot(input_data.ravel(), predicted_output_data_2, label="Regression - Alpha(1e-1)")
plt.grid()
plt.legend()
plt.show()
[Figure: GP models with alpha=1e-10 and alpha=1e-1 vs. reference function]
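
Since interpolation is lost, the standard deviation predicted at a training point is no longer zero either; this can be checked directly, reusing the first training point x = 0:

# With alpha=0.1, the model no longer interpolates, so the standard
# deviation at a training input is strictly positive.
print(model.predict_std({"x": array([0.0])}))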
