Linear regression#
A LinearRegressor is a linear regression model based on scikit-learn.
See also
You can find more information about building linear models in the scikit-learn documentation.
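For context, the sketch below shows roughly what the underlying scikit-learn estimator does on the same data; it assumes the standard sklearn.linear_model.LinearRegression API and is not part of the GEMSEO example itself, which only uses the wrapper.

# Hedged sketch (assumption): fit scikit-learn's LinearRegression directly
# on 6 equispaced samples of f, outside the GEMSEO wrapper.
import numpy as np
from sklearn.linear_model import LinearRegression

x_sketch = np.linspace(0.0, 1.0, 6).reshape(-1, 1)
y_sketch = (6 * x_sketch - 2) ** 2 * np.sin(12 * x_sketch - 4)
sketch = LinearRegression().fit(x_sketch, y_sketch)
sketch.predict(np.array([[0.65]]))  # a rough counterpart of model.predict below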
from __future__ import annotations
from matplotlib import pyplot as plt
from numpy import array
from gemseo import configure_logger
from gemseo import create_design_space
from gemseo import create_discipline
from gemseo import sample_disciplines
from gemseo.mlearning import create_regression_model
configure_logger()
<RootLogger root (INFO)>
Problem#
In this example,
we represent the function \(f(x)=(6x-2)^2\sin(12x-4)\) [FSK08]
by the AnalyticDiscipline
discipline = create_discipline(
"AnalyticDiscipline",
name="f",
expressions={"y": "(6*x-2)**2*sin(12*x-4)"},
)
and seek to approximate it over the input space
input_space = create_design_space()
input_space.add_variable("x", lower_bound=0.0, upper_bound=1.0)
To do this, we create a training dataset with 6 equispaced points:
training_dataset = sample_disciplines(
[discipline], input_space, "y", algo_name="PYDOE_FULLFACT", n_samples=6
)
WARNING - 00:17:08: No coupling in MDA, switching chain_linearize to True.
INFO - 00:17:08:
INFO - 00:17:08: *** Start Sampling execution ***
INFO - 00:17:08: Sampling
INFO - 00:17:08: Disciplines: f
INFO - 00:17:08: MDO formulation: MDF
INFO - 00:17:08: Running the algorithm PYDOE_FULLFACT:
INFO - 00:17:08: 17%|█▋ | 1/6 [00:00<00:00, 558.12 it/sec]
INFO - 00:17:08: 33%|███▎ | 2/6 [00:00<00:00, 889.75 it/sec]
INFO - 00:17:08: 50%|█████ | 3/6 [00:00<00:00, 1156.62 it/sec]
INFO - 00:17:08: 67%|██████▋ | 4/6 [00:00<00:00, 1377.44 it/sec]
INFO - 00:17:08: 83%|████████▎ | 5/6 [00:00<00:00, 1562.82 it/sec]
INFO - 00:17:08: 100%|██████████| 6/6 [00:00<00:00, 1723.69 it/sec]
INFO - 00:17:08: *** End Sampling execution (time: 0:00:00.004825) ***
Basics#
Training#
Then, we train a linear regression model from these samples:
model = create_regression_model("LinearRegressor", training_dataset)
model.learn()
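If you want to inspect the fitted parameters, the model should expose them; this is a minimal sketch assuming the coefficients and intercept attributes of LinearRegressor (check the class documentation for the exact names). Note that, with the default transformers, these values live in the space of transformed variables.

# Hedged sketch (assumed attribute names): fitted parameters of the linear model,
# expressed in the space of transformed variables when default transformers are used.
model.coefficients
model.intercept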
Prediction#
Once it is built, we can predict the output value of \(f\) at a new input point:
input_value = {"x": array([0.65])}
output_value = model.predict(input_value)
output_value
{'y': array([3.29457456])}
as well as its Jacobian value:
jacobian_value = model.predict_jacobian(input_value)
jacobian_value
{'y': {'x': array([[7.26002643]])}}
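As a quick sanity check (not part of the original example), the Jacobian can be compared with a central finite difference of the surrogate itself, using only predict:

# Central finite-difference approximation of dy/dx at x=0.65;
# it should be close to the Jacobian value returned above.
step = 1e-6
y_plus = model.predict({"x": array([0.65 + step])})["y"]
y_minus = model.predict({"x": array([0.65 - step])})["y"]
(y_plus - y_minus) / (2 * step)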
Plotting#
Unsurprisingly, the linear model is a poor approximation of this strongly nonlinear function, as a comparison on a finer test dataset shows:
test_dataset = sample_disciplines(
[discipline], input_space, "y", algo_name="PYDOE_FULLFACT", n_samples=100
)
input_data = test_dataset.get_view(variable_names=model.input_names).to_numpy()
reference_output_data = test_dataset.get_view(variable_names="y").to_numpy().ravel()
predicted_output_data = model.predict(input_data).ravel()
plt.plot(input_data.ravel(), reference_output_data, label="Reference")
plt.plot(input_data.ravel(), predicted_output_data, label="Regression - Basics")
plt.grid()
plt.legend()
plt.show()

WARNING - 00:17:08: No coupling in MDA, switching chain_linearize to True.
INFO - 00:17:08:
INFO - 00:17:08: *** Start Sampling execution ***
INFO - 00:17:08: Sampling
INFO - 00:17:08: Disciplines: f
INFO - 00:17:08: MDO formulation: MDF
INFO - 00:17:08: Running the algorithm PYDOE_FULLFACT:
INFO - 00:17:08: 1%| | 1/100 [00:00<00:00, 2928.98 it/sec]
INFO - 00:17:08: ...
INFO - 00:17:08: 100%|██████████| 100/100 [00:00<00:00, 3506.71 it/sec]
INFO - 00:17:08: *** End Sampling execution (time: 0:00:00.030124) ***
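To put a number on this mismatch, you can compute a simple error measure from the arrays built above; this plain NumPy sketch is not a GEMSEO quality measure:

# Root-mean-square error of the linear surrogate over the 100 test points.
import numpy as np

rmse = np.sqrt(np.mean((predicted_output_data - reference_output_data) ** 2))
rmse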
Settings#
The LinearRegressor has many options defined in the LinearRegressor_Settings Pydantic model.
Intercept#
By default,
the linear model is of the form \(a_0+a_1x_1+\ldots+a_dx_d\).
You can set the option fit_intercept to False if you want a linear model of the form \(a_1x_1+\ldots+a_dx_d\):
model = create_regression_model(
"LinearRegressor", training_dataset, fit_intercept=False, transformer={}
)
model.learn()
Warning
The intercept is defined in the space of transformed variables. This is why we removed the default transformers by setting transformer to {}.
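A quick way to see the effect (an extra check, not in the original example): without an intercept and without transformers, the model reduces to \(a_1x\), so the prediction at \(x=0\) is exactly zero:

# With fit_intercept=False and transformer={}, the model is y = a_1 * x,
# so predicting at x = 0 returns 0.
model.predict({"x": array([0.0])})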
We can see the impact of this option in the following visualization:
predicted_output_data_ = model.predict(input_data).ravel()
plt.plot(input_data.ravel(), reference_output_data, label="Reference")
plt.plot(input_data.ravel(), predicted_output_data, label="Regression - Basics")
plt.plot(input_data.ravel(), predicted_output_data_, label="Regression - No intercept")
plt.grid()
plt.legend()
plt.show()

Regularization#
When the number of samples is small relative to the input dimension,
regularization techniques can save you from overfitting
(a model that reproduces the training data well but generalizes poorly).
The penalty_level
option is a positive real number
defining the degree of regularization (default: no regularization).
By default,
the regularization technique is the ridge penalty (l2 regularization).
The technique can be replaced by the lasso penalty (l1 regularization) by setting the l2_penalty_ratio option to 0.0.
When l2_penalty_ratio is strictly between 0 and 1, the regularization technique is the elastic net penalty, i.e. a linear combination of the ridge and lasso penalties parametrized by this l2_penalty_ratio.
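Schematically, and up to the exact scaling conventions of the underlying scikit-learn estimators (an assumption, not stated on this page), the penalized least-squares problem reads
\[\min_a \|y - Xa\|_2^2 + \lambda\left(r\,\|a\|_2^2 + (1-r)\,\|a\|_1\right),\]
where \(\lambda\) is the penalty_level and \(r\) is the l2_penalty_ratio: \(r=1\) recovers the ridge penalty, \(r=0\) the lasso penalty, and \(0<r<1\) the elastic net penalty.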
For example, we can use the ridge penalty with a level of 1.2:
model = create_regression_model("LinearRegressor", training_dataset, penalty_level=1.2)
model.learn()
predicted_output_data_ = model.predict(input_data).ravel()
plt.plot(input_data.ravel(), reference_output_data, label="Reference")
plt.plot(input_data.ravel(), predicted_output_data, label="Regression - Basics")
plt.plot(input_data.ravel(), predicted_output_data_, label="Regression - Ridge(1.2)")
plt.grid()
plt.legend()
plt.show()

We can see that the coefficient of the linear model is lower due to the penalty.
Note
In the case of a model with many inputs, we could have used the lasso penalty and seen that some coefficients would have been set to zero.
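For instance, here is a hedged sketch of those alternative settings (not run in this example), built from the penalty_level and l2_penalty_ratio options described above:

# Lasso penalty: l2_penalty_ratio=0.0; elastic net penalty: 0 < l2_penalty_ratio < 1.
lasso_model = create_regression_model(
    "LinearRegressor", training_dataset, penalty_level=1.2, l2_penalty_ratio=0.0
)
lasso_model.learn()

elastic_net_model = create_regression_model(
    "LinearRegressor", training_dataset, penalty_level=1.2, l2_penalty_ratio=0.5
)
elastic_net_model.learn()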
Total running time of the script: (0 minutes 0.337 seconds)