Polynomial regression#

A PolynomialRegressor is a polynomial regression model built on top of a LinearRegressor. This design reflects the fact that a polynomial regression model is linear in its coefficients: it is a linear model whose basis functions are monomials of the inputs. A PolynomialRegressor therefore supports the same settings as a LinearRegressor: the offset can be set to zero and regularization techniques can be used.
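To make the "linear model with monomial basis functions" idea concrete, here is a minimal NumPy sketch that does not use GEMSEO. The helper names `monomial_features` and `fit_polynomial`, the `penalty` argument and the sample data are illustrative assumptions, not part of the GEMSEO API. For simplicity, the ridge penalty in this sketch also applies to the offset term.

```python
import numpy as np


def monomial_features(x, degree, fit_intercept=True):
    """Stack the monomials x**k as columns, for k = 0..degree (or 1..degree)."""
    first_power = 0 if fit_intercept else 1
    return np.column_stack([x**k for k in range(first_power, degree + 1)])


def fit_polynomial(x, y, degree, fit_intercept=True, penalty=0.0):
    """Solve the (optionally ridge-penalized) normal equations for the coefficients."""
    features = monomial_features(x, degree, fit_intercept)
    gram = features.T @ features + penalty * np.eye(features.shape[1])
    return np.linalg.solve(gram, features.T @ y)


# Hypothetical data: samples of 1 + 2x + 3x^2, which a degree-2 model fits exactly.
x = np.linspace(0.0, 1.0, 6)
y = 1.0 + 2.0 * x + 3.0 * x**2

# Ordinary least squares with an offset term.
coefficients = fit_polynomial(x, y, degree=2)
# Same model with the offset forced to zero (no constant column).
no_offset_coefficients = fit_polynomial(x, y, degree=2, fit_intercept=False)
# Same model with a ridge (L2) penalty, which shrinks the coefficients.
ridge_coefficients = fit_polynomial(x, y, degree=2, penalty=0.1)
print(coefficients)
```

The fit is an ordinary linear least-squares problem once the inputs are replaced by their monomials, which is why the offset and regularization settings of a linear model carry over unchanged.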

See also

You will find more information about these settings in the example about the linear regression model.

from __future__ import annotations

from matplotlib import pyplot as plt
from numpy import array

from gemseo import configure_logger
from gemseo import create_design_space
from gemseo import create_discipline
from gemseo import sample_disciplines
from gemseo.mlearning import create_regression_model

configure_logger()
<RootLogger root (INFO)>

Problem#

In this example, we represent the function \(f(x)=(6x-2)^2\sin(12x-4)\) [FSK08] by the AnalyticDiscipline

discipline = create_discipline(
    "AnalyticDiscipline",
    name="f",
    expressions={"y": "(6*x-2)**2*sin(12*x-4)"},
)

and seek to approximate it over the input space

input_space = create_design_space()
input_space.add_variable("x", lower_bound=0.0, upper_bound=1.0)

To do this, we create a training dataset with 6 equispaced points:

training_dataset = sample_disciplines(
    [discipline], input_space, "y", algo_name="PYDOE_FULLFACT", n_samples=6
)
WARNING - 00:17:12: No coupling in MDA, switching chain_linearize to True.
   INFO - 00:17:12:
   INFO - 00:17:12: *** Start Sampling execution ***
   INFO - 00:17:12: Sampling
   INFO - 00:17:12:    Disciplines: f
   INFO - 00:17:12:    MDO formulation: MDF
   INFO - 00:17:12: Running the algorithm PYDOE_FULLFACT:
   INFO - 00:17:12:     17%|█▋        | 1/6 [00:00<00:00, 569.11 it/sec]
   INFO - 00:17:12:     33%|███▎      | 2/6 [00:00<00:00, 926.82 it/sec]
   INFO - 00:17:12:     50%|█████     | 3/6 [00:00<00:00, 1203.99 it/sec]
   INFO - 00:17:12:     67%|██████▋   | 4/6 [00:00<00:00, 1419.15 it/sec]
   INFO - 00:17:12:     83%|████████▎ | 5/6 [00:00<00:00, 1603.20 it/sec]
   INFO - 00:17:12:    100%|██████████| 6/6 [00:00<00:00, 1763.42 it/sec]
   INFO - 00:17:12: *** End Sampling execution (time: 0:00:00.004723) ***
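For readers who want to inspect the training data without running GEMSEO, the same samples can be sketched with NumPy alone. This assumes, as the text above states, that the 6-level full-factorial design on \([0, 1]\) is the equispaced grid; the names `f`, `x_train` and `y_train` are illustrative.

```python
import numpy as np


def f(x):
    """Function f(x) = (6x - 2)^2 sin(12x - 4) from the problem statement."""
    return (6.0 * x - 2.0) ** 2 * np.sin(12.0 * x - 4.0)


# Assumption: the 6-level full-factorial design on [0, 1] is the equispaced grid.
x_train = np.linspace(0.0, 1.0, 6)
y_train = f(x_train)

for x_value, y_value in zip(x_train, y_train):
    print(f"x = {x_value:.1f}, y = {y_value:.6f}")
```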

Basics#

Training#

Then, we train a polynomial regression model with a degree of 2 (default) from these samples:

model = create_regression_model("PolynomialRegressor", training_dataset)
model.learn()

Prediction#

Once it is built, we can predict the output value of \(f\) at a new input point:

input_value = {"x": array([0.65])}
output_value = model.predict(input_value)
output_value
{'y': array([-0.90980781])}

as well as its Jacobian value:

jacobian_value = model.predict_jacobian(input_value)
jacobian_value
{'y': {'x': array([[20.65451891]])}}
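As a cross-check that does not depend on GEMSEO, the prediction and its derivative can be approximated with an ordinary least-squares quadratic fitted by NumPy. This sketch assumes that the default settings amount to an unregularized least-squares fit on the six equispaced training points; under that assumption, the two printed numbers should be close to the values reported above.

```python
import numpy as np
from numpy.polynomial import Polynomial

# Training data: f sampled on the assumed equispaced 6-point grid.
x_train = np.linspace(0.0, 1.0, 6)
y_train = (6.0 * x_train - 2.0) ** 2 * np.sin(12.0 * x_train - 4.0)

# Ordinary least-squares polynomial of degree 2.
quadratic = Polynomial.fit(x_train, y_train, deg=2)

# Value and first derivative of the fitted polynomial at x = 0.65.
prediction = quadratic(0.65)
derivative = quadratic.deriv()(0.65)
print(prediction, derivative)
```

For a single input and a single output, the Jacobian reduces to this first derivative of the fitted polynomial.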

Plotting#

Plotting the prediction against the reference over 100 test points shows that the quadratic model approximates \(f\) poorly here:

test_dataset = sample_disciplines(
    [discipline], input_space, "y", algo_name="PYDOE_FULLFACT", n_samples=100
)
input_data = test_dataset.get_view(variable_names=model.input_names).to_numpy()
reference_output_data = test_dataset.get_view(variable_names="y").to_numpy().ravel()
predicted_output_data = model.predict(input_data).ravel()
plt.plot(input_data.ravel(), reference_output_data, label="Reference")
plt.plot(input_data.ravel(), predicted_output_data, label="Regression - Basics")
plt.grid()
plt.legend()
plt.show()
[Figure: reference function and quadratic regression ("Regression - Basics") over the input space]
WARNING - 00:17:12: No coupling in MDA, switching chain_linearize to True.
   INFO - 00:17:12:
   INFO - 00:17:12: *** Start Sampling execution ***
   INFO - 00:17:12: Sampling
   INFO - 00:17:12:    Disciplines: f
   INFO - 00:17:12:    MDO formulation: MDF
   INFO - 00:17:12: Running the algorithm PYDOE_FULLFACT:
   INFO - 00:17:12:      1%|          | 1/100 [00:00<00:00, 3130.08 it/sec]
   INFO - 00:17:12:      2%|▏         | 2/100 [00:00<00:00, 2953.74 it/sec]
   INFO - 00:17:12:      3%|▎         | 3/100 [00:00<00:00, 2951.66 it/sec]
   INFO - 00:17:12:      4%|▍         | 4/100 [00:00<00:00, 2995.40 it/sec]
   INFO - 00:17:12:      5%|▌         | 5/100 [00:00<00:00, 3034.07 it/sec]
   INFO - 00:17:12:      6%|▌         | 6/100 [00:00<00:00, 3072.75 it/sec]
   INFO - 00:17:12:      7%|▋         | 7/100 [00:00<00:00, 3086.32 it/sec]
   INFO - 00:17:12:      8%|▊         | 8/100 [00:00<00:00, 3107.76 it/sec]
   INFO - 00:17:12:      9%|▉         | 9/100 [00:00<00:00, 3146.25 it/sec]
   INFO - 00:17:12:     10%|█         | 10/100 [00:00<00:00, 3175.10 it/sec]
   INFO - 00:17:12:     11%|█         | 11/100 [00:00<00:00, 3201.98 it/sec]
   INFO - 00:17:12:     12%|█▏        | 12/100 [00:00<00:00, 3229.91 it/sec]
   INFO - 00:17:12:     13%|█▎        | 13/100 [00:00<00:00, 3212.32 it/sec]
   INFO - 00:17:12:     14%|█▍        | 14/100 [00:00<00:00, 3227.10 it/sec]
   INFO - 00:17:12:     15%|█▌        | 15/100 [00:00<00:00, 3246.53 it/sec]
   INFO - 00:17:12:     16%|█▌        | 16/100 [00:00<00:00, 3266.91 it/sec]
   INFO - 00:17:12:     17%|█▋        | 17/100 [00:00<00:00, 3287.98 it/sec]
   INFO - 00:17:12:     18%|█▊        | 18/100 [00:00<00:00, 3306.94 it/sec]
   INFO - 00:17:12:     19%|█▉        | 19/100 [00:00<00:00, 3324.09 it/sec]
   INFO - 00:17:12:     20%|██        | 20/100 [00:00<00:00, 3340.34 it/sec]
   INFO - 00:17:12:     21%|██        | 21/100 [00:00<00:00, 3354.93 it/sec]
   INFO - 00:17:12:     22%|██▏       | 22/100 [00:00<00:00, 3366.22 it/sec]
   INFO - 00:17:12:     23%|██▎       | 23/100 [00:00<00:00, 3374.34 it/sec]
   INFO - 00:17:12:     24%|██▍       | 24/100 [00:00<00:00, 3385.01 it/sec]
   INFO - 00:17:12:     25%|██▌       | 25/100 [00:00<00:00, 3385.89 it/sec]
   INFO - 00:17:12:     26%|██▌       | 26/100 [00:00<00:00, 3393.77 it/sec]
   INFO - 00:17:12:     27%|██▋       | 27/100 [00:00<00:00, 3386.65 it/sec]
   INFO - 00:17:12:     28%|██▊       | 28/100 [00:00<00:00, 3391.88 it/sec]
   INFO - 00:17:12:     29%|██▉       | 29/100 [00:00<00:00, 3398.76 it/sec]
   INFO - 00:17:12:     30%|███       | 30/100 [00:00<00:00, 3405.85 it/sec]
   INFO - 00:17:12:     31%|███       | 31/100 [00:00<00:00, 3413.94 it/sec]
   INFO - 00:17:12:     32%|███▏      | 32/100 [00:00<00:00, 3421.04 it/sec]
   INFO - 00:17:12:     33%|███▎      | 33/100 [00:00<00:00, 3427.06 it/sec]
   INFO - 00:17:12:     34%|███▍      | 34/100 [00:00<00:00, 3432.74 it/sec]
   INFO - 00:17:12:     35%|███▌      | 35/100 [00:00<00:00, 3438.12 it/sec]
   INFO - 00:17:12:     36%|███▌      | 36/100 [00:00<00:00, 3444.93 it/sec]
   INFO - 00:17:12:     37%|███▋      | 37/100 [00:00<00:00, 3451.10 it/sec]
   INFO - 00:17:12:     38%|███▊      | 38/100 [00:00<00:00, 3457.64 it/sec]
   INFO - 00:17:12:     39%|███▉      | 39/100 [00:00<00:00, 3459.11 it/sec]
   INFO - 00:17:12:     40%|████      | 40/100 [00:00<00:00, 3462.43 it/sec]
   INFO - 00:17:12:     41%|████      | 41/100 [00:00<00:00, 3459.88 it/sec]
   INFO - 00:17:12:     42%|████▏     | 42/100 [00:00<00:00, 3461.94 it/sec]
   INFO - 00:17:12:     43%|████▎     | 43/100 [00:00<00:00, 3466.37 it/sec]
   INFO - 00:17:12:     44%|████▍     | 44/100 [00:00<00:00, 3470.21 it/sec]
   INFO - 00:17:12:     45%|████▌     | 45/100 [00:00<00:00, 3474.41 it/sec]
   INFO - 00:17:12:     46%|████▌     | 46/100 [00:00<00:00, 3477.30 it/sec]
   INFO - 00:17:12:     47%|████▋     | 47/100 [00:00<00:00, 3480.93 it/sec]
   INFO - 00:17:12:     48%|████▊     | 48/100 [00:00<00:00, 3485.63 it/sec]
   INFO - 00:17:12:     49%|████▉     | 49/100 [00:00<00:00, 3489.50 it/sec]
   INFO - 00:17:12:     50%|█████     | 50/100 [00:00<00:00, 3493.86 it/sec]
   INFO - 00:17:12:     51%|█████     | 51/100 [00:00<00:00, 3497.37 it/sec]
   INFO - 00:17:12:     52%|█████▏    | 52/100 [00:00<00:00, 3500.64 it/sec]
   INFO - 00:17:12:     53%|█████▎    | 53/100 [00:00<00:00, 3501.03 it/sec]
   INFO - 00:17:12:     54%|█████▍    | 54/100 [00:00<00:00, 3501.63 it/sec]
   INFO - 00:17:12:     55%|█████▌    | 55/100 [00:00<00:00, 3504.71 it/sec]
   INFO - 00:17:12:     56%|█████▌    | 56/100 [00:00<00:00, 3501.25 it/sec]
   INFO - 00:17:12:     57%|█████▋    | 57/100 [00:00<00:00, 3503.86 it/sec]
   INFO - 00:17:12:     58%|█████▊    | 58/100 [00:00<00:00, 3506.18 it/sec]
   INFO - 00:17:12:     59%|█████▉    | 59/100 [00:00<00:00, 3508.83 it/sec]
   INFO - 00:17:12:     60%|██████    | 60/100 [00:00<00:00, 3511.74 it/sec]
   INFO - 00:17:12:     61%|██████    | 61/100 [00:00<00:00, 3513.73 it/sec]
   INFO - 00:17:12:     62%|██████▏   | 62/100 [00:00<00:00, 3516.52 it/sec]
   INFO - 00:17:12:     63%|██████▎   | 63/100 [00:00<00:00, 3518.95 it/sec]
   INFO - 00:17:12:     64%|██████▍   | 64/100 [00:00<00:00, 3521.39 it/sec]
   INFO - 00:17:12:     65%|██████▌   | 65/100 [00:00<00:00, 3523.67 it/sec]
   INFO - 00:17:12:     66%|██████▌   | 66/100 [00:00<00:00, 3524.89 it/sec]
   INFO - 00:17:12:     67%|██████▋   | 67/100 [00:00<00:00, 3526.35 it/sec]
   INFO - 00:17:12:     68%|██████▊   | 68/100 [00:00<00:00, 3525.63 it/sec]
   INFO - 00:17:12:     69%|██████▉   | 69/100 [00:00<00:00, 3527.29 it/sec]
   INFO - 00:17:12:     70%|███████   | 70/100 [00:00<00:00, 3526.06 it/sec]
   INFO - 00:17:12:     71%|███████   | 71/100 [00:00<00:00, 3526.59 it/sec]
   INFO - 00:17:12:     72%|███████▏  | 72/100 [00:00<00:00, 3528.54 it/sec]
   INFO - 00:17:12:     73%|███████▎  | 73/100 [00:00<00:00, 3531.01 it/sec]
   INFO - 00:17:12:     74%|███████▍  | 74/100 [00:00<00:00, 3532.77 it/sec]
   INFO - 00:17:12:     75%|███████▌  | 75/100 [00:00<00:00, 3534.17 it/sec]
   INFO - 00:17:12:     76%|███████▌  | 76/100 [00:00<00:00, 3511.62 it/sec]
   INFO - 00:17:12:     77%|███████▋  | 77/100 [00:00<00:00, 3507.17 it/sec]
   INFO - 00:17:12:     78%|███████▊  | 78/100 [00:00<00:00, 3506.68 it/sec]
   INFO - 00:17:12:     79%|███████▉  | 79/100 [00:00<00:00, 3507.02 it/sec]
   INFO - 00:17:12:     80%|████████  | 80/100 [00:00<00:00, 3507.57 it/sec]
   INFO - 00:17:12:     81%|████████  | 81/100 [00:00<00:00, 3508.07 it/sec]
   INFO - 00:17:12:     82%|████████▏ | 82/100 [00:00<00:00, 3504.83 it/sec]
   INFO - 00:17:12:     83%|████████▎ | 83/100 [00:00<00:00, 3505.60 it/sec]
   INFO - 00:17:12:     84%|████████▍ | 84/100 [00:00<00:00, 3502.83 it/sec]
   INFO - 00:17:12:     85%|████████▌ | 85/100 [00:00<00:00, 3503.46 it/sec]
   INFO - 00:17:12:     86%|████████▌ | 86/100 [00:00<00:00, 3504.49 it/sec]
   INFO - 00:17:12:     87%|████████▋ | 87/100 [00:00<00:00, 3506.10 it/sec]
   INFO - 00:17:12:     88%|████████▊ | 88/100 [00:00<00:00, 3506.58 it/sec]
   INFO - 00:17:12:     89%|████████▉ | 89/100 [00:00<00:00, 3507.73 it/sec]
   INFO - 00:17:12:     90%|█████████ | 90/100 [00:00<00:00, 3508.57 it/sec]
   INFO - 00:17:12:     91%|█████████ | 91/100 [00:00<00:00, 3510.04 it/sec]
   INFO - 00:17:12:     92%|█████████▏| 92/100 [00:00<00:00, 3511.89 it/sec]
   INFO - 00:17:12:     93%|█████████▎| 93/100 [00:00<00:00, 3513.74 it/sec]
   INFO - 00:17:12:     94%|█████████▍| 94/100 [00:00<00:00, 3515.45 it/sec]
   INFO - 00:17:12:     95%|█████████▌| 95/100 [00:00<00:00, 3517.13 it/sec]
   INFO - 00:17:12:     96%|█████████▌| 96/100 [00:00<00:00, 3516.07 it/sec]
   INFO - 00:17:12:     97%|█████████▋| 97/100 [00:00<00:00, 3516.19 it/sec]
   INFO - 00:17:12:     98%|█████████▊| 98/100 [00:00<00:00, 3513.09 it/sec]
   INFO - 00:17:12:     99%|█████████▉| 99/100 [00:00<00:00, 3512.91 it/sec]
   INFO - 00:17:12:    100%|██████████| 100/100 [00:00<00:00, 3514.11 it/sec]
   INFO - 00:17:12: *** End Sampling execution (time: 0:00:00.030021) ***

Settings#

The PolynomialRegressor has many options defined in the PolynomialRegressor_Settings Pydantic model. Most of them are presented in the example about the linear regression model.

The only one we will look at here is the degree of the polynomial, which is set with the degree keyword and defaults to 2. For example, we can use a cubic regression model instead of a quadratic one:

model = create_regression_model("PolynomialRegressor", training_dataset, degree=3)
model.learn()

and see that this model seems to fit the reference better than the quadratic one:

predicted_output_data_ = model.predict(input_data).ravel()
plt.plot(input_data.ravel(), reference_output_data, label="Reference")
plt.plot(input_data.ravel(), predicted_output_data, label="Regression - Basics")
plt.plot(input_data.ravel(), predicted_output_data_, label="Regression - Degree(3)")
plt.grid()
plt.legend()
plt.show()
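[Figure: reference function, quadratic regression ("Regression - Basics") and cubic regression ("Regression - Degree(3)") over the input space]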
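The visual comparison can be complemented by a number. The sketch below is a NumPy-only approximation, not the GEMSEO workflow: it assumes unregularized least-squares fits on the six equispaced training points and uses the root-mean-square error (RMSE) on 100 equispaced test points as the error measure, which is a choice made here for illustration.

```python
import numpy as np
from numpy.polynomial import Polynomial


def f(x):
    """Reference function f(x) = (6x - 2)^2 sin(12x - 4)."""
    return (6.0 * x - 2.0) ** 2 * np.sin(12.0 * x - 4.0)


x_train = np.linspace(0.0, 1.0, 6)
x_test = np.linspace(0.0, 1.0, 100)

# Test RMSE of the least-squares polynomial for each degree.
rmse = {}
for degree in (2, 3):
    polynomial = Polynomial.fit(x_train, f(x_train), deg=degree)
    residuals = polynomial(x_test) - f(x_test)
    rmse[degree] = float(np.sqrt(np.mean(residuals**2)))
    print(f"degree {degree}: test RMSE = {rmse[degree]:.3f}")
```

A lower test RMSE for the cubic model is consistent with the plot, but both errors remain large compared with the range of \(f\), so neither model should be read as an accurate surrogate.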
