Radial basis function (RBF) regression#
An RBFRegressor is an RBF regression model based on SciPy.
See also
You can find more information about RBF models on the Wikipedia page about radial basis functions.
from __future__ import annotations
from matplotlib import pyplot as plt
from numpy import array
from gemseo import configure_logger
from gemseo import create_design_space
from gemseo import create_discipline
from gemseo import sample_disciplines
from gemseo.mlearning import create_regression_model
from gemseo.mlearning.regression.algos.rbf_settings import RBF
configure_logger()
<RootLogger root (INFO)>
Problem#
In this example,
we represent the function \(f(x)=(6x-2)^2\sin(12x-4)\) [FSK08]
by the AnalyticDiscipline
discipline = create_discipline(
    "AnalyticDiscipline",
    name="f",
    expressions={"y": "(6*x-2)**2*sin(12*x-4)"},
)
and seek to approximate it over the input space
input_space = create_design_space()
input_space.add_variable("x", lower_bound=0.0, upper_bound=1.0)
To do this, we create a training dataset with 6 equispaced points:
training_dataset = sample_disciplines(
    [discipline], input_space, "y", algo_name="PYDOE_FULLFACT", n_samples=6
)
WARNING - 02:38:19: No coupling in MDA, switching chain_linearize to True.
INFO - 02:38:19: *** Start Sampling execution ***
INFO - 02:38:19: Sampling
INFO - 02:38:19: Disciplines: f
INFO - 02:38:19: MDO formulation: MDF
INFO - 02:38:19: Running the algorithm PYDOE_FULLFACT:
INFO - 02:38:19: 100%|██████████| 6/6 [00:00<00:00, 1735.45 it/sec]
INFO - 02:38:19: *** End Sampling execution (time: 0:00:00.004561) ***
Basics#
Training#
Then, we train an RBF regression model from these samples:
model = create_regression_model("RBFRegressor", training_dataset)
model.learn()
Prediction#
Once it is built, we can predict the output value of \(f\) at a new input point:
input_value = {"x": array([0.65])}
output_value = model.predict(input_value)
output_value
{'y': array([-2.16802353])}
as well as its Jacobian value:
jacobian_value = model.predict_jacobian(input_value)
jacobian_value
{'y': {'x': array([[-45.81825011]])}}
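To illustrate what these predictions correspond to, here is a minimal sketch that rebuilds an equivalent interpolant with SciPy's legacy `Rbf` class directly (an assumption: `RBFRegressor` wraps `scipy.interpolate.Rbf` with comparable defaults) and checks a derivative by finite differences:

```python
import numpy as np
from scipy.interpolate import Rbf  # legacy SciPy RBF interpolator


def f(x):
    """The function to approximate."""
    return (6 * x - 2) ** 2 * np.sin(12 * x - 4)


# Same design as the training dataset: 6 equispaced points in [0, 1].
x_train = np.linspace(0.0, 1.0, 6)
surrogate = Rbf(x_train, f(x_train))  # default kernel: multiquadric

# Prediction at a new input point.
y_new = float(surrogate(0.65))

# Finite-difference check of the derivative at the same point.
h = 1e-6
dy_dx = float(surrogate(0.65 + h) - surrogate(0.65 - h)) / (2 * h)
```

Because the default `smooth` value is 0, the surrogate passes exactly through the 6 training points; the exact predicted values depend on the kernel and on `epsilon`, so they need not match the GEMSEO output above to the last digit.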
Plotting#
You can see that the RBF model is quite accurate on the right part of the domain but poor on the left:
test_dataset = sample_disciplines(
    [discipline], input_space, "y", algo_name="PYDOE_FULLFACT", n_samples=100
)
input_data = test_dataset.get_view(variable_names=model.input_names).to_numpy()
reference_output_data = test_dataset.get_view(variable_names="y").to_numpy().ravel()
predicted_output_data = model.predict(input_data).ravel()
plt.plot(input_data.ravel(), reference_output_data, label="Reference")
plt.plot(input_data.ravel(), predicted_output_data, label="Regression - Basics")
plt.grid()
plt.legend()
plt.show()

WARNING - 02:38:19: No coupling in MDA, switching chain_linearize to True.
INFO - 02:38:19: *** Start Sampling execution ***
INFO - 02:38:19: Sampling
INFO - 02:38:19: Disciplines: f
INFO - 02:38:19: MDO formulation: MDF
INFO - 02:38:19: Running the algorithm PYDOE_FULLFACT:
INFO - 02:38:19: 100%|██████████| 100/100 [00:00<00:00, 3230.71 it/sec]
INFO - 02:38:19: *** End Sampling execution (time: 0:00:00.032183) ***
Settings#
The RBFRegressor has many options, defined in the RBFRegressor_Settings Pydantic model.
Function#
The default RBF is the multiquadric function \(\sqrt{(r/\epsilon)^2 + 1}\), which depends on a radius \(r\) representing the distance between two points and on an adjustable constant \(\epsilon\).
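As a quick illustration in plain NumPy (independent of GEMSEO), this default kernel can be written as:

```python
import numpy as np


def multiquadric(r, epsilon=1.0):
    """Multiquadric RBF: sqrt((r / epsilon)**2 + 1)."""
    return np.sqrt((r / epsilon) ** 2 + 1.0)


# The kernel equals 1 at r = 0 and grows roughly linearly for large r.
```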
The RBF can be changed using the function option, which can be either an RBF enumeration value:
model = create_regression_model("RBFRegressor", training_dataset, function=RBF.GAUSSIAN)
model.learn()
predicted_output_data_g = model.predict(input_data).ravel()
or a Python function:
def rbf(self, r: float) -> float:
    """Evaluate a cubic RBF.

    An RBF must take 2 arguments, namely ``(self, r)``.

    Args:
        r: The radius.

    Returns:
        The RBF value.
    """
    return r**3
model = create_regression_model("RBFRegressor", training_dataset, function=rbf)
model.learn()
predicted_output_data_c = model.predict(input_data).ravel()
We can see that the predictions are different:
plt.plot(input_data.ravel(), reference_output_data, label="Reference")
plt.plot(input_data.ravel(), predicted_output_data, label="Regression - Basics")
plt.plot(input_data.ravel(), predicted_output_data_g, label="Regression - Gaussian RBF")
plt.plot(input_data.ravel(), predicted_output_data_c, label="Regression - Cubic RBF")
plt.grid()
plt.legend()
plt.show()

Epsilon#
Some RBFs depend on an epsilon parameter whose default value is the average distance between the input points. This is the case of the "multiquadric", "gaussian" and "inverse" RBFs.
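To make this default concrete, here is one way to compute an "average distance between input points" with SciPy; note that this mean pairwise distance is an illustration of the idea, not necessarily the exact formula used internally:

```python
import numpy as np
from scipy.spatial.distance import pdist

# The 6 equispaced training inputs used above, as a (n_samples, n_features) array.
x = np.linspace(0.0, 1.0, 6).reshape(-1, 1)

# Mean of all pairwise distances: one natural reading of "average distance".
epsilon_default = pdist(x).mean()
```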
For example, we can train a first multiquadric RBF model with epsilon set to 0.5:
model = create_regression_model("RBFRegressor", training_dataset, epsilon=0.5)
model.learn()
predicted_output_data_1 = model.predict(input_data).ravel()
a second one with epsilon set to 1.0:
model = create_regression_model("RBFRegressor", training_dataset, epsilon=1.0)
model.learn()
predicted_output_data_2 = model.predict(input_data).ravel()
and a last one with epsilon set to 2.0:
model = create_regression_model("RBFRegressor", training_dataset, epsilon=2.0)
model.learn()
predicted_output_data_3 = model.predict(input_data).ravel()
and see that this parameter controls the regularity of the regression model:
plt.plot(input_data.ravel(), reference_output_data, label="Reference")
plt.plot(input_data.ravel(), predicted_output_data, label="Regression - Basics")
plt.plot(input_data.ravel(), predicted_output_data_1, label="Regression - Epsilon(0.5)")
plt.plot(input_data.ravel(), predicted_output_data_2, label="Regression - Epsilon(1)")
plt.plot(input_data.ravel(), predicted_output_data_3, label="Regression - Epsilon(2)")
plt.grid()
plt.legend()
plt.show()

Smooth#
By default,
an RBF model interpolates the training points.
This is parametrized by the smooth
option which is set to 0.
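A minimal sketch of this interpolation-versus-smoothing behavior, using SciPy's legacy `Rbf` class directly (assuming, as stated at the top, that RBFRegressor is based on it):

```python
import numpy as np
from scipy.interpolate import Rbf  # legacy SciPy RBF interpolator


def f(x):
    """The function to approximate."""
    return (6 * x - 2) ** 2 * np.sin(12 * x - 4)


x_train = np.linspace(0.0, 1.0, 6)

# smooth=0 (default): the model passes exactly through the training points.
interpolating = Rbf(x_train, f(x_train))
# smooth>0: the fit is relaxed and no longer interpolates.
smoothing = Rbf(x_train, f(x_train), smooth=0.1)

residual_interp = np.abs(interpolating(x_train) - f(x_train)).max()
residual_smooth = np.abs(smoothing(x_train) - f(x_train)).max()
```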
We can increase the smoothness of the model by increasing this value:
model = create_regression_model("RBFRegressor", training_dataset, smooth=0.1)
model.learn()
predicted_output_data_ = model.predict(input_data).ravel()
and see that the model is not interpolating:
plt.plot(input_data.ravel(), reference_output_data, label="Reference")
plt.plot(input_data.ravel(), predicted_output_data, label="Regression - Basics")
plt.plot(input_data.ravel(), predicted_output_data_, label="Regression - Smooth")
plt.grid()
plt.legend()
plt.show()

Thin plate spline (TPS)#
TPS regression is a specific case of RBF regression where the RBF is the thin plate radial basis function \(r^2\log(r)\).
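As an illustration in plain NumPy, this kernel can be written with its removable singularity at \(r=0\) handled explicitly (a hypothetical helper, not part of GEMSEO):

```python
import numpy as np


def thin_plate(r):
    """Thin plate spline kernel r**2 * log(r), extended by 0 at r = 0."""
    r = np.asarray(r, dtype=float)
    # Clip the log argument to avoid log(0); the r**2 factor makes the limit 0.
    return np.where(r == 0.0, 0.0, r**2 * np.log(np.maximum(r, 1e-300)))
```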
The TPSRegressor class, deriving from RBFRegressor, implements this case:
model = create_regression_model("TPSRegressor", training_dataset)
model.learn()
predicted_output_data_ = model.predict(input_data).ravel()
We can see the difference between this model and the default multiquadric RBF model:
plt.plot(input_data.ravel(), reference_output_data, label="Reference")
plt.plot(input_data.ravel(), predicted_output_data, label="Regression - Basics")
plt.plot(input_data.ravel(), predicted_output_data_, label="Regression - TPS")
plt.grid()
plt.legend()
plt.show()

The TPSRegressor can be customized with the TPSRegressor_Settings Pydantic model.
Total running time of the script: (0 minutes 0.566 seconds)