Radial basis function (RBF) regression#

An RBFRegressor is a radial basis function (RBF) regression model whose implementation is based on SciPy.

See also

You can find more information about RBF models on this Wikipedia page.
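Conceptually, an RBF model predicts with a weighted sum of radial basis functions centered at the training points, the weights being the solution of a linear system. Here is a minimal NumPy sketch of multiquadric RBF interpolation of the function used below, for intuition only (it is not GEMSEO's implementation, and the choice of epsilon is a simplification):

```python
import numpy as np


def f(x):
    """The function of interest, f(x) = (6x - 2)^2 sin(12x - 4)."""
    return (6 * x - 2) ** 2 * np.sin(12 * x - 4)


# 6 equispaced training points over [0, 1]
x_train = np.linspace(0.0, 1.0, 6)
y_train = f(x_train)

# multiquadric kernel; epsilon set to the mean pairwise distance
distances = np.abs(x_train[:, None] - x_train[None, :])
epsilon = distances.mean()
phi = lambda r: np.sqrt((r / epsilon) ** 2 + 1.0)

# solve K w = y for the weights of the kernel combination
weights = np.linalg.solve(phi(distances), y_train)


def predict(x):
    """Predict with the weighted sum of kernels centered at x_train."""
    return phi(np.abs(np.atleast_1d(x)[:, None] - x_train[None, :])) @ weights
```

By construction such a model interpolates: `predict(x_train)` reproduces `y_train` up to machine precision.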

from __future__ import annotations

from matplotlib import pyplot as plt
from numpy import array

from gemseo import configure_logger
from gemseo import create_design_space
from gemseo import create_discipline
from gemseo import sample_disciplines
from gemseo.mlearning import create_regression_model
from gemseo.mlearning.regression.algos.rbf_settings import RBF

configure_logger()
<RootLogger root (INFO)>

Problem#

In this example, we represent the function \(f(x)=(6x-2)^2\sin(12x-4)\) [FSK08] by the AnalyticDiscipline

discipline = create_discipline(
    "AnalyticDiscipline",
    name="f",
    expressions={"y": "(6*x-2)**2*sin(12*x-4)"},
)

and seek to approximate it over the input space

input_space = create_design_space()
input_space.add_variable("x", lower_bound=0.0, upper_bound=1.0)

To do this, we create a training dataset with 6 equispaced points:

training_dataset = sample_disciplines(
    [discipline], input_space, "y", algo_name="PYDOE_FULLFACT", n_samples=6
)
WARNING - 02:38:19: No coupling in MDA, switching chain_linearize to True.
   INFO - 02:38:19: *** Start Sampling execution ***
   INFO - 02:38:19: Sampling
   INFO - 02:38:19:    Disciplines: f
   INFO - 02:38:19:    MDO formulation: MDF
   INFO - 02:38:19: Running the algorithm PYDOE_FULLFACT:
   INFO - 02:38:19:    100%|██████████| 6/6 [00:00<00:00, 1735.45 it/sec]
   INFO - 02:38:19: *** End Sampling execution (time: 0:00:00.004561) ***

Basics#

Training#

Then, we train an RBF regression model from these samples:

model = create_regression_model("RBFRegressor", training_dataset)
model.learn()

Prediction#

Once it is built, we can predict the output value of \(f\) at a new input point:

input_value = {"x": array([0.65])}
output_value = model.predict(input_value)
output_value
{'y': array([-2.16802353])}

as well as its Jacobian value:

jacobian_value = model.predict_jacobian(input_value)
jacobian_value
{'y': {'x': array([[-45.81825011]])}}
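A predicted Jacobian like this one can be sanity-checked against a central finite difference of the model's predictions. A generic sketch for a scalar one-dimensional function (the function `g` below is only an illustration with a known derivative, not the model above):

```python
import numpy as np


def fd_derivative(func, x, step=1e-6):
    """Central finite-difference derivative of a scalar function at x."""
    return (func(x + step) - func(x - step)) / (2.0 * step)


# check against a function whose derivative is known analytically
g = lambda x: np.sin(3.0 * x)  # g'(x) = 3 cos(3x)
error = abs(fd_derivative(g, 0.5) - 3.0 * np.cos(1.5))
```

The same pattern applied to a surrogate's `predict` gives a quick consistency check of `predict_jacobian`.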

Plotting#

Plotting the model against the reference shows that the RBF model is accurate on the right part of the input space but poor on the left:

test_dataset = sample_disciplines(
    [discipline], input_space, "y", algo_name="PYDOE_FULLFACT", n_samples=100
)
input_data = test_dataset.get_view(variable_names=model.input_names).to_numpy()
reference_output_data = test_dataset.get_view(variable_names="y").to_numpy().ravel()
predicted_output_data = model.predict(input_data).ravel()
plt.plot(input_data.ravel(), reference_output_data, label="Reference")
plt.plot(input_data.ravel(), predicted_output_data, label="Regression - Basics")
plt.grid()
plt.legend()
plt.show()
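Beyond visual inspection, the mismatch between reference and predicted outputs can be quantified, e.g. with the root-mean-square error; the arrays below are placeholders standing in for `reference_output_data` and `predicted_output_data`:

```python
import numpy as np

# placeholder data; in practice use the reference and predicted outputs
reference = np.array([0.0, 1.0, 2.0, 4.0])
predicted = np.array([0.1, 0.9, 2.2, 3.8])

# root-mean-square error over the test points
rmse = np.sqrt(np.mean((predicted - reference) ** 2))
```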
WARNING - 02:38:19: No coupling in MDA, switching chain_linearize to True.
   INFO - 02:38:19: *** Start Sampling execution ***
   INFO - 02:38:19: Sampling
   INFO - 02:38:19:    Disciplines: f
   INFO - 02:38:19:    MDO formulation: MDF
   INFO - 02:38:19: Running the algorithm PYDOE_FULLFACT:
   INFO - 02:38:19:    100%|██████████| 100/100 [00:00<00:00, 3230.71 it/sec]
   INFO - 02:38:19: *** End Sampling execution (time: 0:00:00.032183) ***

Settings#

The RBFRegressor has many options defined in the RBFRegressor_Settings Pydantic model.

Function#

The default RBF is the multiquadric function \(\sqrt{(r/\epsilon)^2 + 1}\), where the radius \(r\) is the distance between two points and \(\epsilon\) is an adjustable constant. The RBF can be changed using the function option, whose value can be either a member of the RBF enumeration:

model = create_regression_model("RBFRegressor", training_dataset, function=RBF.GAUSSIAN)
model.learn()
predicted_output_data_g = model.predict(input_data).ravel()

or a Python function:

def rbf(self, r: float) -> float:
    """Evaluate a cubic RBF.

    An RBF must take 2 arguments, namely ``(self, r)``.

    Args:
        r: The radius.

    Returns:
        The RBF value.
    """
    return r**3


model = create_regression_model("RBFRegressor", training_dataset, function=rbf)
model.learn()
predicted_output_data_c = model.predict(input_data).ravel()

We can see that the predictions are different:

plt.plot(input_data.ravel(), reference_output_data, label="Reference")
plt.plot(input_data.ravel(), predicted_output_data, label="Regression - Basics")
plt.plot(input_data.ravel(), predicted_output_data_g, label="Regression - Gaussian RBF")
plt.plot(input_data.ravel(), predicted_output_data_c, label="Regression - Cubic RBF")
plt.grid()
plt.legend()
plt.show()
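The choice of function matters because the kernels behave very differently with the radius: the Gaussian kernel decays to zero while the multiquadric and cubic kernels grow. A quick NumPy sketch of the three kernels used above, with \(\epsilon = 1\):

```python
import numpy as np

epsilon = 1.0
r = np.linspace(0.0, 2.0, 201)

multiquadric = np.sqrt((r / epsilon) ** 2 + 1.0)  # grows with r
gaussian = np.exp(-((r / epsilon) ** 2))          # decays to zero
cubic = r**3                                      # grows like r^3
```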

Epsilon#

Some RBFs depend on an epsilon parameter whose default value is the average distance between input data. This is the case of the "multiquadric", "gaussian" and "inverse" RBFs. For example, we can train a first multiquadric RBF model with epsilon set to 0.5:

model = create_regression_model("RBFRegressor", training_dataset, epsilon=0.5)
model.learn()
predicted_output_data_1 = model.predict(input_data).ravel()

a second one with an epsilon set to 1.0:

model = create_regression_model("RBFRegressor", training_dataset, epsilon=1.0)
model.learn()
predicted_output_data_2 = model.predict(input_data).ravel()

and a last one with an epsilon set to 2.0:

model = create_regression_model("RBFRegressor", training_dataset, epsilon=2.0)
model.learn()
predicted_output_data_3 = model.predict(input_data).ravel()

and see that this parameter controls the regularity of the regression model (the larger epsilon, the smoother the model):

plt.plot(input_data.ravel(), reference_output_data, label="Reference")
plt.plot(input_data.ravel(), predicted_output_data, label="Regression - Basics")
plt.plot(input_data.ravel(), predicted_output_data_1, label="Regression - Epsilon(0.5)")
plt.plot(input_data.ravel(), predicted_output_data_2, label="Regression - Epsilon(1)")
plt.plot(input_data.ravel(), predicted_output_data_3, label="Regression - Epsilon(2)")
plt.grid()
plt.legend()
plt.show()
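The effect of epsilon can be read directly from the kernel: it rescales the radius, so a larger epsilon flattens the multiquadric kernel over the same distances, which tends to smooth the model. A quick check at a fixed radius:

```python
import numpy as np

r = 1.0  # a fixed distance between two points
values = [np.sqrt((r / eps) ** 2 + 1.0) for eps in (0.5, 1.0, 2.0)]
# the kernel value at this radius decreases (the kernel flattens) as epsilon grows
```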

Smooth#

By default, an RBF model interpolates the training points. This behavior is controlled by the smooth option, whose default value is 0. We can relax the interpolation constraint and smooth the model by increasing this value:

model = create_regression_model("RBFRegressor", training_dataset, smooth=0.1)
model.learn()
predicted_output_data_ = model.predict(input_data).ravel()

and see that the model no longer interpolates the training points:

plt.plot(input_data.ravel(), reference_output_data, label="Reference")
plt.plot(input_data.ravel(), predicted_output_data, label="Regression - Basics")
plt.plot(input_data.ravel(), predicted_output_data_, label="Regression - Smooth")
plt.grid()
plt.legend()
plt.show()
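The smoothing can be understood as a regularization of the interpolation system: instead of solving \(Kw = y\) exactly, a term proportional to smooth is placed on the diagonal of the kernel matrix, so the fitted values no longer have to match the training outputs. A minimal NumPy sketch of this idea (a simplified formulation, not necessarily SciPy's exact convention):

```python
import numpy as np

x = np.linspace(0.0, 1.0, 6)
y = (6 * x - 2) ** 2 * np.sin(12 * x - 4)

r = np.abs(x[:, None] - x[None, :])
K = np.sqrt(r**2 + 1.0)  # multiquadric kernel matrix with epsilon = 1

fits = {}
for smooth in (0.0, 0.1):
    # regularized system: (K + smooth * I) w = y
    weights = np.linalg.solve(K + smooth * np.eye(len(x)), y)
    fits[smooth] = K @ weights
```

With `smooth = 0.0` the fitted values reproduce `y` exactly; with `smooth = 0.1` they do not.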

Thin plate spline (TPS)#

TPS regression is a specific case of RBF regression where the RBF is the thin plate spline \(r^2\log(r)\). The TPSRegressor class, deriving from RBFRegressor, implements this case:

model = create_regression_model("TPSRegressor", training_dataset)
model.learn()
predicted_output_data_ = model.predict(input_data).ravel()

We can see the difference between this model and the default multiquadric RBF model:

plt.plot(input_data.ravel(), reference_output_data, label="Reference")
plt.plot(input_data.ravel(), predicted_output_data, label="Regression - Basics")
plt.plot(input_data.ravel(), predicted_output_data_, label="Regression - TPS")
plt.grid()
plt.legend()
plt.show()

The TPSRegressor can be customized with the TPSRegressor_Settings.
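The thin plate spline kernel \(r^2\log(r)\) needs care at \(r = 0\), where its limit is 0 but a direct evaluation of \(\log(0)\) would produce a warning. A sketch of a safe evaluation:

```python
import numpy as np


def thin_plate(r):
    """Evaluate r^2 * log(r), using the limit value 0 at r = 0."""
    r = np.asarray(r, dtype=float)
    out = np.zeros_like(r)
    positive = r > 0
    out[positive] = r[positive] ** 2 * np.log(r[positive])
    return out
```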

Total running time of the script: (0 minutes 0.566 seconds)

Gallery generated by Sphinx-Gallery