Radial basis function (RBF) regression#
An RBFRegressor is an RBF model
based on SciPy.
See also
You can find more information about RBF models on this Wikipedia page.
from __future__ import annotations
from matplotlib import pyplot as plt
from numpy import array
from gemseo import create_design_space
from gemseo import create_discipline
from gemseo import sample_disciplines
from gemseo.mlearning import create_regression_model
from gemseo.mlearning.regression.algos.rbf_settings import RBF
Problem#
In this example,
we represent the function \(f(x)=(6x-2)^2\sin(12x-4)\) [FSK08]
by the AnalyticDiscipline
discipline = create_discipline(
    "AnalyticDiscipline",
    name="f",
    expressions={"y": "(6*x-2)**2*sin(12*x-4)"},
)
and seek to approximate it over the input space:
input_space = create_design_space()
input_space.add_variable("x", lower_bound=0.0, upper_bound=1.0)
To do this, we create a training dataset with 6 equispaced points:
training_dataset = sample_disciplines(
    [discipline], input_space, "y", algo_name="PYDOE_FULLFACT", n_samples=6
)
INFO - 16:21:48: *** Start Sampling execution ***
INFO - 16:21:48: Sampling
INFO - 16:21:48: Disciplines: f
INFO - 16:21:48: MDO formulation: MDF
INFO - 16:21:48: Running the algorithm PYDOE_FULLFACT:
INFO - 16:21:48: 17%|█▋ | 1/6 [00:00<00:00, 690.08 it/sec]
INFO - 16:21:48: 33%|███▎ | 2/6 [00:00<00:00, 1105.22 it/sec]
INFO - 16:21:48: 50%|█████ | 3/6 [00:00<00:00, 1432.32 it/sec]
INFO - 16:21:48: 67%|██████▋ | 4/6 [00:00<00:00, 1706.91 it/sec]
INFO - 16:21:48: 83%|████████▎ | 5/6 [00:00<00:00, 1924.52 it/sec]
INFO - 16:21:48: 100%|██████████| 6/6 [00:00<00:00, 2059.73 it/sec]
INFO - 16:21:48: *** End Sampling execution ***
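For reference, the six training samples can be reproduced without GEMSEO. This is a minimal sketch in plain Python, assuming that the six-level full-factorial design places equispaced points on \([0, 1]\) with both bounds included; the names `f`, `training_inputs` and `training_outputs` are illustrative and not part of GEMSEO.

```python
from math import sin


def f(x: float) -> float:
    """Evaluate the reference function (6x - 2)^2 * sin(12x - 4)."""
    return (6 * x - 2) ** 2 * sin(12 * x - 4)


# Assumption: 6 equispaced points on [0, 1], bounds included.
n_samples = 6
training_inputs = [i / (n_samples - 1) for i in range(n_samples)]
training_outputs = [f(x) for x in training_inputs]

for x, y in zip(training_inputs, training_outputs):
    print(f"x = {x:.1f}, y = {y: .6f}")
```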
Basics#
Training#
Then, we train an RBF regression model from these samples:
model = create_regression_model("RBFRegressor", training_dataset)
model.learn()
Prediction#
Once it is built, we can predict the output value of \(f\) at a new input point:
input_value = {"x": array([0.65])}
output_value = model.predict(input_value)
output_value
{'y': array([-2.16802353])}
as well as its Jacobian value:
jacobian_value = model.predict_jacobian(input_value)
jacobian_value
{'y': {'x': array([[-45.81825011]])}}
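One way to judge a predicted Jacobian is to compare it with the derivative of the reference function. The sketch below, in plain Python, evaluates the analytical derivative of \(f\) at \(x=0.65\) and checks it against a central finite difference. The helpers `f`, `df` and `central_difference` are illustrative and are not GEMSEO functions.

```python
from math import cos, sin


def f(x: float) -> float:
    """Evaluate the reference function (6x - 2)^2 * sin(12x - 4)."""
    return (6 * x - 2) ** 2 * sin(12 * x - 4)


def df(x: float) -> float:
    """Evaluate the analytical derivative of f (product and chain rules)."""
    u = 6 * x - 2
    v = 12 * x - 4
    return 12 * u * sin(v) + 12 * u**2 * cos(v)


def central_difference(func, x: float, step: float = 1e-6) -> float:
    """Approximate the derivative of func at x by a central finite difference."""
    return (func(x + step) - func(x - step)) / (2 * step)


x = 0.65
print("analytical derivative:", df(x))
print("finite difference:    ", central_difference(f, x))
```

The same finite-difference helper could be applied to any scalar prediction function to cross-check a Jacobian returned by a surrogate model.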
Plotting#
The plot shows that the RBF model is accurate on the right of the input space but poor on the left:
test_dataset = sample_disciplines(
    [discipline], input_space, "y", algo_name="PYDOE_FULLFACT", n_samples=100
)
input_data = test_dataset.get_view(variable_names=model.input_names).to_numpy()
reference_output_data = test_dataset.get_view(variable_names="y").to_numpy().ravel()
predicted_output_data = model.predict(input_data).ravel()
plt.plot(input_data.ravel(), reference_output_data, label="Reference")
plt.plot(input_data.ravel(), predicted_output_data, label="Regression - Basics")
plt.grid()
plt.legend()
plt.show()

INFO - 16:21:48: *** Start Sampling execution ***
INFO - 16:21:48: Sampling
INFO - 16:21:48: Disciplines: f
INFO - 16:21:48: MDO formulation: MDF
INFO - 16:21:48: Running the algorithm PYDOE_FULLFACT:
INFO - 16:21:48: 100%|██████████| 100/100 [00:00<00:00, 4304.94 it/sec]
INFO - 16:21:48: *** End Sampling execution ***
Settings#
The RBFRegressor has many options
defined in the RBFRegressor_Settings Pydantic model.
Function#
The default RBF is the multiquadric function \(\sqrt{(r/\epsilon)^2 + 1}\)
depending on a radius \(r\) representing a distance between two points
and an adjustable constant \(\epsilon\).
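To make the role of the RBF concrete, here is a minimal sketch of multiquadric RBF interpolation in plain Python. It is not the SciPy or GEMSEO implementation: the helper names (`multiquadric`, `solve`, `predict`), the six equispaced centers and the value `epsilon = 0.2` are assumptions chosen for illustration. The weights are obtained by solving the interpolation conditions \(\sum_j w_j\,\phi(|x_i - x_j|) = y_i\).

```python
from __future__ import annotations

from math import sin, sqrt


def multiquadric(r: float, epsilon: float) -> float:
    """Evaluate the multiquadric RBF sqrt((r / epsilon)^2 + 1)."""
    return sqrt((r / epsilon) ** 2 + 1)


def solve(matrix: list[list[float]], rhs: list[float]) -> list[float]:
    """Solve a small dense system by Gaussian elimination with partial pivoting."""
    n = len(rhs)
    a = [row[:] + [b] for row, b in zip(matrix, rhs)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda i: abs(a[i][col]))
        a[col], a[pivot] = a[pivot], a[col]
        for row in range(col + 1, n):
            factor = a[row][col] / a[col][col]
            for k in range(col, n + 1):
                a[row][k] -= factor * a[col][k]
    x = [0.0] * n
    for row in reversed(range(n)):
        tail = sum(a[row][k] * x[k] for k in range(row + 1, n))
        x[row] = (a[row][n] - tail) / a[row][row]
    return x


def f(x: float) -> float:
    """Evaluate the reference function (6x - 2)^2 * sin(12x - 4)."""
    return (6 * x - 2) ** 2 * sin(12 * x - 4)


# Illustrative data: 6 equispaced centers and an arbitrary epsilon.
centers = [i / 5 for i in range(6)]
values = [f(c) for c in centers]
epsilon = 0.2

# One row per interpolation condition, one column per center.
gram = [[multiquadric(abs(xi - xj), epsilon) for xj in centers] for xi in centers]
weights = solve(gram, values)


def predict(x: float) -> float:
    """Evaluate the weighted sum of RBFs centered at the training points."""
    return sum(w * multiquadric(abs(x - c), epsilon) for w, c in zip(weights, centers))


print(predict(0.65))
```

By construction, `predict` reproduces the training values at the centers; between them, its accuracy depends on the RBF and on `epsilon`.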
The RBF can be changed using the function option,
which can be either an RBF:
model = create_regression_model("RBFRegressor", training_dataset, function=RBF.GAUSSIAN)
model.learn()
predicted_output_data_g = model.predict(input_data).ravel()
or a Python function:
def rbf(self, r: float) -> float:
    """Evaluate a cubic RBF.

    An RBF must take 2 arguments, namely ``(self, r)``.

    Args:
        r: The radius.

    Returns:
        The RBF value.
    """
    return r**3

model = create_regression_model("RBFRegressor", training_dataset, function=rbf)
model.learn()
predicted_output_data_c = model.predict(input_data).ravel()
We can see that the predictions are different:
plt.plot(input_data.ravel(), reference_output_data, label="Reference")
plt.plot(input_data.ravel(), predicted_output_data, label="Regression - Basics")
plt.plot(input_data.ravel(), predicted_output_data_g, label="Regression - Gaussian RBF")
plt.plot(input_data.ravel(), predicted_output_data_c, label="Regression - Cubic RBF")
plt.grid()
plt.legend()
plt.show()

Epsilon#
Some RBFs depend on an epsilon parameter
whose default value is the average distance between input data.
This is the case for the "multiquadric", "gaussian" and "inverse" RBFs.
For example,
we can train a first multiquadric RBF model with an epsilon set to 0.5:
model = create_regression_model("RBFRegressor", training_dataset, epsilon=0.5)
model.learn()
predicted_output_data_1 = model.predict(input_data).ravel()
a second one with an epsilon set to 1.0:
model = create_regression_model("RBFRegressor", training_dataset, epsilon=1.0)
model.learn()
predicted_output_data_2 = model.predict(input_data).ravel()
and a last one with an epsilon set to 2.0:
model = create_regression_model("RBFRegressor", training_dataset, epsilon=2.0)
model.learn()
predicted_output_data_3 = model.predict(input_data).ravel()
and see that this parameter controls the regularity of the regression model:
plt.plot(input_data.ravel(), reference_output_data, label="Reference")
plt.plot(input_data.ravel(), predicted_output_data, label="Regression - Basics")
plt.plot(input_data.ravel(), predicted_output_data_1, label="Regression - Epsilon(0.5)")
plt.plot(input_data.ravel(), predicted_output_data_2, label="Regression - Epsilon(1)")
plt.plot(input_data.ravel(), predicted_output_data_3, label="Regression - Epsilon(2)")
plt.grid()
plt.legend()
plt.show()
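The effect of epsilon can also be seen directly on the basis functions. The sketch below uses the conventional definitions of the multiquadric, Gaussian and inverse multiquadric RBFs, which are assumed here rather than taken from the GEMSEO or SciPy source code. Since each function depends on \(r\) only through \(r/\epsilon\), a larger epsilon makes the basis function vary more slowly with the distance.

```python
from math import exp, sqrt


def multiquadric(r: float, epsilon: float) -> float:
    """Evaluate sqrt((r / epsilon)^2 + 1)."""
    return sqrt((r / epsilon) ** 2 + 1)


def gaussian(r: float, epsilon: float) -> float:
    """Evaluate exp(-(r / epsilon)^2)."""
    return exp(-((r / epsilon) ** 2))


def inverse(r: float, epsilon: float) -> float:
    """Evaluate 1 / sqrt((r / epsilon)^2 + 1)."""
    return 1 / sqrt((r / epsilon) ** 2 + 1)


# Illustrative distance: the spacing of 6 equispaced points on [0, 1].
r = 0.2
for epsilon in (0.5, 1.0, 2.0):
    print(
        f"epsilon = {epsilon}: "
        f"multiquadric = {multiquadric(r, epsilon):.4f}, "
        f"gaussian = {gaussian(r, epsilon):.4f}, "
        f"inverse = {inverse(r, epsilon):.4f}"
    )
```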

Smooth#
By default,
an RBF model interpolates the training points.
This is controlled by the smooth option, whose default value is 0.
We can increase the smoothness of the model by increasing this value:
model = create_regression_model("RBFRegressor", training_dataset, smooth=0.1)
model.learn()
predicted_output_data_ = model.predict(input_data).ravel()
and see that the model is not interpolating:
plt.plot(input_data.ravel(), reference_output_data, label="Reference")
plt.plot(input_data.ravel(), predicted_output_data, label="Regression - Basics")
plt.plot(input_data.ravel(), predicted_output_data_, label="Regression - Smooth")
plt.grid()
plt.legend()
plt.show()
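The loss of interpolation can be illustrated on a two-point toy problem. This sketch assumes a ridge-like convention in which the smoothing value is added to the diagonal of the RBF matrix, that is \((\Phi + s I) w = y\); the convention actually used by SciPy may differ, and the helpers `gaussian`, `fit_two_points` and `predict` are hypothetical.

```python
from math import exp


def gaussian(r: float, epsilon: float = 1.0) -> float:
    """Evaluate the Gaussian RBF exp(-(r / epsilon)^2)."""
    return exp(-((r / epsilon) ** 2))


def fit_two_points(x0, y0, x1, y1, smooth):
    """Solve (Phi + smooth * I) w = y for two points with Cramer's rule."""
    p = gaussian(abs(x1 - x0))  # off-diagonal term of Phi
    diag = 1.0 + smooth  # gaussian(0) == 1 on the diagonal of Phi
    det = diag**2 - p**2
    w0 = (diag * y0 - p * y1) / det
    w1 = (diag * y1 - p * y0) / det
    return w0, w1


def predict(x, x0, x1, weights):
    """Evaluate the two-term RBF model at x."""
    w0, w1 = weights
    return w0 * gaussian(abs(x - x0)) + w1 * gaussian(abs(x - x1))


# Illustrative training data.
x0, y0, x1, y1 = 0.0, 1.0, 1.0, -1.0
for smooth in (0.0, 0.1):
    weights = fit_two_points(x0, y0, x1, y1, smooth)
    residual = predict(x0, x0, x1, weights) - y0
    print(f"smooth = {smooth}: residual at x0 = {residual:.6f}")
```

With `smooth = 0.0` the residual at the training point vanishes up to rounding, whereas a positive value leaves a nonzero residual.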

Thin plate spline (TPS)#
TPS regression is a specific case of RBF regression
where the RBF is the thin plate function \(r^2\log(r)\).
The TPSRegressor class
deriving from RBFRegressor
implements this case:
model = create_regression_model("TPSRegressor", training_dataset)
model.learn()
predicted_output_data_ = model.predict(input_data).ravel()
We can see the difference between this model and the default multiquadric RBF model:
plt.plot(input_data.ravel(), reference_output_data, label="Reference")
plt.plot(input_data.ravel(), predicted_output_data, label="Regression - Basics")
plt.plot(input_data.ravel(), predicted_output_data_, label="Regression - TPS")
plt.grid()
plt.legend()
plt.show()

The TPSRegressor can be customized with the TPSRegressor_Settings.
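As a small illustration of the thin plate function itself, the sketch below evaluates \(r^2\log(r)\) in plain Python. The function name `thin_plate` is illustrative; the value at \(r=0\) is set to 0, which is the limit of \(r^2\log(r)\) as \(r\) tends to 0.

```python
from math import log


def thin_plate(r: float) -> float:
    """Evaluate r^2 * log(r), extended by continuity with the value 0 at r = 0."""
    if r == 0.0:
        return 0.0
    return r**2 * log(r)


for r in (0.0, 0.5, 1.0, 2.0):
    print(f"thin_plate({r}) = {thin_plate(r):.6f}")
```

Unlike the multiquadric function, this RBF has no epsilon parameter in its expression.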