Linear regression#

A LinearRegressor is a linear regression model based on scikit-learn.

See also

You can find more information about building linear models with scikit-learn on this page.

from __future__ import annotations

from matplotlib import pyplot as plt
from numpy import array

from gemseo import configure_logger
from gemseo import create_design_space
from gemseo import create_discipline
from gemseo import sample_disciplines
from gemseo.mlearning import create_regression_model

configure_logger()
<RootLogger root (INFO)>

Problem#

In this example, we represent the function \(f(x)=(6x-2)^2\sin(12x-4)\) [FSK08] by the AnalyticDiscipline

discipline = create_discipline(
    "AnalyticDiscipline",
    name="f",
    expressions={"y": "(6*x-2)**2*sin(12*x-4)"},
)
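
As a quick sanity check (a minimal sketch; we assume, as in other GEMSEO examples, that AnalyticDiscipline.execute accepts a dictionary of NumPy arrays and returns a dictionary-like object), we can evaluate the discipline at a single point:

# Evaluate f at x = 0.5; the expected value is
# (6*0.5-2)**2 * sin(12*0.5-4) = sin(2) ≈ 0.909.
discipline.execute({"x": array([0.5])})["y"]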

and seek to approximate it over the input space

input_space = create_design_space()
input_space.add_variable("x", lower_bound=0.0, upper_bound=1.0)

To do this, we create a training dataset with 6 equispaced points:

training_dataset = sample_disciplines(
    [discipline], input_space, "y", algo_name="PYDOE_FULLFACT", n_samples=6
)
WARNING - 00:17:08: No coupling in MDA, switching chain_linearize to True.
   INFO - 00:17:08: *** Start Sampling execution ***
   INFO - 00:17:08: Sampling
   INFO - 00:17:08:    Disciplines: f
   INFO - 00:17:08:    MDO formulation: MDF
   INFO - 00:17:08: Running the algorithm PYDOE_FULLFACT:
   INFO - 00:17:08:    100%|██████████| 6/6 [00:00<00:00, 1723.69 it/sec]
   INFO - 00:17:08: *** End Sampling execution (time: 0:00:00.004825) ***

Basics#

Training#

Then, we train a linear regression model from these samples:

model = create_regression_model("LinearRegressor", training_dataset)
model.learn()
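
Once trained, the fitted parameters of the model \(a_0+a_1x\) can be inspected; this small sketch assumes the LinearRegressor exposes coefficients and intercept attributes (check your GEMSEO version):

# Slope and intercept of the fitted linear model (assumed attributes).
model.coefficients, model.intercept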

Prediction#

Once it is built, we can predict the output value of \(f\) at a new input point:

input_value = {"x": array([0.65])}
output_value = model.predict(input_value)
output_value
{'y': array([3.29457456])}

as well as its Jacobian value:

jacobian_value = model.predict_jacobian(input_value)
jacobian_value
{'y': {'x': array([[7.26002643]])}}
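
Since the model is linear, its Jacobian is constant over the input space. As a quick check using the same API as above, we can evaluate it at another point and expect the same value:

# The slope of a linear model does not depend on the input point,
# so this should match the Jacobian computed at x = 0.65.
model.predict_jacobian({"x": array([0.2])})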

Plotting#

Unsurprisingly, a linear model is a poor approximation of this strongly nonlinear function, as a comparison over 100 test points shows:

test_dataset = sample_disciplines(
    [discipline], input_space, "y", algo_name="PYDOE_FULLFACT", n_samples=100
)
input_data = test_dataset.get_view(variable_names=model.input_names).to_numpy()
reference_output_data = test_dataset.get_view(variable_names="y").to_numpy().ravel()
predicted_output_data = model.predict(input_data).ravel()
plt.plot(input_data.ravel(), reference_output_data, label="Reference")
plt.plot(input_data.ravel(), predicted_output_data, label="Regression - Basics")
plt.grid()
plt.legend()
plt.show()
[Figure: reference function vs. linear regression prediction over the input space]
WARNING - 00:17:08: No coupling in MDA, switching chain_linearize to True.
   INFO - 00:17:08: *** Start Sampling execution ***
   INFO - 00:17:08: Sampling
   INFO - 00:17:08:    Disciplines: f
   INFO - 00:17:08:    MDO formulation: MDF
   INFO - 00:17:08: Running the algorithm PYDOE_FULLFACT:
   INFO - 00:17:08:    100%|██████████| 100/100 [00:00<00:00, 3506.71 it/sec]
   INFO - 00:17:08: *** End Sampling execution (time: 0:00:00.030124) ***
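
To quantify the mismatch beyond the visual comparison, we can compute the root-mean-square error over the test points (a small sketch using only the arrays defined above):

from numpy import sqrt

# RMSE between the reference outputs and the linear model predictions.
rmse = sqrt(((predicted_output_data - reference_output_data) ** 2).mean())
rmse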

Settings#

The LinearRegressor has many options defined in the LinearRegressor_Settings Pydantic model.

Intercept#

By default, the linear model is of the form \(a_0+a_1x_1+\ldots+a_dx_d\). You can set the option fit_intercept to False if you want a linear model of the form \(a_1x_1+\ldots+a_dx_d\):

model = create_regression_model(
    "LinearRegressor", training_dataset, fit_intercept=False, transformer={}
)
model.learn()

Warning

The intercept is fitted in the space of transformed variables; this is why we removed the default transformers by setting transformer to {}.
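
If you want to keep some scaling while controlling where the intercept applies, you could pass an explicit transformer dictionary. The following is a sketch only: it assumes MinMaxScaler is importable from gemseo.mlearning.transformers.scaler.min_max_scaler and that the "inputs" group key is accepted (check your GEMSEO version):

from gemseo.mlearning.transformers.scaler.min_max_scaler import MinMaxScaler

# Hypothetical variant: scale the inputs only, so the no-intercept model
# is fitted on scaled inputs but raw outputs.
scaled_model = create_regression_model(
    "LinearRegressor",
    training_dataset,
    fit_intercept=False,
    transformer={"inputs": MinMaxScaler()},
)
scaled_model.learn()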

We can see the impact of this option in the following visualization:

predicted_output_data_ = model.predict(input_data).ravel()
plt.plot(input_data.ravel(), reference_output_data, label="Reference")
plt.plot(input_data.ravel(), predicted_output_data, label="Regression - Basics")
plt.plot(input_data.ravel(), predicted_output_data_, label="Regression - No intercept")
plt.grid()
plt.legend()
plt.show()
[Figure: reference function vs. basic regression and no-intercept regression]

Regularization#

When the number of samples is small relative to the input dimension, regularization techniques can prevent overfitting, i.e. a model that fits the training data well but generalizes poorly. The penalty_level option is a positive real number defining the degree of regularization (by default, no regularization). The default regularization technique is the ridge penalty (l2 regularization); setting the l2_penalty_ratio option to 0.0 replaces it with the lasso penalty (l1 regularization), and any value of l2_penalty_ratio strictly between 0 and 1 yields the elastic net penalty, i.e. a linear combination of the ridge and lasso penalties weighted by l2_penalty_ratio.
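
For instance (a configuration sketch based on the options named above; these models are not used further here), the lasso and elastic net variants are obtained as follows:

# Lasso (l1) penalty: set l2_penalty_ratio to 0.0.
lasso_model = create_regression_model(
    "LinearRegressor", training_dataset, penalty_level=1.2, l2_penalty_ratio=0.0
)

# Elastic net penalty: any l2_penalty_ratio strictly between 0 and 1
# mixes the ridge and lasso penalties.
elastic_net_model = create_regression_model(
    "LinearRegressor", training_dataset, penalty_level=1.2, l2_penalty_ratio=0.5
)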

Returning to our example, we can use the ridge penalty with a level of 1.2:

model = create_regression_model("LinearRegressor", training_dataset, penalty_level=1.2)
model.learn()
predicted_output_data_ = model.predict(input_data).ravel()
plt.plot(input_data.ravel(), reference_output_data, label="Reference")
plt.plot(input_data.ravel(), predicted_output_data, label="Regression - Basics")
plt.plot(input_data.ravel(), predicted_output_data_, label="Regression - Ridge(1.2)")
plt.grid()
plt.legend()
plt.show()
[Figure: reference function vs. basic regression and ridge regression with penalty level 1.2]

We can see that the slope of the linear model is smaller: the penalty shrinks the coefficient towards zero.

Note

For a model with many inputs, the lasso penalty would instead set some coefficients exactly to zero, as sketched below.
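
A minimal sketch of this effect (hypothetical function and settings; the coefficients attribute is assumed, as above): we fit a lasso-penalized model to a function of three inputs, one of which barely matters, and expect its coefficient to be driven to (near) zero.

# A function where x3 contributes almost nothing.
sparse_discipline = create_discipline(
    "AnalyticDiscipline",
    name="g",
    expressions={"z": "2*x1 - 3*x2 + 0.01*x3"},
)
sparse_space = create_design_space()
for name in ["x1", "x2", "x3"]:
    sparse_space.add_variable(name, lower_bound=0.0, upper_bound=1.0)
sparse_dataset = sample_disciplines(
    [sparse_discipline], sparse_space, "z", algo_name="PYDOE_FULLFACT", n_samples=8
)
sparse_model = create_regression_model(
    "LinearRegressor",
    sparse_dataset,
    penalty_level=0.1,
    l2_penalty_ratio=0.0,  # lasso penalty
    transformer={},
)
sparse_model.learn()
sparse_model.coefficients  # the x3 coefficient should be (near) zero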
