.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "examples/mlearning/regression_model/plot_gp_regression.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_mlearning_regression_model_plot_gp_regression.py:


Gaussian process (GP) regression
================================

A :class:`.GaussianProcessRegressor` is a GP regression model
based on scikit-learn.

.. seealso::

   You can find more information about building GP models
   in the scikit-learn documentation on Gaussian processes.

.. GENERATED FROM PYTHON SOURCE LINES 32-48

.. code-block:: Python


    from __future__ import annotations

    from matplotlib import pyplot as plt
    from numpy import array
    from sklearn.gaussian_process.kernels import RBF
    from sklearn.gaussian_process.kernels import Matern

    from gemseo import configure_logger
    from gemseo import create_design_space
    from gemseo import create_discipline
    from gemseo import sample_disciplines
    from gemseo.mlearning import create_regression_model

    configure_logger()


.. GENERATED FROM PYTHON SOURCE LINES 49-54

Problem
-------
In this example,
we represent the function :math:`f(x)=(6x-2)^2\sin(12x-4)` :cite:`forrester2008`
by the :class:`.AnalyticDiscipline`

.. GENERATED FROM PYTHON SOURCE LINES 54-59

.. code-block:: Python

    discipline = create_discipline(
        "AnalyticDiscipline",
        name="f",
        expressions={"y": "(6*x-2)**2*sin(12*x-4)"},
    )


.. GENERATED FROM PYTHON SOURCE LINES 60-61

and seek to approximate it over the input space

.. GENERATED FROM PYTHON SOURCE LINES 61-64

.. code-block:: Python

    input_space = create_design_space()
    input_space.add_variable("x", lower_bound=0.0, upper_bound=1.0)


.. GENERATED FROM PYTHON SOURCE LINES 65-67

To do this,
we create a training dataset with 6 equispaced points:

.. GENERATED FROM PYTHON SOURCE LINES 67-71

.. code-block:: Python

    training_dataset = sample_disciplines(
        [discipline], input_space, "y", algo_name="PYDOE_FULLFACT", n_samples=6
    )


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    WARNING - 20:35:07: No coupling in MDA, switching chain_linearize to True.
    INFO - 20:35:07: *** Start Sampling execution ***
    INFO - 20:35:07: Sampling
    INFO - 20:35:07: Disciplines: f
    INFO - 20:35:07: MDO formulation: MDF
    INFO - 20:35:07: Running the algorithm PYDOE_FULLFACT:
    INFO - 20:35:07: 17%|█▋ | 1/6 [00:00<00:00, 550.94 it/sec]
    INFO - 20:35:07: 33%|███▎ | 2/6 [00:00<00:00, 899.00 it/sec]
    INFO - 20:35:07: 50%|█████ | 3/6 [00:00<00:00, 1204.68 it/sec]
    INFO - 20:35:07: 67%|██████▋ | 4/6 [00:00<00:00, 1468.21 it/sec]
    INFO - 20:35:07: 83%|████████▎ | 5/6 [00:00<00:00, 1699.34 it/sec]
    INFO - 20:35:07: 100%|██████████| 6/6 [00:00<00:00, 1878.89 it/sec]
    INFO - 20:35:07: *** End Sampling execution ***


.. GENERATED FROM PYTHON SOURCE LINES 72-78

Basics
------

Training
~~~~~~~~

Then, we train a GP regression model from these samples:

.. GENERATED FROM PYTHON SOURCE LINES 78-81

.. code-block:: Python

    model = create_regression_model("GaussianProcessRegressor", training_dataset)
    model.learn()


.. GENERATED FROM PYTHON SOURCE LINES 82-86

Prediction
~~~~~~~~~~

Once it is built,
we can predict the output value of :math:`f` at a new input point:

.. GENERATED FROM PYTHON SOURCE LINES 86-90

.. code-block:: Python

    input_value = {"x": array([0.65])}
    output_value = model.predict(input_value)
    output_value


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    {'y': array([2.20380214])}


.. GENERATED FROM PYTHON SOURCE LINES 91-92

but cannot predict its Jacobian value:

.. GENERATED FROM PYTHON SOURCE LINES 92-97

.. code-block:: Python

    try:
        model.predict_jacobian(input_value)
    except NotImplementedError:
        print("The derivatives are not available for GaussianProcessRegressor.")


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    The derivatives are not available for GaussianProcessRegressor.


.. GENERATED FROM PYTHON SOURCE LINES 98-108

Uncertainty
~~~~~~~~~~~

GP models are often valued for their ability to provide model uncertainty.
Indeed, a GP model is a random process
fully characterized by its mean function and its covariance function.
Given an input point :math:`x`,
the prediction is equal to the mean at :math:`x`
and the uncertainty is equal to the standard deviation at :math:`x`:

.. GENERATED FROM PYTHON SOURCE LINES 108-111

.. code-block:: Python

    standard_deviation = model.predict_std(input_value)
    standard_deviation


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    array([[0.3140468]])


.. GENERATED FROM PYTHON SOURCE LINES 112-118

Plotting
~~~~~~~~

The plot below shows that the GP model interpolates the training points
but is very inaccurate elsewhere.
This case-dependent problem is due to poor automatic tuning of the kernel length scales.
We will look at how to correct this next.

.. GENERATED FROM PYTHON SOURCE LINES 118-130

.. code-block:: Python

    test_dataset = sample_disciplines(
        [discipline], input_space, "y", algo_name="PYDOE_FULLFACT", n_samples=100
    )
    input_data = test_dataset.get_view(variable_names=model.input_names).to_numpy()
    reference_output_data = test_dataset.get_view(variable_names="y").to_numpy().ravel()
    predicted_output_data = model.predict(input_data).ravel()
    plt.plot(input_data.ravel(), reference_output_data, label="Reference")
    plt.plot(input_data.ravel(), predicted_output_data, label="Regression - Basics")
    plt.grid()
    plt.legend()
    plt.show()


.. image-sg:: /examples/mlearning/regression_model/images/sphx_glr_plot_gp_regression_001.png
   :alt: plot gp regression
   :srcset: /examples/mlearning/regression_model/images/sphx_glr_plot_gp_regression_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    WARNING - 20:35:07: No coupling in MDA, switching chain_linearize to True.
    INFO - 20:35:07: *** Start Sampling execution ***
    INFO - 20:35:07: Sampling
    INFO - 20:35:07: Disciplines: f
    INFO - 20:35:07: MDO formulation: MDF
    INFO - 20:35:07: Running the algorithm PYDOE_FULLFACT:
    INFO - 20:35:07: 1%| | 1/100 [00:00<00:00, 3492.34 it/sec]
    ...
    INFO - 20:35:07: 100%|██████████| 100/100 [00:00<00:00, 4820.71 it/sec]
    INFO - 20:35:07: *** End Sampling execution ***


.. GENERATED FROM PYTHON SOURCE LINES 131-148

Settings
--------

The :class:`.GaussianProcessRegressor` has many options
defined in the :class:`.GaussianProcessRegressor_Settings` Pydantic model.
Here are the main ones.

Kernel
~~~~~~

The ``kernel`` option defines the kernel function
parametrizing the Gaussian process regressor
and must be passed as a scikit-learn object.
The default kernel is the Matérn 5/2 covariance function
with input length scales belonging to the interval :math:`[0.01,100]`,
initialized at 1 and optimized by the L-BFGS-B algorithm.
We can replace this kernel with the Matérn 5/2 kernel
with input length scales fixed at 1:

.. GENERATED FROM PYTHON SOURCE LINES 148-155

.. code-block:: Python

    model = create_regression_model(
        "GaussianProcessRegressor",
        training_dataset,
        kernel=Matern(length_scale=1.0, length_scale_bounds="fixed", nu=2.5),
    )
    model.learn()
    predicted_output_data_1 = model.predict(input_data).ravel()


.. GENERATED FROM PYTHON SOURCE LINES 156-158

or with a squared exponential covariance kernel
with input length scales fixed at 1:

.. GENERATED FROM PYTHON SOURCE LINES 158-165

.. code-block:: Python

    model = create_regression_model(
        "GaussianProcessRegressor",
        training_dataset,
        kernel=RBF(length_scale=1.0, length_scale_bounds="fixed"),
    )
    model.learn()
    predicted_output_data_2 = model.predict(input_data).ravel()


.. GENERATED FROM PYTHON SOURCE LINES 166-170

These two models are much better than the previous one,
notably the one with the Matérn 5/2 kernel,
which highlights that the problem with the initial model
is the value of the length scales found by numerical optimization:

.. GENERATED FROM PYTHON SOURCE LINES 170-179

.. code-block:: Python

    plt.plot(input_data.ravel(), reference_output_data, label="Reference")
    plt.plot(input_data.ravel(), predicted_output_data, label="Regression - Basics")
    plt.plot(
        input_data.ravel(), predicted_output_data_1, label="Regression - Kernel(Matern 2.5)"
    )
    plt.plot(input_data.ravel(), predicted_output_data_2, label="Regression - Kernel(RBF)")
    plt.grid()
    plt.legend()
    plt.show()


.. image-sg:: /examples/mlearning/regression_model/images/sphx_glr_plot_gp_regression_002.png
   :alt: plot gp regression
   :srcset: /examples/mlearning/regression_model/images/sphx_glr_plot_gp_regression_002.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 180-183

Bounds
~~~~~~

The ``bounds`` option defines the bounds of the input length scales:

.. GENERATED FROM PYTHON SOURCE LINES 183-187

.. code-block:: Python

    model = create_regression_model(
        "GaussianProcessRegressor", training_dataset, bounds=(1e-1, 1e2)
    )
    model.learn()


.. GENERATED FROM PYTHON SOURCE LINES 188-189

Increasing the lower bound (here from 0.01 to 0.1)
can facilitate the training, as in this example:

.. GENERATED FROM PYTHON SOURCE LINES 189-197

.. code-block:: Python

    predicted_output_data_ = model.predict(input_data).ravel()
    plt.plot(input_data.ravel(), reference_output_data, label="Reference")
    plt.plot(input_data.ravel(), predicted_output_data, label="Regression - Basics")
    plt.plot(input_data.ravel(), predicted_output_data_, label="Regression - Bounds")
    plt.grid()
    plt.legend()
    plt.show()


.. image-sg:: /examples/mlearning/regression_model/images/sphx_glr_plot_gp_regression_003.png
   :alt: plot gp regression
   :srcset: /examples/mlearning/regression_model/images/sphx_glr_plot_gp_regression_003.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 198-209

Alpha
~~~~~

The ``alpha`` parameter (default: 1e-10),
often called *nugget effect*,
is the value added to the diagonal of the training kernel matrix
to avoid overfitting.
When ``alpha`` is equal to zero,
the GP model interpolates the training points,
at which the standard deviation is then equal to zero.
The larger ``alpha`` is, the less closely the GP model fits the training points.
For example, we can increase the value to 0.1:

.. GENERATED FROM PYTHON SOURCE LINES 209-214

.. code-block:: Python

    predicted_output_data_1 = predicted_output_data_
    model = create_regression_model(
        "GaussianProcessRegressor", training_dataset, bounds=(1e-1, 1e2), alpha=0.1
    )
    model.learn()


.. GENERATED FROM PYTHON SOURCE LINES 215-216

and see that the model moves away from the training points:

.. GENERATED FROM PYTHON SOURCE LINES 216-223

.. code-block:: Python

    predicted_output_data_2 = model.predict(input_data).ravel()
    plt.plot(input_data.ravel(), reference_output_data, label="Reference")
    plt.plot(input_data.ravel(), predicted_output_data_1, label="Regression - Alpha(1e-10)")
    plt.plot(input_data.ravel(), predicted_output_data_2, label="Regression - Alpha(1e-1)")
    plt.grid()
    plt.legend()
    plt.show()


.. image-sg:: /examples/mlearning/regression_model/images/sphx_glr_plot_gp_regression_004.png
   :alt: plot gp regression
   :srcset: /examples/mlearning/regression_model/images/sphx_glr_plot_gp_regression_004.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 0.311 seconds)
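
The effect of ``alpha`` described in the Alpha section can also be checked numerically.
The following is a minimal sketch that uses scikit-learn's ``GaussianProcessRegressor``
directly rather than the GEMSEO model.
It assumes that the ``alpha`` setting of the GEMSEO model has the same meaning as in scikit-learn,
and it fixes the Matérn 5/2 length scale at 1 to keep the comparison simple.
The names ``f``, ``x_train``, ``max_residual`` and ``max_std`` are illustrative only.

.. code-block:: Python

    import numpy as np
    from sklearn.gaussian_process import GaussianProcessRegressor
    from sklearn.gaussian_process.kernels import Matern


    def f(x):
        """The function used in this example."""
        return (6 * x - 2) ** 2 * np.sin(12 * x - 4)


    # Six equispaced training points on [0, 1], as in the example above.
    x_train = np.linspace(0.0, 1.0, 6).reshape(-1, 1)
    y_train = f(x_train).ravel()

    # Assumption: a Matern 5/2 kernel with a fixed length scale,
    # so that only alpha changes between the two fits.
    kernel = Matern(length_scale=1.0, length_scale_bounds="fixed", nu=2.5)

    max_residual = {}
    max_std = {}
    for alpha in (1e-10, 1e-1):
        gp = GaussianProcessRegressor(kernel=kernel, alpha=alpha)
        gp.fit(x_train, y_train)
        # Predict the mean and the standard deviation at the training points.
        mean, std = gp.predict(x_train, return_std=True)
        # Largest gap between the prediction and the training outputs.
        max_residual[alpha] = np.abs(mean - y_train).max()
        # Largest predictive standard deviation at the training points.
        max_std[alpha] = std.max()

    print(max_residual)
    print(max_std)

With ``alpha=1e-10``, both quantities should be close to zero,
which corresponds to an interpolating model;
with ``alpha=0.1``, both should be clearly larger.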