.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "examples/mlearning/regression_model/plot_rbf_regression.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_mlearning_regression_model_plot_rbf_regression.py:


Radial basis function (RBF) regression
======================================

An :class:`.RBFRegressor` is an RBF model based on `SciPy `__.

.. seealso::

   You can find more information about RBF models on
   `this wikipedia page `__.

.. GENERATED FROM PYTHON SOURCE LINES 32-44

.. code-block:: Python


    from __future__ import annotations

    from matplotlib import pyplot as plt
    from numpy import array

    from gemseo import create_design_space
    from gemseo import create_discipline
    from gemseo import sample_disciplines
    from gemseo.mlearning import create_regression_model
    from gemseo.mlearning.regression.algos.rbf_settings import RBF

.. GENERATED FROM PYTHON SOURCE LINES 45-50

Problem
-------
In this example,
we represent the function :math:`f(x)=(6x-2)^2\sin(12x-4)` :cite:`forrester2008`
by the :class:`.AnalyticDiscipline`

.. GENERATED FROM PYTHON SOURCE LINES 50-55

.. code-block:: Python

    discipline = create_discipline(
        "AnalyticDiscipline",
        name="f",
        expressions={"y": "(6*x-2)**2*sin(12*x-4)"},
    )

.. GENERATED FROM PYTHON SOURCE LINES 56-57

and seek to approximate it over the input space

.. GENERATED FROM PYTHON SOURCE LINES 57-60

.. code-block:: Python

    input_space = create_design_space()
    input_space.add_variable("x", lower_bound=0.0, upper_bound=1.0)

.. GENERATED FROM PYTHON SOURCE LINES 61-63

To do this,
we create a training dataset with 6 equispaced points:

.. GENERATED FROM PYTHON SOURCE LINES 63-67

.. code-block:: Python

    training_dataset = sample_disciplines(
        [discipline], input_space, "y", algo_name="PYDOE_FULLFACT", n_samples=6
    )

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    INFO - 16:22:22: *** Start Sampling execution ***
    INFO - 16:22:22: Sampling
    INFO - 16:22:22:    Disciplines: f
    INFO - 16:22:22:    MDO formulation: MDF
    INFO - 16:22:22: Running the algorithm PYDOE_FULLFACT:
    INFO - 16:22:22:     17%|█▋        | 1/6 [00:00<00:00, 666.29 it/sec]
    INFO - 16:22:22:     33%|███▎      | 2/6 [00:00<00:00, 1081.28 it/sec]
    INFO - 16:22:22:     50%|█████     | 3/6 [00:00<00:00, 1418.91 it/sec]
    INFO - 16:22:22:     67%|██████▋   | 4/6 [00:00<00:00, 1696.38 it/sec]
    INFO - 16:22:22:     83%|████████▎ | 5/6 [00:00<00:00, 1927.53 it/sec]
    INFO - 16:22:22:    100%|██████████| 6/6 [00:00<00:00, 2051.84 it/sec]
    INFO - 16:22:22: *** End Sampling execution ***

.. GENERATED FROM PYTHON SOURCE LINES 68-74

Basics
------

Training
~~~~~~~~
Then, we train an RBF regression model from these samples:

.. GENERATED FROM PYTHON SOURCE LINES 74-77

.. code-block:: Python

    model = create_regression_model("RBFRegressor", training_dataset)
    model.learn()

.. GENERATED FROM PYTHON SOURCE LINES 78-82

Prediction
~~~~~~~~~~
Once it is built,
we can predict the output value of :math:`f` at a new input point:

.. GENERATED FROM PYTHON SOURCE LINES 82-86

.. code-block:: Python

    input_value = {"x": array([0.65])}
    output_value = model.predict(input_value)
    output_value

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    {'y': array([-2.16802353])}

.. GENERATED FROM PYTHON SOURCE LINES 87-88

as well as its Jacobian value:

.. GENERATED FROM PYTHON SOURCE LINES 88-91

.. code-block:: Python

    jacobian_value = model.predict_jacobian(input_value)
    jacobian_value

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    {'y': {'x': array([[-45.81825011]])}}

.. GENERATED FROM PYTHON SOURCE LINES 92-95

Plotting
~~~~~~~~
We can see that the RBF model is accurate on the right part of the input space
but poor on the left part:

.. GENERATED FROM PYTHON SOURCE LINES 95-107

.. code-block:: Python

    test_dataset = sample_disciplines(
        [discipline], input_space, "y", algo_name="PYDOE_FULLFACT", n_samples=100
    )
    input_data = test_dataset.get_view(variable_names=model.input_names).to_numpy()
    reference_output_data = test_dataset.get_view(variable_names="y").to_numpy().ravel()
    predicted_output_data = model.predict(input_data).ravel()

    plt.plot(input_data.ravel(), reference_output_data, label="Reference")
    plt.plot(input_data.ravel(), predicted_output_data, label="Regression - Basics")
    plt.grid()
    plt.legend()
    plt.show()

.. image-sg:: /examples/mlearning/regression_model/images/sphx_glr_plot_rbf_regression_001.png
   :alt: plot rbf regression
   :srcset: /examples/mlearning/regression_model/images/sphx_glr_plot_rbf_regression_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    INFO - 16:22:22: *** Start Sampling execution ***
    INFO - 16:22:22: Sampling
    INFO - 16:22:22:    Disciplines: f
    INFO - 16:22:22:    MDO formulation: MDF
    INFO - 16:22:22: Running the algorithm PYDOE_FULLFACT:
    INFO - 16:22:22:    100%|██████████| 100/100 [00:00<00:00, 4323.71 it/sec]
    INFO - 16:22:22: *** End Sampling execution ***

.. GENERATED FROM PYTHON SOURCE LINES 108-120

Settings
--------
The :class:`.RBFRegressor` has many options
defined in the :class:`.RBFRegressor_Settings` Pydantic model.

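
To make these options more concrete,
the following minimal NumPy sketch shows what an RBF regression model computes in one dimension.
It is an illustration under stated assumptions, not the GEMSEO or SciPy implementation:
it solves the plain interpolation system :math:`\Phi w = y` with :math:`\Phi_{ij}=\phi(|x_i-x_j|)`
(no ``smooth`` term)
and predicts :math:`\hat{f}(x)=\sum_i w_i\phi(|x-x_i|)`.
The helpers ``multiquadric``, ``gaussian``, ``fit_rbf`` and ``predict_rbf``,
as well as the value ``epsilon = 0.2``,
are hypothetical choices made for this sketch.

.. code-block:: Python

    import numpy as np


    def multiquadric(r, epsilon):
        """Multiquadric RBF sqrt((r/epsilon)^2 + 1)."""
        return np.sqrt((r / epsilon) ** 2 + 1.0)


    def gaussian(r, epsilon):
        """Gaussian RBF exp(-(r/epsilon)^2)."""
        return np.exp(-((r / epsilon) ** 2))


    def fit_rbf(x_train, y_train, epsilon, phi=multiquadric):
        """Return the weights w solving Phi w = y (pure interpolation, no smoothing)."""
        # Pairwise distances |x_i - x_j| between the 1D training points.
        r = np.abs(x_train[:, None] - x_train[None, :])
        return np.linalg.solve(phi(r, epsilon), y_train)


    def predict_rbf(x_new, x_train, weights, epsilon, phi=multiquadric):
        """Evaluate sum_i w_i phi(|x - x_i|) at each new point."""
        r = np.abs(x_new[:, None] - x_train[None, :])
        return phi(r, epsilon) @ weights


    def f(x):
        """The function approximated in this example."""
        return (6 * x - 2) ** 2 * np.sin(12 * x - 4)


    # 6 equispaced training points, as in the example above.
    x_train = np.linspace(0.0, 1.0, 6)
    y_train = f(x_train)

    # Hypothetical choice: epsilon set to the spacing between the training points.
    epsilon = 0.2
    weights = fit_rbf(x_train, y_train, epsilon)

    # Without smoothing, the model reproduces the training outputs.
    y_fit = predict_rbf(x_train, x_train, weights, epsilon)

    # Switching to a Gaussian RBF changes the weights, hence the predictions.
    weights_g = fit_rbf(x_train, y_train, epsilon, phi=gaussian)
    y_fit_g = predict_rbf(x_train, x_train, weights_g, epsilon, phi=gaussian)
    x_test = np.linspace(0.0, 1.0, 100)
    y_test = predict_rbf(x_test, x_train, weights, epsilon)
    y_test_g = predict_rbf(x_test, x_train, weights_g, epsilon, phi=gaussian)

Both RBFs reproduce ``y_train`` at the training points
but give different predictions between them,
so the choice of :math:`\phi` and :math:`\epsilon` directly shapes the model.
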

Function
~~~~~~~~
The default RBF is the multiquadric function :math:`\sqrt{(r/\epsilon)^2 + 1}`,
which depends on a radius :math:`r` representing the distance between two points
and on an adjustable constant :math:`\epsilon`.
The RBF can be changed using the ``function`` option,
which can be either an :class:`.RBF`:

.. GENERATED FROM PYTHON SOURCE LINES 120-125

.. code-block:: Python

    model = create_regression_model("RBFRegressor", training_dataset, function=RBF.GAUSSIAN)
    model.learn()
    predicted_output_data_g = model.predict(input_data).ravel()

.. GENERATED FROM PYTHON SOURCE LINES 126-127

or a Python function:

.. GENERATED FROM PYTHON SOURCE LINES 127-144

.. code-block:: Python

    def rbf(self, r: float) -> float:
        """Evaluate a cubic RBF.

        An RBF must take 2 arguments, namely ``(self, r)``.

        Args:
            r: The radius.

        Returns:
            The RBF value.
        """
        return r**3


    model = create_regression_model("RBFRegressor", training_dataset, function=rbf)
    model.learn()
    predicted_output_data_c = model.predict(input_data).ravel()

.. GENERATED FROM PYTHON SOURCE LINES 145-146

We can see that the choice of RBF changes the predictions:

.. GENERATED FROM PYTHON SOURCE LINES 146-154

.. code-block:: Python

    plt.plot(input_data.ravel(), reference_output_data, label="Reference")
    plt.plot(input_data.ravel(), predicted_output_data, label="Regression - Basics")
    plt.plot(input_data.ravel(), predicted_output_data_g, label="Regression - Gaussian RBF")
    plt.plot(input_data.ravel(), predicted_output_data_c, label="Regression - Cubic RBF")
    plt.grid()
    plt.legend()
    plt.show()

.. image-sg:: /examples/mlearning/regression_model/images/sphx_glr_plot_rbf_regression_002.png
   :alt: plot rbf regression
   :srcset: /examples/mlearning/regression_model/images/sphx_glr_plot_rbf_regression_002.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 155-162

Epsilon
~~~~~~~
Some RBFs depend on an ``epsilon`` parameter
whose default value is approximately the average distance between the input points.
This is the case for the ``"multiquadric"``, ``"gaussian"`` and ``"inverse"`` RBFs.
For example,
we can train a first multiquadric RBF model with an ``epsilon`` set to 0.5

.. GENERATED FROM PYTHON SOURCE LINES 162-165

.. code-block:: Python

    model = create_regression_model("RBFRegressor", training_dataset, epsilon=0.5)
    model.learn()
    predicted_output_data_1 = model.predict(input_data).ravel()

.. GENERATED FROM PYTHON SOURCE LINES 166-167

a second one with an ``epsilon`` set to 1.0:

.. GENERATED FROM PYTHON SOURCE LINES 167-170

.. code-block:: Python

    model = create_regression_model("RBFRegressor", training_dataset, epsilon=1.0)
    model.learn()
    predicted_output_data_2 = model.predict(input_data).ravel()

.. GENERATED FROM PYTHON SOURCE LINES 171-172

and a last one with an ``epsilon`` set to 2.0:

.. GENERATED FROM PYTHON SOURCE LINES 172-175

.. code-block:: Python

    model = create_regression_model("RBFRegressor", training_dataset, epsilon=2.0)
    model.learn()
    predicted_output_data_3 = model.predict(input_data).ravel()

.. GENERATED FROM PYTHON SOURCE LINES 176-177

and see that this parameter controls the regularity of the regression model:

.. GENERATED FROM PYTHON SOURCE LINES 177-186

.. code-block:: Python

    plt.plot(input_data.ravel(), reference_output_data, label="Reference")
    plt.plot(input_data.ravel(), predicted_output_data, label="Regression - Basics")
    plt.plot(input_data.ravel(), predicted_output_data_1, label="Regression - Epsilon(0.5)")
    plt.plot(input_data.ravel(), predicted_output_data_2, label="Regression - Epsilon(1)")
    plt.plot(input_data.ravel(), predicted_output_data_3, label="Regression - Epsilon(2)")
    plt.grid()
    plt.legend()
    plt.show()

.. image-sg:: /examples/mlearning/regression_model/images/sphx_glr_plot_rbf_regression_003.png
   :alt: plot rbf regression
   :srcset: /examples/mlearning/regression_model/images/sphx_glr_plot_rbf_regression_003.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 187-193

Smooth
~~~~~~
By default,
an RBF model interpolates the training points.
This behavior is controlled by the ``smooth`` option, whose default value is 0.
We can make the model smoother by increasing this value:

.. GENERATED FROM PYTHON SOURCE LINES 193-196

.. code-block:: Python

    model = create_regression_model("RBFRegressor", training_dataset, smooth=0.1)
    model.learn()
    predicted_output_data_ = model.predict(input_data).ravel()

.. GENERATED FROM PYTHON SOURCE LINES 197-198

and see that the model no longer interpolates the training points:

.. GENERATED FROM PYTHON SOURCE LINES 198-205

.. code-block:: Python

    plt.plot(input_data.ravel(), reference_output_data, label="Reference")
    plt.plot(input_data.ravel(), predicted_output_data, label="Regression - Basics")
    plt.plot(input_data.ravel(), predicted_output_data_, label="Regression - Smooth")
    plt.grid()
    plt.legend()
    plt.show()

.. image-sg:: /examples/mlearning/regression_model/images/sphx_glr_plot_rbf_regression_004.png
   :alt: plot rbf regression
   :srcset: /examples/mlearning/regression_model/images/sphx_glr_plot_rbf_regression_004.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 206-213

Thin plate spline (TPS)
-----------------------
TPS regression is a specific case of RBF regression
where the RBF is the thin plate radial basis function :math:`r^2\log(r)`.
The :class:`.TPSRegressor` class, which derives from :class:`.RBFRegressor`,
implements this case:

.. GENERATED FROM PYTHON SOURCE LINES 213-216

.. code-block:: Python

    model = create_regression_model("TPSRegressor", training_dataset)
    model.learn()
    predicted_output_data_ = model.predict(input_data).ravel()

.. GENERATED FROM PYTHON SOURCE LINES 217-219

We can see the difference between this model
and the default multiquadric RBF model:

.. GENERATED FROM PYTHON SOURCE LINES 219-226

.. code-block:: Python

    plt.plot(input_data.ravel(), reference_output_data, label="Reference")
    plt.plot(input_data.ravel(), predicted_output_data, label="Regression - Basics")
    plt.plot(input_data.ravel(), predicted_output_data_, label="Regression - TPS")
    plt.grid()
    plt.legend()
    plt.show()

.. image-sg:: /examples/mlearning/regression_model/images/sphx_glr_plot_rbf_regression_005.png
   :alt: plot rbf regression
   :srcset: /examples/mlearning/regression_model/images/sphx_glr_plot_rbf_regression_005.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 227-228

The :class:`.TPSRegressor` can be customized with the :class:`.TPSRegressor_Settings`.


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 0.326 seconds)


.. _sphx_glr_download_examples_mlearning_regression_model_plot_rbf_regression.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_rbf_regression.ipynb `

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_rbf_regression.py `

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_rbf_regression.zip `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_